Coverage for gpkit/solution_ensemble.py: 0%

185 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 16:49 -0500

1"implements SolutionEnsemble class" 

2import pickle 

3import numpy as np 

4from gpkit.keydict import KeyDict 

5from gpkit.varkey import VarKey 

6 

def varsort(diff):
    """Sort key for difference tuples.

    The first element of `diff` is taken to be the variable; the key is that
    variable's string form with the "hiddenlineage" part excluded.
    """
    variable, *_rest = diff
    excluded = {"hiddenlineage"}
    return variable.str_without(excluded)

11 

12 

def vardescr(var):
    """Return "<label> (<var>)", a string fully describing a variable."""
    # `!s` forces str() on both pieces, exactly as the "%s" conversion would
    return f"{var.label!s} ({var!s})"

16 

class OpenedSolutionEnsemble:
    """Context manager that loads an ensemble on creation and saves on exit.

    Usage: `with OpenedSolutionEnsemble(filename) as se: ...`
    The ensemble is read from `filename` when the object is constructed; if
    the file is missing or pickle hits an unexpected end-of-file, a new empty
    SolutionEnsemble is used instead. On leaving the `with` block the ensemble
    is written back to the same filename.
    """

    def __init__(self, filename="solensemble.pkl"):
        self.filename = filename
        try:
            ensemble = SolutionEnsemble.load(filename)
        except (EOFError, FileNotFoundError):
            # nothing usable on disk yet: start a fresh ensemble
            ensemble = SolutionEnsemble()
        self.solensemble = ensemble

    def __enter__(self):
        return self.solensemble

    def __exit__(self, type_, val, traceback):
        # the exception details are ignored: the ensemble is always saved
        self.solensemble.save(self.filename)

31 

class SolutionEnsemble:
    """An ensemble of solutions.

    Attributes:
        "baseline" : the first solution appended; every later solution is
                     described by its differences from this one
        "solutions" : all solutions, keyed by modified variables
        "labels" : solution labels, keyed by modified variables

    The baseline is stored under the empty tuple `()` in both dictionaries.

    SolutionEnsemble[varstr] : will return the relevant varkey

    """

    def __str__(self):
        out = ("Solution ensemble with a baseline and %s modified solutions:"
               % (len(self.solutions) - 1))
        for differences in self.solutions:
            if differences:  # the baseline's key is the (falsy) empty tuple
                out += "\n " + self.labels[differences]
        return out

    def __init__(self):
        self.baseline = None
        self.solutions = {}
        self.labels = {}

    def save(self, filename="solensemble.pkl", **pickleargs):
        """Pickles this ensemble to `filename`.

        Any extra keyword arguments are passed through to `pickle.dump`.
        """
        # `with` guarantees the handle is flushed and closed, even on error
        with open(filename, "wb") as pklfile:
            pickle.dump(self, pklfile, **pickleargs)

    @staticmethod
    def load(filename):
        """Loads a SolutionEnsemble from the pickle file `filename`.

        NOTE: unpickling can execute arbitrary code, so only load files
        that come from a trusted source.
        """
        with open(filename, "rb") as pklfile:
            return pickle.load(pklfile)

    def __getitem__(self, var):
        """Returns the baseline's varkey corresponding to `var`.

        Raises KeyError if the baseline's keymap does not hold exactly one
        key for the name built from `var`.
        """
        nameref = self.baseline["variables"]
        k, _ = nameref.parse_and_index(var)
        if isinstance(k, str):
            kstr = k
        else:
            # rebuild the name without index, prefixed by lineage if present
            kstr = k.str_without({"lineage", "idx"})
            if k.lineage:
                kstr = k.lineagestr() + "." + kstr
        keys = nameref.keymap[kstr]
        if len(keys) != 1:
            raise KeyError(var)
        basevar, = keys
        return basevar

    def filter(self, *requirements):
        """Filters by requirements, returning another solution ensemble

        Each requirement is a single item or a sized collection of items
        (a string counts as a single item). Items that can be resolved to a
        baseline varkey via `self[item]` are; others are used as given.
        A solution is kept only if, for every requirement, at least one of
        its differences contains all of that requirement's items.
        """
        candidates = set(self.solutions)
        for requirement in requirements:
            if (isinstance(requirement, str)
                    or not hasattr(requirement, "__len__")):
                requirement = [requirement]
            subreqs = []
            for subreq in requirement:
                try:
                    subreqs.append(self[subreq])
                except (AttributeError, KeyError):
                    # not a known variable: match against it literally
                    subreqs.append(subreq)
            candidates = {
                candidate for candidate in candidates
                if any(all(subreq in difference for subreq in subreqs)
                       for difference in candidate)}
        se = SolutionEnsemble()
        se.append(self.baseline)
        for candidate in candidates:
            se.append(self.solutions[candidate], verbosity=0)
        return se

    def get_solutions(self, *requirements):
        "Filters by requirements, returning a list of solutions."
        return [sol
                for diff, sol in self.filter(*requirements).solutions.items()
                if diff]  # skips the baseline, whose key is the empty tuple

    @staticmethod
    def _compress_sensitivity(val):
        "Zeroes out sensitivities below 1e-2; downcasts other arrays to float16"
        if np.abs(val).max() < 1e-2:
            if hasattr(val, "shape"):
                return np.zeros(val.shape, dtype=np.bool_)
            return 0
        if hasattr(val, "shape"):
            return np.array(val, dtype=np.float16)
        return val

    def append(self, solution, verbosity=1):  # pylint: disable=too-many-locals, too-many-branches, too-many-statements
        """Appends solution to the Ensemble

        The first solution appended becomes the baseline. Every later one is
        slimmed down in place (warnings and free variables dropped, keys
        replaced by the baseline's varkeys, sensitivities compacted) and then
        stored under a tuple describing how it differs from the baseline:
        changed cost function, freed variables, newly set values, and sweeps.

        Arguments:
            solution : the solution to add; it is modified in place
            verbosity : if greater than 0, prints what was added or replaced

        Raises:
            ValueError : if the baseline would be a sweep, or if the
                         solution's constraints differ from the baseline's

        A solution with no differences from the baseline is not stored, so
        `solutions[()]` always remains the baseline.
        """
        solution.set_necessarylineage()
        for var in solution["variables"]:
            var.descr.pop("vecfn", None)
            var.descr.pop("evalfn", None)
        if self.baseline is None:
            if "sweepvariables" in solution:
                raise ValueError("baseline solution cannot be a sweep")
            self.baseline = self.solutions[()] = solution
            self.labels[()] = "Baseline Solution"
            return

        # compare everything from "Constraints" onwards in the model strings
        solconstraintstr, baseconstraintstr = (
            sol.modelstr[sol.modelstr.find("Constraints"):]
            for sol in [solution, self.baseline])
        if solconstraintstr != baseconstraintstr:
            raise ValueError("the new model's constraints are not identical"
                             " to the base model's constraints."
                             " (Use .baseline.diff(sol) to compare.)")

        solution.pop("warnings", None)
        solution.pop("freevariables", None)
        solution["sensitivities"].pop("constants", None)
        # re-key every KeyDict with the baseline's own varkeys
        for subd, value in solution.items():
            if isinstance(value, KeyDict):
                solution[subd] = KeyDict()
                for oldkey, val in value.items():
                    solution[subd][self[oldkey]] = val
        for subd, value in solution["sensitivities"].items():
            if subd == "constraints":
                solution["sensitivities"][subd] = {}
                # constraints are matched to the baseline's by their strings
                cstrs = {str(c): c
                         for c in self.baseline["sensitivities"][subd]}
                for oldkey, val in value.items():
                    solution["sensitivities"][subd][cstrs[str(oldkey)]] = (
                        self._compress_sensitivity(val))
            elif isinstance(value, KeyDict):
                solution["sensitivities"][subd] = KeyDict()
                for oldkey, val in value.items():
                    solution["sensitivities"][subd][self[oldkey]] = (
                        self._compress_sensitivity(val))

        differences = []
        labels = []
        solcostfun = solution["cost function"]
        if len(solution) > 1:
            solcostfun = solcostfun[0]
        solcoststr = solcostfun.str_without({"units"})
        basecoststr = self.baseline["cost function"].str_without({"units"})
        if basecoststr != solcoststr:
            differences.append(("cost", solcoststr))
            labels.append("Cost function set to %s" % solcoststr)

        freedvars = set()
        setvars = set()
        # "sweepvariables" may be absent (it is checked with `in` elsewhere
        # in this method), so never index it unconditionally
        sweepvars = solution.get("sweepvariables", ())

        def check_var(var):
            "Records whether var was freed or set relative to the baseline"
            fixed_in_baseline = var in self.baseline["constants"]
            fixed_in_solution = var in solution["constants"]
            bval = self.baseline["variables"][var]
            if fixed_in_solution:
                sval = solution["constants"][var]
            else:
                sval = solution["variables"][var]
            if fixed_in_solution and getattr(sval, "shape", None):
                pass  # calculated constant that depends on a sweep variable
            elif fixed_in_solution and sval != bval:
                setvars.add((var, sval))  # whether free or fixed before
            elif not fixed_in_solution and fixed_in_baseline:
                if var not in sweepvars:
                    freedvars.add((var,))

        for var in self.baseline["variables"]:
            if var not in solution["variables"]:
                print("Variable", var, "removed (relative to baseline)")
                continue
            if not var.shape:
                check_var(var)
            else:
                # vector variable: check each element by its own indexed key
                it = np.nditer(np.empty(var.shape), flags=["multi_index"])
                while not it.finished:
                    check_var(VarKey(idx=it.multi_index, **var.descr))
                    it.iternext()

        for freedvar, in sorted(freedvars, key=varsort):
            differences.append((freedvar, "freed"))
            labels.append(vardescr(freedvar) + " freed")
        for setvar, setval in sorted(setvars, key=varsort):
            differences.append((setvar, setval))
            labels.append(vardescr(setvar) + " set to %.5g%s"
                          % (setval, setvar.unitstr(into=' %s')))
        if "sweepvariables" in solution:
            for var, vals in sorted(solution["sweepvariables"].items(),
                                    key=varsort):
                var = self[var]
                if var.shape:
                    it = np.nditer(np.empty(var.shape), flags=["multi_index"])
                    while not it.finished:
                        valsi = vals[(...,)+it.multi_index]
                        # elements containing any NaN are not reported
                        if not np.isnan(valsi).any():
                            idxvar = VarKey(idx=it.multi_index, **var.descr)
                            differences.append((idxvar, "sweep",
                                                (min(valsi), max(valsi))))
                            labels.append(vardescr(idxvar) + " swept from"
                                          + " %.5g to" % min(valsi)
                                          + " %.5g" % max(valsi)
                                          + idxvar.unitstr(into=' %s'))
                        it.iternext()
                else:
                    differences.append((var, "sweep", (min(vals), max(vals))))
                    labels.append(vardescr(var) + " swept from"
                                  + " %.5g to" % min(vals)
                                  + " %.5g" % max(vals)
                                  + var.unitstr(into=' %s'))
        difference = tuple(differences)
        label = ", ".join(labels)
        if not difference:
            # No differences means this would overwrite the baseline's slot;
            # keep solutions[()] and labels[()] in step with self.baseline.
            if verbosity > 0 and difference in self.solutions:
                print("The baseline in this ensemble cannot be replaced.")
            return
        if verbosity > 0:
            if difference in self.solutions:
                print(label + " will be replaced in the ensemble.")
            else:
                print(label + " added to the ensemble.")

        self.solutions[difference] = solution
        self.labels[difference] = label