Coverage for gpkit/tools/autosweep.py: 84%

190 statements

Report created at 2022-07-28 12:35 -0400

1"Tools for optimal fits to GP sweeps"

2from time import time

3import numpy as np

4from ..small_classes import Count

5from ..small_scripts import mag

6from ..solution_array import SolutionArray

7from ..exceptions import InvalidGPConstraint

class BinarySweepTree:  # pylint: disable=too-many-instance-attributes
    """Spans a line segment. May contain two subtrees that divide the segment.

    Attributes
    ----------

    bounds : two-element list
        The left and right boundaries of the segment

    sols : two-element list
        The left and right solutions of the segment

    costs : array
        The left and right logcosts of the segment

    splits : None or two-element list
        If not None, contains the left and right subtrees

    splitval : None or float
        The worst-error point, where the split will be if tolerance is too low

    splitlb : None or float
        The cost lower bound at splitval (in log space)

    splitub : None or float
        The cost upper bound at splitval (in log space)

    sweptvar :
        The swept variable; passed unchanged to every subtree

    costposy :
        The cost expression; passed unchanged to every subtree
    """

    def __init__(self, bounds, sols, sweptvar, costposy):
        if len(bounds) != 2:
            raise ValueError("bounds must be of length 2")
        if bounds[1] <= bounds[0]:
            raise ValueError("bounds[0] must be smaller than bounds[1].")
        self.bounds = bounds
        self.sols = sols
        # costs are stored as natural logs of the (unitless) solution costs
        self.costs = np.log([mag(sol["cost"]) for sol in sols])
        self.splits = None
        self.splitval = None
        self.splitlb = None
        self.splitub = None
        self.sweptvar = sweptvar
        self.costposy = costposy

    def _check_splitval(self, splitval):
        "Raises ValueError if already split or splitval is not strictly inside."
        # `is not None` rather than truthiness: a split value is data, not a flag
        if self.splitval is not None:
            raise ValueError("split already exists!")
        if splitval <= self.bounds[0] or splitval >= self.bounds[1]:
            raise ValueError("split value is at or outside bounds.")

    def add_split(self, splitval, splitsol):
        """Creates subtrees from bounds[0] to splitval and splitval to bounds[1]

        Raises
        ------
        ValueError
            If this segment already has a split, or splitval is at/outside bounds.
        """
        self._check_splitval(splitval)
        self.splitval = splitval
        self.splits = [BinarySweepTree([self.bounds[0], splitval],
                                       [self.sols[0], splitsol],
                                       self.sweptvar, self.costposy),
                       BinarySweepTree([splitval, self.bounds[1]],
                                       [splitsol, self.sols[1]],
                                       self.sweptvar, self.costposy)]

    def add_splitcost(self, splitval, splitlb, splitub):
        """Adds a splitval, lower bound, and upper bound

        Raises
        ------
        ValueError
            If this segment already has a split, or splitval is at/outside bounds.
        """
        self._check_splitval(splitval)
        self.splitval = splitval
        self.splitlb, self.splitub = splitlb, splitub

    @staticmethod
    def _loginterp(lo, hi, loval, hival, value):
        "Interpolates linearly in log-log space between (lo, loval), (hi, hival)"
        lo, hi, loval, hival = np.log(list(map(mag, [lo, hi, loval, hival])))
        interp = (hi-np.log(value))/float(hi-lo)
        return np.exp(interp*loval + (1-interp)*hival)

    def posy_at(self, posy, value):
        """Logspace interpolates between sols to get posynomial values.

        No guarantees, just like a regular sweep.

        Raises
        ------
        ValueError
            If value is outside this tree's bounds.
        """
        if value < self.bounds[0] or value > self.bounds[1]:
            raise ValueError("query value is outside bounds.")
        bst = self.min_bst(value)
        loval, hival = [sol(posy) for sol in bst.sols]
        return self._loginterp(bst.bounds[0], bst.bounds[1],
                               loval, hival, value)

    def cost_at(self, _, value, bound=None):
        """Logspace interpolates between split and costs. Guaranteed bounded.

        Arguments
        ---------
        _ : ignored
            Present so this has the same call shape as ``posy_at``.
        value : float
            Point (within bounds) at which to evaluate the cost.
        bound : None, "lb" or "ub"
            Which estimate of the cost at the leaf's splitval to interpolate
            through; None (or any falsy value) uses the log-midpoint.

        Raises
        ------
        ValueError
            If value is outside bounds, or bound is not one of the above.
        """
        if value < self.bounds[0] or value > self.bounds[1]:
            raise ValueError("query value is outside bounds.")
        bst = self.min_bst(value)
        # `is not None`: a log-cost bound of exactly 0.0 (cost == 1) is valid
        if bst.splitlb is not None:
            if not bound:
                splitcost = np.exp((bst.splitlb + bst.splitub)/2)
            elif bound == "lb":
                splitcost = np.exp(bst.splitlb)
            elif bound == "ub":
                splitcost = np.exp(bst.splitub)
            else:
                raise ValueError('bound must be "lb", "ub", or None.')
            if value <= bst.splitval:
                lo, hi = bst.bounds[0], bst.splitval
                loval, hival = bst.sols[0]["cost"], splitcost
            else:
                lo, hi = bst.splitval, bst.bounds[1]
                loval, hival = splitcost, bst.sols[1]["cost"]
        else:
            lo, hi = bst.bounds
            loval, hival = [sol["cost"] for sol in bst.sols]
        return self._loginterp(lo, hi, loval, hival, value)

    def min_bst(self, value):
        "Returns smallest bst around value."
        if not self.splits:
            return self
        choice = self.splits[0] if value <= self.splitval else self.splits[1]
        return choice.min_bst(value)

    def sample_at(self, values):
        "Creates a SolutionOracle at a given range of values"
        return SolutionOracle(self, values)

    @property
    def sollist(self):
        "Returns a list of all the solutions in an autosweep"
        sollist = [self.sols[0]]
        if self.splits:
            # each split shares its endpoint solutions with its neighbours,
            # so slice them off to avoid duplicates
            sollist.extend(self.splits[0].sollist[1:])
            sollist.extend(self.splits[1].sollist[1:-1])
        sollist.append(self.sols[1])
        return sollist

    @property
    def solarray(self):
        "Returns a solution array of all the solutions in an autosweep"
        solution = SolutionArray()
        for sol in self.sollist:
            solution.append(sol)
        solution.to_arrays()
        return solution

    def save(self, filename="autosweep.p"):
        """Pickles the autosweep and saves it to a file.

        NOTE(review): earlier documentation stated that the saved cost is made
        unitless and each solution's 'program' attribute is removed; nothing
        in this method does that, so if it happens it is in the solutions'
        own pickling hooks -- confirm before relying on it.

        The autosweep can then be loaded with e.g.:

            import pickle
            with open("autosweep.p", "rb") as fil:
                bst = pickle.load(fil)
        """
        import pickle
        with open(filename, "wb") as fil:
            pickle.dump(self, fil)

