Coverage for gpkit/tools/autosweep.py: 85%

190 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 16:49 -0500

1"Tools for optimal fits to GP sweeps" 

2from time import time 

3import numpy as np 

4from ..small_classes import Count 

5from ..small_scripts import mag 

6from ..solution_array import SolutionArray 

7from ..exceptions import InvalidGPConstraint 

8 

9 

class BinarySweepTree:  # pylint: disable=too-many-instance-attributes
    """Spans a line segment. May contain two subtrees that divide the segment.

    Attributes
    ----------

    bounds : two-element list
        The left and right boundaries of the segment

    sols : two-element list
        The left and right solutions of the segment

    costs : array
        The left and right logcosts of the segment

    splits : None or two-element list
        If not None, contains the left and right subtrees

    splitval : None or float
        The worst-error point, where the split will be if tolerance is too low

    splitlb : None or float
        The log-cost lower bound at splitval

    splitub : None or float
        The log-cost upper bound at splitval

    sweptvar
        The swept variable; passed unchanged to every subtree

    costposy
        The swept cost; passed unchanged to every subtree

    Nodes that were split by `recurse_splits` also get a `tol` attribute,
    and the root returned by `autosweep_1d` gets an `nsols` attribute.
    """

    def __init__(self, bounds, sols, sweptvar, costposy):
        if len(bounds) != 2:
            raise ValueError("bounds must be of length 2")
        if bounds[1] <= bounds[0]:
            raise ValueError("bounds[0] must be smaller than bounds[1].")
        self.bounds = bounds
        self.sols = sols
        # costs are kept in log space, matching the log-log interpolation
        # performed by posy_at / cost_at
        self.costs = np.log([mag(sol["cost"]) for sol in sols])
        self.splits = None
        self.splitval = None
        self.splitlb = None
        self.splitub = None
        self.sweptvar = sweptvar
        self.costposy = costposy

    def _check_new_split(self, splitval):
        "Raises ValueError unless splitval is a first split strictly in bounds"
        # compare against None: relying on truthiness misses a falsy splitval
        if self.splitval is not None:
            raise ValueError("split already exists!")
        if splitval <= self.bounds[0] or splitval >= self.bounds[1]:
            raise ValueError("split value is at or outside bounds.")

    def _check_in_bounds(self, value):
        "Raises ValueError if value lies outside this segment's bounds"
        if value < self.bounds[0] or value > self.bounds[1]:
            raise ValueError("query value is outside bounds.")

    @staticmethod
    def _log_interp(lo, hi, loval, hival, value):
        """Interpolates linearly in log-log space between two points.

        Returns the value at `value` on the log-log line through
        (lo, loval) and (hi, hival); units are stripped with `mag` first.
        """
        lo, hi, loval, hival = np.log(list(map(mag, [lo, hi, loval, hival])))
        interp = (hi-np.log(value))/float(hi-lo)  # weight of the `lo` end
        return np.exp(interp*loval + (1-interp)*hival)

    def add_split(self, splitval, splitsol):
        """Creates subtrees from bounds[0] to splitval and splitval to bounds[1]

        Raises
        ------
        ValueError
            If this segment already has a split point, or if splitval is not
            strictly inside the bounds.
        """
        self._check_new_split(splitval)
        self.splitval = splitval
        self.splits = [BinarySweepTree([self.bounds[0], splitval],
                                       [self.sols[0], splitsol],
                                       self.sweptvar, self.costposy),
                       BinarySweepTree([splitval, self.bounds[1]],
                                       [splitsol, self.sols[1]],
                                       self.sweptvar, self.costposy)]

    def add_splitcost(self, splitval, splitlb, splitub):
        """Adds a splitval, lower bound, and upper bound

        The bounds are log-costs (cost_at exponentiates them). Raises
        ValueError under the same conditions as add_split.
        """
        self._check_new_split(splitval)
        self.splitval = splitval
        self.splitlb, self.splitub = splitlb, splitub

    def posy_at(self, posy, value):
        """Logspace interpolates between sols to get posynomial values.

        No guarantees, just like a regular sweep.

        Raises ValueError if `value` is outside `bounds`.
        """
        self._check_in_bounds(value)
        bst = self.min_bst(value)
        lo, hi = bst.bounds
        loval, hival = [sol(posy) for sol in bst.sols]
        return self._log_interp(lo, hi, loval, hival, value)

    def cost_at(self, _, value, bound=None):
        """Logspace interpolates between split and costs. Guaranteed bounded.

        Arguments
        ---------
        _ : ignored
            Unused; lets callers invoke cost_at and posy_at interchangeably.
        value : float
            Point at which to evaluate the cost; must lie within `bounds`.
        bound : None, "lb" or "ub"
            Which estimate of the cost at the leaf's split point to use when
            that leaf recorded bounds (via add_splitcost): the lower bound,
            the upper bound, or (if falsy) their log-space midpoint.

        Raises
        ------
        ValueError
            If `value` is outside `bounds`, or `bound` is not recognized.
        """
        self._check_in_bounds(value)
        bst = self.min_bst(value)
        # splitlb is a log-cost and may legitimately be 0.0 (cost == 1),
        # so test for presence with `is not None`, not truthiness
        if bst.splitlb is not None:
            if not bound:
                splitlogcost = (bst.splitlb + bst.splitub)/2
            elif bound == "lb":
                splitlogcost = bst.splitlb
            elif bound == "ub":
                splitlogcost = bst.splitub
            else:
                raise ValueError("bound must be None, 'lb' or 'ub', not %r"
                                 % (bound,))
            splitcost = np.exp(splitlogcost)
            # interpolate only within the half of the leaf containing value
            if value <= bst.splitval:
                lo, hi = bst.bounds[0], bst.splitval
                loval, hival = bst.sols[0]["cost"], splitcost
            else:
                lo, hi = bst.splitval, bst.bounds[1]
                loval, hival = splitcost, bst.sols[1]["cost"]
        else:
            lo, hi = bst.bounds
            loval, hival = [sol["cost"] for sol in bst.sols]
        return self._log_interp(lo, hi, loval, hival, value)

    def min_bst(self, value):
        "Returns smallest bst around value (a value at a split goes left)."
        if not self.splits:
            return self
        choice = self.splits[0] if value <= self.splitval else self.splits[1]
        return choice.min_bst(value)

    def sample_at(self, values):
        "Creates a SolutionOracle at a given range of values"
        return SolutionOracle(self, values)

    @property
    def sollist(self):
        """Returns a list of all the solutions in an autosweep

        Ordered from the left bound to the right bound, each solution once.
        """
        sollist = [self.sols[0]]
        if self.splits:
            # the left subtree ends at, and the right subtree starts at,
            # the split solution: take it only once
            sollist.extend(self.splits[0].sollist[1:])
            sollist.extend(self.splits[1].sollist[1:-1])
        sollist.append(self.sols[1])
        return sollist

    @property
    def solarray(self):
        "Returns a solution array of all the solutions in an autosweep"
        solution = SolutionArray()
        for sol in self.sollist:
            solution.append(sol)
        solution.to_arrays()
        return solution

    def save(self, filename="autosweep.p"):
        """Pickles the autosweep and saves it to a file.

        The whole tree, including every stored solution, is pickled as-is.

        It can then be loaded with e.g.:

            import pickle
            with open("autosweep.p", "rb") as fil:
                bst = pickle.load(fil)
        """
        import pickle
        # `with` guarantees the file is flushed and closed
        with open(filename, "wb") as fil:
            pickle.dump(self, fil)

159 

160 

class SolutionOracle:
    """Acts like a SolutionArray for autosweeps

    Indexing or calling an oracle with a key returns that key's value,
    interpolated from the BinarySweepTree `bst`, at every point in
    `sampled_at`.
    """
    def __init__(self, bst, sampled_at):
        self.sampled_at = sampled_at
        self.bst = bst

    def __call__(self, key):
        return self.__getval(key)

    def __getitem__(self, key):
        return self.__getval(key)

    def _is_cost(self, key):
        "True if key is the literal 'cost' or shares the swept cost's hmap"
        if hasattr(key, "hmap") and key.hmap == self.bst.costposy.hmap:
            return True
        return key == "cost"

    def __getval(self, key):
        "Gets values from the BST and units them"
        if self._is_cost(key):
            key_at = self.bst.cost_at
            v0 = self.bst.sols[0]["cost"]
        else:
            key_at = self.bst.posy_at
            v0 = self.bst.sols[0](key)
        # units are taken from the value at the left-most solution
        units = getattr(v0, "units", None)
        fit = [key_at(key, x) for x in self.sampled_at]
        return fit*units if units else np.array(fit)

    def _cost_bound(self, bound):
        """Gets cost bounds ("lb" or "ub") from the BST and units them"""
        units = getattr(self.bst.sols[0]["cost"], "units", None)
        fit = [self.bst.cost_at("cost", x, bound) for x in self.sampled_at]
        return fit*units if units else np.array(fit)

    def cost_lb(self):
        "Gets cost lower bounds from the BST and units them"
        return self._cost_bound("lb")

    def cost_ub(self):
        "Gets cost upper bounds from the BST and units them"
        return self._cost_bound("ub")

    def plot(self, posys=None, axes=None):
        """Plots the sweep for each posy

        Arguments
        ---------
        posys : None, "cost", a posynomial, or a sequence of these
            What to plot against the swept variable; None and "cost" both
            mean the swept cost, which is drawn as its lower/upper band.
        axes : optional
            Passed to `assign_axes` along with the posys.

        Returns
        -------
        (figure, axes) : the current matplotlib figure and the axes used
            (unpacked to a single axes object when only one was used).
        """
        import matplotlib.pyplot as plt
        from ..interactive.plot_sweep import assign_axes
        from .. import GPBLU
        # a bare string such as "cost" has __len__ but names a single posy
        if isinstance(posys, str) or not hasattr(posys, "__len__"):
            posys = [posys]
        # build a new list rather than mutating the caller's sequence
        # (this also accepts tuples)
        posys = [self.bst.costposy if posy in [None, "cost"] else posy
                 for posy in posys]
        posys, axes = assign_axes(self.bst.sweptvar, posys, axes)
        for posy, ax in zip(posys, axes):
            if self._is_cost(posy):  # with small tol should look like a line
                ax.fill_between(self.sampled_at,
                                self.cost_lb(), self.cost_ub(),
                                facecolor=GPBLU, edgecolor=GPBLU,
                                linewidth=0.75)
            else:
                ax.plot(self.sampled_at, self(posy), color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

224 

225 

def autosweep_1d(model, logtol, sweepvar, bounds, **solvekwargs):
    """Autosweep a model over one sweepvar

    Arguments
    ---------
    model
        Model to sweep. Its substitution for `sweepvar` is overwritten
        during the sweep and restored (or removed, if it had none) before
        returning, even if a solve raises.
    logtol : float
        Segments are split until half the gap between the log-cost bounds
        at each segment's worst-error point is below this.
    sweepvar
        Variable to sweep.
    bounds : two-element sequence
        Lower and upper values of `sweepvar` to sweep between.
    **solvekwargs
        Passed to `model.solve`; "verbosity" (default 1) is reduced by one
        for the individual solves.

    Returns
    -------
    BinarySweepTree
        The root of the sweep, with an extra `nsols` attribute holding the
        solve count reported in the verbose summary.

    Raises
    ------
    InvalidGPConstraint
        If solving at either bound reports that the model is not a GP.
    """
    missing = object()  # sentinel: sweepvar had no substitution beforehand
    original_val = model.substitutions.get(sweepvar, missing)
    start_time = time()
    solvekwargs.setdefault("verbosity", 1)
    solvekwargs["verbosity"] -= 1
    sols = Count().next
    try:
        firstsols = []
        for bound in bounds:
            model.substitutions.update({sweepvar: bound})
            try:
                model.solve(**solvekwargs)
            except InvalidGPConstraint as err:
                raise InvalidGPConstraint("only GPs can be autoswept.") from err
            firstsols.append(model.program.result)
            sols()
        bst = BinarySweepTree(bounds, firstsols, sweepvar, model.cost)
        tol = recurse_splits(model, bst, sweepvar, logtol, solvekwargs, sols)
        bst.nsols = sols()  # pylint: disable=attribute-defined-outside-init
    finally:
        # restore the caller's substitution; comparing with the sentinel
        # (not truthiness) keeps falsy or array-valued originals such as 0
        if original_val is not missing:
            model.substitutions[sweepvar] = original_val
        elif sweepvar in model.substitutions:
            del model.substitutions[sweepvar]
    if solvekwargs["verbosity"] > -1:
        print("Solved in %2i passes, cost logtol +/-%.3g" % (bst.nsols, tol))
        print("Autosweeping took %.3g seconds." % (time() - start_time))
    return bst

253 

254 

def recurse_splits(model, bst, variable, logtol, solvekwargs, sols):
    """Recursively splits a BST until logtol is reached

    The segment's worst-error point and its log-cost bounds come from
    get_tol. If half the gap between those bounds is at least `logtol`, the
    model is solved at that point (calling `sols` once), the segment is
    split there, and both halves are processed recursively. Otherwise the
    point and its bounds are recorded on the leaf with add_splitcost.

    Returns the largest remaining half-gap within `bst`; split nodes also
    store it as `bst.tol`.
    """
    splitval, loglb, logub = get_tol(bst.costs, bst.bounds, bst.sols, variable)
    halfgap = (logub - loglb)/2.0
    needs_split = halfgap >= logtol
    if not needs_split:
        # tight enough: just remember where the worst error is, and its bounds
        bst.add_splitcost(splitval, loglb, logub)
        return halfgap
    model.substitutions.update({variable: splitval})
    model.solve(**solvekwargs)
    bst.add_split(splitval, model.program.result)
    sols()
    subtols = []
    for subtree in bst.splits:
        subtols.append(recurse_splits(model, subtree, variable, logtol,
                                      solvekwargs, sols))
    bst.tol = max(subtols)
    return bst.tol

271 

272 

def get_tol(costs, bounds, sols, variable):  # pylint: disable=too-many-locals
    """Gets the intersection point and corresponding bounds from two solutions.

    Works in log-log space. The tangent lines at both ends of the segment
    (slopes taken from each solution's sensitivity to `variable`) intersect
    at the returned point; the tangent there gives the lower bound. The
    chord between the two end log-costs gives the upper bound.

    When the tangents are parallel, or meet within 1e-7 (relative) of an end,
    the point is placed at the log-space midpoint. Both bounds are then the
    mean of the end log-costs.

    Returns (split point in un-logged units, lower log-cost, upper log-cost).
    """
    logcost0, logcost1 = costs
    logx0, logx1 = np.log(bounds)
    slope0, slope1 = [sol["sensitivities"]["variables"][variable]
                      for sol in sols]
    slope_diff = slope0 - slope1
    if slope_diff != 0:
        # tangents meet where
        #   logcost0 + slope0*(x - logx0) == logcost1 + slope1*(x - logx1)
        logx = (logcost1 - logcost0 + logx0*slope0 - logx1*slope1)/slope_diff
        loglb = logcost0 + slope0*(logx - logx0)
        # weight of the left end on the chord through both ends
        weight = (logx1 - logx)/(logx1 - logx0)
        logub = logcost0*weight + logcost1*(1 - weight)
        # cvxopt on straight lines lands the intersection at an end
        degenerate = weight < 1e-7 or weight > 1 - 1e-7
    else:
        # parallel tangents: mosek on perfect straight lines (and mosek_cli
        # on near-straight ones), where the intersection is undefined
        degenerate = True
    if degenerate:
        # x is undefined? stick it in the middle!
        logx = (logx0 + logx1)/2
        loglb = logub = (logcost0 + logcost1)/2
    return np.exp(logx), loglb, logub