Coverage for gpkit/tools/autosweep.py: 85%

191 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-07 22:15 -0500

1"Tools for optimal fits to GP sweeps" 

2from time import time 

3import pickle 

4import numpy as np 

5from ..small_classes import Count 

6from ..small_scripts import mag 

7from ..solution_array import SolutionArray 

8from ..exceptions import InvalidGPConstraint 

9 

10 

class BinarySweepTree:  # pylint: disable=too-many-instance-attributes
    """Spans a line segment. May contain two subtrees that divide the segment.

    Attributes
    ----------

    bounds : two-element list
        The left and right boundaries of the segment

    sols : two-element list
        The left and right solutions of the segment

    costs : array
        The left and right logcosts of the segment

    splits : None or two-element list
        If not None, contains the left and right subtrees

    splitval : None or float
        The estimated worst-error point of the segment; if the segment was
        split (see `splits`), this is where the split was made

    splitlb : None or float
        The log-cost lower bound at splitval (only set when the segment is
        not split; a segment can have either `splits` or split bounds)

    splitub : None or float
        The log-cost upper bound at splitval (only set when the segment is
        not split)

    sweptvar
        The variable being swept

    costposy
        The cost of the swept model; used by SolutionOracle to recognize
        cost queries

    tol : float (set externally, only on split segments)
        The worst log-cost tolerance over this segment's subtree, as set by
        recurse_splits

    nsols : int (set externally, only on the root)
        The number of solves performed, as set by autosweep_1d
    """

    def __init__(self, bounds, sols, sweptvar, costposy):
        if len(bounds) != 2:
            raise ValueError("bounds must be of length 2")
        if bounds[1] <= bounds[0]:
            raise ValueError("bounds[0] must be smaller than bounds[1].")
        self.bounds = bounds
        self.sols = sols
        # costs are kept in log space to match the log-log interpolation
        self.costs = np.log([mag(sol["cost"]) for sol in sols])
        self.splits = None
        self.splitval = None
        self.splitlb = None
        self.splitub = None
        self.sweptvar = sweptvar
        self.costposy = costposy

    def _check_new_split(self, splitval):
        "Raises ValueError unless splitval is a valid, not-yet-set split point"
        # `is not None`: a split point is a value, not a flag
        if self.splitval is not None:
            raise ValueError("split already exists!")
        # chained form also rejects NaN, which every `<=`/`>=` test accepts
        if not self.bounds[0] < splitval < self.bounds[1]:
            raise ValueError("split value is at or outside bounds.")

    def add_split(self, splitval, splitsol):
        "Creates subtrees from bounds[0] to splitval and splitval to bounds[1]"
        self._check_new_split(splitval)
        self.splitval = splitval
        self.splits = [BinarySweepTree([self.bounds[0], splitval],
                                       [self.sols[0], splitsol],
                                       self.sweptvar, self.costposy),
                       BinarySweepTree([splitval, self.bounds[1]],
                                       [splitsol, self.sols[1]],
                                       self.sweptvar, self.costposy)]

    def add_splitcost(self, splitval, splitlb, splitub):
        "Adds a splitval, lower bound, and upper bound (both log-costs)"
        self._check_new_split(splitval)
        self.splitval = splitval
        self.splitlb, self.splitub = splitlb, splitub

    @staticmethod
    def _loginterp(lo, hi, loval, hival, value):
        "Interpolates (lo, loval)-(hi, hival) linearly in log-log space"
        lo, hi, loval, hival = np.log(list(map(mag, [lo, hi, loval, hival])))
        interp = (hi - np.log(value))/float(hi - lo)
        return np.exp(interp*loval + (1 - interp)*hival)

    def posy_at(self, posy, value):
        """Logspace interpolates between sols to get posynomial values.

        No guarantees, just like a regular sweep.

        Raises ValueError if value is outside the bounds.
        """
        if value < self.bounds[0] or value > self.bounds[1]:
            raise ValueError("query value is outside bounds.")
        bst = self.min_bst(value)
        loval, hival = [sol(posy) for sol in bst.sols]
        return self._loginterp(bst.bounds[0], bst.bounds[1],
                               loval, hival, value)

    def cost_at(self, _, value, bound=None):
        """Logspace interpolates between split and costs. Guaranteed bounded.

        The first argument is accepted and ignored so that cost_at can be
        called the same way as posy_at.

        `bound` selects the cost used at the split point of the segment
        containing `value`: "lb" for its lower bound, "ub" for its upper
        bound, or None (or any falsy value) for their log-space midpoint.

        Raises ValueError if value is outside the bounds, or if bound is not
        one of None, "lb", "ub".
        """
        if value < self.bounds[0] or value > self.bounds[1]:
            raise ValueError("query value is outside bounds.")
        if bound and bound not in ("lb", "ub"):
            raise ValueError(f"bound must be None, 'lb' or 'ub', not {bound!r}")
        bst = self.min_bst(value)
        # `is not None`: a log-cost bound of exactly 0.0 (cost == 1) is valid
        if bst.splitlb is None:
            lo, hi = bst.bounds
            loval, hival = [sol["cost"] for sol in bst.sols]
        else:
            if bound == "lb":
                logcost = bst.splitlb
            elif bound == "ub":
                logcost = bst.splitub
            else:
                logcost = (bst.splitlb + bst.splitub)/2
            splitcost = np.exp(logcost)
            if value <= bst.splitval:
                lo, hi = bst.bounds[0], bst.splitval
                loval, hival = bst.sols[0]["cost"], splitcost
            else:
                lo, hi = bst.splitval, bst.bounds[1]
                loval, hival = splitcost, bst.sols[1]["cost"]
        return self._loginterp(lo, hi, loval, hival, value)

    def min_bst(self, value):
        "Returns smallest bst around value."
        node = self
        while node.splits:
            node = node.splits[0] if value <= node.splitval else node.splits[1]
        return node

    def sample_at(self, values):
        "Creates a SolutionOracle at a given range of values"
        return SolutionOracle(self, values)

    @property
    def sollist(self):
        "Returns a list of all the solutions in an autosweep"
        sollist = [self.sols[0]]
        if self.splits:
            # children share the split solution; drop the duplicate ends
            sollist.extend(self.splits[0].sollist[1:])
            sollist.extend(self.splits[1].sollist[1:-1])
        sollist.append(self.sols[1])
        return sollist

    @property
    def solarray(self):
        "Returns a solution array of all the solutions in an autosweep"
        solution = SolutionArray()
        for sol in self.sollist:
            solution.append(sol)
        solution.to_arrays()
        return solution

    def save(self, filename="autosweep.p"):
        """Pickles the autosweep and saves it to a file.

        The whole tree, including every stored solution, is pickled as-is.

        The autosweep can then be loaded with e.g.::

            import pickle
            with open("autosweep.p", "rb") as f:
                bst = pickle.load(f)

        Only unpickle files from trusted sources: unpickling can execute
        arbitrary code.
        """
        with open(filename, "wb") as f:
            pickle.dump(self, f)

160 

161 

class SolutionOracle:
    "Acts like a SolutionArray for autosweeps"
    def __init__(self, bst, sampled_at):
        self.sampled_at = sampled_at
        self.bst = bst

    def __call__(self, key):
        return self.__getval(key)

    def __getitem__(self, key):
        return self.__getval(key)

    def _is_cost(self, key):
        "True if key is the swept model's cost (matched by hmap) or 'cost'"
        if hasattr(key, "hmap") and key.hmap == self.bst.costposy.hmap:
            return True
        return key == "cost"

    def __getval(self, key):
        "Gets values from the BST and units them"
        if self._is_cost(key):
            key_at = self.bst.cost_at
            v0 = self.bst.sols[0]["cost"]
        else:
            key_at = self.bst.posy_at
            v0 = self.bst.sols[0](key)
        # units are taken from the first stored solution's value
        units = getattr(v0, "units", None)
        fit = [key_at(key, x) for x in self.sampled_at]
        return fit*units if units else np.array(fit)

    def _cost_bound(self, bound):
        "Gets one cost bound ('lb' or 'ub') at each sample and units it"
        units = getattr(self.bst.sols[0]["cost"], "units", None)
        fit = [self.bst.cost_at("cost", x, bound) for x in self.sampled_at]
        return fit*units if units else np.array(fit)

    def cost_lb(self):
        "Gets cost lower bounds from the BST and units them"
        return self._cost_bound("lb")

    def cost_ub(self):
        "Gets cost upper bounds from the BST and units them"
        return self._cost_bound("ub")

    def plot(self, posys=None, axes=None):
        """Plots the sweep for each posy.

        posys : None, "cost", a posynomial, or a sequence of these
            What to plot against the swept variable. None and "cost" both
            mean the model's cost, which is drawn as a band between its
            lower and upper bounds.
        axes : optional
            Passed through to assign_axes (presumably pre-existing axes to
            draw on; confirm against plot_sweep.assign_axes).

        Returns the current matplotlib figure and the axes used (a single
        axes object, rather than a sequence, when only one was used).
        """
        #pylint: disable=import-outside-toplevel
        import matplotlib.pyplot as plt
        from ..interactive.plot_sweep import assign_axes
        from .. import GPBLU
        # a str has __len__ too, but "cost" is a single posy, not a sequence
        if isinstance(posys, str) or not hasattr(posys, "__len__"):
            posys = [posys]
        # build a new list: the caller's sequence is left untouched and need
        # not support item assignment (e.g. a tuple)
        posys = [self.bst.costposy if posy in [None, "cost"] else posy
                 for posy in posys]
        posys, axes = assign_axes(self.bst.sweptvar, posys, axes)
        for posy, ax in zip(posys, axes):
            if self._is_cost(posy):  # with small tol should look like a line
                ax.fill_between(self.sampled_at,
                                self.cost_lb(), self.cost_ub(),
                                facecolor=GPBLU, edgecolor=GPBLU,
                                linewidth=0.75)
            else:
                ax.plot(self.sampled_at, self(posy), color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

226 

227 

def autosweep_1d(model, logtol, sweepvar, bounds, **solvekwargs):
    """Autosweep a model over one sweepvar.

    Solves the model at both bounds, then recursively splits the range
    (see recurse_splits) until half the gap between the log-cost bounds is
    below logtol everywhere.

    Arguments
    ---------
    model
        The model to sweep. Its substitution for sweepvar is changed during
        the sweep and afterwards restored (or removed, if it had none), even
        if a solve raises.
    logtol : float
        Tolerance on the half-gap between log-cost lower and upper bounds.
    sweepvar
        The variable to sweep.
    bounds : two-element sequence
        The lowest and highest values of sweepvar to sweep over.
    **solvekwargs
        Passed to model.solve. "verbosity" defaults to 1 and is reduced by
        one for the individual solves; summary lines are printed when the
        reduced value is above -1.

    Returns
    -------
    BinarySweepTree
        With an extra `nsols` attribute giving the number of solves.

    Raises
    ------
    InvalidGPConstraint
        Re-raised with a clearer message if a boundary solve raises it
        (i.e. the model is not a GP).
    """
    original_val = model.substitutions.get(sweepvar, None)
    start_time = time()
    solvekwargs.setdefault("verbosity", 1)
    solvekwargs["verbosity"] -= 1
    sols = Count().next
    firstsols = []
    try:
        for bound in bounds:
            model.substitutions.update({sweepvar: bound})
            try:
                model.solve(**solvekwargs)
            except InvalidGPConstraint as exc:
                raise InvalidGPConstraint("only GPs can be autoswept.") from exc
            firstsols.append(model.program.result)
            sols()
        bst = BinarySweepTree(bounds, firstsols, sweepvar, model.cost)
        tol = recurse_splits(model, bst, sweepvar, logtol, solvekwargs, sols)
        bst.nsols = sols()  # pylint: disable=attribute-defined-outside-init
    finally:
        # Restore the caller's substitution even if a solve raised.
        # `is not None` (not truthiness) so array-valued or falsy original
        # substitutions are restored rather than deleted or erroring.
        if original_val is not None:
            model.substitutions[sweepvar] = original_val
        elif sweepvar in model.substitutions:
            del model.substitutions[sweepvar]
    if solvekwargs["verbosity"] > -1:
        print(f"Solved in {bst.nsols:2} passes, cost logtol +/-{tol:.3g}")
        print(f"Autosweeping took {time() - start_time:.3g} seconds.")
    return bst

255 

256 

def recurse_splits(model, bst, variable, logtol, solvekwargs, sols):
    # pylint: disable=too-many-arguments
    """Recursively splits a BST until logtol is reached.

    For the segment `bst`, get_tol estimates a split point and the log-cost
    lower/upper bounds there. If half the gap between those bounds is at
    least logtol, the model is solved at the split point, the segment is
    split, and both halves are processed the same way. Otherwise the bounds
    are recorded on the segment.

    `sols` is called once after each solve (autosweep_1d passes a counter).
    Returns the largest half-gap over this subtree's unsplit segments; when
    a split is made, that value is also stored as `bst.tol`.
    """
    splitpoint, lower, upper = get_tol(bst.costs, bst.bounds, bst.sols,
                                       variable)
    halfgap = (upper - lower)/2.0
    # written as `not >=` so a NaN gap also lands on this branch
    if not halfgap >= logtol:
        bst.add_splitcost(splitpoint, lower, upper)
        return halfgap
    model.substitutions.update({variable: splitpoint})
    model.solve(**solvekwargs)
    bst.add_split(splitpoint, model.program.result)
    sols()
    bst.tol = max(recurse_splits(model, subtree, variable, logtol,
                                 solvekwargs, sols)
                  for subtree in bst.splits)
    return bst.tol

274 

275 

def get_tol(costs, bounds, sols, variable):  # pylint: disable=too-many-locals
    """Gets the intersection point and corresponding bounds from two solutions.

    Works in log-log space: x = log(variable), y = log(cost). Each end of the
    segment gives a tangent line whose slope is the cost's sensitivity to
    `variable` at that solution. Where the two tangents intersect, `lb` is
    the tangents' value and `ub` is the value of the chord joining the two
    endpoints. If log-cost is convex in x, these bound it at that point.

    Arguments
    ---------
    costs : two log-costs, at the left and right bounds
    bounds : the left and right values of `variable` (not logged)
    sols : the left and right solutions; each must provide
        sol["sensitivities"]["variables"][variable]
    variable : the swept variable

    Returns
    -------
    (split point in the variable's original, non-log units, lb, ub)
    where lb and ub are log-costs. When the intersection is undefined or
    not strictly inside the segment, the midpoint is returned with
    lb == ub == the mean endpoint log-cost.
    """
    y0, y1 = costs
    x0, x1 = np.log(bounds)
    s0, s1 = [sol["sensitivities"]["variables"][variable] for sol in sols]
    # y0 + s0*(x - x0) == y1 + s1*(x - x1)
    num = y1-y0 + x0*s0-x1*s1
    denom = s0-s1
    # NOTE: several branches below deal with straight lines, where lower
    # and upper bounds are identical and so x is undefined
    if denom == 0:
        # mosek runs into this on perfect straight lines, num also equal to 0
        # mosek_cli also runs into this on near-straight lines, num ~= 0
        interp = -1  # flag interp as out-of bounds
    else:
        x = num/denom
        lb = y0 + s0*(x-x0)
        interp = (x1-x)/(x1-x0)
        ub = y0*interp + y1*(1-interp)
    # cvxopt on straight lines lands here; the negated range test also sends
    # a NaN interp (e.g. from NaN sensitivities) to the midpoint instead of
    # returning a NaN split point
    if not 1e-7 <= interp <= 1 - 1e-7:
        x = (x0 + x1)/2  # x is undefined? stick it in the middle!
        lb = ub = (y0 + y1)/2
    return np.exp(x), lb, ub
298 return np.exp(x), lb, ub