class SolutionOracle:
    """Acts like a SolutionArray for autosweeps

    Values are not stored: every lookup interpolates from ``bst`` at each
    point in ``sampled_at``.
    """
    def __init__(self, bst, sampled_at):
        self.sampled_at = sampled_at
        self.bst = bst

    def __call__(self, key):
        return self.__getval(key)

    def __getitem__(self, key):
        return self.__getval(key)

    def _is_cost(self, key):
        "True if key is the string 'cost' or shares the tree's cost hmap"
        if hasattr(key, "hmap") and key.hmap == self.bst.costposy.hmap:
            return True
        return key == "cost"

    def __getval(self, key):
        "Gets values from the BST and units them"
        if self._is_cost(key):
            key_at = self.bst.cost_at
            v0 = self.bst.sols[0]["cost"]
        else:
            key_at = self.bst.posy_at
            v0 = self.bst.sols[0](key)
        # units are taken from the first stored solution, if it has any
        units = getattr(v0, "units", None)
        fit = [key_at(key, x) for x in self.sampled_at]
        return fit*units if units else np.array(fit)

    def _cost_bound(self, bound):
        "Gets the 'lb' or 'ub' cost estimate at each sample point, with units"
        units = getattr(self.bst.sols[0]["cost"], "units", None)
        fit = [self.bst.cost_at("cost", x, bound) for x in self.sampled_at]
        return fit*units if units else np.array(fit)

    def cost_lb(self):
        "Gets cost lower bounds from the BST and units them"
        return self._cost_bound("lb")

    def cost_ub(self):
        "Gets cost upper bounds from the BST and units them"
        return self._cost_bound("ub")

    def plot(self, posys=None, axes=None):
        """Plots the sweep for each posy

        ``posys`` may be a single item or a sequence; None or "cost" entries
        stand for the tree's cost. The caller's sequence is not modified.
        """
        import matplotlib.pyplot as plt
        from ..interactive.plot_sweep import assign_axes
        from .. import GPBLU
        # strings have __len__ but must still be treated as a single item
        if isinstance(posys, str) or not hasattr(posys, "__len__"):
            posys = [posys]
        # build a new list: never assign into the caller's (possibly
        # immutable) sequence
        posys = [self.bst.costposy
                 if posy is None or (isinstance(posy, str) and posy == "cost")
                 else posy
                 for posy in posys]
        posys, axes = assign_axes(self.bst.sweptvar, posys, axes)
        for posy, ax in zip(posys, axes):
            if self._is_cost(posy):  # with small tol should look like a line
                ax.fill_between(self.sampled_at,
                                self.cost_lb(), self.cost_ub(),
                                facecolor=GPBLU, edgecolor=GPBLU,
                                linewidth=0.75)
            else:
                ax.plot(self.sampled_at, self(posy), color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

def autosweep_1d(model, logtol, sweepvar, bounds, **solvekwargs):
    """Autosweep a model over one sweepvar

    Arguments
    ---------
    model : the model to solve repeatedly (must be a GP)
    logtol : float
        Stop splitting a segment once half the gap between its log-cost
        bounds is below this.
    sweepvar : the variable to sweep
    bounds : two-element sequence
        Lower and upper values of sweepvar.
    **solvekwargs : passed to ``model.solve``; "verbosity" defaults to 1 and
        is decremented by one for the inner solves.

    Returns
    -------
    BinarySweepTree, with an added ``nsols`` attribute.

    Raises
    ------
    InvalidGPConstraint
        If the model is not a GP.

    The model's original substitution for sweepvar (or its absence) is
    restored even if a solve raises.
    """
    original_val = model.substitutions.get(sweepvar, None)
    start_time = time()
    solvekwargs.setdefault("verbosity", 1)
    solvekwargs["verbosity"] -= 1
    sols = Count().next
    try:
        firstsols = []
        for bound in bounds:
            model.substitutions.update({sweepvar: bound})
            try:
                model.solve(**solvekwargs)
            except InvalidGPConstraint as err:
                raise InvalidGPConstraint("only GPs can be autoswept.") from err
            firstsols.append(model.program.result)
            sols()
        bst = BinarySweepTree(bounds, firstsols, sweepvar, model.cost)
        tol = recurse_splits(model, bst, sweepvar, logtol, solvekwargs, sols)
        bst.nsols = sols()  # pylint: disable=attribute-defined-outside-init
        if solvekwargs["verbosity"] > -1:
            print("Solved in %2i passes, cost logtol +/-%.3g" % (bst.nsols, tol))
            print("Autosweeping took %.3g seconds." % (time() - start_time))
    finally:
        # `is not None`: a falsy (0) or array-valued substitution must still be
        # restored, and array truthiness would raise
        if original_val is not None:
            model.substitutions[sweepvar] = original_val
        elif sweepvar in model.substitutions:
            del model.substitutions[sweepvar]
    return bst

def recurse_splits(model, bst, variable, logtol, solvekwargs, sols):
    """Recursively splits a BST until logtol is reached

    At each segment, ``get_tol`` gives the worst-error point x and the
    log-cost bounds there. If half their gap is below logtol, those bounds
    are recorded on the segment. Otherwise the model is solved at x, and
    the segment is split there and recursed into.

    Arguments
    ---------
    sols : callable
        Called once per new solve (used by the caller as a counter).

    Returns
    -------
    float : the largest remaining tolerance among this segment's leaves.
    """
    x, lb, ub = get_tol(bst.costs, bst.bounds, bst.sols, variable)
    tol = (ub-lb)/2.0
    if tol < logtol:
        # within tolerance: keep the bounds at x for cost interpolation
        bst.add_splitcost(x, lb, ub)
        bst.tol = tol  # pylint: disable=attribute-defined-outside-init
        return tol
    model.substitutions.update({variable: x})
    model.solve(**solvekwargs)
    bst.add_split(x, model.program.result)
    sols()
    tols = [recurse_splits(model, split, variable, logtol, solvekwargs, sols)
            for split in bst.splits]
    bst.tol = max(tols)  # pylint: disable=attribute-defined-outside-init
    return bst.tol

def get_tol(costs, bounds, sols, variable):  # pylint: disable=too-many-locals
    """Gets the intersection point and corresponding bounds from two solutions.

    Works in log space. Each endpoint defines a line through
    (log bound, log cost) whose slope is that solution's sensitivity to
    ``variable``. The lower bound is the height of the two lines at their
    crossing; the upper bound is the chord between the endpoints at the
    same point.

    Returns
    -------
    (exp(crossing), lower, upper). If the crossing is undefined (parallel
    lines) or within 1e-7 of either end in interpolation weight, the
    log-midpoint is returned with lower == upper == mean endpoint log-cost.
    """
    left_cost, right_cost = costs
    left_x, right_x = np.log(bounds)
    left_sens, right_sens = (sol["sensitivities"]["variables"][variable]
                             for sol in sols)
    # left_cost + left_sens*(x - left_x) == right_cost + right_sens*(x - right_x)
    slope_gap = left_sens - right_sens
    if slope_gap != 0:
        crossing = (right_cost - left_cost + left_x*left_sens
                    - right_x*right_sens)/slope_gap
        weight = (right_x - crossing)/(right_x - left_x)
        # written as a negated `or` so a NaN weight keeps the crossing result
        if not (weight < 1e-7 or weight > 1 - 1e-7):
            lower = left_cost + left_sens*(crossing - left_x)
            upper = left_cost*weight + right_cost*(1 - weight)
            return np.exp(crossing), lower, upper
    # straight (or nearly straight) line: mosek gives an exactly zero slope
    # gap, mosek_cli/cvxopt a crossing at the segment edges; place the point
    # in the middle, where both bounds coincide
    midpoint = (left_x + right_x)/2
    midcost = (left_cost + right_cost)/2
    return np.exp(midpoint), midcost, midcost