Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import re
3import difflib
4from operator import sub
5import warnings as pywarnings
6import pickle
7import gzip
8import pickletools
9import numpy as np
10from .nomials import NomialArray
11from .small_classes import DictOfLists, Strings
12from .small_scripts import mag, try_str_without
13from .repr_conventions import unitstr, lineagestr
# Places where a constraint's string form may be wrapped onto a new line:
# a single (non-doubled) multiplication sign, an addition, or an
# (in)equality sign.
CONSTRSPLITPATTERN = re.compile("|".join((
    r"([^*]\*[^*])",
    r"( \+ )",
    r"( >= )",
    r"( <= )",
    r"( = )",
)))

# (before, after) substitutions, applied in order to formatted values so
# that NaN entries show up as " - " in printed tables.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    This approximately halves the size of the pickled solution.

    On entry, every attribute listed in STRIPPED_ATTRS that is present and
    truthy on a constraint in ``solarray["sensitivities"]["constraints"]``
    is stashed and deleted; on exit, all stashed attributes are restored.
    If entering fails part-way, whatever was already removed is restored
    before the exception propagates.
    """

    STRIPPED_ATTRS = ("mfm", "pmap", "bounded", "meq_bounded",
                      "v_ss", "unsubbed", "varkeys")

    def __init__(self, solarray):
        self.solarray = solarray
        self.attrstore = {}

    def __enter__(self):
        self.attrstore = {}
        constraints = self.solarray["sensitivities"]["constraints"]
        try:
            for constraint_attr in self.STRIPPED_ATTRS:
                # register the store before deleting anything, so a failure
                # part-way through can still restore what was removed
                store = self.attrstore[constraint_attr] = {}
                for constraint in constraints:
                    value = getattr(constraint, constraint_attr, None)
                    if value:
                        delattr(constraint, constraint_attr)
                        store[constraint] = value
        except BaseException:
            # __exit__ is not called when __enter__ raises, so undo here
            self._restore()
            raise
        return self

    def __exit__(self, type_, val, traceback):
        self._restore()

    def _restore(self):
        "Puts every stashed attribute back onto its constraint."
        for constraint_attr, store in self.attrstore.items():
            for constraint, value in store.items():
                setattr(constraint, constraint_attr, value)
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    ``data`` is either a solution (whose sensitivities["variables"] are
    tabulated) or a mapping of variables to sensitivities. If ``showvars``
    is given, only those keys are kept. Extra keyword arguments are passed
    through to var_table.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    if showvars:
        data = dict((key, data[key]) for key in showvars if key in data)
    table_opts = dict(sortbyvals=True, skipifempty=True,
                      valfmt="%+-.2g ", vecfmt="%+-8.2g",
                      printunits=False, minval=1e-3)
    return var_table(data, title, **table_opts, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns top sensitivity table lines.

    The title notes "Next" when some of the most sensitive variables were
    dropped because they are already in ``showvars``.
    """
    top_data, already_shown = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(top_data, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top ``nvars`` variables.

    Arguments
    ---------
    data : dict
        A solution (whose sensitivities["variables"] are used) or a mapping
        of variables to sensitivity values/arrays.
    showvars : set
        Variables already shown elsewhere. These are left out, and for each
        one the number of returned variables is reduced by one (but never
        below 3 when starting above 3).
    nvars : int
        Maximum number of variables to return; 0 or less returns none.

    Returns
    -------
    (dict, set)
        The variables with the largest mean absolute sensitivity (entries
        containing NaN are ignored), and the subset of ``showvars`` that
        was filtered out.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    # ascending by mean absolute sensitivity: the most sensitive come last
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    filter_already_shown = showvars.intersection(topk)
    topk = [k for k in topk if k not in filter_already_shown]
    if nvars > 3:  # always show at least 3
        nvars = max(3, nvars - len(filter_already_shown))
    # an explicit start index: topk[-0:] would return the whole list
    start = max(len(topk) - nvars, 0)
    return {k: data[k] for k in topk[start:]}, filter_already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    ``data`` is either a solution (whose sensitivities["variables"] are
    used) or a mapping of variables to sensitivities. Only variables whose
    mean absolute sensitivity is below ``maxval`` are tabulated.
    """
    # test for the same key that is read (as senss_table and
    # topsenss_filter do); testing "constants" left a whole solution
    # un-unwrapped whenever only "variables" was present
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return constraint tightness lines.

    Lists up to ``ntightconstrs`` constraints whose sensitivity is at least
    ``tight_senss``, largest first; for sweeps only the last sweep point is
    considered.
    """
    title = "Most Sensitive Constraints"
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for c, s in self["sensitivities"]["constraints"].items():
        senss = s[-1] if is_sweep else s
        if senss >= tight_senss:
            senss_str = "%+6.2g" % senss
            rows.append(((-float(senss_str), str(c)), senss_str, id(c), c))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines for constraints whose sensitivity is at most min_senss.

    For sweeps only the last sweep point is considered.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    in_sweep = len(self) > 1
    if in_sweep:
        title += " (in last sweep)"
    constraints = self["sensitivities"]["constraints"]
    data = [(0, "", id(c), c) for c, s in constraints.items()
            if (s[-1] if in_sweep else s) <= min_senss]
    return constraint_table(data, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, id, constraint) tuples
        Rows are ordered by sorting these tuples; ``openingstr`` is printed
        in the left column beside the constraint.
    title : string
    sortbymodel : bool
        If true, rows are grouped under their constraint's lineage string.
    showmodels : bool
        If false, all lineage is hidden in constraint strings (otherwise
        only unnecessary lineage is hidden).

    Returns
    -------
    list of str
        The title, an underline, the formatted rows, and a trailing "".
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            # models are numbered in order of first appearance
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer between model groups
            # the unnamed model ("") gets no header when it comes first
            if model or lines:
                lines.append([("modelname",), model])
            previous_model = model
        constrstr = constrstr.replace(model, "")
        # wrap the constraint string into lines of at most maxlen chars,
        # preferring to break at the separators matched by CONSTRSPLITPATTERN
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # keep only the separator's first character on this line and
                # prepend the rest (the operator) to the following segment
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:  # hard-wrap anything still too long
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths, ignoring model-header rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("modelname",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("modelname",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    ``self["warnings"]`` maps each warning type to a list of
    (message, data) pairs, or to an object with a ``shape`` (one such list
    per sweep point). Returns [] when there are no warnings.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]
        for i, data in enumerate(data_vec):
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            # data can be empty (e.g. a sweep point without this warning),
            # so check it before peeking at its first entry
            if (data and wtype == "Unexpectedly Tight Constraints"
                    and data[0][1]):
                data = [(-int(1e5*c.relax_sensitivity),
                         "%+6.2g" % c.relax_sensitivity, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            elif (data and wtype == "Unexpectedly Loose Constraints"
                  and data[0][1]):
                data = [(-int(1e5*c.rel_diff),
                         "%.4g %s %.4g" % c.tightvalues, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title, including any " in sweep i"
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    lines[-1] = "~~~~~~~~"
    return lines + [""]
# Table name -> function producing that table's lines; SolutionArray.table
# calls each one as fn(solution, showvars, **kwargs).
TABLEFNS = dict([
    ("sensitivities", senss_table),
    ("top sensitivities", topsenss_table),
    ("insensitivities", insenss_table),
    ("tightest constraints", tight_table),
    ("loose constraints", loose_table),
    ("warnings", warnings_table),
])
def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The returned value keeps its sign; for an array, the single element of
    largest magnitude is returned. NaN entries are ignored, so an array with
    some NaNs still competes on its finite entries, and an all-NaN value is
    never selected. On ties the later value wins. Returns None if
    ``values`` is empty or every value is NaN.
    """
    finalval, absmaxest = None, 0
    with pywarnings.catch_warnings():  # nanmax warns on all-NaN input
        pywarnings.simplefilter("ignore")
        for val in values:
            absmaxval = np.nanmax(np.abs(val))
            # NaN compares False here, so all-NaN values are skipped
            if absmaxval >= absmaxest:
                absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        return finalval[np.unravel_index(np.nanargmax(np.abs(finalval)),
                                         finalval.shape)]
    return finalval
def cast(function, val1, val2):
    """Applies ``function(val1, val2)``, lining arrays up by leading axes.

    When both arguments are arrays of different dimensionality, the one
    with fewer dimensions gets new trailing axes so that it matches the
    leading axes of the other (plain numpy broadcasting would align the
    trailing axes instead). Argument order is preserved, and any warnings
    raised by ``function`` (e.g. division by zero) are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if not both_arrays or val1.ndim == val2.ndim:
            return function(val1, val2)
        if val1.ndim > val2.ndim:
            extra = val1.ndim - val2.ndim
            expander = (slice(None),)*val2.ndim + (np.newaxis,)*extra
            return function(val1, val2[expander])
        extra = val2.ndim - val1.ndim
        expander = (slice(None),)*val1.ndim + (np.newaxis,)*extra
        return function(val1[expander], val2)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    table_titles = {"sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        # computed on first call, then cached on the instance
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        "Length of the cost: 1 if it is a scalar, 0 if there is no cost."
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        """Returns posy substituted with this solution.

        If the substituted result has a ``c`` attribute, that is returned.
        """
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        Returns False if the solutions contain different variables, if any
        value differs by a relative amount >= ``reltol``, or if any variable
        sensitivity differs by an absolute amount >= ``sens_abstol``.
        Array-valued (swept) entries are compared elementwise.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any so that array-valued entries don't raise "truth value
            # of an array is ambiguous"
            if np.any(abs(cast(np.divide, svars[key], ovars[key]) - 1)
                      >= reltol):
                return False
            if np.any(abs(self["sensitivities"]["variables"][key]
                          - other["sensitivities"]["variables"][key])
                      >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0, reldiff=True, reltol=1.0, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (paths ending in ".pgz" are read as gzip-compressed pickles)
        showvars : iterable
            if given, only these variables are compared
        constraintsdiff : boolean
            if True (and both solutions have a modelstr), show a unified
            diff of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing, in percent
        **tableargs
            passed on to var_table

        Returns
        -------
        str
        """
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            # NOTE: unpickling can execute arbitrary code; only load
            # solution files from trusted sources
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                largest = unrolled_absmax(rel_diff.values())
                if largest is not None:  # None when every value is NaN
                    lines.insert(-1, "The largest is %+g%%." % largest)
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                largest = unrolled_absmax(abs_diff.values())
                if largest is not None:
                    lines.insert(-1, "The largest is %+g." % largest)
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                largest = unrolled_absmax(senss_delta.values())
                if largest is not None:
                    lines.insert(-1, "The largest is %+g." % largest)
        return "\n".join(lines)

    def save(self, filename="solution.pkl", **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with SolSavingEnvironment(self):
            # a with-block guarantees the file is flushed and closed
            with open(filename, "wb") as f:
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz", **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        "Load a gzip-compressed pickle file"
        # NOTE: unpickling can execute arbitrary code; only load trusted files
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict mapping display names to variables.

        Names are ``str_without(exclude)`` of each key; keys whose names
        collide are temporarily marked so their names include lineage.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:
            # always undo the temporary marking, even if naming fails
            for key in self.name_collision_varkeys():
                del key.descr["necessarylineage"]
        return names

    def savemat(self, filename="solution.mat", showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # one row per element of an array-valued variable
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "original_fn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr)
            f.write(self.table(**kwargs))

    def savecsv(self, showvars=None, filename="solution.csv", valcols=5):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols)
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("modelname",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    # pad so the units column always lines up
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Expands each entry of showvars into the set of matching varkeys."
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Print summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            # always undo the temporary marking, even if a table fails
            for key in self.name_collision_varkeys():
                del key.descr["necessarylineage"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table (not modified)
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted row lists instead of strings
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values whose largest magnitude is <= minval
    sortbyvals : boolean
        If true, rows are sorted by decreasing mean magnitude instead of
        by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), array entries with magnitude
        <= minval are printed as blanks, like NaN entries
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped by model (lineage)
    maxcolumns : int
        Maximum number of vector entries per printed row
    skipifempty : boolean
        If true, return [] when no values survive the minval filter
    """
    if not data:
        return []
    # needed for latex output even when no rows end up being printed
    coltitles = {1: [title, "Value", "Units", "Description"],
                 2: [title, "Units", "Description"],
                 3: [title, "Value", "Units"]}.get(latex)
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values above minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # build a new array instead of assigning into v, which would
            # overwrite the caller's data (e.g. a solution's sensitivities)
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if not sortbyvals:
            model, isvector, varstr, _, var, val = varlist
        else:
            model, _, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("modelname",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows for the remaining vector entries
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad a short final row so the bracket lines up
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            # column widths, ignoring model-header rows
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("modelname",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
            for i, line in enumerate(lines):
                if line[0] == ("modelname",):
                    line = [fmts[0].format(" | "), line[1]]
                else:
                    line = [fmt.format(s) for fmt, s in zip(fmts, line)]
                lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines