Coverage for gpkit/tools/autosweep.py: 84%

191 statements

, created at 2024-01-07 22:56 -0500

1"Tools for optimal fits to GP sweeps"

2from time import time

3import pickle

4import numpy as np

5from ..small_classes import Count

6from ..small_scripts import mag

7from ..solution_array import SolutionArray

8from ..exceptions import InvalidGPConstraint

class BinarySweepTree:  # pylint: disable=too-many-instance-attributes
    """Spans a line segment. May contain two subtrees that divide the segment.

    Attributes
    ----------

    bounds : two-element list
        The left and right boundaries of the segment

    sols : two-element list
        The left and right solutions of the segment

    costs : array
        The left and right logcosts of the segment

    splits : None or two-element list
        If not None, contains the left and right subtrees

    splitval : None or float
        The worst-error point, where the split will be if tolerance is too low

    splitlb : None or float
        The cost lower bound (in log space) at splitval

    splitub : None or float
        The cost upper bound (in log space) at splitval

    sweptvar :
        The swept variable; passed on unchanged to subtrees

    costposy :
        The cost posynomial; passed on unchanged to subtrees
    """

    def __init__(self, bounds, sols, sweptvar, costposy):
        if len(bounds) != 2:
            raise ValueError("bounds must be of length 2")
        if bounds[1] <= bounds[0]:
            raise ValueError("bounds[0] must be smaller than bounds[1].")
        self.bounds = bounds
        self.sols = sols
        self.costs = np.log([mag(sol["cost"]) for sol in sols])
        self.splits = None
        self.splitval = None
        self.splitlb = None
        self.splitub = None
        self.sweptvar = sweptvar
        self.costposy = costposy

    def _check_splitval(self, splitval):
        "Raises ValueError if already split or splitval isn't inside bounds"
        if self.splitval is not None:
            raise ValueError("split already exists!")
        if splitval <= self.bounds[0] or splitval >= self.bounds[1]:
            raise ValueError("split value is at or outside bounds.")

    def add_split(self, splitval, splitsol):
        "Creates subtrees from bounds[0] to splitval and splitval to bounds[1]"
        self._check_splitval(splitval)
        self.splitval = splitval
        # splitsol becomes the shared inner endpoint of both subtrees
        self.splits = [BinarySweepTree([self.bounds[0], splitval],
                                       [self.sols[0], splitsol],
                                       self.sweptvar, self.costposy),
                       BinarySweepTree([splitval, self.bounds[1]],
                                       [splitsol, self.sols[1]],
                                       self.sweptvar, self.costposy)]

    def add_splitcost(self, splitval, splitlb, splitub):
        "Adds a splitval, lower bound, and upper bound"
        self._check_splitval(splitval)
        self.splitval = splitval
        self.splitlb, self.splitub = splitlb, splitub

    @staticmethod
    def _loginterp(lo, hi, loval, hival, value):
        "Interpolates linearly in log-log space from (lo, loval) to (hi, hival)"
        lo, hi, loval, hival = np.log(list(map(mag, [lo, hi, loval, hival])))
        interp = (hi-np.log(value))/float(hi-lo)
        return np.exp(interp*loval + (1-interp)*hival)

    def posy_at(self, posy, value):
        """Logspace interpolates between sols to get posynomial values.

        No guarantees, just like a regular sweep.
        """
        if value < self.bounds[0] or value > self.bounds[1]:
            raise ValueError("query value is outside bounds.")
        bst = self.min_bst(value)
        lo, hi = bst.bounds
        loval, hival = [sol(posy) for sol in bst.sols]
        return self._loginterp(lo, hi, loval, hival, value)

    def cost_at(self, _, value, bound=None):
        """Logspace interpolates between split and costs. Guaranteed bounded.

        The first argument is ignored; it mirrors posy_at(posy, value).
        `bound` selects which stored split-point cost to interpolate through:
        "lb" (lower bound), "ub" (upper bound), or None for their log-space
        midpoint.

        Raises ValueError if `value` is outside bounds or `bound` is not
        one of "lb", "ub", or None.
        """
        if value < self.bounds[0] or value > self.bounds[1]:
            raise ValueError("query value is outside bounds.")
        bst = self.min_bst(value)
        # splitlb is a log-cost and may legitimately be 0.0, so test for None
        if bst.splitlb is not None:
            if not bound:
                splitcost = np.exp((bst.splitlb + bst.splitub)/2)
            elif bound == "lb":
                splitcost = np.exp(bst.splitlb)
            elif bound == "ub":
                splitcost = np.exp(bst.splitub)
            else:
                raise ValueError(
                    f"bound must be 'lb', 'ub', or None, not {bound!r}.")
            if value <= bst.splitval:
                lo, hi = bst.bounds[0], bst.splitval
                loval, hival = bst.sols[0]["cost"], splitcost
            else:
                lo, hi = bst.splitval, bst.bounds[1]
                loval, hival = splitcost, bst.sols[1]["cost"]
        else:
            lo, hi = bst.bounds
            loval, hival = [sol["cost"] for sol in bst.sols]
        return self._loginterp(lo, hi, loval, hival, value)

    def min_bst(self, value):
        "Returns smallest bst around value."
        if not self.splits:
            return self
        choice = self.splits[0] if value <= self.splitval else self.splits[1]
        return choice.min_bst(value)

    def sample_at(self, values):
        "Creates a SolutionOracle at a given range of values"
        return SolutionOracle(self, values)

    @property
    def sollist(self):
        "Returns a list of all the solutions in an autosweep"
        sollist = [self.sols[0]]
        if self.splits:
            # skip each subtree's endpoints that duplicate a neighbour's
            sollist.extend(self.splits[0].sollist[1:])
            sollist.extend(self.splits[1].sollist[1:-1])
        sollist.append(self.sols[1])
        return sollist

    @property
    def solarray(self):
        "Returns a solution array of all the solutions in an autosweep"
        solution = SolutionArray()
        for sol in self.sollist:
            solution.append(sol)
        solution.to_arrays()
        return solution

    def save(self, filename="autosweep.p"):
        """Pickles the autosweep and saves it to a file.

        The whole tree, including every stored solution, is pickled as-is.
        It can then be loaded with e.g.:

            import pickle
            with open("autosweep.p", "rb") as f:
                bst = pickle.load(f)

        Only load files from trusted sources: unpickling can execute
        arbitrary code.
        """
        with open(filename, "wb") as f:
            pickle.dump(self, f)

class SolutionOracle:
    """Acts like a SolutionArray for autosweeps.

    Indexing or calling it with a key interpolates that key's value from
    the BinarySweepTree `bst` at every point in `sampled_at`.
    """
    def __init__(self, bst, sampled_at):
        self.sampled_at = sampled_at
        self.bst = bst

    def __call__(self, key):
        return self.__getval(key)

    def __getitem__(self, key):
        return self.__getval(key)

    def _is_cost(self, key):
        "True if key is the string 'cost' or has the same hmap as the cost"
        if hasattr(key, "hmap") and key.hmap == self.bst.costposy.hmap:
            return True
        # isinstance guard: `==` between a string and an arbitrary object
        # (e.g. an array) may not return a plain bool
        return isinstance(key, str) and key == "cost"

    @staticmethod
    def _normalize_posys(posys, costposy):
        """Returns a new list of posys to plot, mapping None/"cost" to costposy.

        A single posy (including None or the string "cost") is wrapped in a
        list; sequences are copied so the caller's object is never mutated.
        """
        if (posys is None or isinstance(posys, str)
                or not hasattr(posys, "__len__")):
            posys = [posys]
        return [costposy
                if posy is None or (isinstance(posy, str) and posy == "cost")
                else posy
                for posy in posys]

    def __getval(self, key):
        "Gets values from the BST and units them"
        if self._is_cost(key):
            key_at = self.bst.cost_at
            v0 = self.bst.sols[0]["cost"]
        else:
            key_at = self.bst.posy_at
            v0 = self.bst.sols[0](key)
        units = getattr(v0, "units", None)
        fit = [key_at(key, x) for x in self.sampled_at]
        return fit*units if units else np.array(fit)

    def _cost_bound(self, bound):
        "Gets the cost's `bound` ('lb' or 'ub') at each sample, with units"
        units = getattr(self.bst.sols[0]["cost"], "units", None)
        fit = [self.bst.cost_at("cost", x, bound) for x in self.sampled_at]
        return fit*units if units else np.array(fit)

    def cost_lb(self):
        "Gets cost lower bounds from the BST and units them"
        return self._cost_bound("lb")

    def cost_ub(self):
        "Gets cost upper bounds from the BST and units them"
        return self._cost_bound("ub")

    def plot(self, posys=None, axes=None):
        """Plots the sweep for each posy.

        `posys` may be a single posy, None or "cost" (both meaning the
        cost), or a sequence of these. The cost is drawn as the band
        between its lower and upper bounds; other posys as lines.

        Returns (current figure, axes); a single axis is unwrapped.
        """
        #pylint: disable=import-outside-toplevel
        import matplotlib.pyplot as plt
        from ..interactive.plot_sweep import assign_axes
        from .. import GPBLU
        posys = self._normalize_posys(posys, self.bst.costposy)
        posys, axes = assign_axes(self.bst.sweptvar, posys, axes)
        for posy, ax in zip(posys, axes):
            if self._is_cost(posy):  # with small tol should look like a line
                ax.fill_between(self.sampled_at,
                                self.cost_lb(), self.cost_ub(),
                                facecolor=GPBLU, edgecolor=GPBLU,
                                linewidth=0.75)
            else:
                ax.plot(self.sampled_at, self(posy), color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

def autosweep_1d(model, logtol, sweepvar, bounds, **solvekwargs):
    """Autosweep a model over one sweepvar.

    Solves `model` at both ends of `bounds`, then hands the resulting
    BinarySweepTree to recurse_splits, which refines it until `logtol`
    is met.

    Arguments
    ---------
    model :
        The model to sweep; only GPs are supported. Its substitution for
        `sweepvar` is changed during the sweep and restored afterwards,
        even if a solve raises.
    logtol : float
        Tolerance passed to recurse_splits.
    sweepvar :
        The variable to sweep (used as a key into model.substitutions).
    bounds : two-element list
        The lower and upper values of the sweep.
    **solvekwargs :
        Passed to model.solve. "verbosity" (default 1) is reduced by one
        for the individual solves; a summary is printed when the reduced
        value is above -1.

    Returns
    -------
    BinarySweepTree
        The sweep tree, with an added `nsols` solve count.

    Raises
    ------
    InvalidGPConstraint
        If the model is not a GP.
    """
    original_val = model.substitutions.get(sweepvar, None)
    start_time = time()
    solvekwargs.setdefault("verbosity", 1)
    solvekwargs["verbosity"] -= 1
    sols = Count().next
    firstsols = []
    try:
        for bound in bounds:
            model.substitutions.update({sweepvar: bound})
            try:
                model.solve(**solvekwargs)
            except InvalidGPConstraint as exc:
                raise InvalidGPConstraint("only GPs can be autoswept.") from exc
            firstsols.append(model.program.result)
            sols()
        bst = BinarySweepTree(bounds, firstsols, sweepvar, model.cost)
        tol = recurse_splits(model, bst, sweepvar, logtol, solvekwargs, sols)
        bst.nsols = sols()  # pylint: disable=attribute-defined-outside-init
    finally:
        # Restore the caller's substitution even when a solve fails part-way.
        # `is not None` (not truthiness) so 0 or array values are restored.
        if original_val is not None:
            model.substitutions[sweepvar] = original_val
        elif sweepvar in model.substitutions:
            del model.substitutions[sweepvar]
    if solvekwargs["verbosity"] > -1:
        print(f"Solved in {bst.nsols:2} passes, cost logtol +/-{tol:.3g}")
        print(f"Autosweeping took {time() - start_time:.3g} seconds.")
    return bst

def recurse_splits(model, bst, variable, logtol, solvekwargs, sols):
    # pylint: disable=too-many-arguments
    """Recursively splits a BST until logtol is reached.

    get_tol gives the candidate point x between the node's two solutions
    and the log-cost lower/upper bounds there. If the half-gap
    (ub-lb)/2 is at least logtol, the model is solved at x, the node is
    split there, and both halves are recursed into. Otherwise the bounds
    are stored on the node via add_splitcost.

    `sols` is called once per new solve, to count them.

    Returns the largest remaining half-gap in this subtree; for split
    nodes it is also stored as `bst.tol`.
    """
    x, lb, ub = get_tol(bst.costs, bst.bounds, bst.sols, variable)
    tol = (ub-lb)/2.0
    if tol >= logtol:
        model.substitutions.update({variable: x})
        model.solve(**solvekwargs)
        bst.add_split(x, model.program.result)
        sols()
        tols = [recurse_splits(model, split, variable, logtol, solvekwargs,
                               sols)
                for split in bst.splits]
        bst.tol = max(tols)
        return bst.tol
    # tolerance met: record the bounds so cost_at can interpolate through them
    bst.add_splitcost(x, lb, ub)
    return tol

def get_tol(costs, bounds, sols, variable):  # pylint: disable=too-many-locals
    """Gets the intersection point and corresponding bounds from two solutions.

    Works in log space. Each end of the segment has a tangent line whose
    slope is that solution's sensitivity to `variable`. Where the two
    tangents intersect, the left tangent gives the lower bound `lb` and the
    chord between the endpoint log-costs gives the upper bound `ub`.

    Returns (exp(x), lb, ub). If the tangents are parallel, or meet at or
    outside an endpoint (as happens for straight lines), x is placed at the
    midpoint and lb == ub == the mean endpoint log-cost.
    """
    left_cost, right_cost = costs
    left_x, right_x = np.log(bounds)
    left_slope, right_slope = [sol["sensitivities"]["variables"][variable]
                               for sol in sols]
    # The tangents meet where
    #   left_cost + left_slope*(x - left_x)
    #       == right_cost + right_slope*(x - right_x)
    numerator = right_cost-left_cost + left_x*left_slope-right_x*right_slope
    denominator = left_slope-right_slope
    if denominator != 0:
        x = numerator/denominator
        lb = left_cost + left_slope*(x-left_x)
        weight = (right_x-x)/(right_x-left_x)
        ub = left_cost*weight + right_cost*(1-weight)
    else:
        # Parallel tangents: mosek hits this on (near-)straight lines, where
        # x is undefined; flag the weight as out of bounds for the fallback.
        weight = -1
    if weight < 1e-7 or weight > 1 - 1e-7:  # cvxopt on straight lines
        # x is undefined or degenerate: use the middle of the segment.
        x = (left_x + right_x)/2
        lb = ub = (left_cost + right_cost)/2
    return np.exp(x), lb, ub