Coverage for gpkit/interactive/sankey.py: 89%
184 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-05 22:33 -0500
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-05 22:33 -0500
1"implements Sankey"
2# pylint: disable=import-error,consider-using-f-string
3import os
4import re
5from collections import defaultdict
6from collections.abc import Iterable
7import numpy as np
8from ipywidgets import Layout
9from ipysankeywidget import SankeyWidget
10from ..repr_conventions import lineagestr, unitstr
11from .. import Model, GPCOLORS
12from ..constraints.array import ArrayConstraint
# Sensitivities whose magnitude falls below this threshold are drawn grey by
# getcolor() and are among the first constraint links culled when a diagram
# has more than Sankey.maxlinks links.
INSENSITIVE = 0.01
# Tiny stand-in sensitivity used in place of missing, zero or NaN values so
# that the corresponding links still carry a (negligible) nonzero value.
EPS = 1.0e-10
def isnamedmodel(constraint):
    """Tell whether `constraint` is a Model created from a subclass of Model.

    Instances whose class is literally named "Model" are not considered named.
    """
    if not isinstance(constraint, Model):
        return False
    classname = constraint.__class__.__name__
    return classname != "Model"
def getcolor(value):
    """Pick the link color for a sensitivity value.

    Values with magnitude below INSENSITIVE are drawn light grey; otherwise
    negative values use GPCOLORS[0] and all others use GPCOLORS[1].
    """
    insensitive = abs(value) < INSENSITIVE
    if not insensitive:
        sign_index = 0 if value < 0 else 1
        return GPCOLORS[sign_index]
    return "#cfcfcf"
def cleanfilename(string):
    """Parse `string` into a valid filename.

    Every character that is invalid in a filename on common platforms
    (backslash, forward slash, colon, asterisk, question mark, double quote,
    angle brackets and vertical bar) is replaced by an underscore.
    """
    # A character class matches each invalid character on its own. The
    # previous pattern was an alternation that missed a lone "/", "?", ...
    return re.sub(r'[\\/:*?"<>|]', "_", string)
# pylint: disable=too-many-instance-attributes
class Sankey:
    """Return Jupyter diagrams of sensitivity flow.

    Builds a graph of nodes (``self.nodes``, keyed by node id) and weighted
    links (``self.links``, keyed by ``(source, target)`` node-id pairs) from
    a solution's sensitivities, then renders it with a SankeyWidget.
    A child node's id is its parent's id plus a ".NNNN" counter suffix, and
    _links_and_nodes estimates a node's depth by counting "." characters.
    """
    minsenss = 0            # links with |sensitivity| <= this are hidden
    maxlinks = 20           # soft cap on links shown before culling starts
    showconstraints = True  # whether constraint/subarray nodes are drawn
    last_top_node = None    # node most recently focused via onclick

    def __init__(self, solution, constraintset, csetlabel=None):
        """Store the solution and the constraint set to diagram.

        Arguments
        ---------
        solution : solution mapping; ``solution["sensitivities"]
            ["constraints"]`` is read here, other entries later.
        constraintset : the constraint set (typically a Model) to diagram.
        csetlabel : str, optional
            Label of the root constraint-set node. Defaults to the set's
            lineage string or, if that is empty, its class name.
        """
        self.solution = solution
        self.csenss = solution["sensitivities"]["constraints"]
        self.cset = constraintset
        if csetlabel is None:
            csetlabel = lineagestr(self.cset) or self.cset.__class__.__name__
        self.csetlabel = csetlabel
        self.links = defaultdict(float)
        self.links_to_target = defaultdict(int)
        self.nodes = {}

    def add_node(self, target, title, tag=None):
        """Create a new child node of `target` and return its id.

        The id is ``"<target>.<n>"``, where ``n`` is a zero-padded counter of
        nodes added under `target`. If `tag` is given (e.g. "constraint" or
        "subarray"), it is stored as a ``True`` flag on the node so the
        filters in _links_and_nodes can test for it.
        """
        self.links_to_target[target] += 1
        node = "%s.%04i" % (target, self.links_to_target[target])
        attrs = {"id": node, "title": title}
        if tag is not None:  # avoid storing a meaningless `None: True` entry
            attrs[tag] = True
        self.nodes[node] = attrs
        return node

    def linkfixed(self, cset, target):
        """Add fixed-variable links as if they were (array) constraints.

        Each varkey of ``cset.unique_varkeys`` found in
        ``solution["constants"]`` gets a node under `target`, linked with
        minus the absolute value of its constant sensitivity. When all
        elements of a vector constant share one value, the elements are
        grouped under a single "name = value units" node.

        Returns the summed (non-positive) sensitivity of the links added.
        """
        fixedvecs = {}
        total_sens = 0
        for vk in sorted(cset.unique_varkeys, key=str):
            if vk not in self.solution["constants"]:
                continue
            if vk.veckey and vk.veckey not in fixedvecs:
                vecval = self.solution["constants"][vk.veckey]
                firstval = vecval.flatten()[0]
                # a uniform vector constant gets one grouping node
                if vecval.shape and (firstval == vecval).all():
                    label = "%s = %.4g %s" % (vk.veckey.name, firstval,
                                              unitstr(vk.veckey))
                    fixedvecs[vk.veckey] = self.add_node(target, label,
                                                         "constraint")
            abs_var_sens = -abs(self.solution["sensitivities"]
                                ["constants"].get(vk, EPS))
            if np.isnan(abs_var_sens):
                # keep the same sign as the -EPS used for missing values
                abs_var_sens = -EPS
            label = "%s = %.4g %s" % (vk.str_without(["lineage"]),
                                      self.solution["variables"][vk],
                                      unitstr(vk))
            if vk.veckey in fixedvecs:
                vectarget = fixedvecs[vk.veckey]
                source = self.add_node(vectarget, label, "subarray")
                self.links[source, vectarget] = abs_var_sens
                self.links[vectarget, target] += abs_var_sens
            else:
                source = self.add_node(target, label, "constraint")
                self.links[source, target] = abs_var_sens
            total_sens += abs_var_sens
        return total_sens

    # pylint: disable=too-many-branches, too-many-arguments
    def link(self, cset, target, vk, *, labeled=False, subarray=False):
        """Recursively add nodes and links for `cset` flowing into `target`.

        Arguments
        ---------
        cset : a constraint, a constraint set, a dict of labeled constraints,
            or another iterable of these.
        target : id of the node that this set's links flow into.
        vk : VarKey or None
            If None, leaf links use ``solution["sensitivities"]
            ["constraints"]`` and named models also get fixed-variable links
            (see linkfixed). Otherwise, leaf links use each constraint's
            ``v_ss`` entry for `vk`.
        labeled : bool
            True when the caller already created a node for `cset` (the dict
            branch), so no named-model or leaf node is added here.
        subarray : bool
            True for elements of a multi-element ArrayConstraint; their leaf
            nodes are tagged "subarray" instead of "constraint".

        Returns the total sensitivity flowing out of `cset`.
        """
        total_sens = 0
        switchedtarget = False
        if not labeled and isnamedmodel(cset):
            if cset is not self.cset:  # top-level, no need to switch targets
                switchedtarget = target
                target = self.add_node(target, cset.lineage[-1][0])
            if vk is None:
                total_sens += self.linkfixed(cset, target)
        elif isinstance(cset, ArrayConstraint) and cset.constraints.size > 1:
            switchedtarget = target
            cstr = cset.str_without(["lineage", "units"]).replace("[:]", "")
            label = cstr if len(cstr) <= 30 else "%s ..." % cstr[:30]
            target = self.add_node(target, label, "constraint")
            subarray = True
        if getattr(cset, "idxlookup", None):
            cset = {k: cset[i] for k, i in cset.idxlookup.items()}
        if isinstance(cset, dict):
            for label, c in cset.items():
                source = self.add_node(target, label)
                subtotal_sens = self.link(c, source, vk, labeled=True)
                self.links[source, target] += subtotal_sens
                total_sens += subtotal_sens
        elif isinstance(cset, Iterable):
            for c in cset:
                total_sens += self.link(c, target, vk, subarray=subarray)
        else:
            if vk is None and cset in self.csenss:
                total_sens = -abs(self.csenss[cset]) or -EPS
            elif vk is not None:
                if cset.v_ss is None:
                    if vk in cset.varkeys:
                        total_sens = EPS
                elif vk in cset.v_ss:
                    total_sens = cset.v_ss[vk] or EPS
            if not labeled:
                cstr = cset.str_without(["lineage", "units"])
                label = cstr if len(cstr) <= 30 else "%s ..." % cstr[:30]
                tag = "subarray" if subarray else "constraint"
                source = self.add_node(target, label, tag)
                self.links[source, target] = total_sens
        if switchedtarget:
            # the intermediate node forwards its total to the original target
            self.links[target, switchedtarget] += total_sens
        return total_sens

    def filter(self, links, function, forced=False):
        """If over maxlinks (or `forced`), remove links failing `function`.

        `links` is modified in place. A link (s, t) with value v is deleted
        when ``function(s, t, v)`` is falsy.
        """
        if len(links) > self.maxlinks or forced:
            for (s, t), v in list(links.items()):
                if not function(s, t, v):
                    del links[(s, t)]

    # pylint: disable=too-many-locals, too-many-arguments
    def diagram(self, variable=None, varlabel=None, *, minsenss=0, maxlinks=20,
                top=0, bottom=0, left=230, right=140, width=1000, height=400,
                showconstraints=True):
        """Create links and an ipython widget to show them.

        Arguments
        ---------
        variable : Variable, optional
            If given, diagram sensitivities with respect to ``variable.key``
            instead of the constraint sensitivities.
        varlabel : str, optional
            Root node label when `variable` is given. Defaults to
            ``str(variable.key)``, or the lineage-free string if that is
            longer than 20 characters.
        minsenss, maxlinks, showconstraints :
            Stored on the instance and used when filtering links.
        top, bottom, left, right : widget margins.
        width, height : widget layout size.

        Returns the SankeyWidget. Its PNG/SVG autosave paths are set under a
        ``sankey_autosaves`` directory (created if needed), and node/link
        clicks are routed to onclick. The solution's necessary lineage is
        cleared again even if building the diagram raises.
        """
        margins = {"top": top, "bottom": bottom, "left": left, "right": right}
        self.minsenss = minsenss
        self.maxlinks = maxlinks
        self.showconstraints = showconstraints

        self.solution.set_necessarylineage()
        try:
            if variable:
                variable = variable.key
                if not varlabel:
                    varlabel = str(variable)
                    if len(varlabel) > 20:
                        varlabel = variable.str_without(["lineage"])
                self.nodes[varlabel] = {"id": varlabel, "title": varlabel}
                csetnode = self.add_node(varlabel, self.csetlabel)
                if variable in self.solution["sensitivities"]["cost"]:
                    costnode = self.add_node(varlabel, "[cost function]")
                    self.links[costnode, varlabel] = \
                        self.solution["sensitivities"]["cost"][variable]
            else:
                csetnode = self.csetlabel
                self.nodes[self.csetlabel] = {"id": self.csetlabel,
                                              "title": self.csetlabel}
            total_sens = self.link(self.cset, csetnode, variable)
            if variable:
                self.links[csetnode, varlabel] = total_sens

            links, nodes = self._links_and_nodes()
            out = SankeyWidget(nodes=nodes, links=links, margins=margins,
                               layout=Layout(width=str(width),
                                             height=str(height)))

            filename = self.csetlabel
            if variable:
                filename += "_%s" % variable
            # exist_ok avoids the race between an isdir check and creation
            os.makedirs("sankey_autosaves", exist_ok=True)
            filename = ("sankey_autosaves" + os.path.sep
                        + cleanfilename(filename))
            out.auto_save_png(filename + ".png")
            out.auto_save_svg(filename + ".svg")
            out.on_node_clicked(self.onclick)
            out.on_link_clicked(self.onclick)
        finally:
            # always undo set_necessarylineage() above, even on error
            self.solution.set_necessarylineage(clear=True)
        return out

    def _links_and_nodes(self, top_node=None):
        """Filter a copy of self.links and format it for SankeyWidget.

        The filters run in order. Forced filters always apply; the others
        only act while more than self.maxlinks links remain (see filter).
        When `top_node` is given, only links whose source or target id
        contains it are kept, and a target reached directly from `top_node`
        is titled with a leading arrow.

        Returns (list of link dicts, list of node dicts).
        """
        links = self.links.copy()
        # filter if...not below the chosen top node
        if top_node is not None:
            self.filter(links, lambda s, t, v: top_node in s or top_node in t,
                        forced=True)
        # ...below minimum sensitivity
        self.filter(links, lambda s, t, v: abs(v) > self.minsenss, forced=True)
        if not self.showconstraints:
            # ...is a constraint or subarray and we're not showing those
            self.filter(links, lambda s, t, v:
                        ("constraint" not in self.nodes[s]
                         and "subarray" not in self.nodes[s]), forced=True)
        # ...is a subarray and we still have too many links
        self.filter(links, lambda s, t, v: "subarray" not in self.nodes[s])
        # ...is an insensitive constraint and we still have too many links
        self.filter(links, lambda s, t, v: ("constraint" not in self.nodes[s]
                                            or abs(v) > INSENSITIVE))
        # ...is at culldepth, repeating up to a relative depth of 1 or 2
        culldepth = max(node.count(".") for node in self.nodes) - 1
        mindepth = 1 if not top_node else top_node.count(".") + 1
        while len(links) > self.maxlinks and culldepth > mindepth:
            self.filter(links, lambda s, t, v, depth=culldepth:
                        depth > s.count("."))
            culldepth -= 1
        # ...is a constraint and we still have too many links
        self.filter(links, lambda s, t, v: "constraint" not in self.nodes[s])

        linkslist, nodes, nodeset = [], [], set()
        for (source, target), value in links.items():
            if source == top_node:
                nodes.append({"id": self.nodes[target]["id"],
                              "title": "⟶ %s" % self.nodes[target]["title"]})
                nodeset.add(target)
            for node in [source, target]:
                if node not in nodeset:
                    nodes.append({"id": self.nodes[node]["id"],
                                  "title": self.nodes[node]["title"]})
                    nodeset.add(node)
            linkslist.append({"source": source, "target": target,
                              "value": abs(value), "color": getcolor(value),
                              "title": "%+.2g" % value})
        return linkslist, nodes

    def onclick(self, sankey, node_or_link):
        """Callback for when a node or link is clicked on.

        Refocuses `sankey` on the clicked node (or a clicked link's source)
        by recomputing its links and nodes, unless that node is already the
        most recently focused one.
        """
        if node_or_link is not None:
            if "id" in node_or_link:  # it's a node
                top_node = node_or_link["id"]
            else:  # it's a link
                top_node = node_or_link["source"]
            if self.last_top_node != top_node:
                sankey.links, sankey.nodes = self._links_and_nodes(top_node)
                sankey.send_state()
                self.last_top_node = top_node