Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
"""Defines SolutionArray class.

Also defines the helpers used to render a solution as text: sensitivity,
constraint and warning tables (dispatched through TABLEFNS) and the generic
variable table, var_table.
"""
2import re
3import difflib
4from operator import sub
5import warnings as pywarnings
6import pickle
7import gzip
8import pickletools
9import numpy as np
10from .nomials import NomialArray
11from .small_classes import DictOfLists, Strings
12from .small_scripts import mag, try_str_without
13from .repr_conventions import unitstr, lineagestr
# Places where a long constraint string may be wrapped: a lone "*" between
# two non-"*" characters, or one of the spaced operators " + ", " >= ",
# " <= ", " = ". Used by constraint_table.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# Ordered (old, new) substring replacements applied to formatted values in
# var_table, so that NaN entries end up printed as " - " placeholders.
VALSTR_REPLACES = [("+nan", " nan"),
                   ("-nan", " nan"),
                   ("nan%", "nan "),
                   ("nan", " - ")]
class SolSavingEnvironment:
    """Context manager that slims a solution down while it is pickled.

    On entry, and depending on ``saveconstraints``, it either strips a fixed
    set of construction/solve attributes from every constraint in
    ``solarray["sensitivities"]["constraints"]``, or pops that constraints
    entry out of the solution entirely. On exit, whatever was removed is
    restored. This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        self.attrstore = {}  # attribute name -> {constraint: removed value}
        self.constraintstore = None  # the popped constraints entry, if any

    def __enter__(self):
        senss = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = senss.pop("constraints")
            return
        for attr in ("bounded", "meq_bounded", "vks",
                     "v_ss", "unsubbed", "varkeys"):
            removed = {}
            for constraint in senss["constraints"]:
                value = getattr(constraint, attr, None)
                if value:  # only truthy attributes are stripped
                    removed[constraint] = value
                    delattr(constraint, attr)
            self.attrstore[attr] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = self.constraintstore
            return
        for attr, removed in self.attrstore.items():
            for constraint, value in removed.items():
                setattr(constraint, attr, value)
def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Each named model is listed with its sensitivity, most sensitive first.
    When every entry of a model's sensitivity is smaller than 0.1 in
    magnitude, it is shown as an order-of-magnitude bound ("<1E-n") or
    as "=0". Array-valued sensitivities are shown as the mean followed by
    per-element deltas.

    Returns [] when the solution has no model sensitivities, or when fewer
    than two named models would be listed.
    """
    if "models" not in data.get("sensitivities", {}):
        return []
    data = sorted(data["sensitivities"]["models"].items(),
                  key=lambda i: -np.mean(i[1]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        msenss = np.asarray(msenss)
        # Compare magnitudes. A negative sensitivity would otherwise take
        # this branch and then fail when formatting log10 of a negative.
        if (np.abs(msenss) < 0.1).all():
            msenss = np.max(np.abs(msenss))
            if msenss:
                msenssstr = "%6s" % ("<1E%i" % np.log10(msenss))
            else:
                msenssstr = " =0 "
        elif not msenss.shape:
            msenssstr = "%+6.1f" % msenss
        else:
            meansenss = np.mean(msenss)
            deltas = msenss - meansenss
            deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                         for d in deltas]
            msenssstr = "%+6.1f + [ %s ]" % (meansenss, " ".join(deltastrs))
        lines.append("%s : %s" % (msenssstr, model))
    # the two header lines plus at least two model lines
    return lines + [""] if len(lines) > 3 else []
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    ``data`` is either a solution, in which case its
    ``["sensitivities"]["variables"]`` dict is tabulated, or already a dict
    of sensitivities. When ``showvars`` is non-empty, only the keys from it
    that are present are kept.
    """
    has_var_senss = "variables" in data.get("sensitivities", {})
    senss = data["sensitivities"]["variables"] if has_var_senss else data
    if showvars:
        senss = {key: senss[key] for key in showvars if key in senss}
    return var_table(senss, title, printunits=False, minval=1e-3,
                     sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g", **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns table lines for the variables with the largest sensitivities.

    Variables already in ``showvars`` are left out. When any were left out,
    the table is titled "Next Most Sensitive Variables".
    """
    top, already_shown = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(top, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top ``nvars`` variables.

    ``data`` is a solution, whose ``["sensitivities"]["variables"]`` dict is
    used, or a dict of sensitivities directly. Variables whose sensitivity
    contains any NaN are ignored, and the rest are ranked by mean absolute
    sensitivity. Variables already in ``showvars`` are dropped. Each one
    dropped also reduces ``nvars`` by one, but only while ``nvars`` is above 3.

    Returns (dict of the top-ranked keys to their sensitivities, set of
    keys that were dropped because they were in showvars).
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    # set(): showvars may be any iterable, not only a set
    filter_already_shown = set(showvars).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    # slice explicitly: topk[-0:] would be the whole list, not an empty one
    kept = topk[max(len(topk) - nvars, 0):]
    return {k: data[k] for k in kept}, filter_already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    Tabulates the variables whose mean absolute sensitivity is below
    ``maxval``. ``data`` is a solution, whose
    ``["sensitivities"]["variables"]`` dict is used, or a dict of
    sensitivities directly.
    """
    # Check for the key that is actually read. The previous check was for
    # "constants", which left the whole solution in `data` when that key
    # was missing.
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return constraint tightness lines.

    Lists up to ``ntightconstrs`` constraints whose sensitivity (in the
    last sweep, for sweeps) is at least ``tight_senss``, most sensitive
    first.
    """
    is_sweep = len(self) > 1
    title = "Most Sensitive Constraints"
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        value = senss[-1] if is_sweep else senss
        if value >= tight_senss:
            valstr = "%+6.2g" % value
            # sort key: rounded sensitivity (descending), then constraint str
            rows.append(((-float(valstr), str(constraint)), valstr,
                         id(constraint), constraint))
    data = sorted(rows)[:ntightconstrs]
    return constraint_table(data, title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines listing constraints with negligible sensitivity.

    A constraint is listed when its sensitivity (in the last sweep, for
    sweeps) is at most ``min_senss``.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    data = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        if (senss[-1] if is_sweep else senss) <= min_senss:
            data.append((0, "", id(constraint), constraint))
    return constraint_table(data, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, _, constraint) tuples
        Rows are ordered by sorting these tuples. The third element is only
        used for that sort (callers in this module pass id(constraint)).
        ``openingstr`` is printed right-aligned in the left column.
    title : str
    sortbymodel : bool
        If true, rows are grouped by each constraint's lineage string, in
        order of first appearance.
    showmodels : bool
        If false, lineage is excluded from constraint strings entirely.

    Returns
    -------
    list of str: title, underline, aligned rows, then a trailing "".
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)  # group order = first appearance
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer between model groups
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        # NOTE(review): str.replace removes every occurrence of the model
        # string, not just a leading lineage prefix.
        constrstr = constrstr.replace(model, "")
        minlen, maxlen = 25, 80  # wrap lines at ~80 chars, never below 25
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # keep only the first character of a split point on this
                # line; the rest is prepended to the following segment
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                # soft wrap before a segment that would overflow the line
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:  # hard wrap of overlong pieces
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths ignore model header lines, which span both columns
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    ``self["warnings"]`` maps each warning type to a list of pairs whose
    first element is the message. For sweeps, it maps to an array holding
    one such list per sweep point. Returns [] when there are no warnings.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]
        for i, data in enumerate(data_vec):
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            # A sweep point may have no warnings of this type, so check
            # that `data` is non-empty before indexing into it.
            if (wtype == "Unexpectedly Tight Constraints"
                    and data and data[0][1]):
                data = [(-int(1e5*c.relax_sensitivity),
                         "%+6.2g" % c.relax_sensitivity, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            elif (wtype == "Unexpectedly Loose Constraints"
                  and data and data[0][1]):
                data = [(-int(1e5*c.rel_diff),
                         "%.4g %s %.4g" % c.tightvalues, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title, including any " in sweep i"
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    lines[-1] = "~~~~~~~~"
    return lines + [""]
# Table name (as accepted by SolutionArray.table) -> rendering function.
# SolutionArray.table calls each one as fn(solution, showvars, **kwargs)
# and appends the result to its output lines.
TABLEFNS = dict([
    ("sensitivities", senss_table),
    ("top sensitivities", topsenss_table),
    ("insensitivities", insenss_table),
    ("model sensitivities", msenss_table),
    ("tightest constraints", tight_table),
    ("loose constraints", loose_table),
    ("warnings", warnings_table),
])
def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    Returns the signed element, taken out of any array, whose absolute value
    is largest. Ties go to the later entry. NaN elements are ignored, and
    so are entries that are entirely NaN or empty. Returns None if nothing
    comparable is found.
    """
    finalval, absmaxest = None, 0
    for val in values:
        absval = np.abs(val)
        if np.all(np.isnan(absval)):
            continue  # nothing comparable in this entry (or it is empty)
        # nanmax: one NaN element must not hide the entry's finite values
        absmaxval = np.nanmax(absval)
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        flatidx = np.nanargmax(np.abs(finalval))
        return finalval[np.unravel_index(flatidx, finalval.shape)]
    return finalval
def cast(function, val1, val2):
    """Apply ``function(val1, val2)``, aligning arrays of different rank.

    When both operands are arrays with different numbers of dimensions, the
    lower-rank one gets trailing singleton axes appended. Its existing axes
    then line up with the *leading* axes of the higher-rank operand, rather
    than the trailing axes as in NumPy's default broadcasting. Warnings
    raised during the call (e.g. from division by zero) are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if not both_arrays or val1.ndim == val2.ndim:
            return function(val1, val2)
        if val1.ndim > val2.ndim:
            padding = val1.ndim - val2.ndim
            index = (slice(None),)*val2.ndim + (np.newaxis,)*padding
            return function(val1, val2[index])
        padding = val2.ndim - val1.ndim
        index = (slice(None),)*val1.ndim + (np.newaxis,)*padding
        return function(val1[index], val2)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Other keys read by the methods and table functions of this module:
    "freevariables", "constants", "sweepvariables" (dicts of arrays),
    "warnings" (optional), and, inside "sensitivities", "constraints" and
    (optionally) "models".

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    table_titles = {"sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:  # computed once, then cached
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        """Number of solutions: len(cost), 1 if cost is unsized, 0 if absent."""
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        """Returns posy evaluated at this solution.

        If the substituted result has a ``c`` attribute, that is returned
        instead of the result itself.
        """
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        Returns False if the variable sets differ, if any element of any
        variable's ratio to the other solution's value differs from 1 by
        reltol or more, or if any element of any variable sensitivity
        differs by sens_abstol or more.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any: values may be arrays (vector variables or sweeps),
            # whose comparison result has no single truth value
            if np.any(abs(cast(np.divide, svars[key], ovars[key]) - 1)
                      >= reltol):
                return False
            if np.any(abs(self["sensitivities"]["variables"][key]
                          - other["sensitivities"]["variables"][key])
                      >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (".pgz" paths are read as gzip-compressed pickles). Unpickling
            can execute arbitrary code, so only pass files you trust.
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True, show a unified diff of the two solutions' modelstr
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing, in percent
        sortmodelsbysenss : boolean
            if True, order model sections by model sensitivity

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            # model sensitivities may be absent (see msenss_table)
            tableargs["sortmodelsbysenss"] = \
                self["sensitivities"].get("models", False)
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                # NOTE: pickle.load runs arbitrary code; trusted files only
                with open(other, "rb") as fil:
                    other = pickle.load(fil)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # [2:] drops unified_diff's "---"/"+++" file header lines
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        # In the checks below, a table whose second-to-last line is still its
        # dashed underline came out empty; in that case, report the largest
        # difference that was hidden. unrolled_absmax gives None if there is
        # nothing to report (e.g. all NaN), which can't be %-formatted.
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                largest = unrolled_absmax(rel_diff.values())
                if largest is not None:
                    lines.insert(-1, ("The largest is %+g%%." % largest))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                largest = unrolled_absmax(abs_diff.values())
                if largest is not None:
                    lines.insert(-1, ("The largest is %+g." % largest))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                largest = unrolled_absmax(senss_delta.values())
                if largest is not None:
                    lines.insert(-1, ("The largest is %+g." % largest))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as fil:  # closed even if dump fails
                pickle.dump(self, fil, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickles the solution, optimizes the pickle, and gzips it to filename."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        Unpickling can execute arbitrary code; only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict mapping display names to varkeys.

        Covers showvars if given, else all variables. Names are built with
        ``str_without(exclude)``. Keys whose names collide are temporarily
        marked so that the lineage needed to tell them apart is kept.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:  # never leave the temporary marker behind
            for key in self.name_collision_varkeys():
                del key.descr["necessarylineage"]
        return names

    def savemat(self, filename="solution.mat", showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # one row per element; 1-D indices are unwrapped to ints
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "original_fn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, showvars=None, filename="solution.csv", valcols=5):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Expands each entry of showvars into the set of matching varkeys."
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        """Returns a summary table string.

        Shows cost, warnings and swept/free variables, then sensitivities:
        the full table when there are few constants (or some are in
        showvars), the top ``ntopsenss`` when there are many, and the
        tightest constraints. Constants themselves are not listed.
        """
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "sweepvariables",
                      "model sensitivities", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        showvars: Iterable
            If given, only these variables are shown in variable tables
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities"), or any key
            of TABLEFNS
        sortmodelsbysenss: bool
            If true, order model sections by model sensitivity
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            # model sensitivities may be absent (see msenss_table)
            kwargs["sortmodelsbysenss"] = \
                self["sensitivities"].get("models", False)
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:  # never leave the temporary marker behind
            for key in self.name_collision_varkeys():
                del key.descr["necessarylineage"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy.

        Raises ValueError unless exactly one variable was swept.
        """
        if len(self["sweepvariables"]) != 1:
            # previously this only printed and then failed while unpacking
            raise ValueError(
                "SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table (values are not modified)
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted row lists instead of strings
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip rows whose largest absolute (non-NaN) value is <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), array entries with absolute value
        <= minval are replaced by NaN, which VALSTR_REPLACES renders as " - "
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped by model (the key's lineage)
    maxcolumns : int
        Maximum number of vector entries per printed row
    skipifempty : boolean
        If true, return [] when no row passes the minval filter
    sortmodelsbysenss : dict or None
        If given, maps model names to sensitivities used in the row sort key
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # nothing in this row rises above minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # Build a masked copy. Writing NaNs into `v` itself would corrupt
            # the caller's data (e.g. a solution's sensitivity arrays).
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()  # the unique index i keeps keys/values out of comparisons
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                # spacer row; latex tables 2 and 3 have three columns, all
                # other layouts four
                lines.append([""]*(3 if latex in (2, 3) else 4))
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            # (keeping the other axes in order), so rows of ncols run along it
            dim_order = [d for d in range(last_dim_index + 1)
                         if d != horiz_dim]
            dim_order.append(horiz_dim)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows for the remaining vector entries
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad a short final row so its bracket lines up
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
                coltitles = [title, "Value", "Units", "Description"]
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
                coltitles = [title, "Units", "Description"]
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
                coltitles = [title, "Value", "Units"]
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            # column widths ignore model header lines, which span columns
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines