Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"implements Sankey"
2import os
3import re
4from collections import defaultdict
5from collections.abc import Iterable
6import numpy as np
7from ipywidgets import Layout
8from ipysankeywidget import SankeyWidget
9from ..repr_conventions import lineagestr, unitstr
10from .. import Model, GPCOLORS
11from ..constraints.array import ArrayConstraint
# Sensitivities whose magnitude is below this are drawn grey (see getcolor)
# and are the first constraint links dropped when there are too many links.
INSENSITIVE = 0.01
# Tiny stand-in used for zero or missing sensitivities, so those links still
# pass the default zero minimum-sensitivity filter.
EPS = 1e-10
def isnamedmodel(constraint):
    "True for instances of Model subclasses, False for bare Models or non-Models"
    if not isinstance(constraint, Model):
        return False
    # a plain `Model` has no meaningful name of its own; only subclasses count
    classname = constraint.__class__.__name__
    return classname != "Model"
def getcolor(value):
    "Picks a link colour: grey when insensitive, otherwise by sensitivity sign"
    if not abs(value) < INSENSITIVE:
        # negative sensitivities take the first palette colour, all else the second
        return GPCOLORS[0] if value < 0 else GPCOLORS[1]
    return "#cfcfcf"
def cleanfilename(string):
    """Parses string into a valid filename.

    Every character reserved in filenames on common platforms -- backslash,
    slash, ``?``, ``|``, ``"``, ``>``, ``<``, ``:`` and ``*`` -- is replaced
    with an underscore. All other characters are kept unchanged.
    """
    # a character class, so each reserved character is replaced individually
    return re.sub(r'[\\/?|"><:*]', "_", string)
# pylint: disable=too-many-instance-attributes
class Sankey:
    """Return Jupyter diagrams of sensitivity flow.

    Construct with a solution and the constraint set it came from, then call
    `diagram` to get a SankeyWidget.

    Internal state:
      - ``nodes`` maps node id -> node dict with ``id`` and ``title`` keys,
        plus an optional truthy tag key (``"constraint"`` or ``"subarray"``)
        that the display filters look for.
      - ``links`` maps ``(source_id, target_id)`` -> signed sensitivity.
      - ``links_to_target`` counts nodes created under each target id; it is
        used to mint unique child ids of the form ``"<target>.<NNNN>"``, so a
        node's nesting depth roughly equals the number of dots in its id.
    """
    # display defaults; each `diagram` call overwrites them
    minsenss = 0
    maxlinks = 20
    showconstraints = True
    last_top_node = None  # id last used to refocus the widget in `onclick`

    def __init__(self, solution, constraintset, csetlabel=None):
        """
        Arguments
        ---------
        solution : indexable by "sensitivities", "constants" and "variables"
            NOTE(review): presumably a gpkit SolutionArray -- confirm.
        constraintset : the constraint set whose sensitivities are drawn
        csetlabel : str, optional
            Root-node label; defaults to the set's lineage string or, if that
            is empty, its class name.
        """
        self.solution = solution
        self.csenss = solution["sensitivities"]["constraints"]
        self.cset = constraintset
        if csetlabel is None:
            csetlabel = lineagestr(self.cset) or self.cset.__class__.__name__
        self.csetlabel = csetlabel
        self._reset_graph()

    def _reset_graph(self):
        "Discards all nodes and links so a fresh diagram can be built."
        self.links = defaultdict(float)
        self.links_to_target = defaultdict(int)
        self.nodes = {}
        self.last_top_node = None

    @staticmethod
    def _shortlabel(cstr, maxlen=30):
        "Truncates `cstr` to `maxlen` characters, marking any cut with ' ...'"
        return cstr if len(cstr) <= maxlen else "%s ..." % cstr[:maxlen]

    def add_node(self, target, title, tag=None):
        """Adds a child node of `target` to self.nodes and returns its id.

        The id is `target` plus a zero-padded count of nodes already created
        under `target`. When `tag` is given it is stored as a key set to True
        on the node dict; untagged nodes carry only ``id`` and ``title``.
        """
        self.links_to_target[target] += 1
        node = "%s.%04i" % (target, self.links_to_target[target])
        self.nodes[node] = {"id": node, "title": title}
        if tag is not None:
            self.nodes[node][tag] = True
        return node

    def linkfixed(self, cset, target):
        """Adds fixed-variable links as if they were (array)constraints.

        Each of `cset`'s unique varkeys found in the solution's constants gets
        a node under `target` whose link value is -|sensitivity| (-EPS when
        no sensitivity is recorded, +EPS when it is NaN). Vector variables
        whose elements all share one value are grouped under a single
        "<name> = <value> <units>" node, each element a "subarray" beneath it.

        Returns the summed link value added under `target`.
        """
        fixedvecs = {}  # veckey -> id of its grouping node (uniform vectors)
        checkedvecs = set()  # veckeys already tested, uniform or not
        total_sens = 0
        for vk in sorted(cset.unique_varkeys, key=str):
            if vk not in self.solution["constants"]:
                continue
            if vk.veckey and vk.veckey not in checkedvecs:
                # test each vector once instead of once per element
                checkedvecs.add(vk.veckey)
                vecval = self.solution["constants"][vk.veckey]
                firstval = vecval.flatten()[0]
                if vecval.shape and (firstval == vecval).all():
                    label = "%s = %.4g %s" % (vk.veckey.name, firstval,
                                              unitstr(vk.veckey))
                    fixedvecs[vk.veckey] = self.add_node(target, label,
                                                         "constraint")
            abs_var_sens = -abs(self.solution["sensitivities"]
                                ["constants"].get(vk, EPS))
            if np.isnan(abs_var_sens):
                abs_var_sens = EPS
            label = "%s = %.4g %s" % (vk.str_without(["lineage"]),
                                      self.solution["variables"][vk],
                                      unitstr(vk))
            if vk.veckey in fixedvecs:
                vectarget = fixedvecs[vk.veckey]
                source = self.add_node(vectarget, label, "subarray")
                self.links[source, vectarget] = abs_var_sens
                self.links[vectarget, target] += abs_var_sens
            else:
                source = self.add_node(target, label, "constraint")
                self.links[source, target] = abs_var_sens
            total_sens += abs_var_sens
        return total_sens

    # pylint: disable=too-many-branches
    def link(self, cset, target, var, *, labeled=False, subarray=False):
        """Adds links of a given constraint set (recursively) to self.links.

        Arguments
        ---------
        cset : a constraint, a constraint set, a dict of labeled constraints,
               or any other iterable of constraints
        target : str
            id of the node that `cset`'s links flow into
        var : variable or None
            None: links carry constraint sensitivities (plus fixed-variable
            sensitivities for named models, via `linkfixed`).
            Otherwise: links carry each constraint's ``v_ss[var.key]``.
        labeled : bool
            True when the caller already made a labeled node for `cset`, so
            no further node is created for it here.
        subarray : bool
            True for elements of a multi-element ArrayConstraint; their leaf
            nodes are tagged "subarray" rather than "constraint".

        Returns
        -------
        The summed sensitivity flowing from `cset` into `target`.
        """
        total_sens = 0
        switchedtarget = False  # the caller's target, if we nest a new node
        if not labeled and isnamedmodel(cset):
            if cset is not self.cset:  # top-level, no need to switch targets
                switchedtarget = target
                target = self.add_node(target, cset.lineage[-1][0])
            if var is None:
                total_sens += self.linkfixed(cset, target)
        elif isinstance(cset, ArrayConstraint) and cset.constraints.size > 1:
            switchedtarget = target
            cstr = cset.str_without(["lineage", "units"]).replace("[:]", "")
            target = self.add_node(target, self._shortlabel(cstr),
                                   "constraint")
            subarray = True
        if getattr(cset, "idxlookup", None):
            cset = {k: cset[i] for k, i in cset.idxlookup.items()}
        if isinstance(cset, dict):
            for label, c in cset.items():
                source = self.add_node(target, label)
                subtotal_sens = self.link(c, source, var, labeled=True)
                self.links[source, target] += subtotal_sens
                total_sens += subtotal_sens
        elif isinstance(cset, Iterable):
            for c in cset:
                total_sens += self.link(c, target, var, subarray=subarray)
        else:  # a single constraint: a leaf of the diagram
            if var is None and cset in self.csenss:
                # zero sensitivity becomes -EPS so the link survives filtering
                total_sens = -abs(self.csenss[cset]) or -EPS
            elif var is not None:
                if cset.v_ss is None:
                    if var.key in cset.varkeys:
                        total_sens = EPS
                elif var.key in cset.v_ss:
                    total_sens = cset.v_ss[var.key] or EPS
            if not labeled:
                cstr = cset.str_without(["lineage", "units"])
                tag = "subarray" if subarray else "constraint"
                source = self.add_node(target, self._shortlabel(cstr), tag)
                self.links[source, target] = total_sens
        if switchedtarget:
            self.links[target, switchedtarget] += total_sens
        return total_sens

    def filter(self, links, function, forced=False):
        """If over maxlinks (or `forced`), removes links not matching criteria.

        `function(source, target, value)` returning False drops that link.
        `links` is modified in place.
        """
        if len(links) > self.maxlinks or forced:
            for (s, t), v in list(links.items()):
                if not function(s, t, v):
                    del links[(s, t)]

    # pylint: disable=too-many-locals
    def diagram(self, variable=None, varlabel=None, *, minsenss=0, maxlinks=20,
                top=0, bottom=0, left=230, right=140, width=1000, height=400,
                showconstraints=True):
        """Creates links and an ipython widget to show them.

        Arguments
        ---------
        variable : optional
            When given, the diagram shows the sensitivity of this variable
            (rooted at a node labeled `varlabel`). Otherwise it shows the
            constraint set's own sensitivities.
        varlabel : str, optional
            Root label for `variable`. Defaults to its string without
            unnecessary lineage, or without any lineage if that is over 20
            characters.
        minsenss, maxlinks, showconstraints
            Display filters applied in `_links_and_nodes`.
        top, bottom, left, right, width, height
            Widget margins and size.

        Each call rebuilds the graph from scratch. The widget is registered
        to autosave PNG/SVG files under ./sankey_autosaves/ (the directory is
        created if needed). Clicking a node or link refocuses the widget.
        """
        margins = dict(top=top, bottom=bottom, left=left, right=right)
        self.minsenss = minsenss
        self.maxlinks = maxlinks
        self.showconstraints = showconstraints
        # repeated calls must not accumulate the previous diagram's links
        self._reset_graph()

        self.solution.set_necessarylineage()
        try:
            if variable:
                if not varlabel:
                    varlabel = variable.str_without(["unnecessary lineage"])
                    if len(varlabel) > 20:
                        varlabel = variable.str_without(["lineage"])
                self.nodes[varlabel] = {"id": varlabel, "title": varlabel}
                csetnode = self.add_node(varlabel, self.csetlabel)
                costsenss = self.solution["sensitivities"]["cost"]
                if variable.key in costsenss:
                    costnode = self.add_node(varlabel, "[cost function]")
                    self.links[costnode, varlabel] = costsenss[variable.key]
            else:
                csetnode = self.csetlabel
                self.nodes[self.csetlabel] = {"id": self.csetlabel,
                                              "title": self.csetlabel}
            total_sens = self.link(self.cset, csetnode, variable)
            if variable:
                self.links[csetnode, varlabel] = total_sens

            links, nodes = self._links_and_nodes()
            out = SankeyWidget(nodes=nodes, links=links, margins=margins,
                               layout=Layout(width=str(width),
                                             height=str(height)))

            filename = self.csetlabel
            if variable:
                filename += "_" + variable.str_without(["unnecessary lineage",
                                                        "units"])
            os.makedirs("sankey_autosaves", exist_ok=True)
            filename = os.path.join("sankey_autosaves",
                                    cleanfilename(filename))
            out.auto_save_png(filename + ".png")
            out.auto_save_svg(filename + ".svg")
            out.on_node_clicked(self.onclick)
            out.on_link_clicked(self.onclick)
        finally:
            # always undo set_necessarylineage, even if building/saving fails
            self.solution.set_necessarylineage(clear=True)
        return out

    def _links_and_nodes(self, top_node=None):
        """Filters self.links for display; returns (linkslist, nodes).

        The result is in the dict format SankeyWidget takes. When `top_node`
        is given, only links whose source or target id contains `top_node`
        are kept, and `top_node`'s parent is titled with a leading arrow.
        Further filters apply as commented below; some run only while there
        are more than `maxlinks` links.
        """
        links = self.links.copy()
        # filter if...not below the chosen top node
        if top_node is not None:
            self.filter(links, lambda s, t, v: top_node in s or top_node in t,
                        forced=True)
        # ...below minimum sensitivity
        self.filter(links, lambda s, t, v: abs(v) > self.minsenss, forced=True)
        if not self.showconstraints:
            # ...is a constraint or subarray and we're not showing those
            self.filter(links, lambda s, t, v:
                        ("constraint" not in self.nodes[s]
                         and "subarray" not in self.nodes[s]), forced=True)
        # ...is a subarray and we still have too many links
        self.filter(links, lambda s, t, v: "subarray" not in self.nodes[s])
        # ...is an insensitive constraint and we still have too many links
        self.filter(links, lambda s, t, v: ("constraint" not in self.nodes[s]
                                            or abs(v) > INSENSITIVE))
        # ...is at culldepth, repeating up to a relative depth of 1 or 2
        culldepth = max(node.count(".") for node in self.nodes) - 1
        mindepth = 1 if not top_node else top_node.count(".") + 1
        while len(links) > self.maxlinks and culldepth > mindepth:
            # bind the current depth explicitly rather than via the closure
            self.filter(links, lambda s, t, v, depth=culldepth:
                        depth > s.count("."))
            culldepth -= 1
        # ...is a constraint and we still have too many links
        self.filter(links, lambda s, t, v: "constraint" not in self.nodes[s])

        linkslist, nodes, nodeset = [], [], set()
        for (source, target), value in links.items():
            if source == top_node:
                # mark the node the focused subtree flows into
                nodes.append({"id": self.nodes[target]["id"],
                              "title": "⟶ %s" % self.nodes[target]["title"]})
                nodeset.add(target)
            for node in [source, target]:
                if node not in nodeset:
                    nodes.append({"id": self.nodes[node]["id"],
                                  "title": self.nodes[node]["title"]})
                    nodeset.add(node)
            linkslist.append({"source": source, "target": target,
                              "value": abs(value), "color": getcolor(value),
                              "title": "%+.2g" % value})
        return linkslist, nodes

    def onclick(self, sankey, node_or_link):
        """Callback function for when a node or link is clicked on.

        Refocuses `sankey` on the clicked node, or on a clicked link's
        source. The widget is re-sent only when that node differs from the
        last one used.
        """
        if node_or_link is not None:
            if "id" in node_or_link:  # it's a node
                top_node = node_or_link["id"]
            else:  # it's a link
                top_node = node_or_link["source"]
            if self.last_top_node != top_node:
                sankey.links, sankey.nodes = self._links_and_nodes(top_node)
                sankey.send_state()
                self.last_top_node = top_node