Coverage for gpkit/breakdowns.py: 81%

754 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-07 22:56 -0500

1#TODO: cleanup weird conditionals 

2# add conversions to plotly/sankey 

3 

4# pylint: skip-file 

5import string 

6from collections import defaultdict, namedtuple, Counter 

7from gpkit.nomials import Monomial, Posynomial, Variable 

8from gpkit.nomials.map import NomialMap 

9from gpkit.small_scripts import mag, try_str_without 

10from gpkit.small_classes import FixedScalar, HashVector 

11from gpkit.exceptions import DimensionalityError 

12from gpkit.repr_conventions import unitstr as get_unitstr 

13from gpkit.repr_conventions import lineagestr 

14from gpkit.varkey import VarKey 

15import numpy as np 

16 

# A breakdown-tree node: `key` labels the node, `value` is its numeric size,
# and `branches` is the list of child Trees.
Tree = namedtuple("Tree", "key value branches")

# A tree key marking that the levels below it are reached through a `factor`
# and/or a `power`; `origkey` holds the expression the transform was built from.
Transform = namedtuple("Transform", "factor power origkey")

def is_factor(key):
    "Returns whether key is a Transform whose power equals 1"
    if not isinstance(key, Transform):
        return False
    return key.power == 1

def is_power(key):
    "Returns whether key is a Transform whose power differs from 1"
    if not isinstance(key, Transform):
        return False
    return key.power != 1

23 

def get_free_vks(posy, solution):
    "Returns the set of posy's vks that are not among the solution's constants"
    free = set()
    for vk in posy.vks:
        if vk in solution["constants"]:
            continue  # fixed by substitution, hence not free
        free.add(vk)
    return free

27 

def _add_along_lineage(breakdowns, lineage, senss):
    "Adds senss to breakdowns and to each nested level named in lineage; returns the deepest level"
    subbd = breakdowns
    subbd["|sensitivity|"] += senss
    for parent in lineage.split("."):
        if parent == "":
            continue
        if parent not in subbd:
            subbd[parent] = {}
        subbd = subbd[parent]
        if "|sensitivity|" not in subbd:
            subbd["|sensitivity|"] = 0
        subbd["|sensitivity|"] += senss
    return subbd


def get_model_breakdown(solution):
    """Returns a nested dict of absolute sensitivities grouped by lineage.

    Each level of the returned dict holds a "|sensitivity|" total plus one
    sub-dict per dot-separated lineage component; the innermost level holds
    one {"|sensitivity|": value} entry per constraint string or variable label.

    Arguments
    ---------
    solution : mapping
        Read at solution["sensitivities"]["constraints"] (constraint -> value)
        and solution["sensitivities"]["variables"] (indexable by key, with a
        `keymap` attribute that is iterated for the keys).

    Entries whose absolute sensitivity is at most 1e-5 are left out.
    """
    breakdowns = {"|sensitivity|": 0}
    for constraint, senss in solution["sensitivities"]["constraints"].items():
        senss = abs(senss)  # sign is dropped (original note: "for those monomial")
        if senss <= 1e-5:
            continue
        lineage = lineagestr(constraint)
        subbd = _add_along_lineage(breakdowns, lineage, senss)
        # treat vectors as namespace
        constrstr = try_str_without(constraint, {"units", ":MAGIC:"+lineage})
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        subbd[constrstr] = {"|sensitivity|": senss}
    varsenss = solution["sensitivities"]["variables"]
    for vk in varsenss.keymap:  # could this be done away with for backwards compatibility?
        # skip non-VarKey entries and keys with a shape but no index
        if not isinstance(vk, VarKey) or (vk.shape and not vk.index):
            continue
        senss = abs(varsenss[vk])
        if hasattr(senss, "shape"):  # array-valued: total it, ignoring NaNs
            senss = np.nansum(senss)
        if senss <= 1e-5:
            continue
        subbd = _add_along_lineage(breakdowns, vk.lineagestr(), senss)
        # treat vectors as namespace (indexing vectors above)
        label = (vk.str_without({"lineage"})
                 + get_valstr(vk, solution, " = %s").replace(", fixed", ""))
        subbd[label] = {"|sensitivity|": senss}
    # TODO: track down in a live-solve environment why this isn't the same
    # print(breakdowns["HyperloopSystem"]["|sensitivity|"])
    return breakdowns

75 

def crawl_modelbd(bd, lookup, name="Model"):
    """Returns the Tree form of the nested sensitivity dict bd.

    Pops "|sensitivity|" out of bd (so bd is modified) to use as the node's
    value, records the node in lookup under name when sub-entries remain, and
    recurses into those sub-entries from largest to smallest sensitivity
    (rounded to two significant figures, ties broken by name).
    """
    node = Tree(name, bd.pop("|sensitivity|"), [])
    if bd:
        lookup[name] = node

    def rank(item):
        subname, subbd = item
        return (-float("%.2g" % subbd["|sensitivity|"]), subname)

    for subname, subbd in sorted(bd.items(), key=rank):
        child = crawl_modelbd(subbd, lookup, subname)
        node.branches.append(child)
    return node

84 

def divide_out_vk(vk, pow, lt, gt):
    "Returns lt and gt each divided by vk**pow, with the `ast` of both results cleared"
    # build a monomial that is just `vk`, carrying vk's units
    vkmap = NomialMap({HashVector({vk: 1}): 1.0})
    vkmap.units = vk.units
    divisor = Monomial(vkmap)**pow
    lt = lt/divisor
    gt = gt/divisor
    lt.ast = None
    gt.ast = None
    return lt, gt

92 

93# @profile 

def _innermost_constraint(constraint):
    "Follows `child` links, then `generated` links, while they are set"
    while getattr(constraint, "child", None):
        constraint = constraint.child
    while getattr(constraint, "generated", None):
        constraint = constraint.generated
    return constraint


def _outermost_constraint(constraint):
    "Follows `parent` links, then `generated_by` links, while they are set"
    while getattr(constraint, "parent", None):
        constraint = constraint.parent
    while getattr(constraint, "generated_by", None):
        constraint = constraint.generated_by
    return constraint


def _tight_monomial_bounds(solution):
    "Yields (constraint, lt, gt, pos_gtvks) per tight constraint whose gt side has one term"
    ranked = sorted(solution["sensitivities"]["constraints"].items(),
                    key=lambda kv: (-abs(float("%.2g" % kv[1])), str(kv[0])))
    for constraint, senss in ranked:
        if abs(senss) <= 1e-5:  # only tight-ish ones
            continue
        constraint = _innermost_constraint(constraint)
        if constraint.oper == ">=":
            gt, lt = (constraint.left, constraint.right)
        elif constraint.oper == "<=":
            lt, gt = (constraint.left, constraint.right)
        elif constraint.oper == "=":
            if senss > 0:  # l_over_r is more sensitive - see nomials/math.py
                lt, gt = (constraint.left, constraint.right)
            else:  # r_over_l is more sensitive - see nomials/math.py
                gt, lt = (constraint.left, constraint.right)
        else:
            # unknown operator: skip rather than reuse the sides left over
            # from the previous constraint
            continue
        for gtvk in gt.vks:  # remove RelaxPCCP.C
            if (gtvk.name == "C" and gtvk.lineage[0][0] == "RelaxPCCP"
                    and gtvk not in solution["variables"]):
                lt, gt = lt.sub({gtvk: 1}), gt.sub({gtvk: 1})
        if len(gt.hmap) > 1:
            continue
        pos_gtvks = {vk for vk, pow in gt.exp.items() if pow > 0}
        if len(pos_gtvks) > 1:
            pos_gtvks &= get_free_vks(gt, solution)  # remove constants
        yield constraint, lt, gt, pos_gtvks


def get_breakdowns(basically_fixed_variables, solution):
    """Returns {key: [(lt, gt, constraint), ...]} for breakdown constraints in solution.

    Only constraints with |sensitivity| above 1e-5 and a single-term "gt"
    side are considered, in order of decreasing sensitivity (rounded to two
    significant figures, ties broken by the constraint's string).
    For "=" constraints the sign of the sensitivity decides which side is
    treated as "gt".

    First pass: constraints whose gt has exactly one positive-exponent
    variable (after removing constants when there are several) are filed
    under that variable.
    Second pass: for the remaining constraints, negative powers are divided
    out of both sides and a variable not yet broken down is chosen by
    exponent-weighted "variablerisk"; everything else in gt is divided out.

    Arguments
    ---------
    basically_fixed_variables : set
        Extended in place (through get_fixity) until it stops growing.
    solution : mapping
        Read at ["sensitivities"]["constraints"], ["variables"],
        ["constants"] and ["sensitivities"]["variablerisk"].
    """
    breakdowns = defaultdict(list)
    tight = list(_tight_monomial_bounds(solution))
    for constraint, lt, gt, pos_gtvks in tight:
        if len(pos_gtvks) == 1:
            chosenvk, = pos_gtvks
            breakdowns[chosenvk].append(
                (lt, gt, _outermost_constraint(constraint)))
    for constraint, lt, gt, pos_gtvks in tight:
        if len(pos_gtvks) == 1:
            continue  # handled in the first pass
        # we'll choose our favorite vk
        for vk, pow in gt.exp.items():
            if pow < 0:  # remove all non-positive
                lt, gt = divide_out_vk(vk, pow, lt, gt)
        # bring over common factors from lt
        lt_pows = defaultdict(set)
        for exp in lt.hmap:
            for vk, pow in exp.items():
                lt_pows[vk].add(pow)
        for vk, pows in lt_pows.items():
            if len(pows) == 1:
                pow, = pows
                if pow < 0:  # ...but only if they're positive
                    lt, gt = divide_out_vk(vk, pow, lt, gt)
        # don't choose something that's already been broken down
        candidatevks = {vk for vk in gt.vks if vk not in breakdowns}
        if candidatevks:
            vrisk = solution["sensitivities"]["variablerisk"]
            chosenvk, *_ = sorted(
                candidatevks,
                key=lambda vk: (-float("%.2g" % (gt.exp[vk]*vrisk.get(vk, 0))), str(vk))
            )
            for vk, pow in gt.exp.items():
                if vk is not chosenvk:
                    lt, gt = divide_out_vk(vk, pow, lt, gt)
            breakdowns[chosenvk].append(
                (lt, gt, _outermost_constraint(constraint)))
    breakdowns = dict(breakdowns)  # remove the defaultdict-ness

    # repeat until no further variable is found to be basically fixed
    prevlen = None
    while len(basically_fixed_variables) != prevlen:
        prevlen = len(basically_fixed_variables)
        for key in breakdowns:
            if key not in basically_fixed_variables:
                get_fixity(basically_fixed_variables, key, breakdowns, solution)
    return breakdowns

199 

200 

def get_fixity(basically_fixed, key, bd, solution, visited=None):
    """Adds key to basically_fixed if every free variable it depends on is.

    The dependencies of key are the free variables on either side of its
    first breakdown, bd[key][0]. Each one not already in basically_fixed is
    checked recursively; the search gives up (leaving key out) as soon as one
    has no breakdown, was already visited, or turns out not to be fixed.

    Arguments
    ---------
    basically_fixed : set
        Modified in place.
    key : hashable
        The breakdown key being checked.
    bd : dict
        Breakdowns, as {key: [(lt, gt, constraint), ...]}.
    solution : mapping
        Passed to get_free_vks.
    visited : set, optional
        Keys already explored in this search. A fresh set is made for each
        top-level call, so one search cannot affect another.
    """
    if visited is None:
        visited = set()
    lt, gt, _ = bd[key][0]
    free_vks = get_free_vks(lt, solution).union(get_free_vks(gt, solution))
    for vk in free_vks:
        if vk is key or vk in basically_fixed:
            continue  # currently checking or already checked
        if vk not in bd:
            return  # a very free variable, can't even be broken down
        if vk in visited:
            return  # tried it before, implicitly it didn't work out
        # maybe it's basically fixed?
        visited.add(key)
        get_fixity(basically_fixed, vk, bd, solution, visited)
        if vk not in basically_fixed:
            return  # ...well, we tried
    basically_fixed.add(key)

217 

218# @profile # ~84% of total last check # TODO: remove 

def crawl(basically_fixed_variables, key, bd, solution, basescale=1, permissivity=2, verbosity=0,
          visited_bdkeys=None, gone_negative=False, all_visited_bdkeys=None):
    """Returns the tree of breakdowns of key in bd, sorting by solution's values

    Arguments
    ---------
    basically_fixed_variables : set
        Variables filtered out when picking which variable of a monomial to
        break down further.
    key : VarKey or Posynomial (or an object with a `.key` attribute)
        What to break down. If it is not in bd it must be a Posynomial.
    bd : dict
        Breakdowns, {key: [(composition, keymon, constraint), ...]}; only the
        first entry of each list is used.
    solution : callable mapping
        Called to evaluate keys and monomials, and indexed for
        "cost function", "sensitivities" and "variables".
    basescale : number
        Value given to the returned root Tree; child values are fractions
        of the root's total scaled by it.
    permissivity : number
        Controls how readily variables are chosen for further breakdown
        (see the `1 - permissivity` comparisons and the sort below).
    verbosity : int
        0 is silent; otherwise the breakdown is printed, and verbosity-1 is
        used as the starting indent of the printout.
    visited_bdkeys : set, optional
        Keys already broken down on the path to this call.
    gone_negative : bool
        When true, variables with negative exponents are not broken down.
    all_visited_bdkeys : set, optional
        Keys broken down anywhere in the whole crawl so far.

    Raises
    ------
    TypeError
        If key is neither in bd nor a Posynomial.
    """
    if key != solution["cost function"] and hasattr(key, "key"):
        key = key.key  # clear up Variables
    if key in bd:
        # TODO: do multiple if sensitivities are quite close?
        composition, keymon, constraint = bd[key][0]
    elif isinstance(key, Posynomial):
        composition = key
        keymon = None
    else:
        raise TypeError("the `key` argument must be a VarKey or Posynomial.")

    if visited_bdkeys is None:
        visited_bdkeys = set()
        all_visited_bdkeys = set()
    if verbosity == 1:
        already_set = False #not solution._lineageset TODO
        if not already_set:
            solution.set_necessarylineage()
    if verbosity:
        indent = verbosity-1  # HACK: a bit of overloading, here
        kvstr = "%s (%s)" % (key, get_valstr(key, solution))
        if key in all_visited_bdkeys:
            print(" "*indent + kvstr + " [as broken down above]")
            verbosity = 0
        else:
            print(" "*indent + kvstr)
            indent += 1
    orig_subtree = subtree = []
    tree = Tree(key, basescale, subtree)
    visited_bdkeys.add(key)
    all_visited_bdkeys.add(key)
    if keymon is None:
        scale = solution(key)/basescale
    else:
        if verbosity:
            print(" "*indent + "which in: "
                  + constraint.str_without(["units", "lineage"])
                  + " (sensitivity %+.2g)" % solution["sensitivities"]["constraints"][constraint])
        interesting_vks = {key}
        subkey, = interesting_vks
        power = keymon.exp[subkey]
        boring_vks = set(keymon.vks) - interesting_vks
        scale = solution(key)**power/basescale
        # TODO: make method that can handle both kinds of transforms
        if (power != 1 or boring_vks or mag(keymon.c) != 1
                or keymon.units != key.units):  # some kind of transform here
            units = 1
            exp = HashVector()
            for vk in interesting_vks:
                exp[vk] = keymon.exp[vk]
                if vk.units:
                    units *= vk.units**keymon.exp[vk]
            subhmap = NomialMap({exp: 1})
            try:
                subhmap.units = None if units == 1 else units
            except DimensionalityError:
                # pints was unable to divide a unit by itself bc
                # it has terrible floating-point errors.
                # so let's assume it isn't dimensionless
                # even though it probably is
                subhmap.units = units
            freemon = Monomial(subhmap)
            factor = Monomial(keymon/freemon)
            scale = scale * solution(factor)
            if factor != 1:
                factor = factor**(-1/power)  # invert the transform
                factor.ast = None
                if verbosity:
                    print(" "*indent + "{ through a factor of %s (%s) }" %
                          (factor.str_without(["units"]),
                           get_valstr(factor, solution)))
                subsubtree = []
                transform = Transform(factor, 1, keymon)
                orig_subtree.append(Tree(transform, basescale, subsubtree))
                orig_subtree = subsubtree
            if power != 1:
                if verbosity:
                    print(" "*indent + "{ through a power of %.2g }" % power)
                subsubtree = []
                transform = Transform(1, 1/power, keymon)  # inverted bc it's on the gt side
                orig_subtree.append(Tree(transform, basescale, subsubtree))
                orig_subtree = subsubtree

    # TODO: use ast_parsing instead of chop?
    mons = composition.chop()
    monsols = [solution(mon) for mon in mons]  # ~20% of total last check # TODO: remove
    parsed_monsols = [getattr(mon, "value", mon) for mon in monsols]
    monvals = [float(mon/scale) for mon in parsed_monsols]  # ~10% of total last check # TODO: remove
    # sort by value, preserving order in case of value tie
    sortedmonvals = sorted(zip([-float("%.2g" % mv) for mv in monvals], range(len(mons)), monvals, mons))
    # print([m.str_without({"units", "lineage"}) for m in mons])
    if verbosity:
        if len(monsols) == 1:
            print(" "*indent + "breaks down into:")
        else:
            print(" "*indent + "breaks down into %i monomials:" % len(monsols))
            indent += 1
        indent += 1
    for i, (_, _, scaledmonval, mon) in enumerate(sortedmonvals):
        if not scaledmonval:
            continue
        subtree = orig_subtree  # return to the original subtree
        # time for some filtering
        interesting_vks = mon.vks
        potential_filters = [
            {vk for vk in interesting_vks if vk not in bd},
            mon.vks - get_free_vks(mon, solution),
            {vk for vk in interesting_vks if vk in basically_fixed_variables}
        ]
        if scaledmonval < 1 - permissivity:  # skip breakdown filter
            potential_filters = potential_filters[1:]
        potential_filters.insert(0, visited_bdkeys)
        for vkfilter in potential_filters:
            if interesting_vks - vkfilter:  # don't remove the last one
                interesting_vks = interesting_vks - vkfilter
        # if filters weren't enough and permissivity is high enough, sort!
        if len(interesting_vks) > 1 and permissivity > 1:
            csenss = solution["sensitivities"]["constraints"]
            best_vks = sorted((vk for vk in interesting_vks if vk in bd),
                              key=lambda vk: (-abs(float("%.2g" % (mon.exp[vk]*csenss[bd[vk][0][2]]))),
                                              -float("%.2g" % solution["variables"][vk]),
                                              str(bd[vk][0][0])))  # ~5% of total last check # TODO: remove
            # TODO: changing to str(vk) above does some odd stuff, why?
            if best_vks:
                interesting_vks = set([best_vks[0]])
        boring_vks = mon.vks - interesting_vks

        subkey = None
        if len(interesting_vks) == 1:
            subkey, = interesting_vks
            if subkey in visited_bdkeys and len(sortedmonvals) == 1:
                continue  # don't even go there
            if subkey not in bd:
                power = 1  # no need for a transform
            else:
                power = mon.exp[subkey]
                if power < 0 and gone_negative:
                    subkey = None  # don't breakdown another negative

        if len(monsols) > 1 and verbosity:
            indent -= 1
            print(" "*indent + "%s) forming %i%% of the RHS and %i%% of the total:" % (i+1, scaledmonval/basescale*100, scaledmonval*100))
            indent += 1

        if subkey is None:
            power = 1
            if scaledmonval > 1 - permissivity and not boring_vks:
                boring_vks = interesting_vks
                interesting_vks = set()
            if not interesting_vks:
                # prioritize showing some boring_vks as if they were "free"
                if len(boring_vks) == 1:
                    interesting_vks = boring_vks
                    boring_vks = set()
                else:
                    for vk in list(boring_vks):
                        if vk.units and not vk.units.dimensionless:
                            interesting_vks.add(vk)
                            boring_vks.remove(vk)

        if interesting_vks and (boring_vks or mag(mon.c) != 1):
            units = 1
            exp = HashVector()
            for vk in interesting_vks:
                exp[vk] = mon.exp[vk]
                if vk.units:
                    units *= vk.units**mon.exp[vk]
            subhmap = NomialMap({exp: 1})
            # `units` is still the plain int 1 only if no variable above
            # contributed units; this was previously tested with `is 1`,
            # which depends on int caching and emits a SyntaxWarning
            subhmap.units = None if isinstance(units, int) else units
            freemon = Monomial(subhmap)
            factor = mon/freemon  # autoconvert...
            if (factor.units is None and isinstance(factor, FixedScalar)
                    and abs(factor.value - 1) <= 1e-4):
                factor = 1  # minor fudge to clear numerical inaccuracies
            if factor != 1:
                factor.ast = None
                if verbosity:
                    keyvalstr = "%s (%s)" % (factor.str_without(["units"]),
                                             get_valstr(factor, solution))
                    print(" "*indent + "{ through a factor of %s }" % keyvalstr)
                subsubtree = []
                transform = Transform(factor, 1, mon)
                subtree.append(Tree(transform, scaledmonval, subsubtree))
                subtree = subsubtree
            mon = freemon  # simplifies units
        if power != 1:
            if verbosity:
                print(" "*indent + "{ through a power of %.2g }" % power)
            subsubtree = []
            transform = Transform(1, power, mon)
            subtree.append(Tree(transform, scaledmonval, subsubtree))
            subtree = subsubtree
            mon = mon**(1/power)
            mon.ast = None
        # TODO: make minscale an argument - currently an arbitrary 0.01
        if (subkey is not None and subkey not in visited_bdkeys
                and subkey in bd and scaledmonval > 0.05):
            subverbosity = indent + 1 if verbosity else 0  # slight hack
            # the recursive call gets a copy of visited_bdkeys but shares
            # all_visited_bdkeys
            subsubtree = crawl(basically_fixed_variables, subkey, bd, solution, scaledmonval,
                               permissivity, subverbosity, set(visited_bdkeys),
                               gone_negative, all_visited_bdkeys)
            subtree.append(subsubtree)
        else:
            if verbosity:
                keyvalstr = "%s (%s)" % (mon.str_without(["units"]),
                                         get_valstr(mon, solution))
                print(" "*indent + keyvalstr)
            subtree.append(Tree(mon, scaledmonval, []))
    if verbosity == 1:
        if not already_set:
            solution.set_necessarylineage(clear=True)
    return tree

434 

# Single-character names handed out to legend entries: the ASCII letters,
# minus those listed below as ambiguous.
SYMBOLS = string.ascii_uppercase + string.ascii_lowercase
for ambiguous_symbol in "lILT":
    SYMBOLS = "".join(SYMBOLS.split(ambiguous_symbol))

438 

def get_spanstr(legend, length, label, leftwards, solution):
    """Returns the `length`-character span for label, naming it by a legend symbol.

    A label of None gives blanks. Otherwise the label's symbol is looked up
    in legend (and assigned from SYMBOLS if new, modifying legend) and drawn
    with end caps and spacers around it. `solution` is accepted but unused.
    """
    if label is None:
        return " "*length
    if isinstance(label, Transform):
        if label.power != 1:
            spacer = " "
            lend = rend = "^" if label.power > 0 else "/"
        else:
            spacer, lend, rend = "╎", "╤", "╧"
        # remove origkeys so they collide in the legends dictionary
        label = Transform(label.factor, label.power, None)
        factorstr = str(label.factor)
        if label.power == 1 and len(factorstr) == 1:
            legend[label] = factorstr  # one-character factors name themselves
    else:
        spacer, lend, rend = "│", "┯", "┷"

    try:
        shortname = legend[label]
    except KeyError:
        shortname = legend[label] = SYMBOLS[len(legend)]

    if length <= 1:
        return shortname
    if length == 2:
        return lend + shortname if leftwards else shortname + rend
    shortside = int(max(0, length - 2)/2)
    longside = int(max(0, length - 3)/2)
    if not leftwards:
        # HACK: no corners on long rightwards - only used for depth 0
        return "┃"*(longside+1) + shortname + "┃"*(shortside+1)
    return lend + spacer*shortside + shortname + spacer*longside + rend

471 

def discretize(tree, extent, solution, collapse, depth=0, justsplit=False):
    """Returns a Tree keyed like `tree` whose value is the integer `extent`.

    Branch values are scaled by extent/value and rounded to integer extents;
    branches that round to nothing are gathered into one catch-all branch
    whose key is a tuple of the gathered keys, and any rounding surplus is
    handed out one unit at a time. Branches with a nonzero extent are then
    discretized recursively.

    Arguments
    ---------
    tree : Tree
        Unpacked as (key, value, branches).
    extent : int
        Size given to the returned root.
    solution :
        Only passed on to get_keystr and to recursive calls.
    collapse : bool
        If true, positive-power Transform branches are replaced by their own
        branches, single-branch levels are skipped, and stacked power
        Transforms are combined.
    depth : int
        Recursion depth; 0 for the root.
    justsplit : bool
        Whether the parent level just split into several branches.

    Raises
    ------
    ValueError
        If the surplus cannot be distributed among the branches.
    """
    # TODO: add vertical simplification?
    key, val, branches = tree
    if collapse:  # collapse Transforms with power 1
        # NOTE(review): the test below is `power > 0`, which also matches
        # positive powers other than 1 - confirm which is intended
        while any(isinstance(branch.key, Transform) and branch.key.power > 0 for branch in branches):
            newbranches = []
            for branch in branches:
                # isinstance(branch.key, Transform) and branch.key.power > 0
                if isinstance(branch.key, Transform) and branch.key.power > 0:
                    newbranches.extend(branch.branches)
                else:
                    newbranches.append(branch)
            branches = newbranches

    scale = extent/val
    values = [b.value for b in branches]
    # merge branches that would be labelled identically: the first occurrence
    # absorbs the value of later ones, which are marked None
    bkey_indexs = {}
    for i, b in enumerate(branches):
        k = get_keystr(b.key, solution)
        if isinstance(b.key, Transform):
            if len(b.branches) == 1:
                # a Transform with a single branch is identified by that branch
                k = get_keystr(b.branches[0].key, solution)
        if k in bkey_indexs:
            values[bkey_indexs[k]] += values[i]
            values[i] = None
        else:
            bkey_indexs[k] = i
    if any(v is None for v in values):
        # drop the merged-away branches and re-sort the rest by decreasing
        # value (two significant figures), keeping original order on ties
        bvs = zip(*sorted(((-float("%.2g" % v), i, b, v) for i, (b, v) in enumerate(zip(branches, values)) if v is not None)))
        _, _, branches, values = bvs
        branches = list(branches)
        values = list(values)
    extents = [int(round(scale*v)) for v in values]
    surplus = extent - sum(extents)  # what rounding left over (may be negative)
    for i, b in enumerate(branches):
        if isinstance(b.key, Transform):
            subscale = extents[i]/b.value
            if not any(round(subscale*subv) for _, subv, _ in b.branches):
                extents[i] = 0  # transform with no worthy heirs: misc it
    if not any(extents):
        return Tree(key, extent, [])
    if not all(extents):  # create a catch-all
        branches = branches.copy()
        miscvkeys, miscval = [], 0
        # walk from the last branch backwards, popping it into the catch-all
        # when the condition below holds
        for subextent in reversed(extents):
            if not subextent or (branches[-1].value < miscval and surplus < 0):
                extents.pop()
                k, v, _ = branches.pop()
                # NOTE(review): nesting of the next two `if`s is reconstructed
                # from flattened source - confirm against the original file
                if isinstance(k, Transform):
                    k = k.origkey  # TODO: this is the only use of origkey - remove it
                    if isinstance(k, tuple):
                        vkeys = [(-kv[1], str(kv[0]), kv[0]) for kv in k]
                if not isinstance(k, tuple):
                    vkeys = [(-float("%.2g" % v), str(k), k)]
                miscvkeys += vkeys
                # keep surplus consistent with the catch-all's rounded extent
                surplus -= (round(scale*(miscval + v))
                            - round(scale*miscval) - subextent)
                miscval += v
        misckeys = tuple(k for _, _, k in sorted(miscvkeys))
        branches.append(Tree(misckeys, miscval, []))
        extents.append(int(round(scale*miscval)))
    if surplus:
        sign = int(np.sign(surplus))
        # candidates are popped from the end of this sorted list, so larger
        # extents are adjusted first
        bump_priority = sorted((ext, sign*float("%.2g" % b.value), i) for i, (b, ext)
                               in enumerate(zip(branches, extents)))
        # print(key, surplus, bump_priority)
        while surplus:
            try:
                extents[bump_priority.pop()[-1]] += sign
                surplus -= sign
            except IndexError:
                # ran out of branches before the surplus was used up
                raise ValueError(val, [b.value for b in branches])

    tree = Tree(key, extent, [])
    # simplify based on how we're branching
    branchfactor = len([ext for ext in extents if ext]) - 1
    if depth and not isinstance(key, Transform):
        if extent == 1 or branchfactor >= max(extent-2, 1):
            # if we'd branch too much, stop
            return tree
        if collapse and not branchfactor and not justsplit:
            # if we didn't just split and aren't about to, skip through
            return discretize(branches[0], extent, solution, collapse,
                              depth=depth+1, justsplit=False)
    if branchfactor:
        justsplit = True
    elif not isinstance(key, Transform):  # justsplit passes through transforms
        justsplit = False

    for branch, subextent in zip(branches, extents):
        if subextent:
            branch = discretize(branch, subextent, solution, collapse,
                                depth=depth+1, justsplit=justsplit)
            if (collapse and is_power(branch.key)
                    and all(is_power(b.key) for b in branch.branches)):
                # combine stacked powers
                power = branch.key.power
                for b in branch.branches:
                    key = Transform(1, power*b.key.power, None)
                    if key.power == 1:  # powers canceled, collapse both
                        tree.branches.extend(b.branches)
                    else:  # collapse this level
                        tree.branches.append(Tree(key, b.value, b.branches))
            else:
                tree.branches.append(branch)
    return tree

578 

def layer(map, tree, maxdepth, depth=0):
    """Turns the tree into a 2D-array

    Appends (key, extent) of every node down to maxdepth to map[depth],
    growing map as needed, and returns map. A node without branches is
    padded with a None-keyed child of the same extent.
    """
    if depth > maxdepth:
        return map
    key, extent, branches = tree
    if depth >= len(map):
        map.append([])
    map[depth].append((key, extent))
    children = branches if branches else [Tree(None, extent, [])]  # pad it out
    for child in children:
        layer(map, child, maxdepth, depth+1)
    return map

591 

def plumb(tree, depth=0):
    "Finds maximum depth of a tree, counting the root as `depth`"
    return max((plumb(branch, depth+1) for branch in tree.branches),
               default=depth)

598 

def prune(tree, solution, maxlength, length=-1, prefix=""):
    """Prune branches that are longer than a certain number of characters

    `length` accumulates label widths (plus 3 per level) from the root down;
    if any branch of a node would take the total past maxlength, all of that
    node's branches are dropped.
    """
    key, extent, branches = tree

    def labelwidth(somekey, lineage):
        "Width of the wider of somekey's value string and key string"
        return max(len(get_valstr(somekey, solution, into="(%s)")),
                   len(get_keystr(somekey, solution, lineage)))

    ownwidth = labelwidth(key, prefix)
    # at the root (length == -1) the key's lineage becomes the prefix
    if length == -1 and isinstance(key, VarKey) and key.necessarylineage:
        prefix = key.lineagestr()
    length += ownwidth + 3
    if any(length + labelwidth(branch.key, prefix) + 3 > maxlength
           for branch in branches):
        return Tree(key, extent, [])
    pruned = [prune(b, solution, maxlength, length, prefix) for b in branches]
    return Tree(key, extent, pruned)

615 

def simplify(tree, solution, extent, maxdepth, maxlength, collapse):
    "Discretize, prune (only when collapsing with a truthy maxlength), and layer a tree for printing"
    discretized = discretize(tree, extent, solution, collapse)
    if not (collapse and maxlength):
        return layer([], discretized, maxdepth)
    return layer([], prune(discretized, solution, maxlength), maxdepth)

622 

623# @profile # ~16% of total last check # TODO: remove 

def graph(tree, breakdowns, solution, basically_fixed_variables, *,
          height=None, maxdepth=None, maxwidth=81, showlegend=False):
    """Prints breakdown

    Arguments
    ---------
    tree : Tree
        The breakdown tree to draw.
    breakdowns : dict
        Keys found in it, with nothing drawn to their right, get a "╶⎨" mark.
    solution :
        Used to format keys and values; its necessary lineage is set for
        the duration of the call unless `solution._lineageset` is already
        true.
    basically_fixed_variables : set
        Passed to legend_entry when the legend is printed.
    height : int, optional
        Number of printed rows. If None, starts at 20 and shrinks to at
        most 4 rows per element of the fullest layer.
    maxdepth : int, optional
        Deepest tree level drawn; defaults to the depth of tree.
    maxwidth : int
        Maximum label length handed to prune (only when not showing legend).
    showlegend : bool
        If true, labels are replaced by symbols and a legend is printed.
    """
    already_set = solution._lineageset
    if not already_set:
        solution.set_necessarylineage()
    collapse = (not showlegend)  # TODO: set to True while showlegend is True for first approx of receipts; autoinclude with trace?
    if maxdepth is None:
        maxdepth = plumb(tree)
    if height is not None:
        mt = simplify(tree, solution, height, maxdepth, maxwidth, collapse)
    else:  # zoom in from a default height of 20 to a height of 4 per branch
        prev_height = None
        height = 20
        while prev_height != height:
            mt = simplify(tree, solution, height, maxdepth, maxwidth, collapse)
            prev_height = height
            # the generator is passed directly: unpacking it with * made
            # max() fail with a TypeError when mt had only one layer
            height = min(height, max(4*len(at_depth) for at_depth in mt))

    legend = {}
    chararray = np.full((len(mt), height), "", "object")
    for depth, elements_at_depth in enumerate(mt):
        row = ""
        for element, length in elements_at_depth:
            leftwards = depth > 0 and length > 2
            row += get_spanstr(legend, length, element, leftwards, solution)
        chararray[depth, :] = list(row)

    # Format depth=0
    A_key, = [key for key, value in legend.items() if value == "A"]
    prefix = ""
    if A_key is solution["cost function"]:
        A_str = "Cost"
    else:
        A_str = get_keystr(A_key, solution)
        if isinstance(A_key, VarKey) and A_key.necessarylineage:
            prefix = A_key.lineagestr()
    A_valstr = get_valstr(A_key, solution, into="(%s)")
    fmt = "{0:>%s}" % (max(len(A_str), len(A_valstr)) + 3)
    for j, entry in enumerate(chararray[0,:]):
        if entry == "A":
            chararray[0,j] = fmt.format(A_str + "╺┫")
            chararray[0,j+1] = fmt.format(A_valstr + " ┃")
        else:
            chararray[0,j] = fmt.format(entry)
    # Format depths 1+
    labeled = set()
    reverse_legend = {v: k for k, v in legend.items()}
    legend = {}  # rebuilt below with only the symbols that end up printed
    for pos in range(height):
        for depth in reversed(range(1,len(mt))):
            char = chararray[depth, pos]
            if char not in reverse_legend:
                continue
            key = reverse_legend[char]
            if key not in legend and (isinstance(key, tuple) or (depth != len(mt) - 1 and chararray[depth+1, pos] != " ")):
                legend[key] = SYMBOLS[len(legend)]
            if collapse and is_power(key):
                chararray[depth, pos] = "^" if key.power > 0 else "/"
                del legend[key]
                continue
            if key in legend:
                chararray[depth, pos] = legend[key]
                if isinstance(key, tuple) and not isinstance(key, Transform):
                    # catch-all tuples are starred and left out of the legend
                    chararray[depth, pos] = "*" + chararray[depth, pos]
                    del legend[key]
                if showlegend:
                    continue

            keystr = get_keystr(key, solution, prefix)
            if keystr in labeled:
                valuestr = ""
            else:
                valuestr = get_valstr(key, solution, into=" (%s)")
            if collapse:
                fmt = "{0:<%s}" % max(len(keystr) + 3, len(valuestr) + 2)
            else:
                fmt = "{0:<1}"
            # thicken the span's bars above and below this position, stopping
            # in each direction at the first character that isn't part of it
            span = 0
            tryup, trydn = True, True
            while tryup or trydn:
                span += 1
                if tryup:
                    if pos - span < 0:
                        tryup = False
                    else:
                        upchar = chararray[depth, pos-span]
                        if upchar == "│":
                            chararray[depth, pos-span] = fmt.format("┃")
                        elif upchar == "┯":
                            chararray[depth, pos-span] = fmt.format("┓")
                        else:
                            tryup = False
                if trydn:
                    if pos + span >= height:
                        trydn = False
                    else:
                        dnchar = chararray[depth, pos+span]
                        if dnchar == "│":
                            chararray[depth, pos+span] = fmt.format("┃")
                        elif dnchar == "┷":
                            chararray[depth, pos+span] = fmt.format("┛")
                        else:
                            trydn = False
            linkstr = "┣╸"
            if not isinstance(key, FixedScalar):
                labeled.add(keystr)
                if span > 1 and (collapse or pos + 2 >= height
                                 or chararray[depth, pos+1] == "┃"):
                    # room below: put the value on the next row
                    vallabel = chararray[depth, pos+1].rstrip() + valuestr
                    chararray[depth, pos+1] = fmt.format(vallabel)
                elif showlegend:
                    keystr += valuestr
                if (key in breakdowns and not chararray[depth+1, pos].strip()
                        and (depth >= len(mt)-2
                             or not chararray[depth+2, pos].strip())):
                    keystr = keystr + "╶⎨"
            chararray[depth, pos] = fmt.format(linkstr + keystr)
    # Rotate and print
    rowstrs = ["".join(row).rstrip() for row in chararray.T.tolist()]
    print("\n" + "\n".join(rowstrs) + "\n")

    if showlegend:  # create and print legend
        legend_lines = []
        for key, shortname in sorted(legend.items(), key=lambda kv: kv[1]):
            legend_lines.append(legend_entry(key, shortname, solution, prefix,
                                             basically_fixed_variables))
        maxlens = [max(len(el) for el in col) for col in zip(*legend_lines)]
        fmts = ["{0:<%s}" % L for L in maxlens]
        for line in legend_lines:
            line = "".join(fmt.format(cell)
                           for fmt, cell in zip(fmts, line) if cell).rstrip()
            print(" " + line)

    if not already_set:
        solution.set_necessarylineage(clear=True)

760 

def legend_entry(key, shortname, solution, prefix, basically_fixed_variables):
    """Returns list of legend elements

    The list is [shortname padded to 4 characters, key string, value string,
    note]. A factor Transform is replaced by its factor and noted as a free
    factor if any of its free variables is not basically fixed; a power
    Transform shows its power instead of a value.
    """
    note = ""
    keystr = valuestr = " "
    operator = "= " if shortname else " + "
    if is_factor(key):
        operator = " ×"
        key = key.factor
        if any(vk not in basically_fixed_variables
               for vk in get_free_vks(key, solution)):
            note = " [free factor]"
    if is_power(key):
        valuestr = " ^%.3g" % key.power
    else:
        valuestr = get_valstr(key, solution, into=" "+operator+"%s")
        if not isinstance(key, FixedScalar):
            keystr = get_keystr(key, solution, prefix)
    return ["%-4s" % shortname, keystr, valuestr, note]

780 

def get_keystr(key, solution, prefix=""):
    """Returns formatted string of the key in solution.

    Keys with a `str_without` method are rendered without units and without
    the given lineage prefix; other tuples are summarized by their length;
    anything else goes through str(). Results longer than 67 characters are
    cut to 66 plus an ellipsis. `solution` is accepted but unused.
    """
    limit = 67
    if hasattr(key, "str_without"):
        excluded = {"units", ":MAGIC:"+prefix}
        text = key.str_without(excluded)
    elif not isinstance(key, tuple):
        text = str(key)
    else:
        text = "[%i terms]" % len(key)
    if len(text) > limit:
        text = text[:limit-1]+"…"
    return text

790 

def get_valstr(key, solution, into="%s"):
    """Returns formatted string of the value of key in solution.

    The value is `solution(key)`; if that raises ValueError or TypeError
    the sum of `solution(subkey)` over the members of key is tried, and
    if that fails as well a single space is returned. Magnitudes in
    [1e3, 1e6) are printed with thousands separators and no decimals,
    all others with three significant digits. The unit string (and
    ", fixed" for keys made up only of constants) is appended, and the
    result is interpolated into `into`.
    """
    # numeric part
    try:
        value = solution(key)
    except (ValueError, TypeError):
        try:
            value = sum(solution(subkey) for subkey in key)
        except (ValueError, TypeError):
            return " "
    if isinstance(value, FixedScalar):
        value = value.value
    if 1e3 <= mag(value) < 1e6:
        number = "{:,.0f}".format(mag(value))
    else:
        number = "%-.3g" % mag(value)

    # unit part
    if hasattr(key, "unitstr"):
        units = key.unitstr()
    else:
        try:
            if hasattr(value, "units"):
                value.ito_reduced_units()
        except DimensionalityError:
            pass  # keep the unreduced units
        units = get_unitstr(value)
    if units[:2] == "1/":
        units = "/" + units[2:]

    constants = solution["constants"]
    only_constants = key in constants or (
        hasattr(key, "vks") and key.vks
        and all(vk in constants for vk in key.vks))
    if only_constants:
        units += ", fixed"
    return into % (number + units)

824 

825 

826import plotly.graph_objects as go 

def plotlyify(tree, solution, minval=None):
    """Flattens a breakdown tree into the lists Plotly hierarchy plots take

    Arguments
    ---------
    tree : (key, value, branches) tuple
        Root of the tree to flatten; branches is an iterable of subtrees.

    solution : solution
        Passed to get_keystr and get_valstr to build the node labels.

    minval (optional) : number
        Nodes whose value is not greater than this are dropped together
        with everything below them. Defaults to 1/1000 of the root value.

    Returns
    -------
    ids, labels, parents, values : lists
        One entry per emitted node, in the argument order of the
        `treemap` and `icicle` functions of this module.

    """
    ids = []
    labels = []
    parents = []
    values = []

    key, value, _ = tree
    if isinstance(key, VarKey) and key.necessarylineage:
        prefix = key.lineagestr()
    else:
        prefix = ""

    if minval is None:
        minval = value/1000

    # node id -> share of that node's value not yet claimed by its children
    parent_budgets = {}

    def crawl(tree, parent_id=None):
        key, value, branches = tree
        if value > minval:
            if isinstance(key, Transform):
                # transforms get no node of their own; their branches
                # attach directly to the transform's parent
                node_id = parent_id
            else:
                node_id = len(ids)+1
                ids.append(node_id)
                labels.append(get_keystr(key, solution, prefix))
                if not isinstance(key, str):
                    labels[-1] = labels[-1] + "<br>" + get_valstr(key, solution)
                parents.append(parent_id)
                if parent_id is not None:  # make sure there's no overflow
                    if parent_budgets[parent_id] < value:
                        value = parent_budgets[parent_id]  # take remainder
                    parent_budgets[parent_id] -= value
                values.append(value)
                parent_budgets[node_id] = value
            for branch in branches:
                crawl(branch, node_id)

    crawl(tree)
    return ids, labels, parents, values

887 

def treemap(ids, labels, parents, values):
    "Returns a plotly Figure holding one Treemap trace of the given lists"
    trace = go.Treemap(ids=ids, labels=labels, parents=parents,
                       values=values, branchvalues="total")
    return go.Figure(trace)

896 

def icicle(ids, labels, parents, values):
    "Returns a plotly Figure holding one Icicle trace of the given lists"
    trace = go.Icicle(ids=ids, labels=labels, parents=parents,
                      values=values, branchvalues="total")
    return go.Figure(trace)

905 

906 

907import functools 

908 

class Breakdowns(object):
    """Builds a solution's breakdowns once and offers ways to display them

    Attributes
    ----------
    sol : the solution given to the constructor
    mlookup : dict filled in by crawl_modelbd; maps string keys to trees
    mtree : tree returned by crawl_modelbd for the model breakdown
    basically_fixed_variables : set shared with get_breakdowns, crawl
        and graph
    bd : result of get_breakdowns for the solution
    """

    def __init__(self, sol):
        self.sol = sol
        self.mlookup = {}
        self.mtree = crawl_modelbd(get_model_breakdown(sol), self.mlookup)
        self.basically_fixed_variables = set()
        self.bd = get_breakdowns(self.basically_fixed_variables, self.sol)

    def trace(self, key, *, permissivity=2):
        "Crawls the tree for key with verbosity=1 (the tree is discarded)"
        print("")  # a little padding to start
        self.get_tree(key, permissivity=permissivity, verbosity=1)

    def get_tree(self, key, *, permissivity=2, verbosity=0):
        """Returns (tree, kind) for key, kind being "variable" or "constraint"

        String keys are resolved in this order: "model sensitivities"
        gives the model tree, "cost" stands for the solution's cost
        function, keys of `mlookup` give their stored tree, and any other
        string must be contained in (and share its last character with)
        the string of exactly one key of `bd`. Non-string keys, and
        strings resolved to a variable, are crawled.

        Raises
        ------
        KeyError if a string key matches no key, or more than one key,
        of `bd`.
        """
        tree = None
        kind = "variable"
        if isinstance(key, str):
            if key == "model sensitivities":
                tree = self.mtree
                kind = "constraint"
            elif key == "cost":
                key = self.sol["cost function"]
            elif key in self.mlookup:
                tree = self.mlookup[key]
                kind = "constraint"
            else:
                # TODO: support submodels
                keys = [vk for vk in self.bd
                        if key in str(vk) and key[-1] == str(vk)[-1]]
                if not keys:
                    raise KeyError(key)
                if len(keys) > 1:
                    raise KeyError("There are %i keys containing '%s'."
                                   % (len(keys), key))
                key, = keys
        if tree is None:
            tree = crawl(self.basically_fixed_variables, key, self.bd,
                         self.sol, permissivity=permissivity,
                         verbosity=verbosity)
        return tree, kind

    def plot(self, key, *, height=None, permissivity=2, showlegend=False,
             maxwidth=85):
        "Passes the tree for key to graph, with the lookup matching its kind"
        tree, kind = self.get_tree(key, permissivity=permissivity)
        lookup = self.bd if kind == "variable" else self.mlookup
        graph(tree, lookup, self.sol, self.basically_fixed_variables,
              height=height, showlegend=showlegend, maxwidth=maxwidth)

    @staticmethod
    def _write_figure(fig, key, suffix, filename):
        "Hands fig to plotly.offline.plot, deriving a filename from key if none"
        if filename is None:
            filename = str(key) + suffix
            # keep only characters that are safe in a file name
            keepcharacters = (".", "_")
            filename = "".join(c for c in filename if c.isalnum()
                               or c in keepcharacters).rstrip()
        import plotly
        plotly.offline.plot(fig, filename=filename)

    def treemap(self, key, *, permissivity=2, returnfig=False, filename=None):
        """Builds a treemap figure of the tree for key

        Returns the figure if returnfig is true; otherwise passes it to
        plotly.offline.plot under `filename` (by default derived from
        key plus "_treemap.html") and returns None.
        """
        # permissivity must reach get_tree, exactly as it does in icicle()
        tree, _ = self.get_tree(key, permissivity=permissivity)
        fig = treemap(*plotlyify(tree, self.sol))
        if returnfig:
            return fig
        self._write_figure(fig, key, "_treemap.html", filename)
        return None

    def icicle(self, key, *, permissivity=2, returnfig=False, filename=None):
        """Builds an icicle figure of the tree for key

        Returns the figure if returnfig is true; otherwise passes it to
        plotly.offline.plot under `filename` (by default derived from
        key plus "_icicle.html") and returns None.
        """
        tree, _ = self.get_tree(key, permissivity=permissivity)
        fig = icicle(*plotlyify(tree, self.sol))
        if returnfig:
            return fig
        self._write_figure(fig, key, "_icicle.html", filename)
        return None