Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"implements Sankey"
2import os
3import re
4from collections import defaultdict
5from collections.abc import Iterable
6import numpy as np
7from ipywidgets import Layout
8from ipysankeywidget import SankeyWidget
9from ..repr_conventions import lineagestr, unitstr
10from .. import Model, GPCOLORS
11from ..constraints.array import ArrayConstraint
# Sensitivities whose magnitude is below this are drawn grey by getcolor(),
# and constraint links below it are pruned when a diagram has too many links.
INSENSITIVE = 1e-2
# Tiny stand-in sensitivity used where a sensitivity is missing or zero, so
# that the corresponding link still carries a nonzero value.
EPS = 1e-10
def isnamedmodel(constraint):
    """Tell whether ``constraint`` is a Model of a class other than ``Model``.

    Returns False for objects that are not Model instances and for plain
    ``Model`` instances (class named exactly "Model"); True otherwise,
    which in practice means an instance of a user-defined Model subclass.
    """
    if not isinstance(constraint, Model):
        return False
    classname = constraint.__class__.__name__
    return classname != "Model"
def getcolor(value):
    """Choose the link color for a sensitivity ``value``.

    Magnitudes below INSENSITIVE are grey ("#cfcfcf"); otherwise negative
    values use GPCOLORS[0] and all other values use GPCOLORS[1].
    """
    negligible = abs(value) < INSENSITIVE
    if negligible:
        return "#cfcfcf"
    negative = value < 0
    return GPCOLORS[0] if negative else GPCOLORS[1]
def cleanfilename(string):
    """Return ``string`` with characters that are invalid in filenames replaced.

    Every occurrence of the backslash or any of / ? | " > < : * is replaced
    by an underscore, so the result can be used as a single path component.
    """
    # A character class replaces each forbidden character on its own; an
    # alternation such as r"\\/?|\"><:\*" would only match two specific
    # literal sequences and leave most of these characters in place.
    return re.sub(r'[\\/?|"><:*]', "_", string)
