Coverage for gpkit\interactive\sankey.py: 0%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

184 statements  

1"implements Sankey" 

2import os 

3import re 

4from collections import defaultdict 

5from collections.abc import Iterable 

6import numpy as np 

7from ipywidgets import Layout 

8from ipysankeywidget import SankeyWidget 

9from ..repr_conventions import lineagestr, unitstr 

10from .. import Model, GPCOLORS 

11from ..constraints.array import ArrayConstraint 

12 

13 

# Sensitivity magnitudes below this threshold are drawn in neutral grey, and
# constraint links this weak are culled first when a diagram has too many links.
INSENSITIVE = 0.01
# Tiny stand-in magnitude used when a sensitivity is zero, missing, or NaN, so
# that such links still pass the default minsenss=0 filter.
EPS = 1.0e-10

16 

def isnamedmodel(constraint):
    """Tells whether `constraint` is an instance of a Model subclass.

    Instances of the bare Model class (i.e. any class whose name is exactly
    "Model") are not considered named.
    """
    if not isinstance(constraint, Model):
        return False
    classname = constraint.__class__.__name__
    return classname != "Model"

21 

def getcolor(value):
    """Picks the link color for a sensitivity `value`.

    Magnitudes below INSENSITIVE are drawn grey ("#cfcfcf"); otherwise
    negative values use GPCOLORS[0] and all other values use GPCOLORS[1].
    """
    if abs(value) < INSENSITIVE:
        return "#cfcfcf"
    is_negative = value < 0
    return GPCOLORS[0] if is_negative else GPCOLORS[1]

27 

def cleanfilename(string):
    """Parses string into a valid filename.

    Each character reserved in Windows filenames -- backslash, forward slash
    (also the POSIX path separator), and any of ? | " > < : * -- is replaced
    with an underscore. All other characters are kept unchanged.
    """
    # A character class is needed so that every reserved character is
    # replaced on its own. The former pattern r"\\/?|\"><:\*" matched only a
    # backslash (optionally followed by "/") or the literal run "><:*.
    return re.sub(r'[\\/?|"><:*]', "_", string)

31 

32 

# pylint: disable=too-many-instance-attributes
class Sankey:
    """Return Jupyter diagrams of sensitivity flow.

    Builds an ipysankeywidget SankeyWidget for a solved constraint set. Nodes
    form a tree: every node id is its parent's id plus a ".NNNN" counter
    suffix, and self.links maps (source, target) id pairs to the sensitivity
    flowing from the source node into the target node.
    """
    minsenss = 0
    maxlinks = 20
    showconstraints = True
    last_top_node = None

    def __init__(self, solution, constraintset, csetlabel=None):
        self.solution = solution
        self.csenss = solution["sensitivities"]["constraints"]
        self.cset = constraintset
        if csetlabel is None:
            csetlabel = lineagestr(self.cset) or self.cset.__class__.__name__
        self.csetlabel = csetlabel
        self._reset()

    def _reset(self):
        "Clears all nodes and links so a diagram can be built from scratch"
        self.links = defaultdict(float)  # (source id, target id) -> value
        self.links_to_target = defaultdict(int)  # node id -> children made
        self.nodes = {}  # node id -> {"id", "title"} plus optional tag flag
        self.last_top_node = None

    @staticmethod
    def _shortlabel(cstr, maxlen=30):
        "Truncates a constraint string to `maxlen` characters plus ' ...'"
        return cstr if len(cstr) <= maxlen else "%s ..." % cstr[:maxlen]

    def add_node(self, target, title, tag=None):
        """Adds a child node of `target` to self.nodes and returns its id.

        The id is `target` followed by a 4-digit counter that is unique among
        `target`'s children. If `tag` is given (e.g. "constraint" or
        "subarray"), it is stored on the node as a `tag: True` flag.
        """
        self.links_to_target[target] += 1
        node = "%s.%04i" % (target, self.links_to_target[target])
        self.nodes[node] = {"id": node, "title": title}
        if tag is not None:
            self.nodes[node][tag] = True
        return node

    def linkfixed(self, cset, target):
        """Adds fixedvariable links as if they were (array)constraints.

        Each varkey of `cset` that appears in solution["constants"] becomes
        a node linked into `target` with value -|sensitivity| (-EPS if the
        sensitivity is missing, zero, or NaN). A vector constant whose
        elements all share one value is grouped under a single
        "name = value units" node, and its elements hang off that node as
        "subarray" nodes. Returns the summed value of all links made.
        """
        fixedvecs = {}  # veckey -> id of the node grouping its elements
        total_sens = 0
        for vk in sorted(cset.unique_varkeys, key=str):
            if vk not in self.solution["constants"]:
                continue
            if vk.veckey and vk.veckey not in fixedvecs:
                vecval = self.solution["constants"][vk.veckey]
                firstval = vecval.flatten()[0]
                if vecval.shape and (firstval == vecval).all():
                    label = "%s = %.4g %s" % (vk.veckey.name, firstval,
                                              unitstr(vk.veckey))
                    fixedvecs[vk.veckey] = self.add_node(target, label,
                                                         "constraint")
            abs_var_sens = -abs(self.solution["sensitivities"]
                                ["constants"].get(vk, EPS))
            if np.isnan(abs_var_sens):
                # keep the placeholder negative, like every other one here
                abs_var_sens = -EPS
            label = "%s = %.4g %s" % (vk.str_without(["lineage"]),
                                      self.solution["variables"][vk],
                                      unitstr(vk))
            if vk.veckey in fixedvecs:
                vectarget = fixedvecs[vk.veckey]
                source = self.add_node(vectarget, label, "subarray")
                self.links[source, vectarget] = abs_var_sens
                self.links[vectarget, target] += abs_var_sens
            else:
                source = self.add_node(target, label, "constraint")
                self.links[source, target] = abs_var_sens
            total_sens += abs_var_sens
        return total_sens

    # pylint: disable=too-many-branches
    def link(self, cset, target, vk, *, labeled=False, subarray=False):
        """Adds links of a given constraint set to self.links.

        Recursively walks `cset`, creating nodes under `target`, and returns
        the total sensitivity flowing into `target`. When `vk` is None, the
        values are constraint sensitivities (plus fixed-variable links for
        named models). Otherwise they come from each constraint's `v_ss`
        for `vk`. `labeled` means the caller already created a node for
        `cset`. `subarray` tags leaf constraints as ArrayConstraint elements.
        """
        total_sens = 0
        switchedtarget = False  # parent id, once cset gets its own node
        if not labeled and isnamedmodel(cset):
            if cset is not self.cset:  # top-level, no need to switch targets
                switchedtarget = target
                target = self.add_node(target, cset.lineage[-1][0])
            if vk is None:
                total_sens += self.linkfixed(cset, target)
        elif isinstance(cset, ArrayConstraint) and cset.constraints.size > 1:
            switchedtarget = target
            cstr = cset.str_without(["lineage", "units"]).replace("[:]", "")
            target = self.add_node(target, self._shortlabel(cstr),
                                   "constraint")
            subarray = True
        if getattr(cset, "idxlookup", None):
            cset = {k: cset[i] for k, i in cset.idxlookup.items()}
        if isinstance(cset, dict):
            for label, c in cset.items():
                source = self.add_node(target, label)
                subtotal_sens = self.link(c, source, vk, labeled=True)
                self.links[source, target] += subtotal_sens
                total_sens += subtotal_sens
        elif isinstance(cset, Iterable):
            for c in cset:
                total_sens += self.link(c, target, vk, subarray=subarray)
        else:  # a single (leaf) constraint
            if vk is None and cset in self.csenss:
                total_sens = -abs(self.csenss[cset]) or -EPS
            elif vk is not None:
                if cset.v_ss is None:
                    if vk in cset.varkeys:
                        total_sens = EPS
                elif vk in cset.v_ss:
                    total_sens = cset.v_ss[vk] or EPS
            if not labeled:
                cstr = cset.str_without(["lineage", "units"])
                tag = "subarray" if subarray else "constraint"
                source = self.add_node(target, self._shortlabel(cstr), tag)
                self.links[source, target] = total_sens
        if switchedtarget:
            self.links[target, switchedtarget] += total_sens
        return total_sens

    def filter(self, links, function, forced=False):
        """If over maxlinks (or if `forced`), removes links that do not match
        criteria. Mutates `links` in place; `function(source, target, value)`
        returns True for the links to keep."""
        if len(links) > self.maxlinks or forced:
            for (s, t), v in list(links.items()):
                if not function(s, t, v):
                    del links[(s, t)]

    # pylint: disable=too-many-locals
    def diagram(self, variable=None, varlabel=None, *, minsenss=0, maxlinks=20,
                top=0, bottom=0, left=230, right=140, width=1000, height=400,
                showconstraints=True):
        """Creates links and an ipython widget to show them.

        With `variable`, the diagram shows the sensitivities to that variable
        (its `.key` is used). Otherwise it shows the constraint and
        fixed-variable sensitivities of the whole constraint set. The widget
        auto-saves PNG and SVG copies under ./sankey_autosaves/. Nodes and
        links are rebuilt from scratch on every call.
        """
        margins = dict(top=top, bottom=bottom, left=left, right=right)
        self.minsenss = minsenss
        self.maxlinks = maxlinks
        self.showconstraints = showconstraints
        # without this, a second call would duplicate every node and link
        self._reset()

        self.solution.set_necessarylineage()
        try:
            if variable is not None:
                variable = variable.key
                if not varlabel:
                    varlabel = str(variable)
                    if len(varlabel) > 20:
                        varlabel = variable.str_without(["lineage"])
                self.nodes[varlabel] = {"id": varlabel, "title": varlabel}
                csetnode = self.add_node(varlabel, self.csetlabel)
                costsenss = self.solution["sensitivities"]["cost"]
                if variable in costsenss:
                    costnode = self.add_node(varlabel, "[cost function]")
                    self.links[costnode, varlabel] = costsenss[variable]
            else:
                csetnode = self.csetlabel
                self.nodes[self.csetlabel] = {"id": self.csetlabel,
                                              "title": self.csetlabel}
            total_sens = self.link(self.cset, csetnode, variable)
            if variable is not None:
                self.links[csetnode, varlabel] = total_sens

            links, nodes = self._links_and_nodes()
            out = SankeyWidget(nodes=nodes, links=links, margins=margins,
                               layout=Layout(width=str(width),
                                             height=str(height)))

            filename = self.csetlabel
            if variable is not None:
                filename += "_%s" % variable
            os.makedirs("sankey_autosaves", exist_ok=True)
            filename = os.path.join("sankey_autosaves",
                                    cleanfilename(filename))
            out.auto_save_png(filename + ".png")
            out.auto_save_svg(filename + ".svg")
            out.on_node_clicked(self.onclick)
            out.on_link_clicked(self.onclick)
        finally:
            # undo set_necessarylineage() even if building the diagram failed
            self.solution.set_necessarylineage(clear=True)
        return out

    def _links_and_nodes(self, top_node=None):
        """Filters self.links for display and converts them for SankeyWidget.

        Returns (links, nodes) as lists of dicts. If `top_node` is given,
        only links whose source or target id contains it are kept, and the
        target of a link leaving `top_node` is titled with a leading arrow.
        Further filters only apply while more than self.maxlinks remain.
        """
        links = self.links.copy()
        # filter if...not below the chosen top node
        if top_node is not None:
            self.filter(links, lambda s, t, v: top_node in s or top_node in t,
                        forced=True)
        # ...below minimum sensitivity
        self.filter(links, lambda s, t, v: abs(v) > self.minsenss, forced=True)
        if not self.showconstraints:
            # ...is a constraint or subarray and we're not showing those
            self.filter(links, lambda s, t, v:
                        ("constraint" not in self.nodes[s]
                         and "subarray" not in self.nodes[s]), forced=True)
        # ...is a subarray and we still have too many links
        self.filter(links, lambda s, t, v: "subarray" not in self.nodes[s])
        # ...is an insensitive constraint and we still have too many links
        self.filter(links, lambda s, t, v: ("constraint" not in self.nodes[s]
                                            or abs(v) > INSENSITIVE))
        # ...is at culldepth, repeating up to a relative depth of 1 or 2
        culldepth = max(node.count(".") for node in self.nodes) - 1
        mindepth = 1 if not top_node else top_node.count(".") + 1
        while len(links) > self.maxlinks and culldepth > mindepth:
            # the lambda is consumed inside filter(), so it sees this depth
            self.filter(links, lambda s, t, v: culldepth > s.count("."))
            culldepth -= 1
        # ...is a constraint and we still have too many links
        self.filter(links, lambda s, t, v: "constraint" not in self.nodes[s])

        linkslist, nodes, nodeset = [], [], set()
        for (source, target), value in links.items():
            if source == top_node:
                nodes.append({"id": self.nodes[target]["id"],
                              "title": "⟶ %s" % self.nodes[target]["title"]})
                nodeset.add(target)
            for node in [source, target]:
                if node not in nodeset:
                    nodes.append({"id": self.nodes[node]["id"],
                                  "title": self.nodes[node]["title"]})
                    nodeset.add(node)
            linkslist.append({"source": source, "target": target,
                              "value": abs(value), "color": getcolor(value),
                              "title": "%+.2g" % value})
        return linkslist, nodes

    def onclick(self, sankey, node_or_link):
        """Callback function for when a node or link is clicked on.

        Re-focuses `sankey` on the clicked node (or a clicked link's source)
        by recomputing its links and nodes. Clicking the current focus again
        does nothing.
        """
        if node_or_link is not None:
            if "id" in node_or_link:  # it's a node
                top_node = node_or_link["id"]
            else:  # it's a link
                top_node = node_or_link["source"]
            if self.last_top_node != top_node:
                sankey.links, sankey.nodes = self._links_and_nodes(top_node)
                sankey.send_state()
                self.last_top_node = top_node