Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"implements Sankey" 

2import os 

3import re 

4from collections import defaultdict 

5from collections.abc import Iterable 

6import numpy as np 

7from ipywidgets import Layout 

8from ipysankeywidget import SankeyWidget 

9from ..repr_conventions import lineagestr, unitstr 

10from .. import Model, GPCOLORS 

11from ..constraints.array import ArrayConstraint 

12 

13 

# Sensitivities whose magnitude falls below this threshold are drawn grey
# and are the first constraint links dropped when a diagram is too busy.
INSENSITIVE = 0.01
# Tiny stand-in magnitude used where a link's sensitivity is missing or zero.
EPS = 1e-10

16 

def isnamedmodel(constraint):
    "True if constraint is a Model whose class is not the plain `Model`"
    if not isinstance(constraint, Model):
        return False
    # A bare Model instance is anonymous; only subclasses carry a name.
    classname = constraint.__class__.__name__
    return classname != "Model"

21 

def getcolor(value):
    "Link color for a sensitivity: grey when insensitive, else by its sign"
    is_insensitive = abs(value) < INSENSITIVE
    if is_insensitive:
        return "#cfcfcf"
    # GPCOLORS[0] marks negative sensitivities, GPCOLORS[1] the rest.
    negative = value < 0
    return GPCOLORS[0] if negative else GPCOLORS[1]

27 

def cleanfilename(string):
    """Parses string into valid filename.

    Each character that is not allowed in filenames on common platforms
    (backslash, forward slash, ?, |, double quote, >, <, : and *) is
    replaced by an underscore; everything else is kept unchanged.
    """
    # A character class replaces every invalid character individually.  The
    # previous pattern only matched a backslash (optionally followed by '/')
    # or the literal run '"><:*', so e.g. a lone '/' or ':' slipped through.
    return re.sub(r'[\\/?|"><:*]', "_", string)

31 

32 

