Coverage for gpkit/interactive/sankey.py: 89%

184 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-07 22:15 -0500

1"implements Sankey" 

2# pylint: disable=import-error,consider-using-f-string 

3import os 

4import re 

5from collections import defaultdict 

6from collections.abc import Iterable 

7import numpy as np 

8from ipywidgets import Layout 

9from ipysankeywidget import SankeyWidget 

10from ..repr_conventions import lineagestr, unitstr 

11from .. import Model, GPCOLORS 

12from ..constraints.array import ArrayConstraint 

13 

14 

# Sensitivity magnitudes below this are drawn gray by getcolor and are the
# first "constraint" links culled when a diagram has too many links.
INSENSITIVE = 0.01
# Tiny stand-in magnitude used when a sensitivity is missing, zero or NaN,
# so the corresponding link is still created.
EPS = 1e-10

17 

def isnamedmodel(constraint):
    """Tell whether `constraint` is an instance of a Model subclass.

    Instances whose class is literally named "Model" (the generic base)
    do not count as named models.
    """
    if not isinstance(constraint, Model):
        return False
    return constraint.__class__.__name__ != "Model"

22 

def getcolor(value):
    """Pick a display color for a sensitivity value.

    Magnitudes below INSENSITIVE get neutral gray ("#cfcfcf"); otherwise
    negative values use GPCOLORS[0] and all other values GPCOLORS[1].
    """
    is_negligible = abs(value) < INSENSITIVE
    if is_negligible:
        return "#cfcfcf"
    is_negative = value < 0
    return GPCOLORS[0] if is_negative else GPCOLORS[1]

28 

def cleanfilename(string):
    """Parse `string` into a valid filename.

    Every backslash, slash, question mark, pipe, double quote, angle
    bracket, colon and asterisk is replaced with "_", so the result can be
    used as a single path component.
    """
    # A character class replaces each forbidden character on its own. The
    # former pattern was an alternation of two literal sequences, so most
    # forbidden characters (e.g. a lone "/") slipped through unchanged.
    return re.sub(r'[\\/?|"<>:*]', "_", string)

32 

33 

# pylint: disable=too-many-instance-attributes
class Sankey:
    """Return Jupyter diagrams of sensitivity flow.

    Walks the constraint hierarchy of `constraintset` to build
    ``self.nodes`` (node id -> node dict) and ``self.links``
    ((source id, target id) -> signed value), which ``diagram`` renders
    as a ``SankeyWidget``.
    """
    # Display settings; ``diagram`` overwrites them on the instance.
    minsenss = 0
    maxlinks = 20
    showconstraints = True
    # Node id last focused via ``onclick``; used to skip redundant redraws.
    last_top_node = None

    def __init__(self, solution, constraintset, csetlabel=None):
        self.solution = solution
        self.csenss = solution["sensitivities"]["constraints"]
        self.cset = constraintset
        if csetlabel is None:
            csetlabel = lineagestr(self.cset) or self.cset.__class__.__name__
        self.csetlabel = csetlabel
        self.links = defaultdict(float)  # (source, target) -> link value
        self.links_to_target = defaultdict(int)  # target -> children created
        self.nodes = {}  # node id -> {"id", "title", optional tag: True}

    def add_node(self, target, title, tag=None):
        """Add a child node of `target` to self.nodes and return its id.

        The id is ``"<target>.<NNNN>"``, where NNNN counts the nodes created
        under `target` so far, so each nesting level appends one segment.
        When `tag` is given (this class uses "constraint" and "subarray") it
        is stored as a ``True`` flag that the link filters test for;
        untagged nodes get no extra key.
        """
        self.links_to_target[target] += 1
        node = "%s.%04i" % (target, self.links_to_target[target])
        self.nodes[node] = {"id": node, "title": title}
        if tag is not None:
            self.nodes[node][tag] = True
        return node

    def linkfixed(self, cset, target):
        """Add fixed-variable links as if they were (array) constraints.

        A vector constant whose elements all share one value gets a single
        "constraint" node with one "subarray" child per element; every other
        fixed variable gets its own "constraint" node under `target`.

        Returns the sum of the link values added (each is -|sensitivity|).
        """
        fixedvecs = {}  # veckey -> id of the shared node for that vector
        total_sens = 0
        for vk in sorted(cset.unique_varkeys, key=str):
            if vk not in self.solution["constants"]:
                continue
            if vk.veckey and vk.veckey not in fixedvecs:
                vecval = self.solution["constants"][vk.veckey]
                firstval = vecval.flatten()[0]
                if vecval.shape and (firstval == vecval).all():
                    label = "%s = %.4g %s" % (vk.veckey.name, firstval,
                                              unitstr(vk.veckey))
                    fixedvecs[vk.veckey] = self.add_node(target, label,
                                                         "constraint")
            # Links carry non-positive magnitudes; a missing sensitivity
            # falls back to EPS so the variable still gets a (tiny) link.
            abs_var_sens = -abs(self.solution["sensitivities"]
                                ["constants"].get(vk, EPS))
            if np.isnan(abs_var_sens):
                # keep the same negative sign convention as every other link
                abs_var_sens = -EPS
            label = "%s = %.4g %s" % (vk.str_without(["lineage"]),
                                      self.solution["variables"][vk],
                                      unitstr(vk))
            if vk.veckey in fixedvecs:
                vectarget = fixedvecs[vk.veckey]
                source = self.add_node(vectarget, label, "subarray")
                self.links[source, vectarget] = abs_var_sens
                self.links[vectarget, target] += abs_var_sens
            else:
                source = self.add_node(target, label, "constraint")
                self.links[source, target] = abs_var_sens
            total_sens += abs_var_sens
        return total_sens

    # pylint: disable=too-many-branches, too-many-arguments
    def link(self, cset, target, vk, *, labeled=False, subarray=False):
        """Recursively add links of constraint set `cset` to self.links.

        Args:
            cset: a constraint, a constraint set, a dict of labeled
                constraints, or another iterable of constraints.
            target: id of the node that `cset`'s links flow into.
            vk: key of the variable whose sensitivities are drawn, or None
                to draw constraint sensitivities (and, for named models,
                fixed variables via ``linkfixed``).
            labeled: True when the caller already made a labeled node for
                `cset`, so no model or leaf node is created for it here.
            subarray: True for elements of a multi-element ArrayConstraint;
                their leaf nodes are tagged "subarray" instead of
                "constraint".

        Returns the total link value contributed by `cset`.
        """
        total_sens = 0
        switchedtarget = False  # original target, once a sub-node is made
        if not labeled and isnamedmodel(cset):
            if cset is not self.cset:  # top-level, no need to switch targets
                switchedtarget = target
                target = self.add_node(target, cset.lineage[-1][0])
            if vk is None:
                total_sens += self.linkfixed(cset, target)
        elif isinstance(cset, ArrayConstraint) and cset.constraints.size > 1:
            # one summary node for the whole array; elements hang below it
            switchedtarget = target
            cstr = cset.str_without(["lineage", "units"]).replace("[:]", "")
            label = cstr if len(cstr) <= 30 else "%s ..." % cstr[:30]
            target = self.add_node(target, label, "constraint")
            subarray = True
        if getattr(cset, "idxlookup", None):
            cset = {k: cset[i] for k, i in cset.idxlookup.items()}
        if isinstance(cset, dict):
            for label, c in cset.items():
                source = self.add_node(target, label)
                subtotal_sens = self.link(c, source, vk, labeled=True)
                self.links[source, target] += subtotal_sens
                total_sens += subtotal_sens
        elif isinstance(cset, Iterable):
            for c in cset:
                total_sens += self.link(c, target, vk, subarray=subarray)
        else:  # a single constraint: a leaf of the diagram
            if vk is None and cset in self.csenss:
                total_sens = -abs(self.csenss[cset]) or -EPS
            elif vk is not None:
                if cset.v_ss is None:
                    if vk in cset.varkeys:
                        total_sens = EPS
                elif vk in cset.v_ss:
                    total_sens = cset.v_ss[vk] or EPS
            if not labeled:
                cstr = cset.str_without(["lineage", "units"])
                label = cstr if len(cstr) <= 30 else "%s ..." % cstr[:30]
                tag = "subarray" if subarray else "constraint"
                source = self.add_node(target, label, tag)
                self.links[source, target] = total_sens
        if switchedtarget:
            # route everything gathered under the sub-node to the old target
            self.links[target, switchedtarget] += total_sens
        return total_sens

    def filter(self, links, function, forced=False):
        """If over maxlinks (or `forced`), remove links failing the criteria.

        `links` is modified in place: each (source, target) entry for which
        ``function(source, target, value)`` is falsy is deleted.
        """
        if len(links) > self.maxlinks or forced:
            for (s, t), v in list(links.items()):
                if not function(s, t, v):
                    del links[(s, t)]

    # pylint: disable=too-many-locals, too-many-arguments
    def diagram(self, variable=None, varlabel=None, *, minsenss=0, maxlinks=20,
                top=0, bottom=0, left=230, right=140, width=1000, height=400,
                showconstraints=True):
        """Create links and return an ipython widget that shows them.

        Args:
            variable: if given, draw the sensitivities of ``variable.key``;
                otherwise draw constraint sensitivities.
            varlabel: title of the variable's node; defaults to
                ``str(variable.key)``, or its lineage-free form when that is
                longer than 20 characters.
            minsenss: links with |value| <= minsenss are hidden.
            maxlinks: links are culled while more than this many remain.
            top, bottom, left, right: widget margins.
            width, height: widget layout size.
            showconstraints: if False, hide constraint/subarray sources.

        PNG and SVG autosaves of the widget are requested under
        ``sankey_autosaves/``. The solution's necessary-lineage setting is
        enabled while drawing and always cleared afterwards.
        """
        margins = {"top": top, "bottom": bottom, "left": left, "right": right}
        self.minsenss = minsenss
        self.maxlinks = maxlinks
        self.showconstraints = showconstraints

        self.solution.set_necessarylineage()
        try:
            if variable:
                variable = variable.key
                if not varlabel:
                    varlabel = str(variable)
                    if len(varlabel) > 20:
                        varlabel = variable.str_without(["lineage"])
                self.nodes[varlabel] = {"id": varlabel, "title": varlabel}
                csetnode = self.add_node(varlabel, self.csetlabel)
                if variable in self.solution["sensitivities"]["cost"]:
                    costnode = self.add_node(varlabel, "[cost function]")
                    self.links[costnode, varlabel] = \
                        self.solution["sensitivities"]["cost"][variable]
            else:
                csetnode = self.csetlabel
                self.nodes[self.csetlabel] = {"id": self.csetlabel,
                                              "title": self.csetlabel}
            total_sens = self.link(self.cset, csetnode, variable)
            if variable:
                self.links[csetnode, varlabel] = total_sens

            links, nodes = self._links_and_nodes()
            out = SankeyWidget(nodes=nodes, links=links, margins=margins,
                               layout=Layout(width=str(width),
                                             height=str(height)))

            filename = self.csetlabel
            if variable:
                filename += "_%s" % variable
            os.makedirs("sankey_autosaves", exist_ok=True)
            filename = os.path.join("sankey_autosaves",
                                    cleanfilename(filename))
            out.auto_save_png(filename + ".png")
            out.auto_save_svg(filename + ".svg")
            out.on_node_clicked(self.onclick)
            out.on_link_clicked(self.onclick)
        finally:
            # never leave the solution's lineage state changed on an error
            self.solution.set_necessarylineage(clear=True)
        return out

    def _links_and_nodes(self, top_node=None):
        """Return (links, nodes) lists for SankeyWidget after filtering.

        Works on a copy of self.links. With `top_node`, only links whose
        source or target id contains `top_node` are kept, and targets of
        links leaving `top_node` get a "⟶ " title prefix. The remaining
        filters run in order; the unforced ones act only while more than
        self.maxlinks links remain.
        """
        links = self.links.copy()
        # filter if...not below the chosen top node
        if top_node is not None:
            self.filter(links, lambda s, t, v: top_node in s or top_node in t,
                        forced=True)
        # ...below minimum sensitivity
        self.filter(links, lambda s, t, v: abs(v) > self.minsenss, forced=True)
        if not self.showconstraints:
            # ...is a constraint or subarray and we're not showing those
            self.filter(links, lambda s, t, v:
                        ("constraint" not in self.nodes[s]
                         and "subarray" not in self.nodes[s]), forced=True)
        # ...is a subarray and we still have too many links
        self.filter(links, lambda s, t, v: "subarray" not in self.nodes[s])
        # ...is an insensitive constraint and we still have too many links
        self.filter(links, lambda s, t, v: ("constraint" not in self.nodes[s]
                                            or abs(v) > INSENSITIVE))
        # ...is at culldepth, repeating up to a relative depth of 1 or 2
        culldepth = max(node.count(".") for node in self.nodes) - 1
        mindepth = 1 if not top_node else top_node.count(".") + 1
        while len(links) > self.maxlinks and culldepth > mindepth:
            self.filter(links, lambda s, t, v: culldepth > s.count("."))
            culldepth -= 1
        # ...is a constraint and we still have too many links
        self.filter(links, lambda s, t, v: "constraint" not in self.nodes[s])

        linkslist, nodes, nodeset = [], [], set()
        for (source, target), value in links.items():
            if source == top_node:
                nodes.append({"id": self.nodes[target]["id"],
                              "title": "⟶ %s" % self.nodes[target]["title"]})
                nodeset.add(target)
            for node in [source, target]:
                if node not in nodeset:
                    nodes.append({"id": self.nodes[node]["id"],
                                  "title": self.nodes[node]["title"]})
                    nodeset.add(node)
            linkslist.append({"source": source, "target": target,
                              "value": abs(value), "color": getcolor(value),
                              "title": "%+.2g" % value})
        return linkslist, nodes

    def onclick(self, sankey, node_or_link):
        """Callback for when a node or link is clicked on.

        Refocuses `sankey` on the clicked node (or a clicked link's source)
        and pushes the re-filtered links and nodes to the widget, unless
        that node is already the focus.
        """
        if node_or_link is not None:
            if "id" in node_or_link:  # it's a node
                top_node = node_or_link["id"]
            else:  # it's a link
                top_node = node_or_link["source"]
            if self.last_top_node != top_node:
                sankey.links, sankey.nodes = self._links_and_nodes(top_node)
                sankey.send_state()
                self.last_top_node = top_node