Coverage for gpkit/solution_ensemble.py: 0%
185 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-28 12:36 -0400
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-28 12:36 -0400
1"implements SolutionEnsemble class"
2import pickle
3import numpy as np
4from gpkit.keydict import KeyDict
5from gpkit.varkey import VarKey
def varsort(diff):
    """Sort key for a difference tuple.

    The first element of `diff` is taken as the variable; the key is that
    variable's string form with its hidden lineage omitted.
    """
    leading, *_rest = diff
    excluded = {"hiddenlineage"}
    return leading.str_without(excluded)
def vardescr(var):
    "Describe a variable as '<label> (<variable>)'."
    label_text = str(var.label)
    var_text = str(var)
    return label_text + " (" + var_text + ")"
class OpenedSolutionEnsemble:
    """Context manager that opens a SolutionEnsemble file and saves it on exit.

    On construction the ensemble is loaded from `filename`; if loading raises
    EOFError or FileNotFoundError, a new empty SolutionEnsemble is used
    instead. `with` yields the ensemble. `__exit__` saves it back to
    `filename` unconditionally (also when the `with` body raised) and
    returns None, so any exception still propagates.
    """
    def __init__(self, filename="solensemble.pkl"):
        self.filename = filename
        try:
            ensemble = SolutionEnsemble.load(filename)
        except (EOFError, FileNotFoundError):
            ensemble = SolutionEnsemble()
        self.solensemble = ensemble

    def __enter__(self):
        return self.solensemble

    def __exit__(self, type_, val, traceback):
        ensemble = self.solensemble
        ensemble.save(self.filename)
