Coverage for gpkit/solution_array.py : 83%
![Show keyboard shortcuts](keybd_closed.png)
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import re
3import difflib
4from operator import sub
5import warnings as pywarnings
6import pickle
7import gzip
8import pickletools
9import numpy as np
10from .nomials import NomialArray
11from .small_classes import DictOfLists, Strings
12from .small_scripts import mag, try_str_without
13from .repr_conventions import unitstr, lineagestr
# Matches the places where a printed constraint may be wrapped onto a new
# line: a multiplication sign (with its neighbouring characters), " + ",
# and the relational operators " >= ", " <= " and " = ".
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# (before, after) substitutions applied to formatted value strings.
# They are applied in order: the signed / percent forms of "nan" must be
# rewritten before the final catch-all "nan" replacement runs.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
class SolSavingEnvironment:
    """Context manager that slims down a solution while it is pickled.

    With ``saveconstraints`` true, a fixed list of construction/solve
    attributes is stripped from every constraint found under
    ``solarray["sensitivities"]["constraints"]`` on entry and put back on
    exit.  With ``saveconstraints`` false, the whole ``"constraints"`` entry
    is popped on entry and restored on exit.

    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        # attribute name -> {constraint: removed value}
        self.attrstore = {}
        # the popped "constraints" entry when constraints are not saved
        self.constraintstore = None

    def __enter__(self):
        senss = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = senss.pop("constraints")
            return
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            stashed = {}
            for constraint in senss["constraints"]:
                value = getattr(constraint, attrname, None)
                if value:  # only truthy attributes are worth removing
                    stashed[constraint] = value
                    delattr(constraint, attrname)
            self.attrstore[attrname] = stashed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            senss = self.solarray["sensitivities"]
            senss["constraints"] = self.constraintstore
            return
        for attrname, stashed in self.attrstore.items():
            for constraint, value in stashed.items():
                setattr(constraint, attrname, value)
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    ``data`` may be a whole solution, in which case its
    ``["sensitivities"]["variables"]`` entry is tabulated, or a dict of
    sensitivities used directly.  A non-empty ``showvars`` restricts the
    table to those keys.  Remaining keyword arguments go to ``var_table``.
    """
    senss = data.get("sensitivities", {})
    if "variables" in senss:
        data = senss["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    fixed_options = dict(sortbyvals=True, skipifempty=True,
                         valfmt="%+-.2g ", vecfmt="%+-8.2g",
                         printunits=False, minval=1e-3)
    return var_table(data, title, **fixed_options, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns top sensitivity table lines.

    The title becomes "Next Most Sensitive Variables" when some of the top
    variables were dropped because they are already in ``showvars``.
    """
    top_senss, already_shown = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(top_senss, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to top N vars.

    Returns a tuple ``(top, already_shown)``: ``top`` maps the highest
    mean-absolute-sensitivity keys (in ascending order) to their
    sensitivities, and ``already_shown`` is the subset of ``showvars``
    that was left out of the ranking.  Entries containing NaN are ignored.
    """
    senss = data.get("sensitivities", {})
    if "variables" in senss:
        data = senss["variables"]
    magnitudes = {}
    for key, sens in data.items():
        if not np.isnan(sens).any():
            magnitudes[key] = np.abs(sens).mean()
    ranked = sorted(magnitudes, key=magnitudes.get)  # ascending magnitude
    already_shown = showvars.intersection(ranked)
    ranked = [key for key in ranked if key not in already_shown]
    if nvars > 3:  # always show at least 3
        nvars -= 1
    top = {key: data[key] for key in ranked[-nvars:]}
    return top, already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    ``data`` may be a whole solution, in which case its
    ``["sensitivities"]["variables"]`` entry is used, or a dict of
    sensitivities.  Only entries whose mean absolute sensitivity is below
    ``maxval`` are tabulated.  The second positional argument is ignored;
    it exists so that all table functions share one call signature.
    """
    # Test for the key that is actually read below (the guard previously
    # looked for "constants" while indexing "variables").
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return constraint tightness lines.

    Lists up to ``ntightconstrs`` constraints whose sensitivity is at least
    ``tight_senss``, most sensitive first.  For a sweep (``len(self) > 1``)
    the sensitivity of the last sweep point is used.
    """
    swept = len(self) > 1
    title = "Most Sensitive Constraints"
    if swept:
        title += " (in last sweep)"
    rows = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        sens = senss[-1] if swept else senss
        if sens >= tight_senss:
            sens_str = "%+6.2g" % sens
            # sort on the rounded sensitivity, then on the constraint string
            rows.append(((-float(sens_str), str(constraint)),
                         sens_str, id(constraint), constraint))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines listing the constraints with negligible sensitivity.

    A constraint is listed when its sensitivity is at most ``min_senss``.
    For a sweep (``len(self) > 1``) the last sweep point is used.
    """
    swept = len(self) > 1
    title = "Insensitive Constraints |below %+g|" % min_senss
    if swept:
        title += " (in last sweep)"
    rows = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        sens = senss[-1] if swept else senss
        if sens <= min_senss:
            rows.append((0, "", id(constraint), constraint))
    return constraint_table(rows, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, id, constraint) tuples
        rows to print; ``openingstr`` goes to the left of " : "
    title : string
    sortbymodel : bool
        if True, rows are grouped under their constraint's lineage string
    showmodels : bool
        if False, all lineage is excluded from the constraint strings

    Returns
    -------
    list of str
    """
    # TODO: this should support 1D array inputs from sweeps
    if showmodels:
        excluded = ("units", "unnecessary lineage")
    else:
        excluded = ("units", "lineage")  # hide all of it
    model_order, entries = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        # models are numbered in order of first appearance
        model_order.setdefault(model, len(model_order))
        constrstr = try_str_without(constraint, excluded)
        address_at = constrstr.find(" at 0x")
        if address_at != -1:  # don't print memory addresses
            constrstr = constrstr[:address_at] + ">"
        entries.append((model_order[model], model, sortby,
                        constrstr, openingstr))
    entries.sort()
    minlen, maxlen = 25, 80
    lines, previous_model = [], None
    for _, model, _, constrstr, openingstr in entries:
        if model != previous_model:
            if lines:
                lines.append(["", ""])
            if model or lines:
                lines.append([("modelname",), model])
            previous_model = model
        constrstr = constrstr.replace(model, "")
        segments = [seg for seg in CONSTRSPLITPATTERN.split(constrstr) if seg]
        nsegments = len(segments)
        wrapped, current = [], ""
        for idx in range(nsegments):
            segment = segments[idx]
            if idx + 1 < nsegments and CONSTRSPLITPATTERN.match(segment):
                # keep the separator's first character on this line and
                # push the rest of it onto the front of the next segment
                segments[idx + 1] = segment[1:] + segments[idx + 1]
                segment = segment[0]
            elif (len(current) + len(segment) > maxlen
                  and len(current) > minlen):
                wrapped.append(current)
                current = " "  # start a new line
            current += segment
            while len(current) > maxlen:  # hard-wrap overlong lines
                wrapped.append(current[:maxlen])
                current = " " + current[maxlen:]
        wrapped.append(current)
        lines.append((openingstr + " : ", wrapped[0]))
        lines.extend(("", continuation) for continuation in wrapped[1:])
    if not lines:
        lines = [("", "(none)")]
    widths = np.max([[len(cell) for cell in line] for line in lines
                     if line[0] != ("modelname",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(widths))
    fmts = ["{0:%s%s}" % (direc, width) for direc, width in zip(dirs, widths)]
    formatted = []
    for line in lines:
        if line[0] == ("modelname",):
            cells = [fmts[0].format(" | "), line[1]]
        else:
            cells = [fmt.format(cell) for fmt, cell in zip(fmts, line)]
        formatted.append("".join(cells).rstrip())
    return [title] + ["-"*len(title)] + formatted + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Returns an empty list when ``self`` has no (or empty) ``"warnings"``.
    Otherwise returns a banner, one section per warning type (and per sweep
    point when the stored warnings have a ``shape``), and a closing rule.
    The second positional argument is ignored; keyword arguments are
    forwarded to ``constraint_table``.
    """
    if "warnings" not in self or not self["warnings"]:
        return []
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # unswept: treat as a single sweep point
        for i, data in enumerate(data_vec):
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            # an empty list has no first entry to inspect, so fall through
            # to the plain-message branch instead of raising IndexError
            has_constraint = bool(data) and data[0][1]
            if wtype == "Unexpectedly Tight Constraints" and has_constraint:
                data = [(-int(1e5*c.relax_sensitivity),
                         "%+6.2g" % c.relax_sensitivity, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and has_constraint:
                data = [(-int(1e5*c.rel_diff),
                         "%.4g %s %.4g" % c.tightvalues, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the title that is actually printed, which may
                # carry the " in sweep N" suffix
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    lines[-1] = "~~~~~~~~"  # turn the last blank line into the closing rule
    return lines + [""]
# Table names (as accepted by SolutionArray.table) that are rendered by a
# dedicated function, each called as fn(solution, showvars, **kwargs).
TABLEFNS = {
    "sensitivities": senss_table,
    "top sensitivities": topsenss_table,
    "insensitivities": insenss_table,
    "tightest constraints": tight_table,
    "loose constraints": loose_table,
    "warnings": warnings_table,
}
def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The returned value keeps its sign.  When several entries tie, the last
    one wins.  Returns None for an empty iterable.
    """
    largest, winner = 0, None
    for candidate in values:
        magnitude = np.abs(candidate).max()
        if magnitude >= largest:
            largest, winner = magnitude, candidate
    if not getattr(winner, "shape", None):
        return winner  # a scalar (or nothing at all)
    # the winner is an array: pick out its largest-magnitude element
    flat_index = np.argmax(np.abs(winner))
    return winner[np.unravel_index(flat_index, winner.shape)]
def cast(function, val1, val2):
    """Apply ``function(val1, val2)``, padding the lower-rank array if needed.

    When both values have a ``shape`` and their ``ndim`` differ, trailing
    axes of length one are appended to the one with fewer dimensions so the
    two line up on their leading axes.  Argument order is preserved.
    All warnings raised while evaluating ``function`` are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if hasattr(val1, "shape") and hasattr(val2, "shape"):
            ndim1, ndim2 = val1.ndim, val2.ndim
            if ndim1 < ndim2:
                pad = (slice(None),)*ndim1 + (np.newaxis,)*(ndim2 - ndim1)
                val1 = val1[pad]
            elif ndim2 < ndim1:
                pad = (slice(None),)*ndim2 + (np.newaxis,)*(ndim1 - ndim2)
                val2 = val2[pad]
        return function(val1, val2)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None  # lazily-built cache, see method below
    table_titles = {"sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        "Length of the cost entry: 1 if it has no length, 0 if it is absent"
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        "Substitutes the solution into posy; returns its `c` if it has one"
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions

        Returns False if the variable sets differ, if any variable value
        differs by a relative amount of `reltol` or more, or if any variable
        sensitivity differs by `sens_abstol` or more. Vector-valued variables
        are compared elementwise.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any reduces elementwise comparisons, so array-valued
            # variables no longer raise "truth value is ambiguous"
            reldiff = np.abs(cast(np.divide, svars[key], ovars[key]) - 1)
            if np.any(reldiff >= reltol):
                return False
            sensdiff = np.abs(self["sensitivities"]["variables"][key]
                              - other["sensitivities"]["variables"][key])
            if np.any(sensdiff >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0, reldiff=True, reltol=1.0, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
        showvars : iterable (optional)
            if given, only these variables are compared
        constraintsdiff : boolean
            if True, show a unified diff of the two solutions' modelstr
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing

        Returns
        -------
        str
        """
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            # NOTE: unpickling executes arbitrary code; only pass paths to
            # solution files that come from a trusted source.
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as picklefile:
                    other = pickle.load(picklefile)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # [2:] drops the "---"/"+++" file header lines of the diff
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> with open("solution.pkl", "rb") as f:
        ...     sol = pickle.load(f)
        """
        with SolSavingEnvironment(self, saveconstraints):
            # the with-block closes the file even if pickling fails
            with open(filename, "wb") as picklefile:
                pickle.dump(self, picklefile, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickles the solution and writes it, gzip-compressed, to filename."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file

        NOTE: unpickling executes arbitrary code; only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        "Returns list of variables, optionally with minimal unique names"
        if showvars:
            showvars = self._parse_showvars(showvars)
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:
            # always undo the temporary marker, even if naming fails
            for key in self.name_collision_varkeys():
                key.descr.pop("necessarylineage", None)
        return names

    def savemat(self, filename="solution.mat", showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # one row per element; 1-D indices are unwrapped to ints
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "original_fn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, showvars=None, filename="solution.csv", valcols=5):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols)
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("modelname",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    # pad so that Units always lands in the same column
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Returns the set of all keys that the entries of showvars map to"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Print summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            # always undo the temporary marker, even if a table fails
            for key in self.name_collision_varkeys():
                key.descr.pop("necessarylineage", None)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted rows (lists of cells) instead
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values with all(abs(value)) < minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), vector elements at or below minval
        are shown as blanks. The arrays in `data` are not modified.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped by the lineage string of their key
    maxcolumns : int
        the widest array dimension that will be printed across one line
    skipifempty : boolean
        If true, return [] when no value passes the minval filter

    Returns
    -------
    list of str (or list of rows if rawlines is true)

    Raises
    ------
    ValueError if latex is truthy but not one of 1, 2 or 3
    """
    if not data:
        return []
    # latex option -> (longtable column format, column titles)
    latex_columns = {1: ("llcl", [title, "Value", "Units", "Description"]),
                     2: ("lcl", [title, "Units", "Description"]),
                     3: ("llc", [title, "Value", "Units"])}
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # mask on a copy: writing NaNs into the caller's array would
            # silently corrupt the solution that is being printed
            v = v.copy()
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    # the unique index i in every tuple keeps keys/values from being compared
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if not sortbyvals:
            model, isvector, varstr, _, var, val = varlist
        else:
            model, _, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("modelname",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # print the remaining elements, ncols per continuation line
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("modelname",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("modelname",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        # looked up here rather than set per row, so the titles exist even
        # when every row was filtered out by the model selection
        if latex not in latex_columns:
            raise ValueError("Unexpected latex option, %s." % latex)
        colfmt, coltitles = latex_columns[latex]
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt,
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines