Coverage for gpkit/breakdowns.py: 93%
754 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-28 12:37 -0400
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-28 12:37 -0400
1#TODO: cleanup weird conditionals
2# add conversions to plotly/sankey
4# pylint: skip-file
5import string
6from collections import defaultdict, namedtuple, Counter
7from gpkit.nomials import Monomial, Posynomial, Variable
8from gpkit.nomials.map import NomialMap
9from gpkit.small_scripts import mag, try_str_without
10from gpkit.small_classes import FixedScalar, HashVector
11from gpkit.exceptions import DimensionalityError
12from gpkit.repr_conventions import unitstr as get_unitstr
13from gpkit.repr_conventions import lineagestr
14from gpkit.varkey import VarKey
15import numpy as np
# A node of a breakdown tree: what it is (`key`), how much it accounts for
# (`value`), and its child nodes (`branches`, a list of Trees).
Tree = namedtuple("Tree", "key value branches")
# A rescaling step inside a breakdown tree: multiply by `factor` and/or raise
# to `power`; `origkey` records the monomial the transform came from.
Transform = namedtuple("Transform", "factor power origkey")


def is_factor(key):
    "True when `key` is a Transform that only multiplies (its power is 1)."
    if not isinstance(key, Transform):
        return False
    return key.power == 1


def is_power(key):
    "True when `key` is a Transform that raises to a power other than 1."
    if not isinstance(key, Transform):
        return False
    return key.power != 1
def get_free_vks(posy, solution):
    "Returns all free vks of a given posynomial for a given solution"
    # "free" here means: not listed among the solution's constants
    fixed = solution["constants"]
    return {vk for vk in posy.vks if vk not in fixed}
def _add_lineage_sensitivity(breakdowns, lineage, senss):
    """Adds `senss` to the root of `breakdowns` and to every namespace in the
    dot-separated `lineage`, creating nested dicts as needed.

    Returns the innermost namespace dict, where the leaf entry belongs.
    """
    subbd = breakdowns
    subbd["|sensitivity|"] += senss
    for parent in lineage.split("."):
        if parent == "":
            continue
        subbd = subbd.setdefault(parent, {})
        subbd["|sensitivity|"] = subbd.get("|sensitivity|", 0) + senss
    return subbd


def get_model_breakdown(solution):
    """Returns a nested dict of absolute sensitivities grouped by lineage.

    Every level carries a "|sensitivity|" entry holding the summed absolute
    sensitivity beneath it. Namespaces come from each constraint's or
    variable's lineage string; leaves are {"|sensitivity|": senss} dicts keyed
    by the constraint's string or by "<variable> = <value>". Entries whose
    absolute sensitivity is <= 1e-5 are dropped.
    """
    breakdowns = {"|sensitivity|": 0}
    for constraint, senss in solution["sensitivities"]["constraints"].items():
        senss = abs(senss)  # sensitivities can be negative (e.g. monomial equalities)
        if senss <= 1e-5:
            continue
        lineage = lineagestr(constraint)
        subbd = _add_lineage_sensitivity(breakdowns, lineage, senss)
        # treat vectors as namespace
        constrstr = try_str_without(constraint, {"units", ":MAGIC:"+lineage})
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        subbd[constrstr] = {"|sensitivity|": senss}
    varsenss = solution["sensitivities"]["variables"]
    for vk in varsenss.keymap:  # could this be done away with for backwards compatibility?
        if not isinstance(vk, VarKey) or (vk.shape and not vk.index):
            # skip un-indexed vector keys (presumably their indexed
            # elements also appear in keymap)
            continue
        senss = abs(varsenss[vk])
        if hasattr(senss, "shape"):
            senss = np.nansum(senss)
        if senss <= 1e-5:
            continue
        subbd = _add_lineage_sensitivity(breakdowns, vk.lineagestr(), senss)
        # treat vectors as namespace (indexing vectors above)
        label = (vk.str_without({"lineage"})
                 + get_valstr(vk, solution, " = %s").replace(", fixed", ""))
        subbd[label] = {"|sensitivity|": senss}
    # TODO: track down in a live-solve environment why this isn't the same
    # print(breakdowns["HyperloopSystem"]["|sensitivity|"])
    return breakdowns
def crawl_modelbd(bd, lookup, name="Model"):
    """Converts a nested sensitivity dict (as from get_model_breakdown) into a Tree.

    Each node's value is its "|sensitivity|" entry. Children are sorted by
    descending sensitivity (rounded to 2 significant figures), then by name.
    Every node that has children is also stored in `lookup` under its name.
    `bd` and its nested dicts are left unmodified.
    """
    bd = dict(bd)  # shallow copy so the pop below doesn't consume the caller's dict
    tree = Tree(name, bd.pop("|sensitivity|"), [])
    if bd:
        lookup[name] = tree
        ranked = sorted(
            bd.items(),
            key=lambda kv: (-float("%.2g" % kv[1]["|sensitivity|"]), kv[0]))
        for subname, subtree in ranked:
            tree.branches.append(crawl_modelbd(subtree, lookup, subname))
    return tree
def divide_out_vk(vk, pow, lt, gt):
    """Divides both sides of a breakdown by vk**pow.

    The divisor is built as a unit monomial of `vk` carrying vk's units.
    Returns the new (lt, gt) pair with their `ast` attributes cleared.
    """
    unit_map = NomialMap({HashVector({vk: 1}): 1.0})
    unit_map.units = vk.units
    divisor = Monomial(unit_map)**pow
    new_lt = lt/divisor
    new_gt = gt/divisor
    new_lt.ast = None
    new_gt.ast = None
    return new_lt, new_gt
def _by_sensitivity(item):
    "Sort key: largest |sensitivity| (to 2 significant figures) first, then by name."
    constraint, senss = item
    return (-abs(float("%.2g" % senss)), str(constraint))


def _orient_constraint(constraint, senss, solution):
    """Splits a constraint into its lesser ("lt") and greater ("gt") sides.

    Descends through `child` and `generated` links, orients the sides by
    `oper` (for equalities, by the sign of `senss`), and substitutes out any
    RelaxPCCP.C variable that isn't among the solution's variables.

    Returns (constraint, lt, gt, pos_gtvks): the descended-to constraint, the
    two sides, and the positive-power variables of gt (narrowed to free ones
    when there are several). Returns None when gt is not a monomial or the
    operator is unrecognised.
    """
    while getattr(constraint, "child", None):
        constraint = constraint.child
    while getattr(constraint, "generated", None):
        constraint = constraint.generated
    if constraint.oper == ">=":
        gt, lt = (constraint.left, constraint.right)
    elif constraint.oper == "<=":
        lt, gt = (constraint.left, constraint.right)
    elif constraint.oper == "=":
        if senss > 0:  # l_over_r is more sensitive - see nomials/math.py
            lt, gt = (constraint.left, constraint.right)
        else:  # r_over_l is more sensitive - see nomials/math.py
            gt, lt = (constraint.left, constraint.right)
    else:
        return None  # unknown operator: never borrow another constraint's sides
    for gtvk in gt.vks:  # remove RelaxPCCP.C
        if (gtvk.name == "C" and gtvk.lineage[0][0] == "RelaxPCCP"
                and gtvk not in solution["variables"]):
            lt, gt = lt.sub({gtvk: 1}), gt.sub({gtvk: 1})
    if len(gt.hmap) > 1:
        return None  # gt isn't a monomial
    pos_gtvks = {vk for vk, x in gt.exp.items() if x > 0}
    if len(pos_gtvks) > 1:
        pos_gtvks &= get_free_vks(gt, solution)  # remove constants
    return constraint, lt, gt, pos_gtvks


def _root_constraint(constraint):
    "Climbs `parent` and `generated_by` links back to the outermost constraint."
    while getattr(constraint, "parent", None):
        constraint = constraint.parent
    while getattr(constraint, "generated_by", None):
        constraint = constraint.generated_by
    return constraint


# @profile
def get_breakdowns(basically_fixed_variables, solution):
    """Returns {key: [(lt, gt, constraint), ...]} for breakdown constraints in solution.

    A breakdown constraint is any whose "gt" contains a single free variable.

    (At present, monomial constraints check both sides as "gt")

    Only constraints with |sensitivity| > 1e-5 are considered, most sensitive
    first. A first pass records constraints whose monomial gt already has
    exactly one candidate variable. A second pass rearranges the remaining
    monomial-gt constraints so that one variable not yet broken down is left
    alone in gt (chosen by gt exponent times its "variablerisk").

    `basically_fixed_variables` is grown in place (via get_fixity) until it
    stops changing.
    """
    breakdowns = defaultdict(list)
    ranked = sorted(solution["sensitivities"]["constraints"].items(),
                    key=_by_sensitivity)

    # first pass: gt already isolates a single variable
    for constraint, senss in ranked:
        if abs(senss) <= 1e-5:  # only tight-ish ones
            continue
        oriented = _orient_constraint(constraint, senss, solution)
        if oriented is None:
            continue
        constraint, lt, gt, pos_gtvks = oriented
        if len(pos_gtvks) == 1:
            chosenvk, = pos_gtvks
            breakdowns[chosenvk].append((lt, gt, _root_constraint(constraint)))

    # second pass: rearrange so that one variable is left alone in gt
    for constraint, senss in ranked:
        if abs(senss) <= 1e-5:  # only tight-ish ones
            continue
        oriented = _orient_constraint(constraint, senss, solution)
        if oriented is None:
            continue
        constraint, lt, gt, pos_gtvks = oriented
        if len(pos_gtvks) == 1:
            continue  # handled in the first pass
        for vk, x in gt.exp.items():
            if x < 0:  # remove all non-positive
                lt, gt = divide_out_vk(vk, x, lt, gt)
        # bring over common factors from lt
        lt_pows = defaultdict(set)
        for exp in lt.hmap:
            for vk, x in exp.items():
                lt_pows[vk].add(x)
        for vk, pows in lt_pows.items():
            if len(pows) == 1:
                x, = pows
                # dividing by vk**x (x < 0) clears vk from every lt term
                # and leaves it with a positive power in gt
                if x < 0:
                    lt, gt = divide_out_vk(vk, x, lt, gt)
        # don't choose something that's already been broken down
        candidatevks = {vk for vk in gt.vks if vk not in breakdowns}
        if not candidatevks:
            continue
        vrisk = solution["sensitivities"]["variablerisk"]
        chosenvk, *_ = sorted(
            candidatevks,
            key=lambda vk: (-float("%.2g" % (gt.exp[vk]*vrisk.get(vk, 0))), str(vk))
        )
        for vk, x in gt.exp.items():
            if vk is not chosenvk:
                lt, gt = divide_out_vk(vk, x, lt, gt)
        breakdowns[chosenvk].append((lt, gt, _root_constraint(constraint)))
    breakdowns = dict(breakdowns)  # remove the defaultdict-ness

    # iterate to a fixed point: newly fixed variables can fix others
    prevlen = None
    while len(basically_fixed_variables) != prevlen:
        prevlen = len(basically_fixed_variables)
        for key in breakdowns:
            if key not in basically_fixed_variables:
                get_fixity(basically_fixed_variables, key, breakdowns, solution)
    return breakdowns
def get_fixity(basically_fixed, key, bd, solution, visited=None):
    """Adds `key` to `basically_fixed` if its breakdown depends only on fixed things.

    Uses the first (lt, gt, constraint) breakdown of `key`. `key` is
    basically fixed when every free variable in lt or gt (other than `key`
    itself) is already basically fixed, or can be shown to be so by recursing
    into its own breakdown.

    Arguments
    ---------
    basically_fixed: set
        Grown in place as variables are found to be basically fixed.
    key: VarKey
        Variable to check; must be a key of `bd`.
    bd: dict
        {varkey: [(lt, gt, constraint), ...]}, as from get_breakdowns.
    solution: solution object
        Its "constants" determine which variables are free.
    visited: set, optional
        Keys already explored in this search (guards against cycles).
        A fresh set is created for each top-level call.
    """
    if visited is None:
        # a `visited=set()` default would be shared by every call and every
        # solution, letting stale keys short-circuit unrelated searches
        visited = set()
    lt, gt, _ = bd[key][0]
    free_vks = get_free_vks(lt, solution).union(get_free_vks(gt, solution))
    for vk in free_vks:
        if vk is key or vk in basically_fixed:
            continue  # currently checking or already checked
        if vk not in bd:
            return  # a very free variable, can't even be broken down
        if vk in visited:
            return  # tried it before, implicitly it didn't work out
        # maybe it's basically fixed?
        visited.add(key)
        get_fixity(basically_fixed, vk, bd, solution, visited)
        if vk not in basically_fixed:
            return  # ...well, we tried
    basically_fixed.add(key)
# @profile  # ~84% of total last check # TODO: remove
def crawl(basically_fixed_variables, key, bd, solution, basescale=1, permissivity=2, verbosity=0,
          visited_bdkeys=None, gone_negative=False, all_visited_bdkeys=None):
    """Returns the tree of breakdowns of key in bd, sorting by solution's values

    Arguments
    ---------
    basically_fixed_variables: set
        Variables filtered out when choosing which variable of a monomial
        to follow further.
    key: VarKey, Variable, or Posynomial
        What to break down. A key in `bd` is broken down via its first
        breakdown; a Posynomial not in `bd` is split into its own monomials.
    bd: dict
        {varkey: [(lt, gt, constraint), ...]}, as from get_breakdowns.
    solution: solution object
        Called as solution(expr) for values and indexed for "sensitivities",
        "variables", "constants" and "cost function".
    basescale: number
        Share of the root this node represents (1 at the root).
    permissivity: number
        Above 1, ties between candidate variables are broken by constraint
        sensitivity and value; monomials whose scaled value is below
        1 - permissivity skip the "not broken down" filter.
    verbosity: int
        0 is silent; 1 prints a trace. Larger values are used internally to
        pass the indentation level to recursive calls.
    visited_bdkeys, all_visited_bdkeys: set, optional
        Keys already crawled on this path / anywhere in this trace.
    gone_negative: bool
        When True, variables with negative powers are not broken down.

    Returns
    -------
    Tree(key, basescale, branches)

    Raises
    ------
    TypeError
        If `key` is neither in `bd` nor a Posynomial.
    """
    if key != solution["cost function"] and hasattr(key, "key"):
        key = key.key  # clear up Variables
    if key in bd:
        # TODO: do multiple if sensitivities are quite close?
        composition, keymon, constraint = bd[key][0]
    elif isinstance(key, Posynomial):
        composition = key
        keymon = None
    else:
        raise TypeError("the `key` argument must be a VarKey or Posynomial.")

    if visited_bdkeys is None:
        visited_bdkeys = set()
        all_visited_bdkeys = set()
    if verbosity == 1:
        already_set = False  #not solution._lineageset TODO
        if not already_set:
            solution.set_necessarylineage()
    if verbosity:
        indent = verbosity-1  # HACK: a bit of overloading, here
        kvstr = "%s (%s)" % (key, get_valstr(key, solution))
        if key in all_visited_bdkeys:
            print(" "*indent + kvstr + " [as broken down above]")
            verbosity = 0
        else:
            print(" "*indent + kvstr)
            indent += 1
    orig_subtree = subtree = []
    tree = Tree(key, basescale, subtree)
    visited_bdkeys.add(key)
    all_visited_bdkeys.add(key)
    if keymon is None:
        scale = solution(key)/basescale
    else:
        if verbosity:
            print(" "*indent + "which in: "
                  + constraint.str_without(["units", "lineage"])
                  + " (sensitivity %+.2g)" % solution["sensitivities"]["constraints"][constraint])
        interesting_vks = {key}
        subkey, = interesting_vks
        power = keymon.exp[subkey]
        boring_vks = set(keymon.vks) - interesting_vks
        scale = solution(key)**power/basescale
        # TODO: make method that can handle both kinds of transforms
        if (power != 1 or boring_vks or mag(keymon.c) != 1
                or keymon.units != key.units):  # some kind of transform here
            units = 1
            exp = HashVector()
            for vk in interesting_vks:
                exp[vk] = keymon.exp[vk]
                if vk.units:
                    units *= vk.units**keymon.exp[vk]
            subhmap = NomialMap({exp: 1})
            try:
                subhmap.units = None if units == 1 else units
            except DimensionalityError:
                # pints was unable to divide a unit by itself bc
                #   it has terrible floating-point errors.
                #   so let's assume it isn't dimensionless
                #   even though it probably is
                subhmap.units = units
            freemon = Monomial(subhmap)
            factor = Monomial(keymon/freemon)
            scale = scale * solution(factor)
            if factor != 1:
                factor = factor**(-1/power)  # invert the transform
                factor.ast = None
                if verbosity:
                    print(" "*indent + "{ through a factor of %s (%s) }" %
                          (factor.str_without(["units"]),
                           get_valstr(factor, solution)))
                subsubtree = []
                transform = Transform(factor, 1, keymon)
                orig_subtree.append(Tree(transform, basescale, subsubtree))
                orig_subtree = subsubtree
            if power != 1:
                if verbosity:
                    print(" "*indent + "{ through a power of %.2g }" % power)
                subsubtree = []
                transform = Transform(1, 1/power, keymon)  # inverted bc it's on the gt side
                orig_subtree.append(Tree(transform, basescale, subsubtree))
                orig_subtree = subsubtree

    # TODO: use ast_parsing instead of chop?
    mons = composition.chop()
    monsols = [solution(mon) for mon in mons]  # ~20% of total last check # TODO: remove
    parsed_monsols = [getattr(mon, "value", mon) for mon in monsols]
    monvals = [float(mon/scale) for mon in parsed_monsols]  # ~10% of total last check # TODO: remove
    # sort by value, preserving order in case of value tie
    sortedmonvals = sorted(zip([-float("%.2g" % mv) for mv in monvals], range(len(mons)), monvals, mons))
    # print([m.str_without({"units", "lineage"}) for m in mons])
    if verbosity:
        if len(monsols) == 1:
            print(" "*indent + "breaks down into:")
        else:
            print(" "*indent + "breaks down into %i monomials:" % len(monsols))
            indent += 1
        indent += 1
    for i, (_, _, scaledmonval, mon) in enumerate(sortedmonvals):
        if not scaledmonval:
            continue
        subtree = orig_subtree  # return to the original subtree
        # time for some filtering
        # (copied so the set edits below can never reach mon.vks itself)
        interesting_vks = set(mon.vks)
        potential_filters = [
            {vk for vk in interesting_vks if vk not in bd},
            mon.vks - get_free_vks(mon, solution),
            {vk for vk in interesting_vks if vk in basically_fixed_variables}
        ]
        if scaledmonval < 1 - permissivity:  # skip breakdown filter
            potential_filters = potential_filters[1:]
        potential_filters.insert(0, visited_bdkeys)
        for vkfilter in potential_filters:
            if interesting_vks - vkfilter:  # don't remove the last one
                interesting_vks = interesting_vks - vkfilter
        # if filters weren't enough and permissivity is high enough, sort!
        if len(interesting_vks) > 1 and permissivity > 1:
            csenss = solution["sensitivities"]["constraints"]
            best_vks = sorted((vk for vk in interesting_vks if vk in bd),
                              key=lambda vk: (-abs(float("%.2g" % (mon.exp[vk]*csenss[bd[vk][0][2]]))),
                                              -float("%.2g" % solution["variables"][vk]),
                                              str(bd[vk][0][0])))  # ~5% of total last check # TODO: remove
            # TODO: changing to str(vk) above does some odd stuff, why?
            if best_vks:
                interesting_vks = set([best_vks[0]])
        boring_vks = mon.vks - interesting_vks

        subkey = None
        if len(interesting_vks) == 1:
            subkey, = interesting_vks
            if subkey in visited_bdkeys and len(sortedmonvals) == 1:
                continue  # don't even go there
            if subkey not in bd:
                power = 1  # no need for a transform
            else:
                power = mon.exp[subkey]
                if power < 0 and gone_negative:
                    subkey = None  # don't breakdown another negative

        if len(monsols) > 1 and verbosity:
            indent -= 1
            print(" "*indent + "%s) forming %i%% of the RHS and %i%% of the total:" % (i+1, scaledmonval/basescale*100, scaledmonval*100))
            indent += 1

        if subkey is None:
            power = 1
            if scaledmonval > 1 - permissivity and not boring_vks:
                boring_vks = interesting_vks
                interesting_vks = set()
        if not interesting_vks:
            # prioritize showing some boring_vks as if they were "free"
            if len(boring_vks) == 1:
                interesting_vks = boring_vks
                boring_vks = set()
            else:
                for vk in list(boring_vks):
                    if vk.units and not vk.units.dimensionless:
                        interesting_vks.add(vk)
                        boring_vks.remove(vk)

        if interesting_vks and (boring_vks or mag(mon.c) != 1):
            units = 1
            has_units = False
            exp = HashVector()
            for vk in interesting_vks:
                exp[vk] = mon.exp[vk]
                if vk.units:
                    units *= vk.units**mon.exp[vk]
                    has_units = True
            subhmap = NomialMap({exp: 1})
            # unitless only when no variable contributed units; an explicit
            # flag replaces an identity test (`is`) against the literal 1
            subhmap.units = units if has_units else None
            freemon = Monomial(subhmap)
            factor = mon/freemon  # autoconvert...
            if (factor.units is None and isinstance(factor, FixedScalar)
                    and abs(factor.value - 1) <= 1e-4):
                factor = 1  # minor fudge to clear numerical inaccuracies
            if factor != 1:
                factor.ast = None
                if verbosity:
                    keyvalstr = "%s (%s)" % (factor.str_without(["units"]),
                                             get_valstr(factor, solution))
                    print(" "*indent + "{ through a factor of %s }" % keyvalstr)
                subsubtree = []
                transform = Transform(factor, 1, mon)
                subtree.append(Tree(transform, scaledmonval, subsubtree))
                subtree = subsubtree
            mon = freemon  # simplifies units
        if power != 1:
            if verbosity:
                print(" "*indent + "{ through a power of %.2g }" % power)
            subsubtree = []
            transform = Transform(1, power, mon)
            subtree.append(Tree(transform, scaledmonval, subsubtree))
            subtree = subsubtree
            mon = mon**(1/power)
            mon.ast = None
        # TODO: make minscale an argument - currently an arbitrary 0.01
        if (subkey is not None and subkey not in visited_bdkeys
                and subkey in bd and scaledmonval > 0.05):
            subverbosity = indent + 1 if verbosity else 0  # slight hack
            subsubtree = crawl(basically_fixed_variables, subkey, bd, solution, scaledmonval,
                               permissivity, subverbosity, set(visited_bdkeys),
                               gone_negative, all_visited_bdkeys)
            subtree.append(subsubtree)
        else:
            if verbosity:
                keyvalstr = "%s (%s)" % (mon.str_without(["units"]),
                                         get_valstr(mon, solution))
                print(" "*indent + keyvalstr)
            subtree.append(Tree(mon, scaledmonval, []))
    if verbosity == 1:
        if not already_set:
            solution.set_necessarylineage(clear=True)
    return tree
# Single-character labels for spans in text graphs, in assignment order.
# The letters l, I, L and T are excluded as easily confused with other glyphs.
SYMBOLS = (string.ascii_uppercase + string.ascii_lowercase).translate(
    str.maketrans("", "", "lILT"))
def get_spanstr(legend, length, label, leftwards, solution):
    """Returns span visualization, collapsing labels to symbols

    `legend` maps labels to their one-character symbol and is extended in
    place with the next unused SYMBOLS entry for new labels. A None label
    yields blank padding of `length` characters.
    """
    if label is None:
        return " "*length
    if isinstance(label, Transform):
        if label.power != 1:
            spacer = " "
            lend = rend = "^" if label.power > 0 else "/"
        else:
            spacer, lend, rend = "╎", "╤", "╧"
        # drop origkey so equal transforms share a single legend entry
        label = Transform(label.factor, label.power, None)
        if label.power == 1 and len(str(label.factor)) == 1:
            legend[label] = str(label.factor)
    else:
        spacer, lend, rend = "│", "┯", "┷"
    if label not in legend:
        legend[label] = SYMBOLS[len(legend)]
    shortname = legend[label]

    if length <= 1:
        return shortname
    if length == 2:
        return lend + shortname if leftwards else shortname + rend
    shortside = int(max(0, length - 2)/2)
    longside = int(max(0, length - 3)/2)
    if leftwards:
        return lend + spacer*shortside + shortname + spacer*longside + rend
    # HACK: no corners on long rightwards spans - only used for depth 0
    return "┃"*(longside+1) + shortname + "┃"*(shortside+1)
def discretize(tree, extent, solution, collapse, depth=0, justsplit=False):
    """Rescales a breakdown tree so that every node's value is an integer extent.

    Branch values are scaled by `extent/val` and rounded. Zero-extent
    branches are merged into one catch-all branch keyed by a tuple of their
    keys. Any rounding surplus or deficit is then distributed one unit at a
    time so that the children's extents sum to `extent`.

    Arguments
    ---------
    tree: Tree
        (key, value, branches) node to discretize.
    extent: int
        Number of character cells this node should span.
    solution: solution object
        Passed to get_keystr to detect branches with identical labels.
    collapse: bool
        When True, Transforms with positive power are flattened into their
        children, single-child chains may be skipped through, and stacked
        power Transforms are combined.
    depth: int
        Depth of this node (0 at the root).
    justsplit: bool
        Whether the nearest non-Transform ancestor branched.

    Returns
    -------
    Tree whose values are integer extents.

    Raises
    ------
    ValueError
        If the rounding surplus can't be distributed over the branches.
    """
    # TODO: add vertical simplification?
    key, val, branches = tree
    if collapse:  # collapse Transforms with power 1
        # NOTE: the condition below flattens every Transform with power > 0,
        # not only those with power == 1
        while any(isinstance(branch.key, Transform) and branch.key.power > 0 for branch in branches):
            newbranches = []
            for branch in branches:
                # isinstance(branch.key, Transform) and branch.key.power > 0
                if isinstance(branch.key, Transform) and branch.key.power > 0:
                    newbranches.extend(branch.branches)
                else:
                    newbranches.append(branch)
            branches = newbranches

    scale = extent/val
    values = [b.value for b in branches]
    # merge branches whose labels print identically into the first of them
    bkey_indexs = {}
    for i, b in enumerate(branches):
        k = get_keystr(b.key, solution)
        if isinstance(b.key, Transform):
            if len(b.branches) == 1:
                k = get_keystr(b.branches[0].key, solution)
        if k in bkey_indexs:
            values[bkey_indexs[k]] += values[i]
            values[i] = None
        else:
            bkey_indexs[k] = i
    if any(v is None for v in values):
        # drop merged-away branches and re-sort by (rounded) value, descending
        bvs = zip(*sorted(((-float("%.2g" % v), i, b, v) for i, (b, v) in enumerate(zip(branches, values)) if v is not None)))
        _, _, branches, values = bvs
        branches = list(branches)
        values = list(values)
    extents = [int(round(scale*v)) for v in values]
    surplus = extent - sum(extents)
    for i, b in enumerate(branches):
        if isinstance(b.key, Transform):
            subscale = extents[i]/b.value
            if not any(round(subscale*subv) for _, subv, _ in b.branches):
                extents[i] = 0  # transform with no worthy heirs: misc it
    if not any(extents):
        return Tree(key, extent, [])
    if not all(extents):  # create a catch-all
        branches = branches.copy()
        miscvkeys, miscval = [], 0
        # NOTE(review): this pops extents[-1]/branches[-1] rather than the
        # element `subextent` came from, so once a nonzero tail entry is kept,
        # a zero-extent entry further left (e.g. a heirless Transform) causes
        # the wrong branch to be merged and the surplus to be miscounted.
        for subextent in reversed(extents):
            if not subextent or (branches[-1].value < miscval and surplus < 0):
                extents.pop()
                k, v, _ = branches.pop()
                if isinstance(k, Transform):
                    k = k.origkey  # TODO: this is the only use of origkey - remove it
                if isinstance(k, tuple):
                    vkeys = [(-kv[1], str(kv[0]), kv[0]) for kv in k]
                if not isinstance(k, tuple):
                    vkeys = [(-float("%.2g" % v), str(k), k)]
                miscvkeys += vkeys
                # account for how much the catch-all's rounded extent grows
                surplus -= (round(scale*(miscval + v))
                            - round(scale*miscval) - subextent)
                miscval += v
        misckeys = tuple(k for _, _, k in sorted(miscvkeys))
        branches.append(Tree(misckeys, miscval, []))
        extents.append(int(round(scale*miscval)))
    if surplus:
        # hand out leftover cells one at a time, largest extents first
        sign = int(np.sign(surplus))
        bump_priority = sorted((ext, sign*float("%.2g" % b.value), i) for i, (b, ext)
                               in enumerate(zip(branches, extents)))
        # print(key, surplus, bump_priority)
        while surplus:
            try:
                extents[bump_priority.pop()[-1]] += sign
                surplus -= sign
            except IndexError:
                raise ValueError(val, [b.value for b in branches])

    tree = Tree(key, extent, [])
    # simplify based on how we're branching
    branchfactor = len([ext for ext in extents if ext]) - 1
    if depth and not isinstance(key, Transform):
        if extent == 1 or branchfactor >= max(extent-2, 1):
            # if we'd branch too much, stop
            return tree
        if collapse and not branchfactor and not justsplit:
            # if we didn't just split and aren't about to, skip through
            return discretize(branches[0], extent, solution, collapse,
                              depth=depth+1, justsplit=False)
    if branchfactor:
        justsplit = True
    elif not isinstance(key, Transform):  # justsplit passes through transforms
        justsplit = False

    for branch, subextent in zip(branches, extents):
        if subextent:
            branch = discretize(branch, subextent, solution, collapse,
                                depth=depth+1, justsplit=justsplit)
            if (collapse and is_power(branch.key)
                    and all(is_power(b.key) for b in branch.branches)):
                # combine stacked powers
                power = branch.key.power
                for b in branch.branches:
                    key = Transform(1, power*b.key.power, None)
                    if key.power == 1:  # powers canceled, collapse both
                        tree.branches.extend(b.branches)
                    else:  # collapse this level
                        tree.branches.append(Tree(key, b.value, b.branches))
            else:
                tree.branches.append(branch)
    return tree
def layer(map, tree, maxdepth, depth=0):
    """Turns the tree into a 2D-array

    Appends (key, extent) for each node, in depth-first order, to the row of
    `map` matching its depth (creating rows as needed), down to `maxdepth`.
    Leaves above `maxdepth` are padded downward with None-keyed nodes of the
    same extent. Returns `map`, which is modified in place.
    """
    key, extent, branches = tree
    if depth > maxdepth:
        return map
    if len(map) <= depth:
        map.append([])
    map[depth].append((key, extent))
    children = branches or [Tree(None, extent, [])]  # pad it out
    for child in children:
        layer(map, child, maxdepth, depth + 1)
    return map
def plumb(tree, depth=0):
    "Finds maximum depth of a tree (`depth` is the depth of `tree` itself)."
    return max((plumb(branch, depth + 1) for branch in tree.branches),
               default=depth)
def prune(tree, solution, maxlength, length=-1, prefix=""):
    """Prune branches that are longer than a certain number of characters

    `length` is the printed width used so far (-1 at the root). A node's
    children are all dropped as soon as any one of them would push the
    width past `maxlength`.
    """
    def label_width(k, pfx):
        # the wider of the node's "(value)" string and its key string
        return max(len(get_valstr(k, solution, into="(%s)")),
                   len(get_keystr(k, solution, pfx)))

    key, extent, branches = tree
    own_width = label_width(key, prefix)
    if length == -1 and isinstance(key, VarKey) and key.necessarylineage:
        prefix = key.lineagestr()
    length += own_width + 3
    if any(length + label_width(b.key, prefix) + 3 > maxlength
           for b in branches):
        return Tree(key, extent, [])
    pruned = [prune(b, solution, maxlength, length, prefix) for b in branches]
    return Tree(key, extent, pruned)
def simplify(tree, solution, extent, maxdepth, maxlength, collapse):
    """Discretize, prune, and layer a tree to prepare for printing

    Pruning to `maxlength` only happens when collapsing and `maxlength` is
    truthy. Returns the rows produced by `layer`.
    """
    discretized = discretize(tree, extent, solution, collapse)
    should_prune = collapse and maxlength
    ready = prune(discretized, solution, maxlength) if should_prune else discretized
    return layer([], ready, maxdepth)
# @profile  # ~16% of total last check # TODO: remove
def graph(tree, breakdowns, solution, basically_fixed_variables, *,
          height=None, maxdepth=None, maxwidth=81, showlegend=False):
    """Prints breakdown

    Arguments
    ---------
    tree: Tree
        Breakdown tree to draw (e.g. from `crawl` or `crawl_modelbd`).
    breakdowns: dict
        Keys that can be broken down further; such keys with nothing drawn
        to their right get a "╶⎨" marker.
    solution: solution object
        Used for labels and values.
    basically_fixed_variables: set
        Passed to legend_entry when `showlegend` is True.
    height: int, optional
        Number of printed lines. If None, starts at 20 and shrinks toward
        4 lines per element of the most crowded depth.
    maxdepth: int, optional
        Deepest level to draw; defaults to the full depth of `tree`.
    maxwidth: int
        Passed to simplify as the pruning width.
    showlegend: bool
        If True, label internal nodes with symbols and print a legend;
        otherwise label them inline and collapse transforms.
    """
    already_set = solution._lineageset
    if not already_set:
        solution.set_necessarylineage()
    try:
        collapse = (not showlegend)  # TODO: set to True while showlegend is True for first approx of receipts; autoinclude with trace?
        if maxdepth is None:
            maxdepth = plumb(tree)
        if height is not None:
            mt = simplify(tree, solution, height, maxdepth, maxwidth, collapse)
        else:  # zoom in from a default height of 20 to a height of 4 per branch
            prev_height = None
            height = 20
            while prev_height != height:
                mt = simplify(tree, solution, height, maxdepth, maxwidth, collapse)
                prev_height = height
                # plain max over the generator: max(*gen) fails for a single row
                height = min(height, max(4*len(at_depth) for at_depth in mt))

        legend = {}
        chararray = np.full((len(mt), height), "", "object")
        for depth, elements_at_depth in enumerate(mt):
            row = ""
            for element, length in elements_at_depth:
                leftwards = depth > 0 and length > 2
                row += get_spanstr(legend, length, element, leftwards, solution)
            chararray[depth, :] = list(row)

        # Format depth=0
        A_key, = [key for key, value in legend.items() if value == "A"]
        prefix = ""
        if A_key is solution["cost function"]:
            A_str = "Cost"
        else:
            A_str = get_keystr(A_key, solution)
            if isinstance(A_key, VarKey) and A_key.necessarylineage:
                prefix = A_key.lineagestr()
        A_valstr = get_valstr(A_key, solution, into="(%s)")
        fmt = "{0:>%s}" % (max(len(A_str), len(A_valstr)) + 3)
        for j, entry in enumerate(chararray[0,:]):
            if entry == "A":
                chararray[0,j] = fmt.format(A_str + "╺┫")
                chararray[0,j+1] = fmt.format(A_valstr + " ┃")
            else:
                chararray[0,j] = fmt.format(entry)
        # Format depths 1+
        labeled = set()
        reverse_legend = {v: k for k, v in legend.items()}
        legend = {}
        for pos in range(height):
            for depth in reversed(range(1,len(mt))):
                char = chararray[depth, pos]
                if char not in reverse_legend:
                    continue
                key = reverse_legend[char]
                if key not in legend and (isinstance(key, tuple) or (depth != len(mt) - 1 and chararray[depth+1, pos] != " ")):
                    legend[key] = SYMBOLS[len(legend)]
                if collapse and is_power(key):
                    chararray[depth, pos] = "^" if key.power > 0 else "/"
                    del legend[key]
                    continue
                if key in legend:
                    chararray[depth, pos] = legend[key]
                    if isinstance(key, tuple) and not isinstance(key, Transform):
                        chararray[depth, pos] = "*" + chararray[depth, pos]
                        del legend[key]
                    if showlegend:
                        continue

                keystr = get_keystr(key, solution, prefix)
                if keystr in labeled:
                    valuestr = ""
                else:
                    valuestr = get_valstr(key, solution, into=" (%s)")
                if collapse:
                    fmt = "{0:<%s}" % max(len(keystr) + 3, len(valuestr) + 2)
                else:
                    fmt = "{0:<1}"
                # thicken this span's spacers and corners, up and down from pos
                span = 0
                tryup, trydn = True, True
                while tryup or trydn:
                    span += 1
                    if tryup:
                        if pos - span < 0:
                            tryup = False
                        else:
                            upchar = chararray[depth, pos-span]
                            if upchar == "│":
                                chararray[depth, pos-span] = fmt.format("┃")
                            elif upchar == "┯":
                                chararray[depth, pos-span] = fmt.format("┓")
                            else:
                                tryup = False
                    if trydn:
                        if pos + span >= height:
                            trydn = False
                        else:
                            dnchar = chararray[depth, pos+span]
                            if dnchar == "│":
                                chararray[depth, pos+span] = fmt.format("┃")
                            elif dnchar == "┷":
                                chararray[depth, pos+span] = fmt.format("┛")
                            else:
                                trydn = False
                linkstr = "┣╸"
                if not isinstance(key, FixedScalar):
                    labeled.add(keystr)
                    if span > 1 and (collapse or pos + 2 >= height
                                     or chararray[depth, pos+1] == "┃"):
                        vallabel = chararray[depth, pos+1].rstrip() + valuestr
                        chararray[depth, pos+1] = fmt.format(vallabel)
                    elif showlegend:
                        keystr += valuestr
                # mark keys that could be broken down further but have nothing
                # drawn to their right (the last row has nothing to its right)
                if (key in breakdowns
                        and (depth + 1 >= len(mt) or not chararray[depth+1, pos].strip())
                        and (depth >= len(mt)-2
                             or not chararray[depth+2, pos].strip())):
                    keystr = keystr + "╶⎨"
                chararray[depth, pos] = fmt.format(linkstr + keystr)
        # Rotate and print
        rowstrs = ["".join(row).rstrip() for row in chararray.T.tolist()]
        print("\n" + "\n".join(rowstrs) + "\n")

        if showlegend:  # create and print legend
            legend_lines = []
            for key, shortname in sorted(legend.items(), key=lambda kv: kv[1]):
                legend_lines.append(legend_entry(key, shortname, solution, prefix,
                                                 basically_fixed_variables))
            maxlens = [max(len(el) for el in col) for col in zip(*legend_lines)]
            fmts = ["{0:<%s}" % L for L in maxlens]
            for line in legend_lines:
                line = "".join(fmt.format(cell)
                               for fmt, cell in zip(fmts, line) if cell).rstrip()
                print(" " + line)
    finally:
        # restore the solution's lineage display even if drawing fails
        if not already_set:
            solution.set_necessarylineage(clear=True)
def legend_entry(key, shortname, solution, prefix, basically_fixed_variables):
    """Returns list of legend elements

    The list is [symbol, key string, value string, note]. Factor transforms
    are shown by their factor with a "×" operator (noted as a free factor if
    it has free variables that aren't basically fixed); power transforms by
    their exponent.
    """
    note = ""
    keystr = " "
    operator = "= " if shortname else " + "
    if is_factor(key):
        operator = " ×"
        key = key.factor
        free_unfixed = [vk for vk in get_free_vks(key, solution)
                        if vk not in basically_fixed_variables]
        if free_unfixed:
            note = " [free factor]"
    if not is_power(key):
        valuestr = get_valstr(key, solution, into=" "+operator+"%s")
        if not isinstance(key, FixedScalar):
            keystr = get_keystr(key, solution, prefix)
    else:
        valuestr = " ^%.3g" % key.power
    return ["%-4s" % shortname, keystr, valuestr, note]
def get_keystr(key, solution, prefix=""):
    """Returns formatted string of the key in solution.

    Keys with `str_without` are printed without units and without the
    `prefix` lineage; other tuples (e.g. catch-all groups) as "[N terms]".
    Results over 67 characters are cut to 66 plus an ellipsis.
    """
    if hasattr(key, "str_without"):
        text = key.str_without({"units", ":MAGIC:"+prefix})
    elif isinstance(key, tuple):
        text = "[%i terms]" % len(key)
    else:
        text = str(key)
    if len(text) > 67:
        text = text[:66] + "…"
    return text
def get_valstr(key, solution, into="%s"):
    """Returns formatted string of the value of key in solution.

    If `key` can't be evaluated directly, the sum of its elements' values
    is tried; if that also fails, " " is returned (without using `into`).
    Values in [1e3, 1e6) get thousands separators, others 3 significant
    figures; ", fixed" is appended when the key (or all of its vks) are
    constants.
    """
    try:
        value = solution(key)
    except (ValueError, TypeError):
        try:
            value = sum(solution(subkey) for subkey in key)
        except (ValueError, TypeError):
            return " "
    if isinstance(value, FixedScalar):
        value = value.value
    magnitude = mag(value)
    if 1e3 <= magnitude < 1e6:
        valuestr = "{:,.0f}".format(magnitude)
    else:
        valuestr = "%-.3g" % magnitude

    if not hasattr(key, "unitstr"):
        try:
            if hasattr(value, "units"):
                value.ito_reduced_units()
        except DimensionalityError:
            pass
        unitstr = get_unitstr(value)
    else:
        unitstr = key.unitstr()
    if unitstr.startswith("1/"):
        unitstr = "/" + unitstr[2:]

    constants = solution["constants"]
    is_fixed = key in constants or (
        hasattr(key, "vks") and key.vks
        and all(vk in constants for vk in key.vks))
    if is_fixed:
        unitstr += ", fixed"
    return into % (valuestr + unitstr)
826import plotly.graph_objects as go
def plotlyify(tree, solution, minval=None):
    """Flattens a breakdown Tree into the parallel lists Plotly hierarchies use.

    Arguments
    ---------
    tree: Tree
        Breakdown tree (e.g. from `crawl` or `crawl_modelbd`).
    solution: solution object
        Used to label nodes with their values via `get_valstr`.
    minval (optional): number
        Nodes whose value is not greater than this are omitted together
        with their subtrees; defaults to 1/1000 of the root's value.

    Returns
    -------
    ids, labels, parents, values: lists
        Suitable for `go.Treemap` / `go.Icicle` with branchvalues="total".
        Transform nodes are not emitted; their children attach to the
        transform's parent. Each child's value is clipped to what remains
        of its parent's value, so children never sum past their parent.
    """
    ids = []
    labels = []
    parents = []
    values = []

    key, value, _ = tree
    if isinstance(key, VarKey) and key.necessarylineage:
        prefix = key.lineagestr()
    else:
        prefix = ""

    if minval is None:
        minval = value/1000

    parent_budgets = {}  # node id -> value not yet claimed by its children

    def visit(node, parent_id=None):
        key, value, branches = node
        if value > minval:
            if isinstance(key, Transform):
                node_id = parent_id  # transforms are transparent
            else:
                node_id = len(ids)+1
                ids.append(node_id)
                labels.append(get_keystr(key, solution, prefix))
                if not isinstance(key, str):
                    labels[-1] = labels[-1] + "<br>" + get_valstr(key, solution)
                parents.append(parent_id)
                if parent_id is not None:  # make sure there's no overflow
                    if parent_budgets[parent_id] < value:
                        value = parent_budgets[parent_id]  # take remainder
                    parent_budgets[parent_id] -= value
                values.append(value)
                parent_budgets[node_id] = value
            for branch in branches:
                visit(branch, node_id)

    visit(tree)
    return ids, labels, parents, values
def treemap(ids, labels, parents, values):
    "Builds a Plotly Treemap figure in which each parent's value is its total."
    trace = go.Treemap(ids=ids, labels=labels, parents=parents,
                       values=values, branchvalues="total")
    return go.Figure(trace)
def icicle(ids, labels, parents, values):
    "Builds a Plotly Icicle figure in which each parent's value is its total."
    trace = go.Icicle(ids=ids, labels=labels, parents=parents,
                      values=values, branchvalues="total")
    return go.Figure(trace)
907import functools
class Breakdowns(object):
    """Builds and displays breakdowns of a solution.

    On construction, computes the model-sensitivity tree (`mtree`, with its
    named subtrees in `mlookup`) and the variable breakdowns (`bd`, from
    get_breakdowns, which also fills `basically_fixed_variables`).
    """
    def __init__(self, sol):
        self.sol = sol
        self.mlookup = {}
        self.mtree = crawl_modelbd(get_model_breakdown(sol), self.mlookup)
        self.basically_fixed_variables = set()
        self.bd = get_breakdowns(self.basically_fixed_variables, self.sol)

    def trace(self, key, *, permissivity=2):
        "Prints a step-by-step text trace of how `key` breaks down."
        print("")  # a little padding to start
        self.get_tree(key, permissivity=permissivity, verbosity=1)

    def get_tree(self, key, *, permissivity=2, verbosity=0):
        """Returns (tree, kind) for `key`; kind is "variable" or "constraint".

        String keys: "model sensitivities" gives the whole model-sensitivity
        tree; "cost" breaks down the cost function; a name in `mlookup` gives
        that model-sensitivity subtree; anything else must match exactly one
        broken-down variable whose string contains `key` and ends with the
        same character as `key`.

        Raises KeyError if a string key matches no variable, or several.
        """
        tree = None
        kind = "variable"
        if isinstance(key, str):
            if key == "model sensitivities":
                tree = self.mtree
                kind = "constraint"
            elif key == "cost":
                key = self.sol["cost function"]
            elif key in self.mlookup:
                tree = self.mlookup[key]
                kind = "constraint"
            else:
                # TODO: support submodels
                # endswith(key[-1:]) avoids an IndexError for an empty key
                keys = [vk for vk in self.bd
                        if key in str(vk) and str(vk).endswith(key[-1:])]
                if not keys:
                    raise KeyError(key)
                elif len(keys) > 1:
                    raise KeyError("There are %i keys containing '%s'." % (len(keys), key))
                key, = keys
        if tree is None:
            tree = crawl(self.basically_fixed_variables, key, self.bd, self.sol,
                         permissivity=permissivity, verbosity=verbosity)
        return tree, kind

    def plot(self, key, *, height=None, permissivity=2, showlegend=False,
             maxwidth=85):
        "Prints a text graph of the breakdown of `key`."
        tree, kind = self.get_tree(key, permissivity=permissivity)
        lookup = self.bd if kind == "variable" else self.mlookup
        graph(tree, lookup, self.sol, self.basically_fixed_variables,
              height=height, showlegend=showlegend, maxwidth=maxwidth)

    def treemap(self, key, *, permissivity=2, returnfig=False, filename=None):
        """Shows `key`'s breakdown as a Plotly treemap.

        Returns the figure if `returnfig`, otherwise writes it to HTML.
        """
        # permissivity was previously accepted here but never forwarded
        tree, _ = self.get_tree(key, permissivity=permissivity)
        fig = treemap(*plotlyify(tree, self.sol))
        if returnfig:
            return fig
        self._write_html(fig, key, "treemap", filename)

    def icicle(self, key, *, permissivity=2, returnfig=False, filename=None):
        """Shows `key`'s breakdown as a Plotly icicle chart.

        Returns the figure if `returnfig`, otherwise writes it to HTML.
        """
        tree, _ = self.get_tree(key, permissivity=permissivity)
        fig = icicle(*plotlyify(tree, self.sol))
        if returnfig:
            return fig
        self._write_html(fig, key, "icicle", filename)

    @staticmethod
    def _write_html(fig, key, kindname, filename=None):
        """Writes `fig` to an HTML file via plotly.offline.plot.

        Without an explicit `filename`, one is derived from `key` and
        `kindname`, keeping only alphanumerics, "." and "_".
        """
        if filename is None:
            filename = str(key)+"_"+kindname+".html"
            keepcharacters = (".","_")
            filename = "".join(c for c in filename if c.isalnum()
                               or c in keepcharacters).rstrip()
        import plotly
        plotly.offline.plot(fig, filename=filename)