Coverage for gpkit/solution_array.py: 80%
640 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-07 22:13 -0500
1"""Defines SolutionArray class"""
2import sys
3import re
4import json
5import difflib
6from operator import sub
7import warnings as pywarnings
8import pickle
9import gzip
10import pickletools
11from collections import defaultdict
12import numpy as np
13from .nomials import NomialArray
14from .small_classes import DictOfLists, Strings, SolverLog
15from .small_scripts import mag, try_str_without
16from .repr_conventions import unitstr, lineagestr, UNICODE_EXPONENTS
17from .breakdowns import Breakdowns
# Places where a long constraint string may be wrapped by constraint_table:
# a single "*" (not part of "**") together with its neighbouring characters,
# or one of " + ", " >= ", " <= ", " = ".
CONSTRSPLITPATTERN = re.compile(
    r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# (old, new) substring replacements applied, in this order, to every
# formatted value string in var_table.
VALSTR_REPLACES = [("+nan", " nan"),
                   ("-nan", " nan"),
                   ("nan%", "nan "),
                   ("nan", " - ")]
# pylint: disable=consider-using-f-string # some would be less readable
class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    This approximately halves the size of the pickled solution.

    If ``saveconstraints`` is true, the constraints stay in the solution but
    their truthy ``bounded``, ``meq_bounded``, ``vks``, ``v_ss``, ``unsubbed``
    and ``varkeys`` attributes are deleted on entry and restored on exit.
    Otherwise the whole ``["sensitivities"]["constraints"]`` entry is popped
    on entry and put back on exit. Nothing is done if the solution has no
    "sensitivities".
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.attrstore = {}  # attribute name -> {constraint: removed value}
        self.saveconstraints = saveconstraints
        self.constraintstore = None  # popped constraints dict, if any

    def __enter__(self):
        if "sensitivities" not in self.solarray:
            return
        if self.saveconstraints:
            for constraint_attr in ["bounded", "meq_bounded", "vks",
                                    "v_ss", "unsubbed", "varkeys"]:
                store = {}
                for constraint in self.solarray["sensitivities"]["constraints"]:
                    if getattr(constraint, constraint_attr, None):
                        store[constraint] = getattr(constraint, constraint_attr)
                        delattr(constraint, constraint_attr)
                self.attrstore[constraint_attr] = store
        else:
            # default of None: a solution without constraints is left as-is
            self.constraintstore = \
                self.solarray["sensitivities"].pop("constraints", None)

    def __exit__(self, type_, val, traceback):
        if self.saveconstraints:
            for constraint_attr, store in self.attrstore.items():
                for constraint, value in store.items():
                    setattr(constraint, constraint_attr, value)
        elif self.constraintstore is not None:
            # `is not None` so that an empty constraints dict is restored too
            self.solarray["sensitivities"]["constraints"] = self.constraintstore
def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : solution (mapping)
        Model sensitivities are read from data["sensitivities"]["models"].
    _ : ignored (showvars, for TABLEFNS call compatibility)
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that models below are sorted by it.

    Returns
    -------
    list of str
        Empty if there are no model sensitivities or fewer than two named
        models to show.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other table function
    data = sorted(data["sensitivities"]["models"].items(),
                  key=lambda i: ((i[1] < 0.1).all(),
                                 -np.max(i[1]) if (i[1] < 0.1).all()
                                 else -round(np.mean(i[1]), 1), i[0]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            # small sensitivities are shown only as an order of magnitude
            msenss = np.max(msenss)
            if msenss:
                #pylint: disable=consider-using-f-string
                msenssstr = "%6s" % ("<1e%i" % max(-3, np.log10(msenss)))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = f"{meansenss:+6.1f}"
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                # sweeps: show per-point deviations from the mean
                deltastrs = [f"{d:+4.1f}" if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += f" + [ {' '.join(deltastrs)} ]"
        if msenssstr == previousmsenssstr:
            # blank out repeats so equal values read as a group
            msenssstr = " "*len(msenssstr)
        else:
            previousmsenssstr = msenssstr
        lines.append(f"{msenssstr} : {model}")
    return lines + [""] if len(lines) > 3 else []
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    `data` may be a whole solution, in which case its
    ["sensitivities"]["variables"] mapping is tabulated, or already such a
    mapping. A non-empty `showvars` restricts the table to those keys.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    sensitivity_opts = {"sortbyvals": True, "skipifempty": True,
                        "valfmt": "%+-.2g ", "vecfmt": "%+-8.2g",
                        "printunits": False, "minval": 1e-3}
    return var_table(data, title, **sensitivity_opts, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns top sensitivity table lines.

    The title says "Next Most ..." when some of the top variables were
    dropped because they are already in `showvars`.
    """
    topdata, already_shown = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(topdata, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top `nvars` variables.

    Arguments
    ---------
    data : solution or mapping
        A solution (its ["sensitivities"]["variables"] is used) or a
        mapping of variables to sensitivity values/arrays.
    showvars : iterable
        Variables already shown elsewhere. They are left out of the result;
        each one left out lowers `nvars` by one while `nvars` is above 3.
    nvars : int
        Number of variables (largest mean absolute sensitivity) to keep.

    Returns
    -------
    (dict, set)
        The selected {variable: sensitivity} entries, and the members of
        `showvars` that were filtered out.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # variables with any NaN sensitivity are never candidates
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda kv: kv[1])]
    # set() so that any iterable (e.g. the () default elsewhere) works
    filter_already_shown = set(showvars).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    # explicit start index: topk[-0:] would return *every* variable
    start = max(len(topk) - max(nvars, 0), 0)
    return {k: data[k] for k in topk[start:]}, filter_already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    Lists variables whose mean absolute sensitivity is below `maxval`.
    `data` is a solution (its ["sensitivities"]["variables"] is used) or
    already a mapping of variables to sensitivities.
    """
    # check for the key that is actually read (as senss_table and
    # topsenss_filter do); checking "constants" here left `data` as the
    # whole solution whenever only "variables" was present
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return constraint tightness lines.

    Shows up to `ntightconstrs` constraints whose sensitivity (the last
    sweep point's, for sweeps) is at least `tight_senss`, ordered by
    descending absolute sensitivity and then by constraint string.
    """
    title = "Most Sensitive Constraints"
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        last = senss[-1] if is_sweep else senss
        if last >= tight_senss:
            sensstr = f"{abs(last):+6.2g}"
            rows.append(((-float(sensstr), str(constr)),
                         sensstr, id(constr), constr))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines listing constraints with sensitivity at most `min_senss`.

    For sweeps, only the last sweep point's sensitivity is compared.
    """
    title = f"Insensitive Constraints |below {min_senss:+g}|"
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        if (senss[-1] if is_sweep else senss) <= min_senss:
            rows.append((0, "", id(constr), constr))
    return constraint_table(rows, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, id, constraint) tuples
        Rows are sorted by these tuples; `openingstr` is printed left of
        " : " and the constraint's string to the right.
    title : str
        Table title; underlined with dashes in the output.
    sortbymodel : bool
        If True, rows are grouped under each constraint's lineage (model),
        with groups ordered by first appearance in the sorted data.
    showmodels : bool
        If False, lineage is also excluded when stringifying constraints.

    Returns
    -------
    list of str
        The title, its underline, the formatted rows, and a trailing "".
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = {"units"} if showmodels else {"units", "lineage"}
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            # remember first-appearance order of each model
            models[model] = len(models)
        constrstr = try_str_without(
            constraint, {":MAGIC:"+lineagestr(constraint)}.union(excluded))
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer between model groups
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        # wrap long constraint strings at CONSTRSPLITPATTERN separators,
        # aiming for lines between minlen and maxlen characters
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # keep only the separator's first character on this line and
                # push the rest onto the front of the next segment
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " " # start a new line
            line += segment
            while len(line) > maxlen:
                # hard-wrap anything still too long
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths, ignoring model-header rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"] # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Each entry of self["warnings"] maps a warning type to a list of
    (message, data) tuples, or (for sweeps) an array of such lists. Sweeps
    whose warnings are identical at every point are shown once.
    "Unexpectedly Tight/Loose Constraints" warnings with data are shown as
    constraint tables; all others as their messages.

    Returns
    -------
    list of str
        Empty if there are no warnings to show.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            all_equal = True
            for data in data_vec[1:]:
                eq_i = data == data_vec[0]
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda msg_data: msg_data[0])  # by msg
            title = wtype
            if len(data_vec) > 1:
                title += f" in sweep {i}"
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         f"{relax_sensitivity:+6.2g}", id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title (incl. " in sweep i"), as
                # constraint_table does
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]
def bdtable_gen(key):
    """Generator for breakdown tablefns.

    Returns a function usable in TABLEFNS that plots the `key` breakdown
    of a solution via Breakdowns and returns the captured output lines.
    """

    def bdtable(self, _showvars, **_):
        "Cost breakdown plot"
        breakdowns = Breakdowns(self)
        saved_stdout = sys.stdout
        try:
            # capture the plot's printed output instead of showing it
            sys.stdout = SolverLog(saved_stdout, verbosity=0)
            breakdowns.plot(key)
        finally:
            captured = sys.stdout.lines()
            sys.stdout = saved_stdout
        return captured

    return bdtable
# Table name (as accepted in SolutionArray.table's `tables`) -> function
# returning that table's lines; each is called as fn(solution, showvars,
# **kwargs).
TABLEFNS = {
    "sensitivities": senss_table,
    "top sensitivities": topsenss_table,
    "insensitivities": insenss_table,
    "model sensitivities": msenss_table,
    "tightest constraints": tight_table,
    "loose constraints": loose_table,
    "warnings": warnings_table,
    "model sensitivities breakdown": bdtable_gen("model sensitivities"),
    "cost breakdown": bdtable_gen("cost"),
}
def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The returned value keeps its sign: it is the single element (of any
    number or array in `values`) with the largest absolute value. NaN
    entries are ignored, so an array such as [nan, -4] still contributes
    -4. Returns None if `values` is empty or holds only NaNs.
    """
    finalval, absmaxest = None, 0
    for val in values:
        with pywarnings.catch_warnings():  # nanmax warns on all-NaN input
            pywarnings.simplefilter("ignore")
            # nanmax rather than .max(): one NaN would otherwise hide the
            # whole array, since NaN >= x is always False
            absmaxval = np.nanmax(np.abs(val))
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        # finalval was chosen with a non-NaN max, so nanargmax is safe
        return finalval[np.unravel_index(np.nanargmax(np.abs(finalval)),
                                         finalval.shape)]
    return finalval
def cast(function, val1, val2):
    """Apply ``function(val1, val2)``, aligning arrays of unequal rank.

    If both arguments have a ``shape`` but different ``ndim``, new trailing
    length-1 axes are appended to the lower-rank one, so its axes line up
    with the *leading* axes of the higher-rank one (numpy would otherwise
    align trailing axes). All Python warnings, e.g. numpy's divide-by-zero
    RuntimeWarning, are suppressed while ``function`` runs.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if not (hasattr(val1, "shape") and hasattr(val2, "shape")):
            return function(val1, val2)
        ndim1, ndim2 = val1.ndim, val2.ndim
        if ndim1 == ndim2:
            return function(val1, val2)
        if ndim1 > ndim2:
            padded = val2[(slice(None),)*ndim2 + (np.newaxis,)*(ndim1 - ndim2)]
            return function(val1, padded)
        padded = val1[(slice(None),)*ndim1 + (np.newaxis,)*(ndim2 - ndim1)]
        return function(padded, val2)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
        constraints: dict mapping constraints to their sensitivities
        models: dict mapping model names to their sensitivities
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None  # cache: varkey -> necessary lineage depth
    _lineageset = False  # whether descr["necessarylineage"] is currently set
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def set_necessarylineage(self, clear=False):  # pylint: disable=too-many-branches
        """Sets (or, if `clear`, removes) each varkey's "necessarylineage".

        For variables whose names are shared with other variables, computes
        how many trailing lineage components are needed to tell them apart
        (0 for names that are already unique). The result is computed once
        and cached in `_name_collision_varkeys`; this call then writes it
        into (or deletes it from) each varkey's `descr`. Returns None.
        """
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {}
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            name_collisions = defaultdict(set)
            for key in keymap:
                if hasattr(key, "key"):
                    if len(keymap[key.name]) == 1:  # very unique
                        self._name_collision_varkeys[key] = 0
                    else:
                        shortname = key.str_without(["lineage", "vec"])
                        if len(keymap[shortname]) > 1:
                            name_collisions[shortname].add(key)
            for varkeys in name_collisions.values():
                # start with the last lineage component, then prepend more
                # components until every group holds a single varkey
                min_namespaced = defaultdict(set)
                for vk in varkeys:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                while any(len(vks) > 1 for vks in min_namespaced.values()):
                    for key, vks in list(min_namespaced.items()):
                        if len(vks) <= 1:
                            continue
                        del min_namespaced[key]
                        mineage, idx = key
                        idx += 1
                        for vk in vks:
                            lineages = vk.lineagestr().split(".")
                            submineage = lineages[-idx] + "." + mineage
                            min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), vks in min_namespaced.items():
                    vk, = vks
                    self._name_collision_varkeys[vk] = idx
        if clear:
            self._lineageset = False
            for vk in self._name_collision_varkeys:
                # pop, not del: clearing twice (or before setting) is harmless
                vk.descr.pop("necessarylineage", None)
        else:
            self._lineageset = True
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def __len__(self):
        "Number of solutions: len(cost), 1 for a scalar cost, 0 if no cost"
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        "Value of `posy` in this solution (see `subinto`), without units"
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3):
        """Checks for almost-equality between two solutions.

        True if both have the same variables and every variable's relative
        difference is below `reltol`.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            reldiff = np.max(abs(cast(np.divide, svars[key], ovars[key]) - 1))
            if reldiff >= reltol:
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    # pylint: disable=too-many-arguments
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (".pgz" files are read as gzip-compressed pickles)
        showvars : iterable, optional
            if given, only compare these variables
        constraintsdiff : boolean
            if True, show a unified diff of the two models' constraints
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True (and sensitivities exist), order models by sensitivity
        **tableargs
            passed on to var_table

        Returns
        -------
        str
        """
        # guarded like `table`: a solution may lack sensitivities
        if sortmodelsbysenss and "sensitivities" in self:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            # NOTE(review): unpickling runs arbitrary code; only pass paths
            # to solution files from trusted sources
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # [2:] drops the unified diff's "---"/"+++" file headers
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(f" {key}" for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(f" {key}" for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        # var_table output ends with [title, dashes, ""] when no row passed
        # the threshold; in that case report the largest difference found
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               f"Relative Differences |above {reltol:g}%|",
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               f"Absolute Differences |above {abstol:g}|",
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            # not every variable has a sensitivity (e.g. free variables),
            # so only compare those present in both sensitivity dicts
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks
                           if vk in ssenss and vk in osenss}
            lines += var_table(senss_delta,
                               f"Sensitivity Differences |above {sensstol:g}|",
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> with open("solution.pkl", "rb") as f:
        ...     sol = pickle.load(f)
        """
        with open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        NOTE(review): unpickling can execute arbitrary code; only load
        files from trusted sources.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns {name: varkey}, names using the minimal unique lineage.

        Names are each varkey's str_without(exclude), computed while the
        necessary lineages are set; they are cleared again afterwards.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        self.set_necessarylineage()
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:
            # never leave the lineage annotations behind on failure
            self.set_necessarylineage(clear=True)
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded="vec"):
        "Saves primal solution as matlab file"
        from scipy.io import savemat  # pylint: disable=import-outside-toplevel
        # matlab names can't contain "."; values are stored as float32 ("f")
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None, excluded="vec"):
        """Returns primal solution as pandas dataframe.

        One row per scalar variable, and one row per element of each
        vector variable (with its index in the "Index" column).
        """
        # pylint: disable=import-outside-toplevel
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # C-order element indices; 1-D indices are stored as ints
                idxs = [idx[0] if len(idx) == 1 else idx
                        for idx in np.ndindex(value.shape)]
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join(f"{k}={v}" for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file (optionally preceded by the model)"
        with open(filename, "w", encoding="UTF-8") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json", showvars=None):
        """Saves solution table as a json file.

        Each variable is stored as {str(varkey): {"v": value, "u": units}}.
        """
        sol_dict = {}
        if self._lineageset:
            self.set_necessarylineage(clear=True)
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # add appropriate data for each variable to the dictionary
        for k, v in data.items():
            key = str(k)
            if isinstance(v, (np.ndarray, np.generic)):
                # numpy arrays and scalars (e.g. np.int64, np.float32) aren't
                # JSON serializable; tolist() gives Python lists/scalars
                val = {"v": v.tolist(), "u": k.unitstr()}
            else:
                val = {"v": v, "u": k.unitstr()}
            sol_dict[key] = val
        with open(filename, "w", encoding="UTF-8") as f:
            json.dump(sol_dict, f)

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        """Saves primal solution as a CSV sorted by modelname, like the tables.

        `valcols` is the number of value columns (vectors wrap onto further
        rows); it is reduced when no vector needs that many columns.
        """
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w", encoding="UTF-8") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    # model name fills the first column of the next row
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        """Returns NomialArray of each solution substituted into posy.

        Raises ValueError if `posy` is neither a solution variable nor
        something with a `sub` method.
        """
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError(f"no variable '{posy}' found in the solution")

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"], require_positive=False)

    def _parse_showvars(self, showvars):
        "Expands names/variables in `showvars` into the set of their varkeys"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), **kwargs):
        "Print summary table, showing no sensitivities or constants"
        return self.table(showvars,
                          ["cost breakdown", "model sensitivities breakdown",
                           "warnings", "sweepvariables", "freevariables"],
                          **kwargs)

    def table(self, showvars=(),
              tables=("cost breakdown", "model sensitivities breakdown",
                      "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=False, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities") or any key of
            TABLEFNS
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        self.set_necessarylineage()
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if "breakdown" in table:
                    if (len(self) > 1 or not UNICODE_EXPONENTS
                            or "sensitivities" not in self):
                        # no breakdowns for sweeps or no-unicode environments
                        table = table.replace(" breakdown", "")
                if "sensitivities" not in self and ("sensitivities" in table or
                                                    "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\nOptimal Cost\n------------"]
                    if len(self) > 1:
                        costs = [f"{c:-8.3g}" for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [f" {mag(cost):-.4g}"]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            # always undo the lineage annotations, even if a table fails
            self.set_necessarylineage(clear=True)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy.

        Only 1-dimensional sweeps are supported; raises ValueError otherwise.
        Returns (figure, axes), with a single axes unwrapped from its list.
        """
        if len(self["sweepvariables"]) != 1:
            # previously printed this and then failed unpacking below
            raise ValueError(
                "SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        # pylint: disable=import-outside-toplevel
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
# pylint: disable=too-many-arguments
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table (never modified)
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted rows (lists of column strings)
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values that are all NaN or whose largest |value| is <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true, vector entries with |value| <= minval are shown as " - "
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, group rows under each variable's lineage
    maxcolumns : int
        Most vector values printed per row
    skipifempty : boolean
        If true, return [] when no rows survive filtering
    sortmodelsbysenss : dict or falsy
        If given, model name -> sensitivity, used to order model groups

    Returns
    -------
    list
        Lines of the table (or raw rows, if `rawlines`).
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values above minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # blank out a copy: writing NaN into `v` itself would corrupt
            # the caller's data (e.g. a solution's stored sensitivities)
            v = v.copy()
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table above
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 4)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float(f"{-val:.4g}"), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()  # the unique index i keeps k and v from being compared
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            # (transpose puts input axis dim_order[j] at output axis j, so
            # horiz_dim must be the final entry; the previous insert-based
            # order only did this for arrays of up to 2 dimensions)
            dim_order = [d for d in range(len(val.shape)) if d != horiz_dim]
            dim_order.append(horiz_dim)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows, ncols values each
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad the short last row so its "]" lines up
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, f"${var.latex_unitstr()}$",
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, f"${var.latex_unitstr()}$", label])
            elif latex == 3:  # no description
                lines.append([varstr, valstr, f"${var.latex_unitstr()}$"])
            else:
                raise ValueError(f"Unexpected latex option, {latex}.")
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        # chosen here (not per row) so a table with no rows still has them
        coltitles = {1: [title, "Value", "Units", "Description"],
                     2: [title, "Units", "Description"],
                     3: [title, "Value", "Units"]}[latex]
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines