Coverage for gpkit/interactive/sankey.py: 89%
184 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-28 11:50 -0400
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-28 11:50 -0400
1"implements Sankey"
2import os
3import re
4from collections import defaultdict
5from collections.abc import Iterable
6import numpy as np
7from ipywidgets import Layout
8from ipysankeywidget import SankeyWidget
9from ..repr_conventions import lineagestr, unitstr
10from .. import Model, GPCOLORS
11from ..constraints.array import ArrayConstraint
# Sensitivity magnitudes below this are drawn grey, and constraint links this
# small are among the first culled when a diagram has too many links.
INSENSITIVE = 0.01
# Stand-in magnitude for links whose sensitivity is zero, missing, or NaN.
EPS = 1.0e-10
def isnamedmodel(constraint):
    "True for instances of Model subclasses (not the bare Model class itself)"
    if not isinstance(constraint, Model):
        return False
    # a plain Model is anonymous; only subclasses carry a meaningful name
    return constraint.__class__.__name__ != "Model"
def getcolor(value):
    "Maps a sensitivity to a link color: grey if tiny, otherwise by sign"
    grey = "#cfcfcf"
    if abs(value) < INSENSITIVE:
        return grey
    is_negative = value < 0
    return GPCOLORS[0] if is_negative else GPCOLORS[1]
def cleanfilename(string):
    """Parses string into a valid filename.

    Each character that is invalid in filenames on common platforms
    (backslash, slash, ?, |, ", >, <, :, *) is replaced with "_".
    """
    # A character class is required: without the brackets the pattern only
    # matched a backslash (optionally followed by "/") or the literal run
    # '"><:*', leaving e.g. "/" and ":" in the filename.
    return re.sub(r'[\\/?|"<>:*]', "_", string)
# pylint: disable=too-many-instance-attributes
class Sankey:
    "Return Jupyter diagrams of sensitivity flow"
    minsenss = 0  # links with |sensitivity| not above this are always dropped
    maxlinks = 20  # above this many links, less important links are culled
    showconstraints = True  # False hides constraint/subarray source nodes
    last_top_node = None  # node that onclick most recently focused on

    def __init__(self, solution, constraintset, csetlabel=None):
        """Stores the solution and constraint set to be diagrammed.

        Arguments
        ---------
        solution : indexable by "sensitivities", "constants", "variables"
            (presumably a gpkit SolutionArray -- confirm against callers)
        constraintset : the constraint set whose sensitivity flow is drawn
        csetlabel : str, optional
            Label for the top node; defaults to the set's lineage string
            or, if that is empty, its class name.
        """
        self.solution = solution
        self.csenss = solution["sensitivities"]["constraints"]
        self.cset = constraintset
        if csetlabel is None:
            csetlabel = lineagestr(self.cset) or self.cset.__class__.__name__
        self.csetlabel = csetlabel
        # (source id, target id) -> accumulated sensitivity
        self.links = defaultdict(float)
        # target id -> number of child nodes created under it so far
        self.links_to_target = defaultdict(int)
        # node id -> {"id": ..., "title": ..., optional tag flag}
        self.nodes = {}

    def add_node(self, target, title, tag=None):
        """Creates a child node under `target` and returns its id.

        Children are numbered per target, so ids look like
        "<target>.0001", "<target>.0002", ...; each level of nesting adds
        one "." to the id. If `tag` is given (e.g. "constraint" or
        "subarray") it is stored as a True-valued key so filters can test
        for it with `in`.
        """
        self.links_to_target[target] += 1
        node = "%s.%04i" % (target, self.links_to_target[target])
        self.nodes[node] = {"id": node, "title": title}
        if tag is not None:  # don't store a meaningless None key
            self.nodes[node][tag] = True
        return node

    def linkfixed(self, cset, target):
        """Adds fixed-variable links as if they were (array)constraints.

        Every variable of `cset` found in solution["constants"] gets a
        "constraint" node under `target`. A vector variable whose entries
        all share one value gets one grouping "constraint" node instead,
        with each element as a "subarray" child of it.

        Returns the summed (non-positive) sensitivity of the linked variables.
        """
        fixedvecs = {}  # veckey -> id of the node grouping its elements
        checkedvecs = set()  # veckeys already tested for a uniform value
        total_sens = 0
        for vk in sorted(cset.unique_varkeys, key=str):
            if vk not in self.solution["constants"]:
                continue
            if vk.veckey and vk.veckey not in checkedvecs:
                # test each vector once, not once per element
                checkedvecs.add(vk.veckey)
                vecval = self.solution["constants"][vk.veckey]
                firstval = vecval.flatten()[0]
                if vecval.shape and (firstval == vecval).all():
                    label = "%s = %.4g %s" % (vk.veckey.name, firstval,
                                              unitstr(vk.veckey))
                    fixedvecs[vk.veckey] = self.add_node(target, label,
                                                         "constraint")
            abs_var_sens = -abs(
                self.solution["sensitivities"]["constants"].get(vk, EPS))
            if np.isnan(abs_var_sens):
                # keep the non-positive sign used for every other value here
                abs_var_sens = -EPS
            label = "%s = %.4g %s" % (vk.str_without(["lineage"]),
                                      self.solution["variables"][vk],
                                      unitstr(vk))
            if vk.veckey in fixedvecs:
                vectarget = fixedvecs[vk.veckey]
                source = self.add_node(vectarget, label, "subarray")
                self.links[source, vectarget] = abs_var_sens
                self.links[vectarget, target] += abs_var_sens
            else:
                source = self.add_node(target, label, "constraint")
                self.links[source, target] = abs_var_sens
            total_sens += abs_var_sens
        return total_sens

    # pylint: disable=too-many-branches
    def link(self, cset, target, vk, *, labeled=False, subarray=False):
        """Recursively adds links for `cset` and its children to self.links.

        Arguments
        ---------
        cset : a constraint, a dict of labeled constraints, or an iterable
            of constraints (named models and ArrayConstraints get their own
            intermediate node)
        target : id of the node that `cset`'s links flow into
        vk : variable key whose sensitivity is traced, or None to use
            constraint sensitivities (plus fixed-variable links)
        labeled : True when the caller already made a node for `cset`
        subarray : True when `cset` is an element of an ArrayConstraint

        Returns the total sensitivity flowing from `cset` into `target`.
        """
        total_sens = 0
        switchedtarget = False  # the original target, if a new node is made
        if not labeled and isnamedmodel(cset):
            if cset is not self.cset:  # top-level, no need to switch targets
                switchedtarget = target
                target = self.add_node(target, cset.lineage[-1][0])
            if vk is None:
                total_sens += self.linkfixed(cset, target)
        elif isinstance(cset, ArrayConstraint) and cset.constraints.size > 1:
            switchedtarget = target
            cstr = cset.str_without(["lineage", "units"]).replace("[:]", "")
            label = cstr if len(cstr) <= 30 else "%s ..." % cstr[:30]
            target = self.add_node(target, label, "constraint")
            subarray = True
        if getattr(cset, "idxlookup", None):
            # turn the label -> index lookup into label -> constraint
            cset = {k: cset[i] for k, i in cset.idxlookup.items()}
        if isinstance(cset, dict):
            for label, c in cset.items():
                source = self.add_node(target, label)
                subtotal_sens = self.link(c, source, vk, labeled=True)
                self.links[source, target] += subtotal_sens
                total_sens += subtotal_sens
        elif isinstance(cset, Iterable):
            for c in cset:
                total_sens += self.link(c, target, vk, subarray=subarray)
        else:  # a single constraint: a leaf of the diagram
            if vk is None and cset in self.csenss:
                total_sens = -abs(self.csenss[cset]) or -EPS
            elif vk is not None:
                if cset.v_ss is None:
                    if vk in cset.varkeys:
                        total_sens = EPS
                elif vk in cset.v_ss:
                    total_sens = cset.v_ss[vk] or EPS
            if not labeled:
                cstr = cset.str_without(["lineage", "units"])
                label = cstr if len(cstr) <= 30 else "%s ..." % cstr[:30]
                tag = "subarray" if subarray else "constraint"
                source = self.add_node(target, label, tag)
                self.links[source, target] = total_sens
        if switchedtarget:
            self.links[target, switchedtarget] += total_sens
        return total_sens

    def filter(self, links, function, forced=False):
        """If over maxlinks (or if forced), removes non-matching links.

        `links` is modified in place; `function(source, target, value)`
        returns True for links that should be kept.
        """
        if len(links) > self.maxlinks or forced:
            for (s, t), v in list(links.items()):
                if not function(s, t, v):
                    del links[(s, t)]

    # pylint: disable=too-many-locals
    def diagram(self, variable=None, varlabel=None, *, minsenss=0, maxlinks=20,
                top=0, bottom=0, left=230, right=140, width=1000, height=400,
                showconstraints=True):
        """Creates links and an ipython widget to show them.

        Arguments
        ---------
        variable : variable whose sensitivity flow is shown; if omitted,
            constraint sensitivities are shown instead
        varlabel : str, optional
            label for the variable's node; defaults to str(variable.key),
            or its lineage-free name if that is longer than 20 characters
        minsenss : links whose |sensitivity| is not above this are dropped
        maxlinks : above this many links, less important links are culled
        top, bottom, left, right : widget margins
        width, height : widget size
        showconstraints : if False, constraint/subarray nodes are hidden

        Returns the SankeyWidget, with PNG/SVG autosave paths registered
        under the "sankey_autosaves" directory.

        NOTE(review): self.links and self.nodes are not reset here, so a
        second call on the same instance keeps the earlier call's links;
        presumably unintended -- confirm.
        """
        margins = dict(top=top, bottom=bottom, left=left, right=right)
        self.minsenss = minsenss
        self.maxlinks = maxlinks
        self.showconstraints = showconstraints

        self.solution.set_necessarylineage()
        try:
            if variable:
                variable = variable.key
                if not varlabel:
                    varlabel = str(variable)
                    if len(varlabel) > 20:
                        varlabel = variable.str_without(["lineage"])
                self.nodes[varlabel] = {"id": varlabel, "title": varlabel}
                csetnode = self.add_node(varlabel, self.csetlabel)
                if variable in self.solution["sensitivities"]["cost"]:
                    costnode = self.add_node(varlabel, "[cost function]")
                    self.links[costnode, varlabel] = \
                        self.solution["sensitivities"]["cost"][variable]
            else:
                csetnode = self.csetlabel
                self.nodes[self.csetlabel] = {"id": self.csetlabel,
                                              "title": self.csetlabel}
            total_sens = self.link(self.cset, csetnode, variable)
            if variable:
                self.links[csetnode, varlabel] = total_sens

            links, nodes = self._links_and_nodes()
            out = SankeyWidget(nodes=nodes, links=links, margins=margins,
                               layout=Layout(width=str(width),
                                             height=str(height)))

            filename = self.csetlabel
            if variable:
                filename += "_%s" % variable
            # exist_ok avoids the race between checking and creating the dir
            os.makedirs("sankey_autosaves", exist_ok=True)
            filename = ("sankey_autosaves" + os.path.sep
                        + cleanfilename(filename))
            out.auto_save_png(filename + ".png")
            out.auto_save_svg(filename + ".svg")
            out.on_node_clicked(self.onclick)
            out.on_link_clicked(self.onclick)
        finally:
            # always undo the lineage setting, even if building the diagram
            # raised partway through
            self.solution.set_necessarylineage(clear=True)
        return out

    def _links_and_nodes(self, top_node=None):
        """Filters a copy of self.links down to what should be displayed.

        If `top_node` is given, only links whose source or target id
        contains it (the node itself and its descendants) are kept.
        Returns (links, nodes) as lists of dicts for SankeyWidget.
        """
        links = self.links.copy()
        # filter if...not below the chosen top node
        if top_node is not None:
            self.filter(links, lambda s, t, v: top_node in s or top_node in t,
                        forced=True)
        # ...below minimum sensitivity
        self.filter(links, lambda s, t, v: abs(v) > self.minsenss, forced=True)
        if not self.showconstraints:
            # ...is a constraint or subarray and we're not showing those
            self.filter(links, lambda s, t, v:
                        ("constraint" not in self.nodes[s]
                         and "subarray" not in self.nodes[s]), forced=True)
        # ...is a subarray and we still have too many links
        self.filter(links, lambda s, t, v: "subarray" not in self.nodes[s])
        # ...is an insensitive constraint and we still have too many links
        self.filter(links, lambda s, t, v: ("constraint" not in self.nodes[s]
                                            or abs(v) > INSENSITIVE))
        # ...is at culldepth, repeating up to a relative depth of 1 or 2
        culldepth = max(node.count(".") for node in self.nodes) - 1
        mindepth = 1 if not top_node else top_node.count(".") + 1
        while len(links) > self.maxlinks and culldepth > mindepth:
            self.filter(links, lambda s, t, v, depth=culldepth:
                        depth > s.count("."))
            culldepth -= 1
        # ...is a constraint and we still have too many links
        self.filter(links, lambda s, t, v: "constraint" not in self.nodes[s])

        linkslist, nodes, nodeset = [], [], set()
        for (source, target), value in links.items():
            if source == top_node:
                # mark the node the focused node flows into with an arrow
                nodes.append({"id": self.nodes[target]["id"],
                              "title": "⟶ %s" % self.nodes[target]["title"]})
                nodeset.add(target)
            for node in [source, target]:
                if node not in nodeset:
                    nodes.append({"id": self.nodes[node]["id"],
                                  "title": self.nodes[node]["title"]})
                    nodeset.add(node)
            linkslist.append({"source": source, "target": target,
                              "value": abs(value), "color": getcolor(value),
                              "title": "%+.2g" % value})
        return linkslist, nodes

    def onclick(self, sankey, node_or_link):
        """Callback function for when a node or link is clicked on.

        Refocuses `sankey` on the clicked node (or a clicked link's source
        node); clicking the node already in focus does nothing.
        """
        if node_or_link is not None:
            if "id" in node_or_link:  # it's a node
                top_node = node_or_link["id"]
            else:  # it's a link
                top_node = node_or_link["source"]
            if self.last_top_node != top_node:
                sankey.links, sankey.nodes = self._links_and_nodes(top_node)
                sankey.send_state()
                self.last_top_node = top_node