class SolutionEnsemble:
    """An ensemble of solutions.

    Attributes:
    "baseline" : the first solution appended; later ones are compared to it
    "solutions" : all solutions, keyed by modified variables
    "labels" : solution labels, keyed by modified variables

    SolutionEnsemble[varstr] : will return the relevant varkey

    """

    def __str__(self):
        # The baseline (keyed by the empty tuple) is not a "modified"
        # solution; clamp at 0 so an empty ensemble doesn't report -1.
        nmodified = max(len(self.solutions) - 1, 0)
        lines = ["Solution ensemble with a baseline and %s modified solutions:"
                 % nmodified]
        lines.extend(self.labels[differences]
                     for differences in self.solutions if differences)
        return "\n ".join(lines)

    def __init__(self):
        self.baseline = None
        self.solutions = {}  # difference tuple -> solution; () is baseline
        self.labels = {}     # difference tuple -> human-readable label

    def save(self, filename="solensemble.pkl", **pickleargs):
        """Pickle this ensemble into `filename`.

        Keyword arguments (e.g. `protocol`) are passed to `pickle.dump`.
        The file is closed before returning.
        """
        with open(filename, "wb") as fil:
            pickle.dump(self, fil, **pickleargs)

    @staticmethod
    def load(filename):
        """Loads a SolutionEnsemble previously written by `save`.

        Only load files you trust: unpickling can execute arbitrary code.
        """
        with open(filename, "rb") as fil:
            return pickle.load(fil)

    def __getitem__(self, var):
        """Return the unique baseline VarKey matching `var`.

        `var` is parsed against the baseline's "variables" and reduced to its
        name without index (prefixed by its lineage, if any). Raises KeyError
        unless exactly one baseline key has that name.
        """
        nameref = self.baseline["variables"]
        k, _ = nameref.parse_and_index(var)
        if isinstance(k, str):
            kstr = k
        else:
            kstr = k.str_without({"lineage", "idx"})
            if k.lineage:
                kstr = k.lineagestr() + "." + kstr
        keys = nameref.keymap[kstr]
        if len(keys) != 1:
            raise KeyError(var)
        basevar, = keys
        return basevar

    def filter(self, *requirements):
        """Filters by requirements, returning another solution ensemble.

        Each requirement is a single item or a sequence of items (a string
        counts as a single item). Items that resolve through `self[item]`
        are replaced by the matching baseline VarKey; the rest (e.g. "freed"
        or "sweep") are used as-is. A solution is kept only if, for every
        requirement, one of its differences contains all of that
        requirement's items. The new ensemble reuses this ensemble's
        baseline and solution objects.
        """
        candidates = set(self.solutions)
        for requirement in requirements:
            if (isinstance(requirement, str)
                    or not hasattr(requirement, "__len__")):
                requirement = [requirement]
            subreqs = []
            for subreq in requirement:
                try:
                    subreqs.append(self[subreq])
                except (AttributeError, KeyError):
                    subreqs.append(subreq)
            for candidate in set(candidates):
                found_requirement = any(
                    all(subreq in difference for subreq in subreqs)
                    for difference in candidate)
                if not found_requirement:
                    candidates.remove(candidate)
        se = SolutionEnsemble()
        se.append(self.baseline)
        for candidate in candidates:
            se.append(self.solutions[candidate], verbosity=0)
        return se

    def get_solutions(self, *requirements):
        "Filters by requirements, returning a list of solutions."
        return [sol
                for diff, sol in self.filter(*requirements).solutions.items()
                if diff]

    @staticmethod
    def _compress_sensitivity(val):
        "Shrink a sensitivity: near-zero -> bool zeros/0, arrays -> float16."
        if np.abs(val).max() < 1e-2:
            if hasattr(val, "shape"):
                return np.zeros(val.shape, dtype=np.bool_)
            return 0
        if hasattr(val, "shape"):
            return np.array(val, dtype=np.float16)
        return val

    def _reindex_to_baseline(self, solution):
        "Drop bulky entries from `solution` and re-key it onto baseline keys."
        solution.pop("warnings", None)
        solution.pop("freevariables", None)
        solution["sensitivities"].pop("constants", None)
        for subd, value in solution.items():
            if isinstance(value, KeyDict):
                solution[subd] = KeyDict()
                for oldkey, val in value.items():
                    solution[subd][self[oldkey]] = val
        for subd, value in solution["sensitivities"].items():
            if subd == "constraints":
                solution["sensitivities"][subd] = {}
                # constraints are matched to the baseline's by string form
                cstrs = {str(c): c
                         for c in self.baseline["sensitivities"][subd]}
                for oldkey, val in value.items():
                    solution["sensitivities"][subd][cstrs[str(oldkey)]] = \
                        self._compress_sensitivity(val)
            elif isinstance(value, KeyDict):
                solution["sensitivities"][subd] = KeyDict()
                for oldkey, val in value.items():
                    solution["sensitivities"][subd][self[oldkey]] = \
                        self._compress_sensitivity(val)

    def append(self, solution, verbosity=1):  # pylint: disable=too-many-locals, too-many-branches, too-many-statements
        """Appends solution to the Ensemble.

        The first solution appended becomes the baseline and may not be a
        sweep. Later solutions must have the same constraints as the
        baseline; they are stripped down, re-keyed onto the baseline's
        VarKeys, and stored under a tuple describing how they differ from
        the baseline (cost function, freed/set variables, swept variables).
        An existing entry with the same differences is replaced.

        `solution` is modified in place.

        Arguments
        ---------
        solution : the solution to add
        verbosity : int
            If > 0, print whether the solution was added or replaced.

        Raises
        ------
        ValueError
            If the baseline is a sweep, or if the constraints differ from
            the baseline's.
        """
        solution.set_necessarylineage()
        for var in solution["variables"]:
            # NOTE(review): presumably dropped because function-valued
            # entries don't pickle -- confirm.
            var.descr.pop("vecfn", None)
            var.descr.pop("evalfn", None)
        if self.baseline is None:
            if "sweepvariables" in solution:
                raise ValueError("baseline solution cannot be a sweep")
            self.baseline = self.solutions[()] = solution
            self.labels[()] = "Baseline Solution"
            return

        solconstraintstr, baseconstraintstr = (
            sol.modelstr[sol.modelstr.find("Constraints"):]
            for sol in [solution, self.baseline])
        if solconstraintstr != baseconstraintstr:
            raise ValueError("the new model's constraints are not identical"
                             " to the base model's constraints."
                             " (Use .baseline.diff(sol) to compare.)")

        self._reindex_to_baseline(solution)

        differences = []
        labels = []
        solcostfun = solution["cost function"]
        if len(solution) > 1:
            # NOTE(review): presumably len(solution) counts sweep points and
            # a sweep stores one cost function per point; compare the first.
            solcostfun = solcostfun[0]
        solcoststr = solcostfun.str_without({"units"})
        basecoststr = self.baseline["cost function"].str_without({"units"})
        if basecoststr != solcoststr:
            differences.append(("cost", solcoststr))
            labels.append("Cost function set to %s" % solcoststr)

        freedvars = set()
        setvars = set()
        # Non-sweep solutions have no "sweepvariables" entry (see the
        # baseline check above), so don't index it unguarded.
        sweepvars = (solution["sweepvariables"]
                     if "sweepvariables" in solution else {})

        def check_var(var):
            "Record `var` as set or freed relative to the baseline."
            fixed_in_baseline = var in self.baseline["constants"]
            fixed_in_solution = var in solution["constants"]
            bval = self.baseline["variables"][var]
            if fixed_in_solution:
                sval = solution["constants"][var]
            else:
                sval = solution["variables"][var]
            if fixed_in_solution and getattr(sval, "shape", None):
                pass  # calculated constant that depends on a sweep variable
            elif fixed_in_solution and sval != bval:
                setvars.add((var, sval))  # whether free or fixed before
            elif not fixed_in_solution and fixed_in_baseline:
                if var not in sweepvars:
                    freedvars.add((var,))

        for var in self.baseline["variables"]:
            if var not in solution["variables"]:
                print("Variable", var, "removed (relative to baseline)")
                continue
            if not var.shape:
                check_var(var)
            else:
                for idx in np.ndindex(var.shape):  # C-order element indices
                    check_var(VarKey(idx=idx, **var.descr))

        for freedvar, in sorted(freedvars, key=varsort):
            differences.append((freedvar, "freed"))
            labels.append(vardescr(freedvar) + " freed")
        for setvar, setval in sorted(setvars, key=varsort):
            differences.append((setvar, setval))
            labels.append(vardescr(setvar) + " set to %.5g%s"
                          % (setval, setvar.unitstr(into=' %s')))
        if "sweepvariables" in solution:
            for var, vals in sorted(solution["sweepvariables"].items(),
                                    key=varsort):
                var = self[var]
                if var.shape:
                    for idx in np.ndindex(var.shape):
                        valsi = vals[(...,) + idx]
                        if not np.isnan(valsi).any():
                            idxvar = VarKey(idx=idx, **var.descr)
                            differences.append((idxvar, "sweep",
                                                (min(valsi), max(valsi))))
                            labels.append(vardescr(idxvar) + " swept from"
                                          + " %.5g to" % min(valsi)
                                          + " %.5g" % max(valsi)
                                          + idxvar.unitstr(into=' %s'))
                else:
                    differences.append((var, "sweep", (min(vals), max(vals))))
                    labels.append(vardescr(var) + " swept from"
                                  + " %.5g to" % min(vals)
                                  + " %.5g" % max(vals)
                                  + var.unitstr(into=' %s'))
        difference = tuple(differences)
        label = ", ".join(labels)
        if verbosity > 0:
            if difference in self.solutions:
                if not difference:
                    print("The baseline in this ensemble cannot be replaced.")
                else:
                    print(label + " will be replaced in the ensemble.")
            else:
                print(label + " added to the ensemble.")

        self.solutions[difference] = solution
        self.labels[difference] = label