Coverage for gpkit/tools/autosweep.py: 85%
191 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-05 22:33 -0500
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-05 22:33 -0500
1"Tools for optimal fits to GP sweeps"
2from time import time
3import pickle
4import numpy as np
5from ..small_classes import Count
6from ..small_scripts import mag
7from ..solution_array import SolutionArray
8from ..exceptions import InvalidGPConstraint
class BinarySweepTree:  # pylint: disable=too-many-instance-attributes
    """Spans a line segment. May contain two subtrees that divide the segment.

    Attributes
    ----------

    bounds : two-element list
        The left and right boundaries of the segment

    sols : two-element list
        The left and right solutions of the segment

    costs : array
        The left and right logcosts of the segment

    splits : None or two-element list
        If not None, contains the left and right subtrees

    splitval : None or float
        The worst-error point, where the split will be if tolerance is too low

    splitlb : None or float
        The cost lower bound at splitval (stored in log space; `cost_at`
        exponentiates it)

    splitub : None or float
        The cost upper bound at splitval (stored in log space; `cost_at`
        exponentiates it)

    sweptvar : object
        The swept variable, stored as given and handed on to any subtrees

    costposy : object
        The cost expression, stored as given and handed on to any subtrees
    """

    def __init__(self, bounds, sols, sweptvar, costposy):
        if len(bounds) != 2:
            raise ValueError("bounds must be of length 2")
        if bounds[1] <= bounds[0]:
            raise ValueError("bounds[0] must be smaller than bounds[1].")
        self.bounds = bounds
        self.sols = sols
        self.costs = np.log([mag(sol["cost"]) for sol in sols])
        self.splits = None
        self.splitval = None
        self.splitlb = None
        self.splitub = None
        self.sweptvar = sweptvar
        self.costposy = costposy

    def _check_new_split(self, splitval):
        "Raises ValueError if a split exists or splitval is not interior."
        # Compare against None explicitly: a falsy numeric split value must
        # still count as "already split".
        if self.splitval is not None:
            raise ValueError("split already exists!")
        if splitval <= self.bounds[0] or splitval >= self.bounds[1]:
            raise ValueError("split value is at or outside bounds.")

    def _check_in_bounds(self, value):
        "Raises ValueError if value lies outside this segment."
        if value < self.bounds[0] or value > self.bounds[1]:
            raise ValueError("query value is outside bounds.")

    @staticmethod
    def _loginterp(lo, hi, loval, hival, value):
        "Interpolates linearly in log-log space between (lo, loval), (hi, hival)"
        lo, hi, loval, hival = np.log(list(map(mag, [lo, hi, loval, hival])))
        interp = (hi-np.log(value))/float(hi-lo)
        return np.exp(interp*loval + (1-interp)*hival)

    def add_split(self, splitval, splitsol):
        """Creates subtrees from bounds[0] to splitval and splitval to bounds[1]

        Raises ValueError if this segment already has a split value or if
        splitval is not strictly inside the segment.
        """
        self._check_new_split(splitval)
        self.splitval = splitval
        self.splits = [BinarySweepTree([self.bounds[0], splitval],
                                       [self.sols[0], splitsol],
                                       self.sweptvar, self.costposy),
                       BinarySweepTree([splitval, self.bounds[1]],
                                       [splitsol, self.sols[1]],
                                       self.sweptvar, self.costposy)]

    def add_splitcost(self, splitval, splitlb, splitub):
        """Adds a splitval, lower bound, and upper bound

        Raises ValueError if this segment already has a split value or if
        splitval is not strictly inside the segment.
        """
        self._check_new_split(splitval)
        self.splitval = splitval
        self.splitlb, self.splitub = splitlb, splitub

    def posy_at(self, posy, value):
        """Logspace interpolates between sols to get posynomial values.

        No guarantees, just like a regular sweep.

        Raises ValueError if value is outside this segment's bounds.
        """
        self._check_in_bounds(value)
        bst = self.min_bst(value)
        lo, hi = bst.bounds
        loval, hival = [sol(posy) for sol in bst.sols]
        return self._loginterp(lo, hi, loval, hival, value)

    def cost_at(self, _, value, bound=None):
        """Logspace interpolates between split and costs. Guaranteed bounded.

        Arguments
        ---------
        _ : ignored
            Present so the signature matches `posy_at`.
        value : float
            Point within this segment's bounds at which to evaluate the cost.
        bound : None, "lb", or "ub"
            Which cost to use at the split point of the smallest segment
            around value: its lower bound, its upper bound, or (for any
            falsy value) the log-space midpoint of the two.

        Raises ValueError if value is outside the bounds, or if the segment
        has split costs and bound is truthy but neither "lb" nor "ub".
        """
        self._check_in_bounds(value)
        bst = self.min_bst(value)
        # `is None` rather than truthiness: a log-cost bound of exactly 0.0
        # (a cost of 1) is a valid stored value and must not be skipped.
        if bst.splitlb is None:
            lo, hi = bst.bounds
            loval, hival = [sol["cost"] for sol in bst.sols]
        else:
            if not bound:
                splitcost = np.exp((bst.splitlb + bst.splitub)/2)
            elif bound == "lb":
                splitcost = np.exp(bst.splitlb)
            elif bound == "ub":
                splitcost = np.exp(bst.splitub)
            else:
                raise ValueError("bound must be None, 'lb', or 'ub'.")
            if value <= bst.splitval:
                lo, hi = bst.bounds[0], bst.splitval
                loval, hival = bst.sols[0]["cost"], splitcost
            else:
                lo, hi = bst.splitval, bst.bounds[1]
                loval, hival = splitcost, bst.sols[1]["cost"]
        return self._loginterp(lo, hi, loval, hival, value)

    def min_bst(self, value):
        "Returns smallest bst around value."
        if not self.splits:
            return self
        # values equal to the split point are assigned to the left subtree
        choice = self.splits[0] if value <= self.splitval else self.splits[1]
        return choice.min_bst(value)

    def sample_at(self, values):
        "Creates a SolutionOracle at a given range of values"
        return SolutionOracle(self, values)

    @property
    def sollist(self):
        "Returns a list of all the solutions in an autosweep"
        sollist = [self.sols[0]]
        if self.splits:
            # Each subtree's list starts with its left solution, which is
            # already present here; the right subtree's last solution is
            # self.sols[1], appended below.
            sollist.extend(self.splits[0].sollist[1:])
            sollist.extend(self.splits[1].sollist[1:-1])
        sollist.append(self.sols[1])
        return sollist

    @property
    def solarray(self):
        "Returns a solution array of all the solutions in an autosweep"
        solution = SolutionArray()
        for sol in self.sollist:
            solution.append(sol)
        solution.to_arrays()
        return solution

    def save(self, filename="autosweep.p"):
        """Pickles the autosweep and saves it to a file.

        This method pickles the tree as-is with `pickle.dump`.
        NOTE(review): earlier documentation stated that the saved cost is
        made unitless and each solution's 'program' attribute is removed;
        nothing in this method does that, so confirm whether the stored
        solution objects handle it when pickled.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> with open("autosweep.p", "rb") as f:
        ...     bst = pickle.load(f)
        """
        with open(filename, "wb") as f:
            pickle.dump(self, f)
class SolutionOracle:
    """Acts like a SolutionArray for autosweeps

    Wraps a BinarySweepTree together with the points it is sampled at;
    calling or indexing the oracle with a key returns that key's
    interpolated values at every sampled point.
    """

    def __init__(self, bst, sampled_at):
        self.sampled_at = sampled_at
        self.bst = bst

    def __call__(self, key):
        return self.__getval(key)

    def __getitem__(self, key):
        return self.__getval(key)

    def _is_cost(self, key):
        "True if key is the string 'cost' or shares the cost posy's hmap"
        if hasattr(key, "hmap") and key.hmap == self.bst.costposy.hmap:
            return True
        return key == "cost"

    @staticmethod
    def _with_units(fit, reference):
        "Multiplies fit by reference's units, if any; else makes it an array"
        units = getattr(reference, "units", None)
        return fit*units if units else np.array(fit)

    def __getval(self, key):
        "Gets values from the BST and units them"
        if self._is_cost(key):
            key_at = self.bst.cost_at
            v0 = self.bst.sols[0]["cost"]
        else:
            key_at = self.bst.posy_at
            v0 = self.bst.sols[0](key)
        fit = [key_at(key, x) for x in self.sampled_at]
        return self._with_units(fit, v0)

    def _cost_bound(self, bound):
        "Gets the 'lb' or 'ub' cost bound at each sampled point, with units"
        fit = [self.bst.cost_at("cost", x, bound) for x in self.sampled_at]
        return self._with_units(fit, self.bst.sols[0]["cost"])

    def cost_lb(self):
        "Gets cost lower bounds from the BST and units them"
        return self._cost_bound("lb")

    def cost_ub(self):
        "Gets cost upper bounds from the BST and units them"
        return self._cost_bound("ub")

    def plot(self, posys=None, axes=None):
        """Plots the sweep for each posy

        Arguments
        ---------
        posys : None, "cost", a single posy, or a sequence of those
            What to plot; None and "cost" both stand for the cost posy.
            The caller's sequence is not modified.
        axes : optional
            Passed through to `assign_axes`.

        Returns the current matplotlib figure and the axes (unwrapped if
        there is exactly one).
        """
        #pylint: disable=import-outside-toplevel
        import matplotlib.pyplot as plt
        from ..interactive.plot_sweep import assign_axes
        from .. import GPBLU
        # A bare string has __len__ but is a single key, not a sequence.
        if isinstance(posys, str) or not hasattr(posys, "__len__"):
            posys = [posys]
        # Build a fresh list instead of assigning into the caller's
        # sequence, which also lets tuples be passed in.
        posys = [self.bst.costposy if posy in [None, "cost"] else posy
                 for posy in posys]
        posys, axes = assign_axes(self.bst.sweptvar, posys, axes)
        for posy, ax in zip(posys, axes):
            if self._is_cost(posy):  # with small tol should look like a line
                ax.fill_between(self.sampled_at,
                                self.cost_lb(), self.cost_ub(),
                                facecolor=GPBLU, edgecolor=GPBLU,
                                linewidth=0.75)
            else:
                ax.plot(self.sampled_at, self(posy), color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
def autosweep_1d(model, logtol, sweepvar, bounds, **solvekwargs):
    """Autosweep a model over one sweepvar

    Arguments
    ---------
    model : object with `substitutions`, `solve`, `program` and `cost`
        The model to sweep; it is solved once at each bound and then at
        every split point chosen by `recurse_splits`.
    logtol : float
        Target tolerance on the log of the cost, passed to `recurse_splits`.
    sweepvar : key of model.substitutions
        The variable to sweep.
    bounds : two-element sequence
        The lower and upper values of sweepvar.
    **solvekwargs
        Passed to `model.solve`. "verbosity" defaults to 1 and is lowered
        by one before solving.

    Returns
    -------
    BinarySweepTree covering bounds, with `nsols` set from the solve counter.

    Raises
    ------
    InvalidGPConstraint if solving at a bound raises InvalidGPConstraint.
    """
    original_val = model.substitutions.get(sweepvar, None)
    start_time = time()
    solvekwargs.setdefault("verbosity", 1)
    solvekwargs["verbosity"] -= 1
    sols = Count().next
    firstsols = []
    for bound in bounds:
        model.substitutions.update({sweepvar: bound})
        try:
            model.solve(**solvekwargs)
            firstsols.append(model.program.result)
        except InvalidGPConstraint as exc:
            raise InvalidGPConstraint("only GPs can be autoswept.") from exc
        sols()
    bst = BinarySweepTree(bounds, firstsols, sweepvar, model.cost)
    tol = recurse_splits(model, bst, sweepvar, logtol, solvekwargs, sols)
    bst.nsols = sols()  # pylint: disable=attribute-defined-outside-init
    if solvekwargs["verbosity"] > -1:
        print(f"Solved in {bst.nsols:2} passes, cost logtol +/-{tol:.3g}")
        print(f"Autosweeping took {time() - start_time:.3g} seconds.")
    # Restore whatever substitution was there before the sweep. Test against
    # None rather than truthiness so that a falsy or array-valued original
    # substitution is put back instead of being deleted (or raising).
    if original_val is not None:
        model.substitutions[sweepvar] = original_val
    else:
        del model.substitutions[sweepvar]
    return bst
def recurse_splits(model, bst, variable, logtol, solvekwargs, sols):
    # pylint: disable=too-many-arguments
    """Recursively splits a BST until logtol is reached

    Computes the split point and log-cost bounds for `bst` with `get_tol`.
    If half the gap between the bounds is below `logtol`, the bounds are
    recorded on `bst` and that half-gap is returned. Otherwise the model is
    solved at the split point, `bst` is split there, both subtrees are
    processed the same way, and the larger of their tolerances is stored as
    `bst.tol` and returned. `sols` is called once per additional solve.
    """
    splitval, lower, upper = get_tol(bst.costs, bst.bounds, bst.sols,
                                     variable)
    halfgap = (upper-lower)/2.0
    # Written as a negated >= so that a NaN gap takes the leaf path.
    if not halfgap >= logtol:
        bst.add_splitcost(splitval, lower, upper)
        return halfgap
    model.substitutions.update({variable: splitval})
    model.solve(**solvekwargs)
    bst.add_split(splitval, model.program.result)
    sols()
    subtols = []
    for subtree in bst.splits:
        subtols.append(recurse_splits(model, subtree, variable, logtol,
                                      solvekwargs, sols))
    bst.tol = max(subtols)
    return bst.tol
def get_tol(costs, bounds, sols, variable):  # pylint: disable=too-many-locals
    """Gets the intersection point and corresponding bounds from two solutions.

    Works in log space: each endpoint defines a line through its log-cost
    with slope equal to that solution's sensitivity to `variable`. The two
    lines cross at x, where
        y0 + s0*(x - x0) == y1 + s1*(x - x1)
    The lower bound is the height of that crossing; the upper bound is the
    straight-line interpolation between the two endpoint log-costs at x.

    Returns (exp(x), lower bound, upper bound); the bounds stay in log space.
    When the crossing is undefined or falls within 1e-7 (as a fraction of
    the segment) of either end, x is the log-space midpoint and both bounds
    are the mean of the two log-costs.
    """
    logcost_lo, logcost_hi = costs
    logx_lo, logx_hi = np.log(bounds)
    sens_lo, sens_hi = (sol["sensitivities"]["variables"][variable]
                        for sol in sols)
    numerator = logcost_hi-logcost_lo + logx_lo*sens_lo-logx_hi*sens_hi
    slope_gap = sens_lo-sens_hi
    # NOTE: the degenerate cases deal with straight lines, where lower
    #       and upper bounds are identical and so x is undefined
    # mosek runs into equal slopes on perfect straight lines (numerator is
    # then also 0); mosek_cli also does on near-straight lines
    degenerate = slope_gap == 0
    if not degenerate:
        logx = numerator/slope_gap
        lower = logcost_lo + sens_lo*(logx-logx_lo)
        frac = (logx_hi-logx)/(logx_hi-logx_lo)
        upper = logcost_lo*frac + logcost_hi*(1-frac)
        # cvxopt on straight lines puts the crossing at (or past) an end
        degenerate = frac < 1e-7 or frac > 1 - 1e-7
    if degenerate:
        logx = (logx_lo + logx_hi)/2  # x is undefined? stick it in the middle!
        lower = upper = (logcost_lo + logcost_hi)/2
    return np.exp(logx), lower, upper