Coverage for gpkit/tools/autosweep.py: 85%
190 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-28 11:50 -0400
1"Tools for optimal fits to GP sweeps"
2from time import time
3import numpy as np
4from ..small_classes import Count
5from ..small_scripts import mag
6from ..solution_array import SolutionArray
7from ..exceptions import InvalidGPConstraint
class BinarySweepTree:  # pylint: disable=too-many-instance-attributes
    """Spans a line segment. May contain two subtrees that divide the segment.

    Attributes
    ----------

    bounds : two-element list
        The left and right boundaries of the segment

    sols : two-element list
        The left and right solutions of the segment

    costs : array
        The left and right logcosts of the segment

    splits : None or two-element list
        If not None, contains the left and right subtrees

    splitval : None or float
        The worst-error point, where the split will be if tolerance is too low

    splitlb : None or float
        The log-cost lower bound at splitval

    splitub : None or float
        The log-cost upper bound at splitval

    sweptvar
        The swept variable (shared by every node of the tree)

    costposy
        The cost of the swept model (shared by every node of the tree)
    """

    def __init__(self, bounds, sols, sweptvar, costposy):
        if len(bounds) != 2:
            raise ValueError("bounds must be of length 2")
        if bounds[1] <= bounds[0]:
            raise ValueError("bounds[0] must be smaller than bounds[1].")
        self.bounds = bounds
        self.sols = sols
        # costs are kept in log space; interpolation happens in log-log space
        self.costs = np.log([mag(sol["cost"]) for sol in sols])
        self.splits = None
        self.splitval = None
        self.splitlb = None
        self.splitub = None
        self.sweptvar = sweptvar
        self.costposy = costposy

    def _check_splitval(self, splitval):
        """Raises ValueError if this segment already has a split value or
        if splitval is not strictly inside bounds."""
        # `is not None` rather than truthiness: a split placed at 0 is a split
        if self.splitval is not None:
            raise ValueError("split already exists!")
        if splitval <= self.bounds[0] or splitval >= self.bounds[1]:
            raise ValueError("split value is at or outside bounds.")

    def add_split(self, splitval, splitsol):
        "Creates subtrees from bounds[0] to splitval and splitval to bounds[1]"
        self._check_splitval(splitval)
        self.splitval = splitval
        self.splits = [BinarySweepTree([self.bounds[0], splitval],
                                       [self.sols[0], splitsol],
                                       self.sweptvar, self.costposy),
                       BinarySweepTree([splitval, self.bounds[1]],
                                       [splitsol, self.sols[1]],
                                       self.sweptvar, self.costposy)]

    def add_splitcost(self, splitval, splitlb, splitub):
        """Adds a splitval and its log-cost lower and upper bounds.

        Used on leaves that are already within tolerance, so cost_at can
        interpolate through the bounds instead of between endpoint costs.
        """
        self._check_splitval(splitval)
        self.splitval = splitval
        self.splitlb, self.splitub = splitlb, splitub

    @staticmethod
    def _loginterp(lo, hi, loval, hival, value):
        """Interpolates linearly in log-log space between (lo, loval) and
        (hi, hival), evaluated at `value`."""
        lo, hi, loval, hival = np.log(list(map(mag, [lo, hi, loval, hival])))
        interp = (hi-np.log(value))/float(hi-lo)
        return np.exp(interp*loval + (1-interp)*hival)

    def posy_at(self, posy, value):
        """Logspace interpolates between sols to get posynomial values.

        No guarantees, just like a regular sweep.

        Raises ValueError if value is outside bounds.
        """
        if value < self.bounds[0] or value > self.bounds[1]:
            raise ValueError("query value is outside bounds.")
        bst = self.min_bst(value)
        loval, hival = [sol(posy) for sol in bst.sols]
        return self._loginterp(bst.bounds[0], bst.bounds[1],
                               loval, hival, value)

    def cost_at(self, _, value, bound=None):
        """Logspace interpolates between split and costs. Guaranteed bounded.

        The first argument is ignored; it mirrors posy_at's signature so the
        two can be called interchangeably.

        bound : None, "lb" or "ub"
            Which cost bound at the leaf's split value to interpolate
            through; a falsy value uses the log-space midpoint of the two.

        Raises ValueError if value is outside bounds or bound is unknown.
        """
        if value < self.bounds[0] or value > self.bounds[1]:
            raise ValueError("query value is outside bounds.")
        bst = self.min_bst(value)
        # `is not None`: a log-cost bound of exactly 0.0 (cost == 1) is valid
        if bst.splitlb is not None:
            if bound == "lb":
                splitcost = np.exp(bst.splitlb)
            elif bound == "ub":
                splitcost = np.exp(bst.splitub)
            elif not bound:
                splitcost = np.exp((bst.splitlb + bst.splitub)/2)
            else:
                raise ValueError("bound must be None, 'lb' or 'ub', not %r"
                                 % (bound,))
            if value <= bst.splitval:
                lo, hi = bst.bounds[0], bst.splitval
                loval, hival = bst.sols[0]["cost"], splitcost
            else:
                lo, hi = bst.splitval, bst.bounds[1]
                loval, hival = splitcost, bst.sols[1]["cost"]
        else:
            lo, hi = bst.bounds
            loval, hival = [sol["cost"] for sol in bst.sols]
        return self._loginterp(lo, hi, loval, hival, value)

    def min_bst(self, value):
        "Returns smallest bst around value."
        if not self.splits:
            return self
        choice = self.splits[0] if value <= self.splitval else self.splits[1]
        return choice.min_bst(value)

    def sample_at(self, values):
        "Creates a SolutionOracle at a given range of values"
        return SolutionOracle(self, values)

    @property
    def sollist(self):
        "Returns a list of all the solutions in an autosweep"
        sollist = [self.sols[0]]
        if self.splits:
            # skip each subtree's first sol (already present) and the right
            # subtree's last sol (appended below as self.sols[1])
            sollist.extend(self.splits[0].sollist[1:])
            sollist.extend(self.splits[1].sollist[1:-1])
        sollist.append(self.sols[1])
        return sollist

    @property
    def solarray(self):
        "Returns a solution array of all the solutions in an autosweep"
        solution = SolutionArray()
        for sol in self.sollist:
            solution.append(sol)
        solution.to_arrays()
        return solution

    def save(self, filename="autosweep.p"):
        """Pickles the autosweep and saves it to a file.

        The whole tree, including every stored solution, is pickled as-is.
        This method does not itself strip units or any solution attributes.
        Whether anything is dropped depends on how the stored solution
        objects pickle themselves.

        The autosweep can then be loaded with e.g.:

            import pickle
            with open("autosweep.p", "rb") as fil:
                bst = pickle.load(fil)
        """
        import pickle
        with open(filename, "wb") as fil:
            pickle.dump(self, fil)
