Coverage for gpkit/solution_ensemble.py: 0%

188 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-07 22:15 -0500

1"implements SolutionEnsemble class" 

2import pickle 

3import numpy as np 

4from gpkit.keydict import KeyDict 

5from gpkit.varkey import VarKey 

6 

def varsort(diff):
    """Sort key for a difference entry, based on its leading variable.

    Only the first element of `diff` is used; it is rendered via its
    ``str_without`` method with the "hiddenlineage" part excluded.
    """
    head, *_rest = diff
    excluded = {"hiddenlineage"}
    return head.str_without(excluded)

11 

12 

def vardescr(var):
    """Describe `var` as its label followed by its string form in parentheses."""
    template = "{} ({})"
    return template.format(var.label, var)

16 

class OpenedSolutionEnsemble:
    """Context manager giving a `with` block access to a pickled ensemble.

    On construction the ensemble is loaded from `filename`. If that file is
    missing (FileNotFoundError) or loading hits EOFError (e.g. an empty
    file), a fresh SolutionEnsemble is used instead. `__enter__` hands the
    ensemble to the `with` block. `__exit__` always writes it back to
    `filename`, even when the block raised, and does not suppress the
    exception.
    """

    def __init__(self, filename="solensemble.pkl"):
        self.filename = filename
        try:
            ensemble = SolutionEnsemble.load(filename)
        except (EOFError, FileNotFoundError):
            # nothing usable on disk yet: start from an empty ensemble
            ensemble = SolutionEnsemble()
        self.solensemble = ensemble

    def __enter__(self):
        return self.solensemble

    def __exit__(self, type_, val, traceback):
        # Unconditional save; the implicit None return lets exceptions
        # propagate out of the `with` block.
        target = self.filename
        self.solensemble.save(target)

31 

class SolutionEnsemble:
    """An ensemble of solutions that share one baseline model.

    The first solution passed to `append` becomes the baseline and is
    stored under the empty tuple ``()``. Every later solution must have
    constraints identical to the baseline's. It is stored under a tuple
    describing how it differs from the baseline: a changed cost function,
    variables that were freed or set to new values, and swept variables.

    Attributes:
    "baseline" : the first appended solution (None until one is appended)
    "solutions" : all solutions, keyed by their difference tuples
    "labels" : human-readable solution labels, keyed the same way

    SolutionEnsemble[varstr] : will return the relevant baseline varkey
    """

    def __str__(self):
        # Clamp so an ensemble without a baseline doesn't report -1.
        nmods = max(len(self.solutions) - 1, 0)
        lines = ["Solution ensemble with a baseline and "
                 f"{nmods} modified solutions:"]
        # The baseline is keyed by the empty (falsy) tuple; list only mods.
        lines.extend(" " + self.labels[differences]
                     for differences in self.solutions if differences)
        return "\n".join(lines)

    def __init__(self):
        self.baseline = None
        self.solutions = {}
        self.labels = {}

    def save(self, filename="solensemble.pkl", **pickleargs):
        """Pickle this ensemble (uncompressed) into `filename`.

        Extra keyword arguments are passed through to `pickle.dump`.
        """
        with open(filename, "wb") as f:
            pickle.dump(self, f, **pickleargs)

    @staticmethod
    def load(filename):
        """Load and return a SolutionEnsemble previously written by `save`.

        Security: unpickling can execute arbitrary code, so only load
        files from a trusted source.
        """
        with open(filename, "rb") as f:
            return pickle.load(f)

    def __getitem__(self, var):
        """Return the single baseline varkey that `var` refers to.

        `var` is parsed against the baseline's variables. The resulting
        key's name (with lineage prefix, without index) is looked up in the
        baseline's keymap. KeyError is raised when that lookup does not
        yield exactly one key.
        """
        nameref = self.baseline["variables"]
        k, _ = nameref.parse_and_index(var)
        if isinstance(k, str):
            kstr = k
        else:
            kstr = k.str_without({"lineage", "idx"})
            if k.lineage:
                kstr = k.lineagestr() + "." + kstr
        keys = nameref.keymap[kstr]
        if len(keys) != 1:
            raise KeyError(var)
        basevar, = keys
        return basevar

    def filter(self, *requirements):
        """Return a new ensemble: the baseline plus solutions matching all requirements.

        Each requirement is a single item or a non-string sequence of
        items. Items that resolve via ``self[item]`` are replaced by the
        baseline varkey; others (e.g. "sweep", "freed", "cost") are used
        as-is. A solution satisfies a requirement when at least one of its
        differences contains every item of that requirement.
        """
        candidates = set(self.solutions)
        for requirement in requirements:
            if (isinstance(requirement, str)
                    or not hasattr(requirement, "__len__")):
                requirement = [requirement]
            subreqs = []
            for subreq in requirement:
                try:
                    subreqs.append(self[subreq])
                except (AttributeError, KeyError):
                    subreqs.append(subreq)
            for candidate in set(candidates):
                found_requirement = False
                for difference in candidate:
                    if all(subreq in difference for subreq in subreqs):
                        found_requirement = True
                        break
                if not found_requirement:
                    candidates.remove(candidate)
        se = SolutionEnsemble()
        se.append(self.baseline)
        for candidate in candidates:
            se.append(self.solutions[candidate], verbosity=0)
        return se

    def get_solutions(self, *requirements):
        """Return the non-baseline solutions that match `requirements`.

        Uses the same matching rules as `filter`.
        """
        return [sol
                for diff, sol in self.filter(*requirements).solutions.items()
                if diff]

    @staticmethod
    def _compress_sensitivity(val):
        """Shrink one sensitivity value for storage.

        If its largest magnitude is below 1e-2 it becomes 0, or an all-False
        bool array of the same shape when `val` has a shape. Otherwise
        values with a shape are stored as float16 arrays and plain scalars
        are returned unchanged.
        """
        if np.abs(val).max() < 1e-2:
            if hasattr(val, "shape"):
                return np.zeros(val.shape, dtype=np.bool_)
            return 0
        if hasattr(val, "shape"):
            return np.array(val, dtype=np.float16)
        return val

    def append(self, solution, verbosity=1):  # pylint: disable=too-many-locals, too-many-branches, too-many-statements
        """Add `solution` to the ensemble, keyed by its differences.

        The first solution appended becomes the baseline; it must not
        contain "sweepvariables". For later solutions, the constraints are
        checked against the baseline's. Some entries are then dropped, every
        KeyDict is re-keyed onto the baseline's varkeys, and sensitivities
        are compressed. The solution is then stored under a tuple of its
        differences from the baseline, with a matching text label. An
        existing entry with the same differences is replaced.

        `solution` is modified in place.

        Arguments
        ---------
        solution : the solution mapping to add (mutated in place)
        verbosity : if > 0, print what was added or replaced and any
            baseline variables missing from `solution`

        Raises
        ------
        ValueError : if the baseline solution contains "sweepvariables",
            or if the solution's constraints differ from the baseline's
        """
        solution.set_necessarylineage()
        for var in solution["variables"]:
            # NOTE(review): presumably dropped so the solution pickles
            # cleanly (these look function-valued) -- confirm.
            var.descr.pop("vecfn", None)
            var.descr.pop("evalfn", None)
        if self.baseline is None:
            if "sweepvariables" in solution:
                raise ValueError("baseline solution cannot be a sweep")
            self.baseline = self.solutions[()] = solution
            self.labels[()] = "Baseline Solution"
            return

        # Compare everything from "Constraints" onward in the model strings.
        solconstraintstr, baseconstraintstr = (
            sol.modelstr[sol.modelstr.find("Constraints"):]
            for sol in [solution, self.baseline])
        if solconstraintstr != baseconstraintstr:
            raise ValueError("the new model's constraints are not identical"
                             " to the base model's constraints."
                             " (Use .baseline.diff(sol) to compare.)")

        # Drop entries that are not kept, then re-key every KeyDict onto
        # the baseline's varkeys so all solutions share the same keys.
        solution.pop("warnings", None)
        solution.pop("freevariables", None)
        solution["sensitivities"].pop("constants", None)
        for subd, value in solution.items():
            if isinstance(value, KeyDict):
                solution[subd] = KeyDict()
                for oldkey, val in value.items():
                    solution[subd][self[oldkey]] = val
        sens = solution["sensitivities"]
        for subd, value in sens.items():
            if subd == "constraints":
                # constraints are matched to the baseline's by string form
                sens[subd] = {}
                cstrs = {str(c): c
                         for c in self.baseline["sensitivities"][subd]}
                for oldkey, val in value.items():
                    sens[subd][cstrs[str(oldkey)]] = \
                        self._compress_sensitivity(val)
            elif isinstance(value, KeyDict):
                sens[subd] = KeyDict()
                for oldkey, val in value.items():
                    sens[subd][self[oldkey]] = self._compress_sensitivity(val)

        differences = []
        labels = []
        solcostfun = solution["cost function"]
        # NOTE(review): presumably len(solution) > 1 marks a sweep whose
        # cost function is stored per point -- confirm against SolutionArray.
        if len(solution) > 1:
            solcostfun = solcostfun[0]
        solcoststr = solcostfun.str_without({"units"})
        basecoststr = self.baseline["cost function"].str_without({"units"})
        if basecoststr != solcoststr:
            differences.append(("cost", solcoststr))
            labels.append(f"Cost function set to {solcoststr}")

        freedvars = set()
        setvars = set()
        # Non-sweep solutions have no "sweepvariables" entry (the baseline
        # check above relies on that), so fall back to an empty container
        # instead of raising KeyError when a variable is freed.
        sweepvars = (solution["sweepvariables"]
                     if "sweepvariables" in solution else ())

        def check_var(var):
            "Record `var` in setvars or freedvars if it differs from baseline."
            fixed_in_baseline = var in self.baseline["constants"]
            fixed_in_solution = var in solution["constants"]
            bval = self.baseline["variables"][var]
            if fixed_in_solution:
                sval = solution["constants"][var]
            else:
                sval = solution["variables"][var]
            if fixed_in_solution and getattr(sval, "shape", None):
                pass  # calculated constant that depends on a sweep variable
            elif fixed_in_solution and sval != bval:
                setvars.add((var, sval))  # whether free or fixed before
            elif not fixed_in_solution and fixed_in_baseline:
                if var not in sweepvars:
                    freedvars.add((var,))

        for var in self.baseline["variables"]:
            if var not in solution["variables"]:
                if verbosity > 0:
                    print("Variable", var, "removed (relative to baseline)")
                continue
            if not var.shape:
                check_var(var)
            else:
                # check each element of a vector variable separately
                it = np.nditer(np.empty(var.shape), flags=["multi_index"])
                while not it.finished:
                    check_var(VarKey(idx=it.multi_index, **var.descr))
                    it.iternext()

        for freedvar, in sorted(freedvars, key=varsort):
            differences.append((freedvar, "freed"))
            labels.append(vardescr(freedvar) + " freed")
        for setvar, setval in sorted(setvars, key=varsort):
            differences.append((setvar, setval))
            ustr = setvar.unitstr(into=' %s')
            labels.append(vardescr(setvar) + f" set to {setval:.5g}" + ustr)
        if "sweepvariables" in solution:
            for var, vals in sorted(solution["sweepvariables"].items(),
                                    key=varsort):
                var = self[var]
                if var.shape:
                    it = np.nditer(np.empty(var.shape), flags=["multi_index"])
                    while not it.finished:
                        valsi = vals[(...,)+it.multi_index]
                        # elements that are NaN anywhere are not recorded
                        if not np.isnan(valsi).any():
                            idxvar = VarKey(idx=it.multi_index, **var.descr)
                            differences.append((idxvar, "sweep",
                                                (min(valsi), max(valsi))))
                            labels.append(vardescr(idxvar) + " swept from"
                                          + f" {min(valsi):.5g} to"
                                          + f" {max(valsi):.5g}"
                                          + idxvar.unitstr(into=' %s'))
                        it.iternext()
                else:
                    differences.append((var, "sweep", (min(vals), max(vals))))
                    labels.append(vardescr(var) + " swept from"
                                  + f" {min(vals):.5g} to"
                                  + f" {max(vals):.5g}"
                                  + var.unitstr(into=' %s'))
        difference = tuple(differences)
        label = ", ".join(labels)
        if verbosity > 0:
            if difference in self.solutions:
                if not difference:
                    print("The baseline in this ensemble cannot be replaced.")
                else:
                    print(label + " will be replaced in the ensemble.")
            else:
                print(label + " added to the ensemble.")

        self.solutions[difference] = solution
        self.labels[difference] = label