Coverage for gpkit/breakdowns.py: 93%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#TODO: cleanup weird conditionals
2# add conversions to plotly/sankey
4# pylint: skip-file
5import string
6from collections import defaultdict, namedtuple, Counter
7from gpkit.nomials import Monomial, Posynomial, Variable
8from gpkit.nomials.map import NomialMap
9from gpkit.small_scripts import mag, try_str_without
10from gpkit.small_classes import FixedScalar, HashVector
11from gpkit.exceptions import DimensionalityError
12from gpkit.repr_conventions import unitstr as get_unitstr
13from gpkit.repr_conventions import lineagestr
14from gpkit.varkey import VarKey
15import numpy as np
# A node of a breakdown tree: a key, a value, and a list of child Trees.
Tree = namedtuple("Tree", "key value branches")
# A factor and/or power applied between a node and its children; origkey
# holds the object the transform was built from (or None).
Transform = namedtuple("Transform", "factor power origkey")
def is_factor(key):
    "Returns whether key is a Transform whose power equals 1"
    if not isinstance(key, Transform):
        return False
    return key.power == 1
def is_power(key):
    "Returns whether key is a Transform whose power differs from 1"
    if not isinstance(key, Transform):
        return False
    return key.power != 1
def get_free_vks(posy, solution):
    "Returns the set of posy's vks that are not among solution's constants"
    free = set()
    for varkey in posy.vks:
        if varkey in solution["constants"]:
            continue
        free.add(varkey)
    return free
def _descend_lineage(breakdowns, lineage, senss):
    "Adds senss at every level of breakdowns along the dotted lineage; returns the deepest level"
    subbd = breakdowns
    subbd["|sensitivity|"] += senss
    for parent in lineage.split("."):
        if parent == "":
            continue
        subbd = subbd.setdefault(parent, {})
        subbd["|sensitivity|"] = subbd.get("|sensitivity|", 0) + senss
    return subbd


def get_model_breakdown(solution):
    """Returns a nested dict of absolute sensitivities, grouped by lineage.

    Every level of the returned dict has a "|sensitivity|" entry holding the
    sum of the absolute sensitivities beneath it. The other entries are keyed
    either by a lineage component (a nested level) or by the string of a
    constraint / variable (a leaf holding only "|sensitivity|").

    Constraints and variables whose absolute sensitivity is at most 1e-5
    are left out.

    Arguments
    ---------
    solution : mapping
        Must provide solution["sensitivities"]["constraints"] (a mapping of
        constraint to sensitivity) and solution["sensitivities"]["variables"]
        (indexable by varkey, with a `keymap` attribute).
    """
    # NOTE: an earlier version also tallied solution["sensitivities"]["models"]
    # here, but that result was overwritten before being used.
    # TODO: track down in a live-solve environment why the two tallies differ
    sensitivities = solution["sensitivities"]
    breakdowns = {"|sensitivity|": 0}
    for constraint, senss in sensitivities["constraints"].items():
        senss = abs(senss)  # for those monomial equalities
        if senss <= 1e-5:
            continue
        lineage = lineagestr(constraint)
        subbd = _descend_lineage(breakdowns, lineage, senss)
        # treat vectors as namespace
        constrstr = try_str_without(constraint, {"units", ":MAGIC:"+lineage})
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        subbd[constrstr] = {"|sensitivity|": senss}
    varsenss = sensitivities["variables"]
    # TODO: could this be done away with for backwards compatibility?
    for vk in varsenss.keymap:
        # skip non-VarKeys and unindexed (whole-vector) keys
        if not isinstance(vk, VarKey) or (vk.shape and not vk.index):
            continue
        senss = abs(varsenss[vk])
        if hasattr(senss, "shape"):
            senss = np.nansum(senss)
        if senss <= 1e-5:
            continue
        subbd = _descend_lineage(breakdowns, vk.lineagestr(), senss)
        # treat vectors as namespace (indexing vectors above)
        label = (vk.str_without({"lineage"})
                 + get_valstr(vk, solution, " = %s").replace(", fixed", ""))
        subbd[label] = {"|sensitivity|": senss}
    return breakdowns
def crawl_modelbd(bd, lookup, name="Model"):
    """Recursively turns a nested sensitivity dict into a Tree.

    Pops "|sensitivity|" out of bd (bd is modified) to use as the node's
    value. Nodes that have children are also stored in lookup under their
    name. Children are ordered by descending sensitivity (rounded to two
    significant figures), then by name.
    """
    sensitivity = bd.pop("|sensitivity|")
    tree = Tree(name, sensitivity, [])
    if bd:
        lookup[name] = tree

    def sortkey(item):
        subname, subbd = item
        return (-float("%.2g" % subbd["|sensitivity|"]), subname)

    children = list(bd.items())
    children.sort(key=sortkey)
    for subname, subbd in children:
        tree.branches.append(crawl_modelbd(subbd, lookup, subname))
    return tree
def divide_out_vk(vk, pow, lt, gt):
    "Returns (lt, gt) each divided by vk**pow, with the ast of both results cleared"
    vk_hmap = NomialMap({HashVector({vk: 1}): 1.0})
    vk_hmap.units = vk.units
    divisor = Monomial(vk_hmap)**pow
    new_lt = lt/divisor
    new_gt = gt/divisor
    new_lt.ast = None
    new_gt.ast = None
    return new_lt, new_gt
108# @profile
def get_breakdowns(basically_fixed_variables, solution):
    """Returns {key: [(lt, gt, constraint), ...]} for breakdown constraints in solution.

    A breakdown constraint is any whose "gt" contains a single free variable.

    (At present, monomial constraints check both sides as "gt")

    Constraints whose absolute sensitivity is at most 1e-5 are ignored, as
    are those whose "gt" side has more than one term. A first pass records
    the constraints whose "gt" has exactly one candidate variable with a
    positive exponent; a second pass rearranges the remaining constraints to
    isolate a variable that has no breakdown yet. Finally get_fixity is
    called on the breakdown keys until basically_fixed_variables stops
    changing size, so that argument is modified in place.

    NOTE(review): nesting below was reconstructed from a whitespace-flattened
    listing - confirm against the repository copy.
    """
    breakdowns = defaultdict(list)
    beatout = defaultdict(set)  # NOTE(review): never read in this function
    # first pass: constraints that already isolate one variable on "gt"
    # (most sensitive first; ties broken by the constraint's string)
    for constraint, senss in sorted(solution["sensitivities"]["constraints"].items(), key=lambda kv: (-abs(float("%.2g" % kv[1])), str(kv[0]))):
        # unwrap to the innermost child / generated constraint
        while getattr(constraint, "child", None):
            constraint = constraint.child
        while getattr(constraint, "generated", None):
            constraint = constraint.generated
        if abs(senss) <= 1e-5:  # only tight-ish ones
            continue
        # orient the constraint so that lt <= gt
        # NOTE(review): any other `oper` leaves lt/gt as they were set by a
        # previous iteration (or unbound on the first) - confirm unreachable
        if constraint.oper == ">=":
            gt, lt = (constraint.left, constraint.right)
        elif constraint.oper == "<=":
            lt, gt = (constraint.left, constraint.right)
        elif constraint.oper == "=":
            if senss > 0:  # l_over_r is more sensitive - see nomials/math.py
                lt, gt = (constraint.left, constraint.right)
            else:  # r_over_l is more sensitive - see nomials/math.py
                gt, lt = (constraint.left, constraint.right)
        for gtvk in gt.vks:  # remove RelaxPCCP.C
            if (gtvk.name == "C" and gtvk.lineage[0][0] == "RelaxPCCP"
                    and gtvk not in solution["freevariables"]):
                lt, gt = lt.sub({gtvk: 1}), gt.sub({gtvk: 1})
        if len(gt.hmap) > 1:
            continue
        pos_gtvks = {vk for vk, pow in gt.exp.items() if pow > 0}
        if len(pos_gtvks) > 1:
            pos_gtvks &= get_free_vks(gt, solution)  # remove constants
        if len(pos_gtvks) == 1:
            chosenvk, = pos_gtvks
            # climb back up to the outermost parent / generating constraint
            while getattr(constraint, "parent", None):
                constraint = constraint.parent
            while getattr(constraint, "generated_by", None):
                constraint = constraint.generated_by
            breakdowns[chosenvk].append((lt, gt, constraint))
    # second pass: same preprocessing, but handles the constraints the
    # first pass did not record by choosing a variable to isolate
    for constraint, senss in sorted(solution["sensitivities"]["constraints"].items(), key=lambda kv: (-abs(float("%.2g" % kv[1])), str(kv[0]))):
        if abs(senss) <= 1e-5:  # only tight-ish ones
            continue
        while getattr(constraint, "child", None):
            constraint = constraint.child
        while getattr(constraint, "generated", None):
            constraint = constraint.generated
        if constraint.oper == ">=":
            gt, lt = (constraint.left, constraint.right)
        elif constraint.oper == "<=":
            lt, gt = (constraint.left, constraint.right)
        elif constraint.oper == "=":
            if senss > 0:  # l_over_r is more sensitive - see nomials/math.py
                lt, gt = (constraint.left, constraint.right)
            else:  # r_over_l is more sensitive - see nomials/math.py
                gt, lt = (constraint.left, constraint.right)
        for gtvk in gt.vks:
            if (gtvk.name == "C" and gtvk.lineage[0][0] == "RelaxPCCP"
                    and gtvk not in solution["freevariables"]):
                lt, gt = lt.sub({gtvk: 1}), gt.sub({gtvk: 1})
        if len(gt.hmap) > 1:
            continue
        pos_gtvks = {vk for vk, pow in gt.exp.items() if pow > 0}
        if len(pos_gtvks) > 1:
            pos_gtvks &= get_free_vks(gt, solution)  # remove constants
        if len(pos_gtvks) != 1:  # we'll choose our favorite vk
            for vk, pow in gt.exp.items():
                if pow < 0:  # remove all non-positive
                    lt, gt = divide_out_vk(vk, pow, lt, gt)
            # bring over common factors from lt
            lt_pows = defaultdict(set)
            for exp in lt.hmap:
                for vk, pow in exp.items():
                    lt_pows[vk].add(pow)
            for vk, pows in lt_pows.items():
                if len(pows) == 1:
                    pow, = pows
                    if pow < 0:  # ...but only if they're positive
                        lt, gt = divide_out_vk(vk, pow, lt, gt)
            # don't choose something that's already been broken down
            candidatevks = {vk for vk in gt.vks if vk not in breakdowns}
            if candidatevks:
                vrisk = solution["sensitivities"]["variablerisk"]
                # highest exponent-weighted risk first; ties broken by name
                chosenvk, *_ = sorted(
                    candidatevks,
                    key=lambda vk: (-float("%.2g" % (gt.exp[vk]*vrisk.get(vk, 0))), str(vk))
                )
                # move every other variable over to the lt side
                for vk, pow in gt.exp.items():
                    if vk is not chosenvk:
                        lt, gt = divide_out_vk(vk, pow, lt, gt)
                while getattr(constraint, "parent", None):
                    constraint = constraint.parent
                while getattr(constraint, "generated_by", None):
                    constraint = constraint.generated_by
                breakdowns[chosenvk].append((lt, gt, constraint))
    breakdowns = dict(breakdowns)  # remove the defaultdict-ness

    # repeat until a full sweep adds nothing to basically_fixed_variables
    prevlen = None
    while len(basically_fixed_variables) != prevlen:
        prevlen = len(basically_fixed_variables)
        for key in breakdowns:
            if key not in basically_fixed_variables:
                get_fixity(basically_fixed_variables, key, breakdowns, solution)
    return breakdowns
def get_fixity(basically_fixed, key, bd, solution, visited=None):
    """Adds key to basically_fixed if its first breakdown has no truly free variables.

    Looks at the free variables on both sides of key's first breakdown
    constraint. Each one other than key itself must already be in
    basically_fixed, or become so when checked recursively; otherwise the
    function returns without adding key.

    Arguments
    ---------
    basically_fixed : set
        Modified in place.
    key : varkey present in bd
    bd : dict
        Maps keys to lists of (lt, gt, constraint) tuples.
    solution : mapping
        Passed to get_free_vks.
    visited : set, optional
        Keys already being explored in this call chain. A fresh set is made
        for each top-level call, so no state leaks between unrelated calls.
    """
    if visited is None:
        visited = set()
    lt, gt, _ = bd[key][0]
    free_vks = get_free_vks(lt, solution).union(get_free_vks(gt, solution))
    for vk in free_vks:
        if vk is key or vk in basically_fixed:
            continue  # currently checking or already checked
        if vk not in bd:
            return  # a very free variable, can't even be broken down
        if vk in visited:
            return  # tried it before, implicitly it didn't work out
        # maybe it's basically fixed?
        visited.add(key)
        get_fixity(basically_fixed, vk, bd, solution, visited)
        if vk not in basically_fixed:
            return  # ...well, we tried
    basically_fixed.add(key)
233# @profile # ~84% of total last check # TODO: remove
def crawl(basically_fixed_variables, key, bd, solution, basescale=1, permissivity=2, verbosity=0,
          visited_bdkeys=None, gone_negative=False, all_visited_bdkeys=None):
    """Returns the tree of breakdowns of key in bd, sorting by solution's values

    key must either be in bd or be a Posynomial; otherwise TypeError is
    raised. The returned Tree has key as its key and basescale as its value.
    When verbosity is truthy the traversal is also printed (verbosity - 1 is
    used as the starting indent).

    NOTE(review): nesting below was reconstructed from a whitespace-flattened
    listing - confirm against the repository copy.
    """
    if key != solution["cost function"] and hasattr(key, "key"):
        key = key.key  # clear up Variables
    if key in bd:
        # TODO: do multiple if sensitivities are quite close?
        composition, keymon, constraint = bd[key][0]
    elif isinstance(key, Posynomial):
        composition = key
        keymon = None
    else:
        raise TypeError("the `key` argument must be a VarKey or Posynomial.")

    if visited_bdkeys is None:
        visited_bdkeys = set()
        all_visited_bdkeys = set()
    if verbosity == 1:
        already_set = False  # not solution._lineageset TODO
        if not already_set:
            solution.set_necessarylineage()
    if verbosity:
        indent = verbosity-1  # HACK: a bit of overloading, here
        kvstr = "%s (%s)" % (key, get_valstr(key, solution))
        if key in all_visited_bdkeys:
            print(" "*indent + kvstr + " [as broken down above]")
            verbosity = 0
        else:
            print(" "*indent + kvstr)
            indent += 1
    orig_subtree = subtree = []
    tree = Tree(key, basescale, subtree)
    visited_bdkeys.add(key)
    all_visited_bdkeys.add(key)
    if keymon is None:
        scale = solution(key)/basescale
    else:
        if verbosity:
            print(" "*indent + "which in: "
                  + constraint.str_without(["units", "lineage"])
                  + " (sensitivity %+.2g)" % solution["sensitivities"]["constraints"][constraint])
        interesting_vks = {key}
        subkey, = interesting_vks
        power = keymon.exp[subkey]
        boring_vks = set(keymon.vks) - interesting_vks
        scale = solution(key)**power/basescale
        # TODO: make method that can handle both kinds of transforms
        if (power != 1 or boring_vks or mag(keymon.c) != 1
                or keymon.units != key.units):  # some kind of transform here
            units = 1
            exp = HashVector()
            for vk in interesting_vks:
                exp[vk] = keymon.exp[vk]
                if vk.units:
                    units *= vk.units**keymon.exp[vk]
            subhmap = NomialMap({exp: 1})
            try:
                subhmap.units = None if units == 1 else units
            except DimensionalityError:
                # pints was unable to divide a unit by itself bc
                # it has terrible floating-point errors.
                # so let's assume it isn't dimensionless
                # even though it probably is
                subhmap.units = units
            freemon = Monomial(subhmap)
            factor = Monomial(keymon/freemon)
            scale = scale * solution(factor)
            if factor != 1:
                factor = factor**(-1/power)  # invert the transform
                factor.ast = None
                if verbosity:
                    print(" "*indent + "{ through a factor of %s (%s) }" %
                          (factor.str_without(["units"]),
                           get_valstr(factor, solution)))
                subsubtree = []
                transform = Transform(factor, 1, keymon)
                orig_subtree.append(Tree(transform, basescale, subsubtree))
                orig_subtree = subsubtree
            if power != 1:
                if verbosity:
                    print(" "*indent + "{ through a power of %.2g }" % power)
                subsubtree = []
                transform = Transform(1, 1/power, keymon)  # inverted bc it's on the gt side
                orig_subtree.append(Tree(transform, basescale, subsubtree))
                orig_subtree = subsubtree

    # TODO: use ast_parsing instead of chop?
    mons = composition.chop()
    monsols = [solution(mon) for mon in mons]  # ~20% of total last check # TODO: remove
    parsed_monsols = [getattr(mon, "value", mon) for mon in monsols]
    monvals = [float(mon/scale) for mon in parsed_monsols]  # ~10% of total last check # TODO: remove
    # sort by value, preserving order in case of value tie
    sortedmonvals = sorted(zip([-float("%.2g" % mv) for mv in monvals], range(len(mons)), monvals, mons))
    if verbosity:
        if len(monsols) == 1:
            print(" "*indent + "breaks down into:")
        else:
            print(" "*indent + "breaks down into %i monomials:" % len(monsols))
            indent += 1
        indent += 1
    for i, (_, _, scaledmonval, mon) in enumerate(sortedmonvals):
        if not scaledmonval:
            continue
        subtree = orig_subtree  # return to the original subtree
        # time for some filtering
        interesting_vks = mon.vks
        potential_filters = [
            {vk for vk in interesting_vks if vk not in bd},
            mon.vks - get_free_vks(mon, solution),
            {vk for vk in interesting_vks if vk in basically_fixed_variables}
        ]
        if scaledmonval < 1 - permissivity:  # skip breakdown filter
            potential_filters = potential_filters[1:]
        potential_filters.insert(0, visited_bdkeys)
        for filter in potential_filters:
            if interesting_vks - filter:  # don't remove the last one
                interesting_vks = interesting_vks - filter
        # if filters weren't enough and permissivity is high enough, sort!
        if len(interesting_vks) > 1 and permissivity > 1:
            csenss = solution["sensitivities"]["constraints"]
            best_vks = sorted((vk for vk in interesting_vks if vk in bd),
                              key=lambda vk: (-abs(float("%.2g" % (mon.exp[vk]*csenss[bd[vk][0][2]]))),
                                              -float("%.2g" % solution["variables"][vk]),
                                              str(bd[vk][0][0])))  # ~5% of total last check # TODO: remove
            # TODO: changing to str(vk) above does some odd stuff, why?
            if best_vks:
                interesting_vks = set([best_vks[0]])
        boring_vks = mon.vks - interesting_vks

        subkey = None
        if len(interesting_vks) == 1:
            subkey, = interesting_vks
            if subkey in visited_bdkeys and len(sortedmonvals) == 1:
                continue  # don't even go there
            if subkey not in bd:
                power = 1  # no need for a transform
            else:
                power = mon.exp[subkey]
                if power < 0 and gone_negative:
                    subkey = None  # don't breakdown another negative

        if len(monsols) > 1 and verbosity:
            indent -= 1
            print(" "*indent + "%s) forming %i%% of the RHS and %i%% of the total:" % (i+1, scaledmonval/basescale*100, scaledmonval*100))
            indent += 1

        if subkey is None:
            power = 1
            if scaledmonval > 1 - permissivity and not boring_vks:
                boring_vks = interesting_vks
                interesting_vks = set()
            if not interesting_vks:
                # prioritize showing some boring_vks as if they were "free"
                if len(boring_vks) == 1:
                    interesting_vks = boring_vks
                    boring_vks = set()
                else:
                    for vk in list(boring_vks):
                        if vk.units and not vk.units.dimensionless:
                            interesting_vks.add(vk)
                            boring_vks.remove(vk)

        if interesting_vks and (boring_vks or mag(mon.c) != 1):
            units = 1
            exp = HashVector()
            for vk in interesting_vks:
                exp[vk] = mon.exp[vk]
                if vk.units:
                    units *= vk.units**mon.exp[vk]
            subhmap = NomialMap({exp: 1})
            # `units` is still the plain int 1 only if no vk contributed
            # units; the isinstance check avoids comparing a unit object
            # with 1 (and avoids an identity test against a literal)
            unitless = isinstance(units, int) and units == 1
            subhmap.units = None if unitless else units
            freemon = Monomial(subhmap)
            factor = mon/freemon  # autoconvert...
            if (factor.units is None and isinstance(factor, FixedScalar)
                    and abs(factor.value - 1) <= 1e-4):
                factor = 1  # minor fudge to clear numerical inaccuracies
            if factor != 1:
                factor.ast = None
                if verbosity:
                    keyvalstr = "%s (%s)" % (factor.str_without(["units"]),
                                             get_valstr(factor, solution))
                    print(" "*indent + "{ through a factor of %s }" % keyvalstr)
                subsubtree = []
                transform = Transform(factor, 1, mon)
                subtree.append(Tree(transform, scaledmonval, subsubtree))
                subtree = subsubtree
            mon = freemon  # simplifies units
        if power != 1:
            if verbosity:
                print(" "*indent + "{ through a power of %.2g }" % power)
            subsubtree = []
            transform = Transform(1, power, mon)
            subtree.append(Tree(transform, scaledmonval, subsubtree))
            subtree = subsubtree
            mon = mon**(1/power)
            mon.ast = None
        # TODO: make minscale an argument - currently an arbitrary 0.01
        if (subkey is not None and subkey not in visited_bdkeys
                and subkey in bd and scaledmonval > 0.05):
            subverbosity = indent + 1 if verbosity else 0  # slight hack
            # each branch gets its own copy of visited_bdkeys, while
            # all_visited_bdkeys is shared across the whole crawl
            subsubtree = crawl(basically_fixed_variables, subkey, bd, solution, scaledmonval,
                               permissivity, subverbosity, set(visited_bdkeys),
                               gone_negative, all_visited_bdkeys)
            subtree.append(subsubtree)
        else:
            if verbosity:
                keyvalstr = "%s (%s)" % (mon.str_without(["units"]),
                                         get_valstr(mon, solution))
                print(" "*indent + keyvalstr)
            subtree.append(Tree(mon, scaledmonval, []))
    if verbosity == 1:
        if not already_set:
            solution.set_necessarylineage(clear=True)
    return tree
# Pool of single-character legend symbols: A-Z then a-z, minus the
# characters listed below as ambiguous.
SYMBOLS = string.ascii_uppercase + string.ascii_lowercase
for ambiguous_symbol in ("l", "I", "L", "T"):
    SYMBOLS = "".join(SYMBOLS.split(ambiguous_symbol))
def get_spanstr(legend, length, label, leftwards, solution):
    """Returns span visualization, collapsing labels to symbols

    A None label gives `length` spaces. Otherwise label's one-character
    symbol is looked up in legend (and added to it if missing), and that
    symbol is returned padded out to `length` characters with end caps.
    """
    if label is None:
        return " "*length
    if isinstance(label, Transform):
        if label.power != 1:
            spacer = " "
            lend = rend = "^" if label.power > 0 else "/"
        else:
            spacer, lend, rend = "╎", "╤", "╧"
        # remove origkeys so they collide in the legends dictionary
        label = Transform(label.factor, label.power, None)
        if label.power == 1:
            factorstr = str(label.factor)
            if len(factorstr) == 1:
                # a one-character factor serves as its own symbol
                legend[label] = factorstr
    else:
        spacer, lend, rend = "│", "┯", "┷"

    if label not in legend:
        legend[label] = SYMBOLS[len(legend)]
    symbol = legend[label]

    if length <= 1:
        return symbol
    shortside = int(max(0, length - 2)/2)
    longside = int(max(0, length - 3)/2)
    if length == 2:
        return lend + symbol if leftwards else symbol + rend
    if leftwards:
        return lend + spacer*shortside + symbol + spacer*longside + rend
    # HACK: no corners on long rightwards - only used for depth 0
    return "┃"*(longside+1) + symbol + "┃"*(shortside+1)
def discretize(tree, extent, solution, collapse, depth=0, justsplit=False):
    """Returns a copy of tree whose node values are integer extents.

    tree's own value is replaced by extent, and its branches are given
    integer extents that sum to extent, recursively. Branches that round to
    zero are merged into one catch-all branch whose key is a tuple of keys.
    When collapse is true, Transform branches with positive power are
    replaced by their children and stacked power transforms are combined.

    NOTE(review): nesting below was reconstructed from a whitespace-flattened
    listing - confirm against the repository copy.
    """
    # TODO: add vertical simplification?
    key, val, branches = tree
    if collapse:  # collapse Transforms with power 1
        # NOTE(review): the test is power > 0, not power == 1 - confirm intent
        while any(isinstance(branch.key, Transform) and branch.key.power > 0 for branch in branches):
            newbranches = []
            for branch in branches:
                # isinstance(branch.key, Transform) and branch.key.power > 0
                if isinstance(branch.key, Transform) and branch.key.power > 0:
                    newbranches.extend(branch.branches)
                else:
                    newbranches.append(branch)
            branches = newbranches

    scale = extent/val
    values = [b.value for b in branches]
    # merge branches that would display the same key string: the first one
    # seen accumulates the value, later duplicates are marked None
    bkey_indexs = {}
    for i, b in enumerate(branches):
        k = get_keystr(b.key, solution)
        if isinstance(b.key, Transform):
            if len(b.branches) == 1:
                k = get_keystr(b.branches[0].key, solution)
        if k in bkey_indexs:
            values[bkey_indexs[k]] += values[i]
            values[i] = None
        else:
            bkey_indexs[k] = i
    if any(v is None for v in values):
        # drop the merged-away duplicates and re-sort by descending value
        bvs = zip(*sorted(((-float("%.2g" % v), i, b, v) for i, (b, v) in enumerate(zip(branches, values)) if v is not None)))
        _, _, branches, values = bvs
        branches = list(branches)
        values = list(values)
    extents = [int(round(scale*v)) for v in values]
    surplus = extent - sum(extents)  # rounding error left to distribute
    for i, b in enumerate(branches):
        if isinstance(b.key, Transform):
            subscale = extents[i]/b.value
            if not any(round(subscale*subv) for _, subv, _ in b.branches):
                extents[i] = 0  # transform with no worthy heirs: misc it
    if not any(extents):
        return Tree(key, extent, [])
    if not all(extents):  # create a catch-all
        branches = branches.copy()
        miscvkeys, miscval = [], 0
        # walks from the last branch backwards, popping branches and extents
        # as it goes
        for subextent in reversed(extents):
            if not subextent or (branches[-1].value < miscval and surplus < 0):
                extents.pop()
                k, v, _ = branches.pop()
                if isinstance(k, Transform):
                    k = k.origkey  # TODO: this is the only use of origkey - remove it
                if isinstance(k, tuple):
                    vkeys = [(-kv[1], str(kv[0]), kv[0]) for kv in k]
                if not isinstance(k, tuple):
                    vkeys = [(-float("%.2g" % v), str(k), k)]
                miscvkeys += vkeys
                surplus -= (round(scale*(miscval + v))
                            - round(scale*miscval) - subextent)
                miscval += v
        misckeys = tuple(k for _, _, k in sorted(miscvkeys))
        branches.append(Tree(misckeys, miscval, []))
        extents.append(int(round(scale*miscval)))
    if surplus:
        # hand out the remaining +/-1s one at a time, starting from the
        # branch with the largest extent
        sign = int(np.sign(surplus))
        bump_priority = sorted((ext, sign*float("%.2g" % b.value), i) for i, (b, ext)
                               in enumerate(zip(branches, extents)))
        # print(key, surplus, bump_priority)
        while surplus:
            try:
                extents[bump_priority.pop()[-1]] += sign
                surplus -= sign
            except IndexError:
                # ran out of branches before the surplus was used up
                raise ValueError(val, [b.value for b in branches])

    tree = Tree(key, extent, [])
    # simplify based on how we're branching
    branchfactor = len([ext for ext in extents if ext]) - 1
    if depth and not isinstance(key, Transform):
        if extent == 1 or branchfactor >= max(extent-2, 1):
            # if we'd branch too much, stop
            return tree
        if collapse and not branchfactor and not justsplit:
            # if we didn't just split and aren't about to, skip through
            return discretize(branches[0], extent, solution, collapse,
                              depth=depth+1, justsplit=False)
    if branchfactor:
        justsplit = True
    elif not isinstance(key, Transform):  # justsplit passes through transforms
        justsplit = False

    for branch, subextent in zip(branches, extents):
        if subextent:
            branch = discretize(branch, subextent, solution, collapse,
                                depth=depth+1, justsplit=justsplit)
            if (collapse and is_power(branch.key)
                    and all(is_power(b.key) for b in branch.branches)):
                # combine stacked powers
                power = branch.key.power
                for b in branch.branches:
                    key = Transform(1, power*b.key.power, None)
                    if key.power == 1:  # powers canceled, collapse both
                        tree.branches.extend(b.branches)
                    else:  # collapse this level
                        tree.branches.append(Tree(key, b.value, b.branches))
            else:
                tree.branches.append(branch)
    return tree
def layer(map, tree, maxdepth, depth=0):
    """Turns the tree into a 2D-array

    Appends (key, extent) for tree to map[depth], then recurses into its
    branches, ignoring anything deeper than maxdepth. A node without
    branches gets one placeholder child with key None and the same extent.
    Returns map, which is modified in place.
    """
    key, extent, branches = tree
    if depth > maxdepth:
        return map
    if len(map) <= depth:
        map.append([])
    map[depth].append((key, extent))
    children = branches if branches else [Tree(None, extent, [])]  # pad it out
    for child in children:
        layer(map, child, maxdepth, depth+1)
    return map