class SolutionOracle:
    "Acts like a SolutionArray for autosweeps"
    def __init__(self, bst, sampled_at):
        self.sampled_at = sampled_at
        self.bst = bst

    def __call__(self, key):
        return self.__getval(key)

    def __getitem__(self, key):
        return self.__getval(key)

    def _is_cost(self, key):
        "Returns True if key is the autoswept cost (by hmap or as \"cost\")."
        if hasattr(key, "hmap") and key.hmap == self.bst.costposy.hmap:
            return True
        return key == "cost"

    @staticmethod
    def _with_units(fit, units):
        "Attaches units to a list of fitted values, or arrays it if unitless."
        return fit*units if units else np.array(fit)

    def __getval(self, key):
        "Gets values from the BST and units them"
        if self._is_cost(key):
            key_at = self.bst.cost_at
            v0 = self.bst.sols[0]["cost"]
        else:
            key_at = self.bst.posy_at
            v0 = self.bst.sols[0](key)
        units = getattr(v0, "units", None)
        fit = [key_at(key, x) for x in self.sampled_at]
        return self._with_units(fit, units)

    def _cost_bound(self, bound):
        "Gets cost bounds of kind `bound` (\"lb\"/\"ub\") and units them"
        units = getattr(self.bst.sols[0]["cost"], "units", None)
        fit = [self.bst.cost_at("cost", x, bound) for x in self.sampled_at]
        return self._with_units(fit, units)

    def cost_lb(self):
        "Gets cost lower bounds from the BST and units them"
        return self._cost_bound("lb")

    def cost_ub(self):
        "Gets cost upper bounds from the BST and units them"
        return self._cost_bound("ub")

    def plot(self, posys=None, axes=None):
        """Plots the sweep for each posy.

        posys : None, a single posy/string, or a sequence of them
            None and "cost" both mean the swept model's cost. The cost is
            drawn as a band between its lower and upper bounds.
        axes : passed through to assign_axes

        Returns (current figure, axes). When only one axis is used, the
        axes sequence is unpacked to that single axis.
        """
        import matplotlib.pyplot as plt
        from ..interactive.plot_sweep import assign_axes
        from .. import GPBLU
        # a bare string (e.g. "cost") has __len__ but names one posy
        if isinstance(posys, str) or not hasattr(posys, "__len__"):
            posys = [posys]
        # build a new list so the caller's sequence is never mutated; the
        # explicit checks avoid invoking a posy's own __eq__ on None/"cost"
        posys = [self.bst.costposy
                 if posy is None or (isinstance(posy, str) and posy == "cost")
                 else posy
                 for posy in posys]
        posys, axes = assign_axes(self.bst.sweptvar, posys, axes)
        for posy, ax in zip(posys, axes):
            if self._is_cost(posy):  # with small tol should look like a line
                ax.fill_between(self.sampled_at,
                                self.cost_lb(), self.cost_ub(),
                                facecolor=GPBLU, edgecolor=GPBLU,
                                linewidth=0.75)
            else:
                ax.plot(self.sampled_at, self(posy), color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
def autosweep_1d(model, logtol, sweepvar, bounds, **solvekwargs):
    """Autosweep a model over one sweepvar.

    Solves `model` with `sweepvar` substituted at each of `bounds`. It then
    recursively splits the interval with recurse_splits until the log-cost
    bounds are within +/- `logtol`.

    Arguments
    ---------
    model : the model to sweep; only GPs can be autoswept
    logtol : float, tolerance on half the gap between log-cost bounds
    sweepvar : the variable to sweep
    bounds : two-element sequence of sweep values, low then high
    **solvekwargs : passed to model.solve; "verbosity" defaults to 1 and is
        reduced by one for the inner solves. A summary is printed when the
        reduced verbosity is above -1.

    Returns
    -------
    BinarySweepTree, with an `nsols` attribute holding the solve count

    Raises
    ------
    InvalidGPConstraint if the model is not a GP

    The model's substitution for sweepvar (or its absence) is restored
    before returning, including when a solve raises.
    """
    original_val = model.substitutions.get(sweepvar, None)
    start_time = time()
    solvekwargs.setdefault("verbosity", 1)
    solvekwargs["verbosity"] -= 1
    sols = Count().next
    try:
        firstsols = []
        for bound in bounds:
            model.substitutions.update({sweepvar: bound})
            try:
                model.solve(**solvekwargs)
            except InvalidGPConstraint as err:
                raise InvalidGPConstraint("only GPs can be autoswept.") from err
            firstsols.append(model.program.result)
            sols()
        bst = BinarySweepTree(bounds, firstsols, sweepvar, model.cost)
        tol = recurse_splits(model, bst, sweepvar, logtol, solvekwargs, sols)
        bst.nsols = sols()  # pylint: disable=attribute-defined-outside-init
        if solvekwargs["verbosity"] > -1:
            print("Solved in %2i passes, cost logtol +/-%.3g" % (bst.nsols, tol))
            print("Autosweeping took %.3g seconds." % (time() - start_time))
    finally:
        # restore the caller's substitution even if a solve failed; compare
        # to None so falsy or array-valued substitutions are handled too
        if original_val is not None:
            model.substitutions[sweepvar] = original_val
        elif sweepvar in model.substitutions:
            del model.substitutions[sweepvar]
    return bst
def recurse_splits(model, bst, variable, logtol, solvekwargs, sols):
    """Recursively splits a BST until logtol is reached.

    If the half-gap between the log-cost bounds at the segment's worst point
    is at least `logtol`, the model is solved there. The node is then split
    at that point and both halves are recursed into. Otherwise the bounds
    are recorded on the node with add_splitcost.

    Returns the largest half-gap among the leaves under `bst`.
    """
    split_at, cost_lb, cost_ub = get_tol(bst.costs, bst.bounds, bst.sols,
                                         variable)
    half_gap = (cost_ub-cost_lb)/2.0
    needs_split = half_gap >= logtol
    if not needs_split:
        # within tolerance: keep the bounds so cost_at can interpolate them
        bst.add_splitcost(split_at, cost_lb, cost_ub)
        return half_gap
    model.substitutions.update({variable: split_at})
    model.solve(**solvekwargs)
    bst.add_split(split_at, model.program.result)
    sols()  # advance the solve counter
    bst.tol = max(recurse_splits(model, child, variable, logtol,
                                 solvekwargs, sols)
                  for child in bst.splits)
    return bst.tol
def get_tol(costs, bounds, sols, variable):  # pylint: disable=too-many-locals
    """Gets the intersection point and corresponding bounds from two solutions.

    Works in log space. `costs` are the log-costs at the two ends, the ends
    sit at log(bounds), and each solution's sensitivity to `variable` is the
    slope of its tangent line.

    Returns (exp(x), lb, ub), where:
    - x is the log-space point where the two tangents cross;
    - lb is the tangent value at x;
    - ub is the chord (straight line between the endpoints) value at x.

    When the tangents are parallel, or they cross within 1e-7 (as a fraction
    of the log span) of either end, the segment is treated as a straight
    line. In that case the midpoint is returned with lb == ub.
    """
    y_lo, y_hi = costs
    x_lo, x_hi = np.log(bounds)
    s_lo, s_hi = (sol["sensitivities"]["variables"][variable] for sol in sols)
    slope_gap = s_lo - s_hi
    # exactly parallel tangents: mosek hits this on perfect straight lines
    # (numerator also 0), mosek_cli on near-straight lines (numerator ~= 0)
    if slope_gap != 0:
        # solve y_lo + s_lo*(x - x_lo) == y_hi + s_hi*(x - x_hi) for x
        x_cross = (y_hi-y_lo + x_lo*s_lo-x_hi*s_hi)/slope_gap
        tangent_cost = y_lo + s_lo*(x_cross-x_lo)
        frac = (x_hi-x_cross)/(x_hi-x_lo)
        chord_cost = y_lo*frac + y_hi*(1-frac)
        # cvxopt lands near an endpoint on straight lines
        out_of_bounds = frac < 1e-7 or frac > 1 - 1e-7
        if not out_of_bounds:
            return np.exp(x_cross), tangent_cost, chord_cost
    # straight line: the crossing is undefined, so use the middle
    mid_cost = (y_lo + y_hi)/2
    return np.exp((x_lo + x_hi)/2), mid_cost, mid_cost