# pylint: disable=too-many-instance-attributes
class Sankey:
    """Return Jupyter diagrams of sensitivity flow.

    Sensitivities from a solution are accumulated into ``self.links``, a
    mapping ``(source_node_id, target_node_id) -> sensitivity``, by walking
    the constraint-set tree; ``diagram`` turns them into an
    ``ipysankeywidget.SankeyWidget``.

    Node ids are hierarchical: each child id is its parent's id followed by
    a ``.NNNN`` counter (see ``add_node``), and ``_links_and_nodes`` uses
    the number of dots in an id as a depth measure when pruning.
    """
    minsenss = 0  # links with |sensitivity| <= this are always hidden
    maxlinks = 20  # above this many links, the optional pruning filters run
    showconstraints = True  # False hides constraint/subarray source nodes
    last_top_node = None  # node id most recently focused via onclick

    def __init__(self, solution, constraintset, csetlabel=None):
        """Prepare an empty diagram for ``constraintset`` within ``solution``.

        Arguments
        ---------
        solution : solution mapping
            Read via the keys "sensitivities", "constants" and "variables",
            and via its ``name_collision_varkeys()`` method.
        constraintset : constraint set or Model
            Root of the constraint tree to be drawn.
        csetlabel : str, optional
            Title of the root node; defaults to the set's lineage string or,
            if that is empty, its class name.
        """
        self.solution = solution
        self.csenss = solution["sensitivities"]["constraints"]
        self.cset = constraintset
        if csetlabel is None:
            csetlabel = lineagestr(self.cset) or self.cset.__class__.__name__
        self.csetlabel = csetlabel
        self.links = defaultdict(float)
        self.links_to_target = defaultdict(int)
        self.nodes = {}

    def add_node(self, target, title, tag=None):
        """Create a child node of ``target`` and return its id.

        The id is ``"<target>.NNNN"``, where NNNN counts the children
        created for ``target`` so far. ``tag`` (e.g. "constraint" or
        "subarray"), if given, is stored as a key on the node so the
        filters in ``_links_and_nodes`` can test for it with ``in``.
        """
        self.links_to_target[target] += 1
        node = "%s.%04i" % (target, self.links_to_target[target])
        self.nodes[node] = {"id": node, "title": title}
        if tag is not None:
            # untagged nodes get no extra key (rather than a None key)
            self.nodes[node][tag] = True
        return node

    @staticmethod
    def _truncate(cstr, maxlen=30):
        "Shortens a label to ``maxlen`` characters followed by ' ...'."
        return cstr if len(cstr) <= maxlen else "%s ..." % cstr[:maxlen]

    def linkfixed(self, cset, target):
        """Add links for the fixed variables of ``cset`` below ``target``.

        Fixed variables (those present in ``solution["constants"]``) are
        drawn as if they were constraints. When every element of a vector
        variable has the same fixed value, its elements become "subarray"
        nodes grouped under one "constraint" node for the whole vector.

        Returns the summed sensitivity of the links added; every individual
        value is non-positive.
        """
        fixedvecs = {}
        total_sens = 0
        for vk in sorted(cset.unique_varkeys, key=str):
            if vk not in self.solution["constants"]:
                continue
            if vk.veckey and vk.veckey not in fixedvecs:
                vecval = self.solution["constants"][vk.veckey]
                firstval = vecval.flatten()[0]
                # group the elements only if the whole vector shares a value
                if vecval.shape and (firstval == vecval).all():
                    label = "%s = %.4g %s" % (vk.veckey.name, firstval,
                                              unitstr(vk.veckey))
                    fixedvecs[vk.veckey] = self.add_node(target, label,
                                                         "constraint")
            abs_var_sens = -abs(self.solution["sensitivities"]
                                ["constants"].get(vk, EPS))
            if np.isnan(abs_var_sens):
                # keep the negative sign used for every other fixed variable
                abs_var_sens = -EPS
            label = "%s = %.4g %s" % (vk.str_without(["lineage"]),
                                      self.solution["variables"][vk],
                                      unitstr(vk))
            if vk.veckey in fixedvecs:
                vectarget = fixedvecs[vk.veckey]
                source = self.add_node(vectarget, label, "subarray")
                self.links[source, vectarget] = abs_var_sens
                self.links[vectarget, target] += abs_var_sens
            else:
                source = self.add_node(target, label, "constraint")
                self.links[source, target] = abs_var_sens
            total_sens += abs_var_sens
        return total_sens

    # pylint: disable=too-many-branches
    def link(self, cset, target, var, *, labeled=False, subarray=False):
        """Recursively add links for ``cset`` below the node ``target``.

        With ``var`` None, leaf constraints use their sensitivity from
        ``solution["sensitivities"]["constraints"]`` and fixed variables of
        named models are added via ``linkfixed``; otherwise ``var.key`` is
        looked up in each leaf constraint's ``v_ss``. Named models (other
        than the root) and multi-element ArrayConstraints get their own
        intermediate node. ``labeled`` means the caller has already created
        a node for ``cset``; ``subarray`` tags leaf nodes as array elements.

        Returns the total sensitivity accumulated for ``cset``.
        """
        total_sens = 0
        switchedtarget = False
        if not labeled and isnamedmodel(cset):
            if cset is not self.cset:  # top-level, no need to switch targets
                switchedtarget = target
                target = self.add_node(target, cset.lineage[-1][0])
            if var is None:
                total_sens += self.linkfixed(cset, target)
        elif isinstance(cset, ArrayConstraint) and cset.constraints.size > 1:
            switchedtarget = target
            cstr = cset.str_without(["lineage", "units"]).replace("[:]", "")
            target = self.add_node(target, self._truncate(cstr), "constraint")
            subarray = True
        if getattr(cset, "idxlookup", None):
            cset = {k: cset[i] for k, i in cset.idxlookup.items()}
        if isinstance(cset, dict):
            for label, c in cset.items():
                source = self.add_node(target, label)
                subtotal_sens = self.link(c, source, var, labeled=True)
                self.links[source, target] += subtotal_sens
                total_sens += subtotal_sens
        elif isinstance(cset, Iterable):
            for c in cset:
                total_sens += self.link(c, target, var, subarray=subarray)
        else:  # a single constraint
            if var is None and cset in self.csenss:
                total_sens = -abs(self.csenss[cset]) or -EPS
            elif var is not None:
                if cset.v_ss is None:
                    if var.key in cset.varkeys:
                        total_sens = EPS
                elif var.key in cset.v_ss:
                    total_sens = cset.v_ss[var.key] or EPS
            if not labeled:
                label = self._truncate(cset.str_without(["lineage", "units"]))
                tag = "subarray" if subarray else "constraint"
                source = self.add_node(target, label, tag)
                self.links[source, target] = total_sens
        if switchedtarget:
            # connect the intermediate node back to its original parent
            self.links[target, switchedtarget] += total_sens
        return total_sens

    def filter(self, links, function, forced=False):
        """If over maxlinks (or ``forced``), removes links that do not match.

        Deletes, in place, every entry of ``links`` for which
        ``function(source, target, value)`` is falsy.
        """
        if len(links) > self.maxlinks or forced:
            for (s, t), v in list(links.items()):
                if not function(s, t, v):
                    del links[(s, t)]

    # pylint: disable=too-many-locals
    def diagram(self, variable=None, varlabel=None, *, minsenss=0, maxlinks=20,
                top=0, bottom=0, left=230, right=140, width=1000, height=400,
                showconstraints=True):
        """Build the sensitivity links and return a SankeyWidget of them.

        With ``variable`` given, the root node is that variable and the
        links come from each constraint's ``v_ss`` (plus the cost-function
        sensitivity, when present); otherwise the root is the constraint
        set and the links are constraint and fixed-variable sensitivities.
        ``minsenss``, ``maxlinks`` and ``showconstraints`` are stored on the
        instance and control pruning in ``_links_and_nodes``. The widget is
        auto-saved as PNG and SVG under ``sankey_autosaves``, and clicking a
        node or link refocuses it via ``onclick``.
        """
        margins = dict(top=top, bottom=bottom, left=left, right=right)
        self.minsenss = minsenss
        self.maxlinks = maxlinks
        self.showconstraints = showconstraints

        # temporarily force full lineage on ambiguous names while labelling
        collisions = list(self.solution.name_collision_varkeys())
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:
            if variable is not None:
                if not varlabel:
                    varlabel = variable.str_without(["unnecessary lineage"])
                    if len(varlabel) > 20:
                        varlabel = variable.str_without(["lineage"])
                self.nodes[varlabel] = {"id": varlabel, "title": varlabel}
                csetnode = self.add_node(varlabel, self.csetlabel)
                if variable.key in self.solution["sensitivities"]["cost"]:
                    costnode = self.add_node(varlabel, "[cost function]")
                    self.links[costnode, varlabel] = \
                        self.solution["sensitivities"]["cost"][variable.key]
            else:
                csetnode = self.csetlabel
                self.nodes[self.csetlabel] = {"id": self.csetlabel,
                                              "title": self.csetlabel}
            total_sens = self.link(self.cset, csetnode, variable)
            if variable is not None:
                self.links[csetnode, varlabel] = total_sens

            links, nodes = self._links_and_nodes()
            out = SankeyWidget(nodes=nodes, links=links, margins=margins,
                               layout=Layout(width=str(width),
                                             height=str(height)))

            filename = self.csetlabel
            if variable is not None:
                filename += "_" + variable.str_without(["unnecessary lineage",
                                                        "units"])
            os.makedirs("sankey_autosaves", exist_ok=True)
            filename = os.path.join("sankey_autosaves",
                                    cleanfilename(filename))
            out.auto_save_png(filename + ".png")
            out.auto_save_svg(filename + ".svg")
            out.on_node_clicked(self.onclick)
            out.on_link_clicked(self.onclick)
        finally:
            # undo the temporary flag even if building or saving failed
            for key in collisions:
                key.descr.pop("necessarylineage", None)
        return out

    def _links_and_nodes(self, top_node=None):
        """Filter a copy of ``self.links`` and convert it to widget format.

        Returns ``(links, nodes)``, lists of dicts for SankeyWidget. When
        ``top_node`` is given, only links whose source or target id contains
        ``top_node`` are kept, and the target of the link leaving
        ``top_node`` (its parent) is titled with a leading arrow.
        """
        links = self.links.copy()
        # filter if...not below the chosen top node
        if top_node is not None:
            self.filter(links, lambda s, t, v: top_node in s or top_node in t,
                        forced=True)
        # ...below minimum sensitivity
        self.filter(links, lambda s, t, v: abs(v) > self.minsenss, forced=True)
        if not self.showconstraints:
            # ...is a constraint or subarray and we're not showing those
            self.filter(links, lambda s, t, v:
                        ("constraint" not in self.nodes[s]
                         and "subarray" not in self.nodes[s]), forced=True)
        # ...is a subarray and we still have too many links
        self.filter(links, lambda s, t, v: "subarray" not in self.nodes[s])
        # ...is an insensitive constraint and we still have too many links
        self.filter(links, lambda s, t, v: ("constraint" not in self.nodes[s]
                                            or abs(v) > INSENSITIVE))
        # ...is at culldepth, repeating up to a relative depth of 1 or 2
        culldepth = max(node.count(".") for node in self.nodes) - 1
        mindepth = 1 if not top_node else top_node.count(".") + 1
        while len(links) > self.maxlinks and culldepth > mindepth:
            # the lambda is applied immediately, so late binding is safe here
            self.filter(links, lambda s, t, v: culldepth > s.count("."))
            culldepth -= 1
        # ...is a constraint and we still have too many links
        self.filter(links, lambda s, t, v: "constraint" not in self.nodes[s])

        linkslist, nodes, nodeset = [], [], set()
        for (source, target), value in links.items():
            if source == top_node:
                nodes.append({"id": self.nodes[target]["id"],
                              "title": "⟶ %s" % self.nodes[target]["title"]})
                nodeset.add(target)
            for node in [source, target]:
                if node not in nodeset:
                    nodes.append({"id": self.nodes[node]["id"],
                                  "title": self.nodes[node]["title"]})
                    nodeset.add(node)
            linkslist.append({"source": source, "target": target,
                              "value": abs(value), "color": getcolor(value),
                              "title": "%+.2g" % value})
        return linkslist, nodes

    def onclick(self, sankey, node_or_link):
        """Callback function for when a node or link is clicked on.

        Refocuses ``sankey`` on the clicked node (or on a clicked link's
        source node) and pushes the new state to the widget; repeated
        clicks on the same node do nothing.
        """
        if node_or_link is not None:
            if "id" in node_or_link:  # it's a node
                top_node = node_or_link["id"]
            else:  # it's a link
                top_node = node_or_link["source"]
            if self.last_top_node != top_node:
                sankey.links, sankey.nodes = self._links_and_nodes(top_node)
                sankey.send_state()
            self.last_top_node = top_node