def plumb(tree, depth=0):
    "Returns the depth of tree's deepest node, counting tree itself as `depth`"
    return max((plumb(branch, depth+1) for branch in tree.branches),
               default=depth)
def prune(tree, solution, maxlength, length=-1, prefix=""):
    """Prune branches that are longer than a certain number of characters

    length is the number of label characters accumulated above this node.
    If any direct branch would push the total beyond maxlength, all of this
    node's branches are dropped; otherwise each branch is pruned in turn.
    """
    key, extent, branches = tree

    def labelwidth(nodekey, lineageprefix):
        "Width of the wider of nodekey's value and key strings"
        return max(len(get_valstr(nodekey, solution, into="(%s)")),
                   len(get_keystr(nodekey, solution, lineageprefix)))

    # the node's own width is measured before the prefix is (possibly) set
    ownwidth = labelwidth(key, prefix)
    if length == -1 and isinstance(key, VarKey) and key.necessarylineage:
        prefix = key.lineagestr()
    length += ownwidth + 3
    if any(length + labelwidth(branch.key, prefix) + 3 > maxlength
           for branch in branches):
        return Tree(key, extent, [])
    pruned = [prune(branch, solution, maxlength, length, prefix)
              for branch in branches]
    return Tree(key, extent, pruned)
def simplify(tree, solution, extent, maxdepth, maxlength, collapse):
    "Returns tree discretized to extent, pruned (if collapsing with a maxlength), and layered"
    discretized = discretize(tree, extent, solution, collapse)
    if not (collapse and maxlength):
        return layer([], discretized, maxdepth)
    return layer([], prune(discretized, solution, maxlength), maxdepth)
638# @profile # ~16% of total last check # TODO: remove
def graph(tree, breakdowns, solution, basically_fixed_variables, *,
          height=None, maxdepth=None, maxwidth=81, showlegend=False):
    """Prints breakdown

    Arguments
    ---------
    tree : Tree
        The breakdown tree to draw.
    breakdowns : dict
        Keys found here are marked when nothing is drawn beneath them.
    solution : solution object
        Used for values, lineage handling, and solution["cost function"].
    basically_fixed_variables : set
        Passed through to legend_entry.
    height : int, optional
        Number of printed rows; if None, starts at 20 and shrinks.
    maxdepth : int, optional
        Deepest tree level drawn; defaults to the tree's full depth.
    maxwidth : int
        Passed to simplify as its maxlength.
    showlegend : bool
        If true, symbols are printed with a legend instead of inline labels.

    NOTE(review): nesting below was reconstructed from a whitespace-flattened
    listing - confirm against the repository copy.
    """
    already_set = solution._lineageset
    if not already_set:
        solution.set_necessarylineage()
    collapse = (not showlegend)  # TODO: set to True while showlegend is True for first approx of receipts; autoinclude with trace?
    if maxdepth is None:
        maxdepth = plumb(tree)
    if height is not None:
        mt = simplify(tree, solution, height, maxdepth, maxwidth, collapse)
    else:  # zoom in from a default height of 20 to a height of 4 per branch
        prev_height = None
        height = 20
        while prev_height != height:
            mt = simplify(tree, solution, height, maxdepth, maxwidth, collapse)
            prev_height = height
            # max() over the generator itself, so that a single-row mt
            # does not end up as max(<int>) and raise TypeError
            height = min(height, max(4*len(at_depth) for at_depth in mt))

    legend = {}
    chararray = np.full((len(mt), height), "", "object")
    for depth, elements_at_depth in enumerate(mt):
        row = ""
        for element, length in elements_at_depth:
            leftwards = depth > 0 and length > 2
            row += get_spanstr(legend, length, element, leftwards, solution)
        chararray[depth, :] = list(row)

    # Format depth=0
    A_key, = [key for key, value in legend.items() if value == "A"]
    prefix = ""
    if A_key is solution["cost function"]:
        A_str = "Cost"
    else:
        A_str = get_keystr(A_key, solution)
        if isinstance(A_key, VarKey) and A_key.necessarylineage:
            prefix = A_key.lineagestr()
    A_valstr = get_valstr(A_key, solution, into="(%s)")
    fmt = "{0:>%s}" % (max(len(A_str), len(A_valstr)) + 3)
    for j, entry in enumerate(chararray[0, :]):
        if entry == "A":
            chararray[0, j] = fmt.format(A_str + "╺┫")
            chararray[0, j+1] = fmt.format(A_valstr + " ┃")
        else:
            chararray[0, j] = fmt.format(entry)
    # Format depths 1+
    labeled = set()
    reverse_legend = {v: k for k, v in legend.items()}
    legend = {}  # rebuilt below with only the symbols that get printed
    for pos in range(height):
        for depth in reversed(range(1, len(mt))):
            char = chararray[depth, pos]
            if char not in reverse_legend:
                continue
            key = reverse_legend[char]
            if key not in legend and (isinstance(key, tuple) or (depth != len(mt) - 1 and chararray[depth+1, pos] != " ")):
                legend[key] = SYMBOLS[len(legend)]
                if collapse and is_power(key):
                    chararray[depth, pos] = "^" if key.power > 0 else "/"
                    del legend[key]
                    continue
            if key in legend:
                chararray[depth, pos] = legend[key]
                if isinstance(key, tuple) and not isinstance(key, Transform):
                    chararray[depth, pos] = "*" + chararray[depth, pos]
                    del legend[key]
                if showlegend:
                    continue

            keystr = get_keystr(key, solution, prefix)
            if keystr in labeled:
                valuestr = ""
            else:
                valuestr = get_valstr(key, solution, into=" (%s)")
            if collapse:
                fmt = "{0:<%s}" % max(len(keystr) + 3, len(valuestr) + 2)
            else:
                fmt = "{0:<1}"
            # thicken the span's line upwards and downwards from the label
            span = 0
            tryup, trydn = True, True
            while tryup or trydn:
                span += 1
                if tryup:
                    if pos - span < 0:
                        tryup = False
                    else:
                        upchar = chararray[depth, pos-span]
                        if upchar == "│":
                            chararray[depth, pos-span] = fmt.format("┃")
                        elif upchar == "┯":
                            chararray[depth, pos-span] = fmt.format("┓")
                        else:
                            tryup = False
                if trydn:
                    if pos + span >= height:
                        trydn = False
                    else:
                        dnchar = chararray[depth, pos+span]
                        if dnchar == "│":
                            chararray[depth, pos+span] = fmt.format("┃")
                        elif dnchar == "┷":
                            chararray[depth, pos+span] = fmt.format("┛")
                        else:
                            trydn = False
            linkstr = "┣╸"
            if not isinstance(key, FixedScalar):
                labeled.add(keystr)
                if span > 1 and (collapse or pos + 2 >= height
                                 or chararray[depth, pos+1] == "┃"):
                    vallabel = chararray[depth, pos+1].rstrip() + valuestr
                    chararray[depth, pos+1] = fmt.format(vallabel)
                elif showlegend:
                    keystr += valuestr
                if key in breakdowns and not chararray[depth+1, pos].strip():
                    keystr = keystr + "╶⎨"
            chararray[depth, pos] = fmt.format(linkstr + keystr)
    # Rotate and print
    rowstrs = ["".join(row).rstrip() for row in chararray.T.tolist()]
    print("\n" + "\n".join(rowstrs) + "\n")

    if showlegend:  # create and print legend
        legend_lines = []
        for key, shortname in sorted(legend.items(), key=lambda kv: kv[1]):
            legend_lines.append(legend_entry(key, shortname, solution, prefix,
                                             basically_fixed_variables))
        maxlens = [max(len(el) for el in col) for col in zip(*legend_lines)]
        fmts = ["{0:<%s}" % L for L in maxlens]
        for line in legend_lines:
            line = "".join(fmt.format(cell)
                           for fmt, cell in zip(fmts, line) if cell).rstrip()
            print(" " + line)

    if not already_set:
        solution.set_necessarylineage(clear=True)
def legend_entry(key, shortname, solution, prefix, basically_fixed_variables):
    """Returns list of legend elements

    The list is [shortname padded to 4 characters, key string, value string,
    note]. For a factor Transform the factor itself is described, and the
    note flags it if any of its free variables is not in
    basically_fixed_variables. For a power Transform the value string shows
    the power instead of a value.
    """
    note = ""
    keystr = " "  # kept as-is when key is a FixedScalar
    operator = "= " if shortname else " + "
    if is_factor(key):
        operator = " ×"
        key = key.factor
        if any(vk not in basically_fixed_variables
               for vk in get_free_vks(key, solution)):
            note = " [free factor]"
    if is_power(key):
        valuestr = " ^%.3g" % key.power
    else:
        valuestr = get_valstr(key, solution, into=" "+operator+"%s")
    if not isinstance(key, FixedScalar):
        keystr = get_keystr(key, solution, prefix)
    return ["%-4s" % shortname, keystr, valuestr, note]
def get_keystr(key, solution, prefix=""):
    """Returns formatted string of the key in solution.

    Uses key.str_without (excluding units and the prefix) when available,
    "[N terms]" for tuples, and str(key) otherwise. Results longer than
    67 characters are cut to 66 characters plus an ellipsis.
    """
    if hasattr(key, "str_without"):
        text = key.str_without({"units", ":MAGIC:"+prefix})
    else:
        text = "[%i terms]" % len(key) if isinstance(key, tuple) else str(key)
    if len(text) > 67:
        return text[:66]+"…"
    return text
def get_valstr(key, solution, into="%s"):
    """Returns formatted string of the value of key in solution.

    If solution(key) raises ValueError or TypeError, key is treated as an
    iterable of keys and their values are summed; if that fails too, a
    single space is returned. The value's magnitude and unit string are
    substituted into `into`, with ", fixed" appended when key (or every one
    of its vks) is in solution["constants"].
    """
    # get valuestr
    try:
        value = solution(key)
    except (ValueError, TypeError):
        try:
            value = sum(solution(subkey) for subkey in key)
        except (ValueError, TypeError):
            return " "
    if isinstance(value, FixedScalar):
        value = value.value
    magnitude = mag(value)
    if 1e3 <= magnitude < 1e6:
        valuestr = format(magnitude, ",.0f")  # thousands separators
    else:
        valuestr = "%-.3g" % magnitude
    # get unitstr
    if hasattr(key, "unitstr"):
        unitstr = key.unitstr()
    else:
        try:
            if hasattr(value, "units"):
                value.ito_reduced_units()
        except DimensionalityError:
            pass
        unitstr = get_unitstr(value)
    if unitstr.startswith("1/"):
        unitstr = "/" + unitstr[2:]
    constants = solution["constants"]
    fixed = key in constants
    if not fixed and hasattr(key, "vks") and key.vks:
        fixed = all(vk in constants for vk in key.vks)
    if fixed:
        unitstr += ", fixed"
    return into % (valuestr + unitstr)
839import plotly.graph_objects as go
def plotlyify(tree, solution, minval=None):
    """Flattens a breakdown tree into the lists plotly's hierarchy plots take

    Arguments
    ---------
    tree : Tree
        Breakdown tree of (key, value, branches) nodes

    solution : solution object
        Used to format each node's label and value

    minval (optional): number
        Nodes whose value is not above this are left out, together with
        everything beneath them. Defaults to 1/1000 of the root's value.

    Returns
    -------
    (ids, labels, parents, values) : tuple of lists
        One entry per emitted node, suitable for `treemap` or `icicle`.
        Transform nodes are not emitted; their children are attached to the
        nearest emitted ancestor. A node's value is capped at what remains
        of its parent's value, so children never sum to more than it.
    """
    ids = []
    labels = []
    parents = []
    values = []

    key, value, branches = tree
    if isinstance(key, VarKey) and key.necessarylineage:
        prefix = key.lineagestr()
    else:
        prefix = ""

    if minval is None:
        minval = value/1000

    parent_budgets = {}  # node id -> value not yet claimed by its children

    def _add_node(node, parent_id=None):
        "Appends node (unless it is a Transform) and then its branches"
        nodekey, nodeval, subnodes = node
        if nodeval > minval:
            if isinstance(nodekey, Transform):
                node_id = parent_id  # transforms are transparent
            else:
                node_id = len(ids)+1
                ids.append(node_id)
                labels.append(get_keystr(nodekey, solution, prefix))
                if not isinstance(nodekey, str):
                    labels[-1] = labels[-1] + "<br>" + get_valstr(nodekey, solution)
                parents.append(parent_id)
                parent_budgets[node_id] = nodeval
                if parent_id is not None:  # make sure there's no overflow
                    if parent_budgets[parent_id] < nodeval:
                        nodeval = parent_budgets[parent_id]  # take remainder
                    parent_budgets[parent_id] -= nodeval
                values.append(nodeval)
            for subnode in subnodes:
                _add_node(subnode, node_id)

    _add_node(tree)
    return ids, labels, parents, values
def treemap(ids, labels, parents, values):
    """Returns a Plotly Figure holding one Treemap trace.

    The four arguments are forwarded unchanged to `go.Treemap` under the
    keywords of the same name, together with branchvalues="total".
    """
    trace = go.Treemap(ids=ids, labels=labels, parents=parents,
                       values=values, branchvalues="total")
    return go.Figure(trace)
def icicle(ids, labels, parents, values):
    """Returns a Plotly Figure holding one Icicle trace.

    The four arguments are forwarded unchanged to `go.Icicle` under the
    keywords of the same name, together with branchvalues="total".
    """
    trace = go.Icicle(ids=ids, labels=labels, parents=parents,
                      values=values, branchvalues="total")
    return go.Figure(trace)
920import functools
class Breakdowns(object):
    """Traces and plots breakdown trees for a solution.

    Attributes set at construction:
      sol                        the solution passed in
      mlookup                    dict filled in by crawl_modelbd
      mtree                      tree returned by crawl_modelbd
      basically_fixed_variables  set shared with get_breakdowns and crawl
      bd                         breakdowns returned by get_breakdowns
    """

    def __init__(self, sol):
        self.sol = sol
        self.mlookup = {}
        self.mtree = crawl_modelbd(get_model_breakdown(sol), self.mlookup)
        self.basically_fixed_variables = set()
        self.bd = get_breakdowns(self.basically_fixed_variables, self.sol)

    def trace(self, key, *, permissivity=2):
        "Builds the tree for key with verbosity=1; returns nothing."
        print("")  # a little padding to start
        self.get_tree(key, permissivity=permissivity, verbosity=1)

    def get_tree(self, key, *, permissivity=2, verbosity=0):
        """Resolves key to a tree and says which lookup table it belongs to.

        Arguments
        ---------
        key: string or breakdown key
            Strings are resolved in this order: "model sensitivities" gives
            the model tree, "cost" stands for the solution's cost function,
            a key of `mlookup` gives that stored tree, and anything else
            must be a substring of exactly one key in `bd`.
            Non-string keys are handed to `crawl` as they are.

        permissivity, verbosity (optional):
            Forwarded to `crawl` when the tree has to be built.

        Returns
        -------
        (tree, kind), where kind is "constraint" for trees taken from the
        model breakdown and "variable" otherwise.

        Raises
        ------
        KeyError if a string key matches no key of `bd`, or more than one.
        """
        tree = None
        kind = "variable"
        if isinstance(key, str):
            if key == "model sensitivities":
                tree = self.mtree
                kind = "constraint"
            elif key == "cost":
                key = self.sol["cost function"]
            elif key in self.mlookup:
                tree = self.mlookup[key]
                kind = "constraint"
            else:
                # TODO: support submodels
                keys = [vk for vk in self.bd if key in str(vk)]
                if not keys:
                    raise KeyError(key)
                if len(keys) > 1:
                    raise KeyError("There are %i keys containing '%s'."
                                   % (len(keys), key))
                key, = keys
        if tree is None:
            tree = crawl(self.basically_fixed_variables, key, self.bd,
                         self.sol, permissivity=permissivity,
                         verbosity=verbosity)
        return tree, kind

    def plot(self, key, *, height=None, permissivity=2, showlegend=False,
             maxwidth=85):
        "Passes the tree for key, with its matching lookup table, to graph."
        tree, kind = self.get_tree(key, permissivity=permissivity)
        lookup = self.bd if kind == "variable" else self.mlookup
        graph(tree, lookup, self.sol, self.basically_fixed_variables,
              height=height, showlegend=showlegend, maxwidth=maxwidth)

    @staticmethod
    def _return_or_write(fig, key, suffix, returnfig, filename):
        "Returns fig if returnfig, else writes it to html with plotly.offline."
        if returnfig:
            return fig
        if filename is None:
            # default name comes from the key, reduced to alphanumerics,
            # "." and "_"; caller-supplied filenames are used untouched
            filename = str(key) + suffix
            keepcharacters = (".", "_")
            filename = "".join(c for c in filename
                               if c.isalnum() or c in keepcharacters).rstrip()
        import plotly
        plotly.offline.plot(fig, filename=filename)
        return None

    def treemap(self, key, *, permissivity=2, returnfig=False, filename=None):
        """Plots the tree for key as a Plotly treemap.

        Returns the figure if returnfig is set; otherwise writes it to
        `filename` (default: "<key>_treemap.html", sanitized).
        """
        # permissivity is forwarded here just as in `icicle`
        tree, _ = self.get_tree(key, permissivity=permissivity)
        fig = treemap(*plotlyify(tree, self.sol))
        return self._return_or_write(fig, key, "_treemap.html",
                                     returnfig, filename)

    def icicle(self, key, *, permissivity=2, returnfig=False, filename=None):
        """Plots the tree for key as a Plotly icicle chart.

        Returns the figure if returnfig is set; otherwise writes it to
        `filename` (default: "<key>_icicle.html", sanitized).
        """
        tree, _ = self.get_tree(key, permissivity=permissivity)
        fig = icicle(*plotlyify(tree, self.sol))
        return self._return_or_write(fig, key, "_icicle.html",
                                     returnfig, filename)