Coverage for gpkit/interactive/sankey.py: 89%

184 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 16:49 -0500

1"implements Sankey" 

2import os 

3import re 

4from collections import defaultdict 

5from collections.abc import Iterable 

6import numpy as np 

7from ipywidgets import Layout 

8from ipysankeywidget import SankeyWidget 

9from ..repr_conventions import lineagestr, unitstr 

10from .. import Model, GPCOLORS 

11from ..constraints.array import ArrayConstraint 

12 

13 

# Links whose |sensitivity| is below this are coloured grey, and insensitive
# constraint links are the first ones culled when there are too many links.
INSENSITIVE = 0.01
# Tiny stand-in sensitivity, used where a sensitivity is missing, zero or NaN
# so the corresponding link still gets a value.
EPS = 1e-10

16 

def isnamedmodel(constraint):
    "True if constraint is a Model whose class is not the plain `Model` class"
    if not isinstance(constraint, Model):
        return False
    # a subclass defines its own (non-"Model") class name
    return constraint.__class__.__name__ != "Model"

21 

def getcolor(value):
    "Maps a sensitivity to a link color: grey if insensitive, else by sign"
    is_insensitive = abs(value) < INSENSITIVE
    if is_insensitive:
        return "#cfcfcf"
    # negative sensitivities use the first gpkit color, all others the second
    color_index = 0 if value < 0 else 1
    return GPCOLORS[color_index]

27 

def cleanfilename(string):
    """Parses string into a valid filename.

    Every character that is invalid in filenames on common platforms
    (backslash, slash, ?, |, ", <, >, :, *) is replaced with "_".
    """
    # a character class: each listed character is matched individually
    return re.sub(r'[\\/?|"<>:*]', "_", string)

31 

32 

33# pylint: disable=too-many-instance-attributes 

class Sankey:
    "Return Jupyter diagrams of sensitivity flow"
    # Defaults for the filtering knobs; diagram() overwrites them per call.
    minsenss = 0
    maxlinks = 20
    showconstraints = True
    # Id of the node most recently used as the focus of onclick().
    last_top_node = None

    def __init__(self, solution, constraintset, csetlabel=None):
        """Stores the solution and constraint set to be diagrammed.

        solution : indexed by "sensitivities", "constants" and "variables"
        constraintset : the constraint set whose sensitivity flow is shown
        csetlabel : label of the root node; defaults to the set's lineage
            string, or to its class name if that is empty.
        """
        self.solution = solution
        self.csenss = solution["sensitivities"]["constraints"]
        self.cset = constraintset
        if csetlabel is None:
            csetlabel = lineagestr(self.cset) or self.cset.__class__.__name__
        self.csetlabel = csetlabel
        self._reset()

    def _reset(self):
        "Clears all links and nodes so a new diagram starts from scratch."
        # (source, target) -> sensitivity flowing along that link
        self.links = defaultdict(float)
        # target -> number of nodes created under it (numbers the node ids)
        self.links_to_target = defaultdict(int)
        # node id -> {"id": ..., "title": ..., optional tag: True}
        self.nodes = {}
        self.last_top_node = None

    def add_node(self, target, title, tag=None):
        """Adds a node under `target` to self.nodes and returns its id.

        Ids are "<target>.<NNNN>", numbered per target, so the number of "."
        in an id is used as the node's depth when culling. If `tag` is given
        (e.g. "constraint" or "subarray") the node dict gets `tag: True`.
        """
        self.links_to_target[target] += 1
        node = "%s.%04i" % (target, self.links_to_target[target])
        self.nodes[node] = {"id": node, "title": title}
        if tag is not None:
            self.nodes[node][tag] = True
        return node

    def linkfixed(self, cset, target):
        """Adds fixedvariable links as if they were (array)constraints.

        Each varkey of `cset` that appears in the solution's constants gets a
        node under `target`. Vector constants whose elements all share one
        value are grouped under a "name = value units" node, with one
        "subarray" node per element below it. Link values are the negated
        absolute constant sensitivities, or -EPS when missing or NaN.
        Returns the sum of all those link values.
        """
        constants = self.solution["constants"]
        constsenss = self.solution["sensitivities"]["constants"]
        fixedvecs = {}  # veckey -> id of its grouping node
        total_sens = 0
        for vk in sorted(cset.unique_varkeys, key=str):
            if vk not in constants:
                continue
            if vk.veckey and vk.veckey not in fixedvecs:
                vecval = constants[vk.veckey]
                firstval = vecval.flatten()[0]
                # only group vectors whose elements are all the same value
                if vecval.shape and (firstval == vecval).all():
                    label = "%s = %.4g %s" % (vk.veckey.name, firstval,
                                              unitstr(vk.veckey))
                    fixedvecs[vk.veckey] = self.add_node(target, label,
                                                         "constraint")
            abs_var_sens = -abs(constsenss.get(vk, EPS))
            if np.isnan(abs_var_sens):
                abs_var_sens = -EPS  # same sign as every other fixed link
            label = "%s = %.4g %s" % (vk.str_without(["lineage"]),
                                      self.solution["variables"][vk],
                                      unitstr(vk))
            if vk.veckey in fixedvecs:
                vectarget = fixedvecs[vk.veckey]
                source = self.add_node(vectarget, label, "subarray")
                self.links[source, vectarget] = abs_var_sens
                self.links[vectarget, target] += abs_var_sens
            else:
                source = self.add_node(target, label, "constraint")
                self.links[source, target] = abs_var_sens
            total_sens += abs_var_sens
        return total_sens

    # pylint: disable=too-many-branches
    def link(self, cset, target, vk, *, labeled=False, subarray=False):
        """Recursively adds links of constraint set `cset` under node `target`.

        With `vk` None the link values are negated constraint sensitivities
        (plus fixed-variable sensitivities of named models); otherwise they
        are each constraint's `v_ss[vk]`. Named sub-models and ArrayConstraints
        with more than one element get an intermediate node of their own.
        Returns the total sensitivity added for `cset`.
        """
        total_sens = 0
        switchedtarget = False  # original target, once we nest below it
        if not labeled and isnamedmodel(cset):
            if cset is not self.cset:  # top-level, no need to switch targets
                switchedtarget = target
                target = self.add_node(target, cset.lineage[-1][0])
            if vk is None:
                total_sens += self.linkfixed(cset, target)
        elif isinstance(cset, ArrayConstraint) and cset.constraints.size > 1:
            switchedtarget = target
            cstr = cset.str_without(["lineage", "units"]).replace("[:]", "")
            label = cstr if len(cstr) <= 30 else "%s ..." % cstr[:30]
            target = self.add_node(target, label, "constraint")
            subarray = True
        if getattr(cset, "idxlookup", None):
            # relabel the members by their lookup keys
            cset = {k: cset[i] for k, i in cset.idxlookup.items()}
        if isinstance(cset, dict):
            for label, c in cset.items():
                source = self.add_node(target, label)
                subtotal_sens = self.link(c, source, vk, labeled=True)
                self.links[source, target] += subtotal_sens
                total_sens += subtotal_sens
        elif isinstance(cset, Iterable):
            for c in cset:
                total_sens += self.link(c, target, vk, subarray=subarray)
        else:  # a single constraint: a leaf of the diagram
            if vk is None and cset in self.csenss:
                total_sens = -abs(self.csenss[cset]) or -EPS
            elif vk is not None:
                if cset.v_ss is None:
                    if vk in cset.varkeys:
                        total_sens = EPS
                elif vk in cset.v_ss:
                    total_sens = cset.v_ss[vk] or EPS
            if not labeled:
                cstr = cset.str_without(["lineage", "units"])
                label = cstr if len(cstr) <= 30 else "%s ..." % cstr[:30]
                tag = "subarray" if subarray else "constraint"
                source = self.add_node(target, label, tag)
                self.links[source, target] = total_sens
        if switchedtarget:
            self.links[target, switchedtarget] += total_sens
        return total_sens

    def filter(self, links, function, forced=False):
        """If over maxlinks, removes links that do not match criteria.

        `function(source, target, value)` returns True for links to keep.
        `links` is modified in place; `forced` applies the filter regardless
        of how many links there are.
        """
        if len(links) > self.maxlinks or forced:
            for (s, t), v in list(links.items()):
                if not function(s, t, v):
                    del links[(s, t)]

    # pylint: disable=too-many-locals
    def diagram(self, variable=None, varlabel=None, *, minsenss=0, maxlinks=20,
                top=0, bottom=0, left=230, right=140, width=1000, height=400,
                showconstraints=True):
        """Creates links and an ipython widget to show them.

        variable : if given, shows sensitivities to this variable (its `.key`
            is used) rather than the constraints' sensitivities.
        varlabel : label for the variable's node; defaults to str(key),
            shortened to the name without lineage if over 20 characters.
        minsenss : links with |sensitivity| <= minsenss are always dropped.
        maxlinks : above this many links, less important ones are culled.
        top, bottom, left, right : widget margins.
        width, height : widget layout size.
        showconstraints : if False, constraint/subarray nodes are dropped.

        The widget is also asked to auto-save PNG and SVG copies into the
        "sankey_autosaves" directory. Each call starts from an empty graph.
        """
        margins = dict(top=top, bottom=bottom, left=left, right=right)
        self.minsenss = minsenss
        self.maxlinks = maxlinks
        self.showconstraints = showconstraints
        # repeated calls must not accumulate (and double-count) old links
        self._reset()

        self.solution.set_necessarylineage()
        try:
            if variable is not None:
                variable = variable.key
                if not varlabel:
                    varlabel = str(variable)
                    if len(varlabel) > 20:
                        varlabel = variable.str_without(["lineage"])
                self.nodes[varlabel] = {"id": varlabel, "title": varlabel}
                csetnode = self.add_node(varlabel, self.csetlabel)
                costsenss = self.solution["sensitivities"]["cost"]
                if variable in costsenss:
                    costnode = self.add_node(varlabel, "[cost function]")
                    self.links[costnode, varlabel] = costsenss[variable]
            else:
                csetnode = self.csetlabel
                self.nodes[self.csetlabel] = {"id": self.csetlabel,
                                              "title": self.csetlabel}
            total_sens = self.link(self.cset, csetnode, variable)
            if variable is not None:
                self.links[csetnode, varlabel] = total_sens

            links, nodes = self._links_and_nodes()
            out = SankeyWidget(nodes=nodes, links=links, margins=margins,
                               layout=Layout(width=str(width),
                                             height=str(height)))
            self._autosave(out, variable)
            out.on_node_clicked(self.onclick)
            out.on_link_clicked(self.onclick)
        finally:
            # always undo set_necessarylineage, even if diagramming failed
            self.solution.set_necessarylineage(clear=True)
        return out

    def _autosave(self, widget, variable=None):
        "Asks `widget` to auto-save .png/.svg copies in sankey_autosaves/"
        filename = self.csetlabel
        if variable is not None:
            filename += "_%s" % variable
        os.makedirs("sankey_autosaves", exist_ok=True)
        path = os.path.join("sankey_autosaves", cleanfilename(filename))
        widget.auto_save_png(path + ".png")
        widget.auto_save_svg(path + ".svg")

    def _links_and_nodes(self, top_node=None):
        """Filters a copy of self.links and formats it for SankeyWidget.

        Returns (links, nodes) as lists of dicts. When `top_node` is given,
        only links whose source or target id contains it (its own links and
        those of nodes numbered beneath it) are kept.
        """
        links = self.links.copy()
        # filter if...not below the chosen top node
        if top_node is not None:
            self.filter(links, lambda s, t, v: top_node in s or top_node in t,
                        forced=True)
        # ...below minimum sensitivity
        self.filter(links, lambda s, t, v: abs(v) > self.minsenss, forced=True)
        if not self.showconstraints:
            # ...is a constraint or subarray and we're not showing those
            self.filter(links, lambda s, t, v:
                        ("constraint" not in self.nodes[s]
                         and "subarray" not in self.nodes[s]), forced=True)
        # ...is a subarray and we still have too many links
        self.filter(links, lambda s, t, v: "subarray" not in self.nodes[s])
        # ...is an insensitive constraint and we still have too many links
        self.filter(links, lambda s, t, v: ("constraint" not in self.nodes[s]
                                            or abs(v) > INSENSITIVE))
        # ...is at culldepth, repeating up to a relative depth of 1 or 2
        culldepth = max((node.count(".") for node in self.nodes),
                        default=0) - 1
        mindepth = 1 if not top_node else top_node.count(".") + 1
        while len(links) > self.maxlinks and culldepth > mindepth:
            self.filter(links, lambda s, t, v, depth=culldepth:
                        depth > s.count("."))
            culldepth -= 1
        # ...is a constraint and we still have too many links
        self.filter(links, lambda s, t, v: "constraint" not in self.nodes[s])

        linkslist, nodes, nodeset = [], [], set()
        for (source, target), value in links.items():
            if source == top_node:
                # the focused node's parent gets an arrow in its title
                nodes.append({"id": self.nodes[target]["id"],
                              "title": "⟶ %s" % self.nodes[target]["title"]})
                nodeset.add(target)
            for node in [source, target]:
                if node not in nodeset:
                    nodes.append({"id": self.nodes[node]["id"],
                                  "title": self.nodes[node]["title"]})
                    nodeset.add(node)
            linkslist.append({"source": source, "target": target,
                              "value": abs(value), "color": getcolor(value),
                              "title": "%+.2g" % value})
        return linkslist, nodes

    def onclick(self, sankey, node_or_link):
        """Callback function for when a node or link is clicked on.

        Re-filters the diagram around the clicked node (or a clicked link's
        source) and pushes the new links and nodes to the widget `sankey`.
        """
        if node_or_link is None:
            return
        if "id" in node_or_link:  # it's a node
            top_node = node_or_link["id"]
        else:  # it's a link
            top_node = node_or_link["source"]
        if self.last_top_node != top_node:
            sankey.links, sankey.nodes = self._links_and_nodes(top_node)
            sankey.send_state()
        self.last_top_node = top_node