Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import re
3import difflib
4from operator import sub
5import warnings as pywarnings
6import pickle
7import gzip
8import pickletools
9import numpy as np
10from .nomials import NomialArray
11from .small_classes import DictOfLists, Strings
12from .small_scripts import mag, try_str_without
13from .repr_conventions import unitstr, lineagestr
# Separators at which long constraint strings may be wrapped: a lone "*"
# together with its neighbouring characters, " + ", " >= ", " <= " and " = ".
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# (before, after) substitutions applied in order to formatted values,
# so that NaNs end up rendered as a dash.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
class SolSavingEnvironment:
    """Context manager that slims a solution down while it is pickled.

    With ``saveconstraints`` true, a fixed list of construction/solve
    attributes is stripped from every constraint in
    ``solarray["sensitivities"]["constraints"]`` on entry (only where the
    attribute is truthy) and restored on exit. Otherwise the whole
    ``"constraints"`` entry is popped on entry and put back on exit.
    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        self.attrstore = {}          # attr name -> {constraint: removed value}
        self.constraintstore = None  # the popped "constraints" entry, if any

    def __enter__(self):
        senss = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = senss.pop("constraints")
            return
        constraints = senss["constraints"]
        for attr in ("bounded", "meq_bounded", "vks",
                     "v_ss", "unsubbed", "varkeys"):
            removed = self.attrstore[attr] = {}
            for constraint in constraints:
                value = getattr(constraint, attr, None)
                if value:
                    removed[constraint] = value
                    delattr(constraint, attr)

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = self.constraintstore
            return
        for attr, removed in self.attrstore.items():
            for constraint, value in removed.items():
                setattr(constraint, attr, value)
def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : SolutionArray or dict
        Its ``["sensitivities"]["models"]`` maps model names to a model
        sensitivity (a number, or an array of numbers).
    _ : unused
        Stands in for the ``showvars`` argument every TABLEFNS entry takes.
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that later sections follow this order.

    Returns
    -------
    list of str
        Empty if there is no model sensitivity data or fewer than two named
        models; the unnamed ("") model is never listed.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other return path
    data = sorted(data["sensitivities"]["models"].items(),
                  key=lambda i: -np.mean(i[1]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        if np.ndim(msenss) == 0:  # also accepts plain Python numbers
            msenssstr = "%+5.2f" % msenss
        else:
            meansenss = np.mean(msenss)
            deltas = np.asarray(msenss) - meansenss
            deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                         for d in deltas]
            msenssstr = "%+5.2f + [ %s ]" % (meansenss, " ".join(deltastrs))
        lines.append(" %s : %s" % (msenssstr, model))
    # lines starts with two header rows, so > 3 means at least two models
    return lines + [""] if len(lines) > 3 else []
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    ``data`` may be a whole solution, in which case its
    ``["sensitivities"]["variables"]`` is tabulated when present, or already
    a dict of sensitivities. If ``showvars`` is given, only those keys that
    are present are shown. Values with magnitude at or below 1e-3 are
    skipped by var_table.
    """
    sensitivities = data.get("sensitivities", {})
    if "variables" in sensitivities:
        data = sensitivities["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns lines for a table of the most sensitive variables.

    Variables already in ``showvars`` are left out (see topsenss_filter);
    when that happens the title becomes "Next Most Sensitive Variables".
    """
    top, already_shown = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(top, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top N variables.

    Arguments
    ---------
    data : SolutionArray or dict
        A solution (its ``["sensitivities"]["variables"]`` is used when
        present) or a dict mapping keys to sensitivities.
    showvars : iterable
        Keys already displayed elsewhere; they are excluded here, and each
        one found reduces ``nvars`` by one (never below 3).
    nvars : int
        How many of the largest mean-|sensitivity| entries to keep.
        Entries containing any NaN are ignored.

    Returns
    -------
    (dict, set)
        The kept ``{key: sensitivity}`` entries and the set of ``showvars``
        keys that were filtered out.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    filter_already_shown = set(showvars).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    # topk[-nvars:] would return *everything* when nvars == 0
    start = max(len(topk) - nvars, 0)
    return {k: data[k] for k in topk[start:]}, filter_already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    Tabulates entries whose mean absolute sensitivity is below ``maxval``.
    ``data`` may be a whole solution (its ``["sensitivities"]["variables"]``
    is used when present) or already a dict of sensitivities.
    """
    # Test for the key that is actually read (previously "constants" was
    # tested while "variables" was read), matching senss_table and
    # topsenss_filter.
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return constraint tightness lines.

    Lists at most ``ntightconstrs`` constraints whose sensitivity is at
    least ``tight_senss``, most sensitive first (ties broken by the
    constraint's string form). For sweeps only the last point counts.
    """
    is_sweep = len(self) > 1
    title = "Most Sensitive Constraints"
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        if is_sweep:
            senss = senss[-1]
        if senss >= tight_senss:
            senssstr = "%+6.2g" % senss
            rows.append(((-float(senssstr), str(constraint)), senssstr,
                         id(constraint), constraint))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines listing constraints with sensitivity <= ``min_senss``.

    For sweeps only the last sweep point's sensitivity is considered.
    """
    is_sweep = len(self) > 1
    title = "Insensitive Constraints |below %+g|" % min_senss
    if is_sweep:
        title += " (in last sweep)"
    data = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        if (senss[-1] if is_sweep else senss) <= min_senss:
            data.append((0, "", id(constraint), constraint))
    return constraint_table(data, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of 4-tuples (sortby, openingstr, tiebreak, constraint)
        Rows are sorted by these tuples; callers pass ``id(constraint)`` as
        the tie-breaker so constraints themselves are never compared.
        ``openingstr`` is printed (right-aligned) left of the constraint.
    title : str
    sortbymodel : bool
        If true, rows are grouped under their ``lineagestr(constraint)``.
    showmodels : bool
        If false, all lineage is hidden from the constraint strings.

    Returns
    -------
    list of str
        ``[title, underline, *rows, ""]``; rows are ``"(none)"`` if empty.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            # models are numbered in order of first appearance in sorted data
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer between model groups
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        # the model name is already shown in the group header
        constrstr = constrstr.replace(model, "")
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # keep only the separator's first character on this line and
                # push the rest onto the next segment, so a wrap happens
                # right after that character
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            # hard-wrap anything still longer than maxlen
            while len(line) > maxlen:
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths, ignoring model-header rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    ``self["warnings"]`` maps a warning type to a list of
    ``(message, payload)`` pairs, or (when it has a ``shape``) to one such
    list per sweep point. For "Unexpectedly Tight/Loose Constraints" with a
    truthy payload, the payload is a constraint and the section is rendered
    with constraint_table; otherwise the messages are listed.

    Returns
    -------
    list of str
        Empty if there are no warnings at all.
    """
    if "warnings" not in self or not self["warnings"]:
        return []
    banner = "WARNINGS"
    lines = ["~"*len(banner), banner, "~"*len(banner)]
    nheaderlines = len(lines)
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]
        for i, data in enumerate(data_vec):
            if len(data) == 0:  # e.g. a sweep point without this warning
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*c.relax_sensitivity),
                         "%+6.2g" % c.relax_sensitivity, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*c.rel_diff),
                         "%.4g %s %.4g" % c.tightvalues, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title, including any sweep suffix
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == nheaderlines:  # every warning list was empty
        return []
    lines[-1] = "~~~~~~~~"  # replace the last section's blank line
    return lines + [""]
# Maps a table name accepted by SolutionArray.table() to the function that
# renders it; each is called as fn(solution, showvars, **kwargs).
TABLEFNS = {
    "sensitivities": senss_table,
    "top sensitivities": topsenss_table,
    "insensitivities": insenss_table,
    "model sensitivities": msenss_table,
    "tightest constraints": tight_table,
    "loose constraints": loose_table,
    "warnings": warnings_table,
}
def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The returned element keeps its sign (e.g. -3 beats 2). NaN entries are
    ignored, so an array containing some NaNs still competes on its other
    entries. Between values of equal magnitude the later one wins. Returns
    None if ``values`` is empty or contains nothing but NaNs.
    """
    finalval, absmaxest = None, 0
    for val in values:
        absval = np.abs(val)
        if np.isnan(absval).all():
            continue  # nothing comparable in this value
        absmaxval = np.nanmax(absval)
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        # nanargmax: plain argmax would point at a NaN entry if one exists
        flatidx = np.nanargmax(np.abs(finalval))
        return finalval[np.unravel_index(flatidx, finalval.shape)]
    return finalval
def cast(function, val1, val2):
    """Returns ``function(val1, val2)``, aligning arrays on leading axes.

    When both arguments are arrays (have a ``shape``) of different ndim, the
    lower-dimensional one gets new *trailing* axes so that it lines up with
    the *leading* axes of the other (numpy's own broadcasting would align
    trailing axes instead). Argument order is preserved. All warnings raised
    during the call (e.g. numpy divide-by-zero) are suppressed.

    Callers use it for e.g. ``cast(np.divide, a, b)`` (ratios) and
    ``cast(operator.sub, a, b)`` (differences).
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if (hasattr(val1, "shape") and hasattr(val2, "shape")
                and val1.ndim != val2.ndim):
            lessdim, dimmest = sorted([val1, val2], key=lambda v: v.ndim)
            dimdelta = dimmest.ndim - lessdim.ndim
            add_axes = (slice(None),)*lessdim.ndim + (np.newaxis,)*dimdelta
            if dimmest is val1:
                return function(dimmest, lessdim[add_axes])
            return function(lessdim[add_axes], dimmest)
        return function(val1, val2)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
        constraints: dict mapping constraints to sensitivities
        models: dict of arrays (optional; used to order model sections)
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    table_titles = {"sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:  # computed once, then cached
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        "len(cost); 1 if cost has no length, 0 if there is no cost."
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        "Substitutes into posy; returns the result's ``.c`` when it has one."
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        Both must have the same variables; every value must agree within
        relative tolerance ``reltol`` and every sensitivity within absolute
        tolerance ``sens_abstol``. Array-valued entries (e.g. sweeps) are
        compared elementwise.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any: a bare `if array >= tol` raises for multi-element arrays
            reldiff = abs(cast(np.divide, svars[key], ovars[key]) - 1)
            if np.any(reldiff >= reltol):
                return False
            sensdiff = abs(self["sensitivities"]["variables"][key]
                           - other["sensitivities"]["variables"][key])
            if np.any(sensdiff >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions (".pgz"
            files are read with decompress_file). Unpickling can execute
            arbitrary code: only load files you trust.
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True and both solutions have a modelstr, show a unified diff
            of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing, in percent
        sortmodelsbysenss : boolean
            if True, pass this solution's model sensitivities to the tables

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            tableargs["sortmodelsbysenss"] = \
                self["sensitivities"].get("models", {})
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:  # trusted files only
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])  # [2:] drops the ---/+++ header
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        # In each section below, a table whose second-to-last line is its
        # dashed underline is empty, so the largest difference is reported.
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                largest = unrolled_absmax(rel_diff.values())
                if largest is not None:
                    lines.insert(-1, "The largest is %+g%%." % largest)
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                largest = unrolled_absmax(abs_diff.values())
                if largest is not None:
                    lines.insert(-1, "The largest is %+g." % largest)
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                largest = unrolled_absmax(senss_delta.values())
                if largest is not None:
                    lines.insert(-1, "The largest is %+g." % largest)
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.::

            import pickle
            with open("solution.pkl", "rb") as f:
                sol = pickle.load(f)
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as f:  # closed even if dumping fails
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickles the solution, optimizes the pickle, and gzips it to filename."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        Unpickling can execute arbitrary code: only load files you trust.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict mapping variable names to their varkeys.

        Names are ``str_without(exclude)`` of each key; varkeys whose names
        collide are temporarily flagged with "necessarylineage" so their
        names keep enough lineage to stay unique.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:  # never leave the temporary flag behind
            for key in self.name_collision_varkeys():
                del key.descr["necessarylineage"]
        return names

    def savemat(self, filename="solution.mat", showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        # NOTE(review): dtype "f" stores float32, i.e. ~7 significant digits
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe, one row per element"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # C-order multi-indices; 1-D indices are unwrapped to ints
                idxs = [idx[0] if len(idx) == 1 else idx
                        for idx in np.ndindex(*value.shape)]
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "original_fn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, showvars=None, filename="solution.csv", valcols=5):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    # no newline: the model's first row continues this line
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Expands each entry of showvars into the set of matching varkeys."
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Returns a summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "sweepvariables",
                      "model sensitivities", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print: "cost", any key of TABLEFNS, or any key of this
            solution that has a title in ``table_titles``
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            # .get: solutions without model sensitivities just skip sorting
            kwargs["sortmodelsbysenss"] = \
                self["sensitivities"].get("models", {})
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:  # never leave the temporary flag behind
            for key in self.name_collision_varkeys():
                del key.descr["necessarylineage"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy.

        Raises ValueError unless exactly one variable was swept.
        """
        if len(self["sweepvariables"]) != 1:
            # previously this printed and then failed on the unpack below
            raise ValueError(
                "SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table (never modified)
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted rows (lists of cells)
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip entries that are all NaN or whose max magnitude is <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average magnitude (largest
        first) instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), elements of array values with
        magnitude <= minval are shown as dashes
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, group rows under their model (lineage)
    maxcolumns : int
        maximum number of vector elements per printed row
    skipifempty : boolean
        If true, return [] when no entry survives the minval filter
    sortmodelsbysenss : dict or None
        model name -> sensitivity; when truthy and sortbyvals is false,
        model groups are ordered by decreasing sensitivity
    """
    if not data:
        return []
    # latex column titles depend only on the option; choosing them up front
    # means they exist even if every row ends up filtered out below
    coltitles = {1: [title, "Value", "Units", "Description"],
                 2: [title, "Units", "Description"],
                 3: [title, "Value", "Units"]}.get(latex)
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # all NaN, or nothing above minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # blank small entries in a *new* array: assigning into v would
            # overwrite the caller's arrays (often the solution's own)
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # move horiz_dim to the end so flattening walks along it; the old
            # insert-based permutation was only right for 1-D and 2-D values
            dim_order = [d for d in range(val.ndim) if d != horiz_dim]
            dim_order.append(horiz_dim)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % x for x in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows of ncols values each
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % x for x in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad a short last row so the bracket lines up
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines