Coverage for gpkit/solution_ensemble.py: 0%
188 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-07 22:15 -0500
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-07 22:15 -0500
1"implements SolutionEnsemble class"
2import pickle
3import numpy as np
4from gpkit.keydict import KeyDict
5from gpkit.varkey import VarKey
def varsort(diff):
    """Sort key for difference tuples.

    The key is the tuple's leading element (a variable), rendered with
    `str_without` so that its hidden lineage is left out.
    """
    leading, *_ = diff
    omitted = {"hiddenlineage"}
    return leading.str_without(omitted)
def vardescr(var):
    """Describe `var` as "<label> (<var>)" for human-readable labels."""
    label = var.label
    return f"{label} ({var})"
class OpenedSolutionEnsemble:
    """Context manager that opens a SolutionEnsemble file and saves on exit.

    Entering yields the ensemble loaded from `filename`. If that file does
    not exist, or loading hits EOFError, it yields a new empty ensemble
    instead. Exiting writes the ensemble back to `filename`, whether or not
    the `with` body raised.
    """

    def __init__(self, filename="solensemble.pkl"):
        self.filename = filename
        try:
            ensemble = SolutionEnsemble.load(filename)
        except (EOFError, FileNotFoundError):
            # nothing usable on disk yet: start from an empty ensemble
            ensemble = SolutionEnsemble()
        self.solensemble = ensemble

    def __enter__(self):
        return self.solensemble

    def __exit__(self, type_, val, traceback):
        # returns None, so exceptions from the `with` body still propagate
        self.solensemble.save(self.filename)
class SolutionEnsemble:
    """An ensemble of solutions, each described relative to a baseline.

    Attributes:
        "baseline" : the first solution appended (None until then)
        "solutions" : all solutions, keyed by modified variables
        "labels" : solution labels, keyed by modified variables

    The baseline is stored in `solutions` and `labels` under the empty
    tuple `()`.

    SolutionEnsemble[varstr] : will return the relevant varkey

    """

    def __str__(self):
        """Summarize the ensemble, listing every non-baseline label."""
        # the baseline entry is not a modification; an empty ensemble
        # should report 0, not -1
        nmods = max(len(self.solutions) - 1, 0)
        out = ("Solution ensemble with a baseline and "
               f"{nmods} modified solutions:")
        for differences in self.solutions:
            if differences:
                out += "\n " + self.labels[differences]
        return out

    def __init__(self):
        self.baseline = None
        self.solutions = {}
        self.labels = {}

    def save(self, filename="solensemble.pkl", **pickleargs):
        """Pickle this ensemble into `filename` (no compression is applied).

        Extra keyword arguments are passed through to `pickle.dump`.
        """
        with open(filename, "wb") as f:
            pickle.dump(self, f, **pickleargs)

    @staticmethod
    def load(filename):
        """Load a SolutionEnsemble pickled into `filename` by `save`.

        NOTE: this unpickles the file, which can execute arbitrary code;
        only load files from trusted sources.
        """
        with open(filename, "rb") as f:
            return pickle.load(f)

    def __getitem__(self, var):
        """Return the single baseline varkey that `var` refers to.

        `var` is parsed against the baseline solution's "variables".
        Raises KeyError unless exactly one baseline key matches.
        """
        nameref = self.baseline["variables"]
        k, _ = nameref.parse_and_index(var)
        if isinstance(k, str):
            kstr = k
        else:
            # build the (lineage-qualified, index-free) name used to look
            # the key up in nameref.keymap
            kstr = k.str_without({"lineage", "idx"})
            if k.lineage:
                kstr = k.lineagestr() + "." + kstr
        keys = nameref.keymap[kstr]
        if len(keys) != 1:
            raise KeyError(var)
        basevar, = keys
        return basevar

    def filter(self, *requirements):
        """Filter by requirements, returning another solution ensemble.

        Each requirement is a single item or a sized collection of items
        (a string counts as a single item). Items that resolve via
        `self[...]` are replaced by the baseline varkey; others are used
        as-is. A solution is kept when, for every requirement, at least
        one of its differences contains all of that requirement's items.
        The returned ensemble always has this ensemble's baseline.
        """
        # the baseline (key `()`) is re-added below as the new ensemble's
        # baseline; leaving it among the candidates would append it a
        # second time as though it were a modified solution
        candidates = set(self.solutions) - {()}
        for requirement in requirements:
            if (isinstance(requirement, str)
                    or not hasattr(requirement, "__len__")):
                requirement = [requirement]
            subreqs = []
            for subreq in requirement:
                try:
                    subreqs.append(self[subreq])
                except (AttributeError, KeyError):
                    subreqs.append(subreq)
            for candidate in set(candidates):
                if not any(all(subreq in difference for subreq in subreqs)
                           for difference in candidate):
                    candidates.remove(candidate)
        se = SolutionEnsemble()
        se.append(self.baseline)
        for candidate in candidates:
            se.append(self.solutions[candidate], verbosity=0)
        return se

    def get_solutions(self, *requirements):
        "Filters by requirements, returning a list of solutions."
        return [sol
                for diff, sol in self.filter(*requirements).solutions.items()
                if diff]

    def append(self, solution, verbosity=1):
        """Append `solution` to the ensemble.

        The first solution appended becomes the baseline and may not be a
        sweep. Each later solution must have the same constraints as the
        baseline. It is compacted in place (see `_rekey`) and stored
        under a tuple of its differences from the baseline (see
        `_differences`). A solution with no differences is not stored,
        because the baseline cannot be replaced.

        Arguments
        ---------
        solution : the solution to add; it is modified in place
        verbosity : if > 0, print whether the solution was added, replaced
            an existing entry, or was identical to the baseline

        Raises
        ------
        ValueError : if the first solution is a sweep, or if a later
            solution's constraints differ from the baseline's
        """
        solution.set_necessarylineage()
        for var in solution["variables"]:
            var.descr.pop("vecfn", None)
            var.descr.pop("evalfn", None)
        if self.baseline is None:
            if "sweepvariables" in solution:
                raise ValueError("baseline solution cannot be a sweep")
            self.baseline = self.solutions[()] = solution
            self.labels[()] = "Baseline Solution"
            return

        solconstraintstr, baseconstraintstr = (
            sol.modelstr[sol.modelstr.find("Constraints"):]
            for sol in [solution, self.baseline])
        if solconstraintstr != baseconstraintstr:
            raise ValueError("the new model's constraints are not identical"
                             " to the base model's constraints."
                             " (Use .baseline.diff(sol) to compare.)")

        self._rekey(solution)
        differences, labels = self._differences(solution)
        difference = tuple(differences)
        label = ", ".join(labels)
        if not difference:
            # identical to the baseline: keep the existing baseline entry
            # (and its "Baseline Solution" label) untouched
            if verbosity > 0:
                print("The baseline in this ensemble cannot be replaced.")
            return
        if verbosity > 0:
            if difference in self.solutions:
                print(label + " will be replaced in the ensemble.")
            else:
                print(label + " added to the ensemble.")
        self.solutions[difference] = solution
        self.labels[difference] = label

    @staticmethod
    def _compress_sensitivity(val):
        """Shrink one sensitivity value for storage.

        Values whose magnitude is everywhere below 1e-2 become zeros: a
        boolean zero array of the same shape for array values, otherwise
        the int 0. Other array values are converted to float16. Other
        non-array values are returned unchanged.
        """
        if np.abs(val).max() < 1e-2:
            if hasattr(val, "shape"):
                return np.zeros(val.shape, dtype=np.bool_)
            return 0
        if hasattr(val, "shape"):
            return np.array(val, dtype=np.float16)
        return val

    def _rekey(self, solution):
        """Compact `solution` in place and key it by the baseline's keys.

        This does four things:
        - drops "warnings", "freevariables" and constant sensitivities;
        - rebuilds every top-level KeyDict so its keys are baseline
          varkeys (via `self[...]`);
        - re-keys constraint sensitivities by the baseline's constraint
          objects, matched by their string form;
        - rebuilds KeyDicts of sensitivities the same way, passing their
          values through `_compress_sensitivity`.
        """
        solution.pop("warnings", None)
        solution.pop("freevariables", None)
        solution["sensitivities"].pop("constants", None)
        for subd, value in solution.items():
            if isinstance(value, KeyDict):
                solution[subd] = KeyDict()
                for oldkey, val in value.items():
                    solution[subd][self[oldkey]] = val
        sens = solution["sensitivities"]
        for subd, value in sens.items():
            if subd == "constraints":
                sens[subd] = {}
                cstrs = {str(c): c
                         for c in self.baseline["sensitivities"][subd]}
                for oldkey, val in value.items():
                    sens[subd][cstrs[str(oldkey)]] = \
                        self._compress_sensitivity(val)
            elif isinstance(value, KeyDict):
                sens[subd] = KeyDict()
                for oldkey, val in value.items():
                    sens[subd][self[oldkey]] = self._compress_sensitivity(val)

    def _differences(self, solution):  # pylint: disable=too-many-locals, too-many-branches
        """List how `solution` differs from the baseline.

        The differences cover the cost function, freed variables, variables
        set to new values, and swept variables.

        Returns a pair of equal-length lists: the difference tuples (used
        to key the ensemble) and a human-readable label for each.
        """
        differences = []
        labels = []
        solcostfun = solution["cost function"]
        if len(solution) > 1:
            # presumably a sweep storing one cost function per point
            # -- TODO confirm
            solcostfun = solcostfun[0]
        solcoststr = solcostfun.str_without({"units"})
        basecoststr = self.baseline["cost function"].str_without({"units"})
        if basecoststr != solcoststr:
            differences.append(("cost", solcoststr))
            labels.append(f"Cost function set to {solcoststr}")

        # "sweepvariables" may be absent (append only allows sweeps after
        # checking for it), so fall back to an empty mapping instead of
        # raising KeyError
        sweepvars = (solution["sweepvariables"]
                     if "sweepvariables" in solution else {})
        freedvars = set()
        setvars = set()

        def check_var(var):
            "Record whether `var` was freed or set to a new value."
            fixed_in_baseline = var in self.baseline["constants"]
            fixed_in_solution = var in solution["constants"]
            bval = self.baseline["variables"][var]
            if fixed_in_solution:
                sval = solution["constants"][var]
            else:
                sval = solution["variables"][var]
            if fixed_in_solution and getattr(sval, "shape", None):
                pass  # calculated constant that depends on a sweep variable
            elif fixed_in_solution and sval != bval:
                setvars.add((var, sval))  # whether free or fixed before
            elif not fixed_in_solution and fixed_in_baseline:
                # swept variables are reported by the sweep loop below
                if var not in sweepvars:
                    freedvars.add((var,))

        for var in self.baseline["variables"]:
            if var not in solution["variables"]:
                print("Variable", var, "removed (relative to baseline)")
                continue
            if not var.shape:
                check_var(var)
            else:
                # check each element of a vector variable separately
                it = np.nditer(np.empty(var.shape), flags=["multi_index"])
                while not it.finished:
                    check_var(VarKey(idx=it.multi_index, **var.descr))
                    it.iternext()

        for freedvar, in sorted(freedvars, key=varsort):
            differences.append((freedvar, "freed"))
            labels.append(vardescr(freedvar) + " freed")
        for setvar, setval in sorted(setvars, key=varsort):
            differences.append((setvar, setval))
            ustr = setvar.unitstr(into=' %s')
            labels.append(vardescr(setvar) + f" set to {setval:.5g}" + ustr)
        for var, vals in sorted(sweepvars.items(), key=varsort):
            var = self[var]
            if var.shape:
                it = np.nditer(np.empty(var.shape), flags=["multi_index"])
                while not it.finished:
                    valsi = vals[(...,)+it.multi_index]
                    # elements that are NaN anywhere are not reported
                    if not np.isnan(valsi).any():
                        idxvar = VarKey(idx=it.multi_index, **var.descr)
                        differences.append((idxvar, "sweep",
                                            (min(valsi), max(valsi))))
                        labels.append(vardescr(idxvar) + " swept from"
                                      + f" {min(valsi):.5g} to"
                                      + f" {max(valsi):.5g}"
                                      + idxvar.unitstr(into=' %s'))
                    it.iternext()
            else:
                differences.append((var, "sweep", (min(vals), max(vals))))
                labels.append(vardescr(var) + " swept from"
                              + f" {min(vals):.5g} to"
                              + f" {max(vals):.5g}"
                              + var.unitstr(into=' %s'))
        return differences, labels