# pylint: disable=too-many-instance-attributes
class Sankey:
    """Return Jupyter diagrams of sensitivity flow.

    Sensitivities of a solved constraint set are turned into a tree of nodes
    (``self.nodes``) and weighted links (``self.links``) that is rendered by
    ``SankeyWidget``.  Node ids encode tree position: every child id is its
    parent's id plus a ``.NNNN`` counter suffix (see ``add_node``), which is
    what the depth-based culling in ``_links_and_nodes`` relies on.
    """
    minsenss = 0  # links with |sensitivity| <= this are always hidden
    maxlinks = 20  # above this many links, optional filters are applied
    showconstraints = True  # whether constraint/subarray leaves are shown
    last_top_node = None  # node most recently zoomed onto by onclick

    def __init__(self, solution, constraintset, csetlabel=None):
        """Prepares an empty diagram of *constraintset* for *solution*.

        Arguments
        ---------
        solution : solution mapping; ``["sensitivities"]["constraints"]``
            is read here, and ``diagram`` additionally reads ``"constants"``,
            ``"variables"``, ``["sensitivities"]["constants"]``,
            ``["sensitivities"]["cost"]`` and ``name_collision_varkeys()``.
        constraintset : the model / constraint set to diagram.
        csetlabel : str, optional
            Title of the constraint-set node; defaults to the set's lineage
            string or, failing that, its class name.
        """
        self.solution = solution
        self.csenss = solution["sensitivities"]["constraints"]
        self.cset = constraintset
        if csetlabel is None:
            csetlabel = lineagestr(self.cset) or self.cset.__class__.__name__
        self.csetlabel = csetlabel
        self._reset()

    def _reset(self):
        "Clears the node/link graph so a fresh diagram can be built."
        self.links = defaultdict(float)  # (source id, target id) -> value
        self.links_to_target = defaultdict(int)  # node id -> child count
        self.nodes = {}  # node id -> {"id", "title", optional tag: True}
        self.last_top_node = None

    @staticmethod
    def _shortlabel(text, maxlen=30):
        "Truncates text to maxlen characters, marking any cut with ' ...'."
        return text if len(text) <= maxlen else "%s ..." % text[:maxlen]

    def add_node(self, target, title, tag=None):
        """Adds a new child node of *target* to self.nodes; returns its id.

        The id is ``"<target>.<NNNN>"``, NNNN being the 1-based, zero-padded
        count of children created under *target* so far.  A *tag* such as
        "constraint" or "subarray" is stored as a ``tag: True`` entry so the
        filters in ``_links_and_nodes`` can recognise the node.
        """
        self.links_to_target[target] += 1
        node = "%s.%04i" % (target, self.links_to_target[target])
        self.nodes[node] = {"id": node, "title": title}
        if tag is not None:
            self.nodes[node][tag] = True
        return node

    def linkfixed(self, cset, target):
        """Adds fixed-variable links as if they were (array)constraints.

        Every varkey of *cset* found in ``solution["constants"]`` gets a leaf
        under *target* weighted by minus the magnitude of its constant
        sensitivity (``-EPS`` when missing or NaN).  Elements of a vector
        whose entries are all equal are grouped under one shared vector
        node, their own leaves being tagged "subarray".

        Returns the summed (non-positive) value of the links added.
        """
        fixedvecs = {}  # veckey -> id of the shared node for that vector
        total_sens = 0
        for vk in sorted(cset.unique_varkeys, key=str):
            if vk not in self.solution["constants"]:
                continue
            if vk.veckey and vk.veckey not in fixedvecs:
                vecval = self.solution["constants"][vk.veckey]
                firstval = vecval.flatten()[0]
                # only group a vector when every one of its entries is equal
                if vecval.shape and (firstval == vecval).all():
                    label = "%s = %.4g %s" % (vk.veckey.name, firstval,
                                              unitstr(vk.veckey))
                    fixedvecs[vk.veckey] = self.add_node(target, label,
                                                         "constraint")
            abs_var_sens = -abs(self.solution["sensitivities"]
                                ["constants"].get(vk, EPS))
            if np.isnan(abs_var_sens):
                # keep the placeholder negative, like every other value here
                abs_var_sens = -EPS
            label = "%s = %.4g %s" % (vk.str_without(["lineage"]),
                                      self.solution["variables"][vk],
                                      unitstr(vk))
            if vk.veckey in fixedvecs:
                vectarget = fixedvecs[vk.veckey]
                source = self.add_node(vectarget, label, "subarray")
                self.links[source, vectarget] = abs_var_sens
                self.links[vectarget, target] += abs_var_sens
            else:
                source = self.add_node(target, label, "constraint")
                self.links[source, target] = abs_var_sens
            total_sens += abs_var_sens
        return total_sens

    # pylint: disable=too-many-branches
    def link(self, cset, target, var, *, labeled=False, subarray=False):
        """Adds links of constraint set *cset* beneath node *target*.

        Recurses through named models, multi-element array constraints,
        labelled dicts (or sets with an ``idxlookup``) and other iterables
        down to single constraints.  With ``var=None`` leaves are weighted by
        constraint sensitivities (plus fixed variables via ``linkfixed``);
        otherwise by each constraint's ``v_ss`` entry for ``var.key``.

        Arguments
        ---------
        labeled : True when *target* already is the node labelling *cset*,
            so no further node is created for it.
        subarray : True when *cset* is an element of an array constraint;
            its leaf is then tagged "subarray" instead of "constraint".

        Returns the total sensitivity flowing out of *cset*.
        """
        total_sens = 0
        switchedtarget = None  # parent to link to when a new node is made
        if not labeled and isnamedmodel(cset):
            # the top-level model is represented by the root node itself
            if cset is not self.cset:
                switchedtarget = target
                target = self.add_node(target, cset.lineage[-1][0])
            if var is None:
                total_sens += self.linkfixed(cset, target)
        elif isinstance(cset, ArrayConstraint) and cset.constraints.size > 1:
            switchedtarget = target
            cstr = cset.str_without(["lineage", "units"]).replace("[:]", "")
            target = self.add_node(target, self._shortlabel(cstr),
                                   "constraint")
            subarray = True
        if getattr(cset, "idxlookup", None):
            # re-key by the set's own labels so each entry gets a named node
            cset = {k: cset[i] for k, i in cset.idxlookup.items()}
        if isinstance(cset, dict):
            for label, c in cset.items():
                source = self.add_node(target, label)
                subtotal_sens = self.link(c, source, var, labeled=True)
                self.links[source, target] += subtotal_sens
                total_sens += subtotal_sens
        elif isinstance(cset, Iterable):
            for c in cset:
                total_sens += self.link(c, target, var, subarray=subarray)
        else:
            if var is None and cset in self.csenss:
                total_sens = -abs(self.csenss[cset]) or -EPS
            elif var is not None:
                if cset.v_ss is None:
                    if var.key in cset.varkeys:
                        total_sens = EPS
                elif var.key in cset.v_ss:
                    total_sens = cset.v_ss[var.key] or EPS
            if not labeled:
                cstr = cset.str_without(["lineage", "units"])
                tag = "subarray" if subarray else "constraint"
                source = self.add_node(target, self._shortlabel(cstr), tag)
                self.links[source, target] = total_sens
        if switchedtarget is not None:
            self.links[target, switchedtarget] += total_sens
        return total_sens

    def filter(self, links, function, forced=False):
        """If over maxlinks, removes links that do not match criteria.

        *links* is modified in place; ``function(source, target, value)``
        returns True for links to keep.  With ``forced=True`` the filter is
        applied regardless of how many links remain.
        """
        if len(links) > self.maxlinks or forced:
            for (s, t), v in list(links.items()):
                if not function(s, t, v):
                    del links[(s, t)]

    # pylint: disable=too-many-locals
    def diagram(self, variable=None, varlabel=None, *, minsenss=0, maxlinks=20,
                top=0, bottom=0, left=230, right=140, width=1000, height=400,
                showconstraints=True):
        """Creates links and an ipython widget to show them.

        Arguments
        ---------
        variable : optional
            If given, show sensitivity flow into this variable (from each
            constraint's ``v_ss``); otherwise show constraint sensitivities.
        varlabel : str, optional
            Title of the variable's node; by default derived from its name,
            dropping lineage when that label exceeds 20 characters.
        minsenss : links whose absolute value is <= this are always hidden.
        maxlinks : while more links than this remain, progressively hide
            subarray leaves, insensitive constraints, deep nodes, and then
            all constraint leaves.
        top, bottom, left, right : widget margins.
        width, height : widget size.
        showconstraints : if False, hide every constraint/subarray leaf.

        Returns
        -------
        SankeyWidget, set to auto-save PNG/SVG copies under
        ``sankey_autosaves/`` and to zoom onto clicked nodes or links.
        """
        margins = dict(top=top, bottom=bottom, left=left, right=right)
        self.minsenss = minsenss
        self.maxlinks = maxlinks
        self.showconstraints = showconstraints
        self._reset()  # a repeated call must not stack onto old links

        # Temporarily force lineage into labels of colliding variable names.
        collisions = list(self.solution.name_collision_varkeys())
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:
            self._build_links(variable, varlabel)
            links, nodes = self._links_and_nodes()
            out = SankeyWidget(nodes=nodes, links=links, margins=margins,
                               layout=Layout(width=str(width),
                                             height=str(height)))
            self._autosave(out, variable)
            out.on_node_clicked(self.onclick)
            out.on_link_clicked(self.onclick)
        finally:
            # always undo the shared varkey flags, even if drawing failed
            for key in collisions:
                key.descr.pop("necessarylineage", None)
        return out

    def _build_links(self, variable, varlabel):
        "Fills self.nodes/self.links for the whole constraint set."
        if variable:
            if not varlabel:
                varlabel = variable.str_without(["unnecessary lineage"])
                if len(varlabel) > 20:
                    varlabel = variable.str_without(["lineage"])
            self.nodes[varlabel] = {"id": varlabel, "title": varlabel}
            csetnode = self.add_node(varlabel, self.csetlabel)
            costsenss = self.solution["sensitivities"]["cost"]
            if variable.key in costsenss:
                costnode = self.add_node(varlabel, "[cost function]")
                self.links[costnode, varlabel] = costsenss[variable.key]
        else:
            csetnode = self.csetlabel
            self.nodes[csetnode] = {"id": csetnode, "title": csetnode}
        total_sens = self.link(self.cset, csetnode, variable)
        if variable:
            self.links[csetnode, varlabel] = total_sens

    def _autosave(self, widget, variable):
        "Points the widget's PNG/SVG auto-saves into ./sankey_autosaves/."
        filename = self.csetlabel
        if variable:
            filename += "_" + variable.str_without(["unnecessary lineage",
                                                    "units"])
        os.makedirs("sankey_autosaves", exist_ok=True)
        filename = "sankey_autosaves" + os.path.sep + cleanfilename(filename)
        widget.auto_save_png(filename + ".png")
        widget.auto_save_svg(filename + ".svg")

    def _links_and_nodes(self, top_node=None):
        """Filters self.links for display; returns (links list, nodes list).

        When *top_node* is given, only links touching it or its descendants
        are kept, and the node it flows into is titled with a leading arrow
        so it can be clicked to zoom back out.
        """
        links = self.links.copy()
        # filter if...not below the chosen top node
        if top_node is not None:
            self.filter(links, lambda s, t, v: top_node in s or top_node in t,
                        forced=True)
        # ...below minimum sensitivity
        self.filter(links, lambda s, t, v: abs(v) > self.minsenss, forced=True)
        if not self.showconstraints:
            # ...is a constraint or subarray and we're not showing those
            self.filter(links, lambda s, t, v:
                        ("constraint" not in self.nodes[s]
                         and "subarray" not in self.nodes[s]), forced=True)
        # ...is a subarray and we still have too many links
        self.filter(links, lambda s, t, v: "subarray" not in self.nodes[s])
        # ...is an insensitive constraint and we still have too many links
        self.filter(links, lambda s, t, v: ("constraint" not in self.nodes[s]
                                            or abs(v) > INSENSITIVE))
        # ...is at culldepth, repeating up to a relative depth of 1 or 2
        culldepth = max((node.count(".") for node in self.nodes),
                        default=0) - 1
        mindepth = 1 if not top_node else top_node.count(".") + 1
        while len(links) > self.maxlinks and culldepth > mindepth:
            # filter() applies the lambda immediately, so the current
            # culldepth is the one used
            self.filter(links, lambda s, t, v: culldepth > s.count("."))
            culldepth -= 1
        # ...is a constraint and we still have too many links
        self.filter(links, lambda s, t, v: "constraint" not in self.nodes[s])

        linkslist, nodes, nodeset = [], [], set()
        for (source, target), value in links.items():
            if source == top_node:
                nodes.append({"id": self.nodes[target]["id"],
                              "title": "⟶ %s" % self.nodes[target]["title"]})
                nodeset.add(target)
            for node in (source, target):
                if node not in nodeset:
                    nodes.append({"id": self.nodes[node]["id"],
                                  "title": self.nodes[node]["title"]})
                    nodeset.add(node)
            linkslist.append({"source": source, "target": target,
                              "value": abs(value), "color": getcolor(value),
                              "title": "%+.2g" % value})
        return linkslist, nodes

    def onclick(self, sankey, node_or_link):
        """Callback function for when a node or link is clicked on.

        Re-renders *sankey* zoomed onto the clicked node (or a clicked
        link's source), unless that node is already the one shown.
        """
        if node_or_link is None:
            return
        if "id" in node_or_link:  # it's a node
            top_node = node_or_link["id"]
        else:  # it's a link
            top_node = node_or_link["source"]
        if self.last_top_node != top_node:
            sankey.links, sankey.nodes = self._links_and_nodes(top_node)
            sankey.send_state()
        self.last_top_node = top_node