Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import re
3import difflib
4from operator import sub
5import warnings as pywarnings
6import pickle
7import gzip
8import pickletools
9import numpy as np
10from .nomials import NomialArray
11from .small_classes import DictOfLists, Strings
12from .small_scripts import mag, try_str_without
13from .repr_conventions import unitstr, lineagestr
# Places where a constraint's string form may be wrapped onto a new line:
# a "*" with a non-"*" character captured on each side, or a spaced "+",
# ">=", "<=" or "=". Every alternative is a capture group, so re.split
# keeps the separators (with None for non-participating groups).
CONSTRSPLITPATTERN = re.compile("|".join((
    r"([^*]\*[^*])",
    r"( \+ )",
    r"( >= )",
    r"( <= )",
    r"( = )",
)))

# (before, after) substitutions applied in order to formatted table values;
# the final entry makes any remaining "nan" display as " - ".
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
class SolSavingEnvironment:
    """Context manager that slims a solution down while it is pickled.

    On entry it either strips bulky construction/solve attributes from every
    constraint in ``solarray["sensitivities"]["constraints"]`` (when
    ``saveconstraints`` is true) or pops that "constraints" entry entirely;
    on exit it puts everything back.
    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        self.attrstore = {}  # attribute name -> {constraint: removed value}
        self.constraintstore = None  # the popped "constraints" entry

    def __enter__(self):
        sensitivities = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = sensitivities.pop("constraints")
            return
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            removed = {}
            for constr in sensitivities["constraints"]:
                attrval = getattr(constr, attrname, None)
                if attrval:  # only truthy attributes are stripped
                    removed[constr] = attrval
                    delattr(constr, attrname)
            self.attrstore[attrname] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for attrname, removed in self.attrstore.items():
            for constr, attrval in removed.items():
                setattr(constr, attrname, attrval)
def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : mapping (usually a SolutionArray)
        ``data["sensitivities"]["models"]`` maps model names to their
        (possibly swept) sensitivities
    _ : ignored (the showvars argument every TABLEFNS function receives)
    sortmodelsbysenss : optional keyword
        if truthy, the title notes that model sections are sorted by it

    Returns
    -------
    list of str
        empty when there are no model sensitivities, or when fewer than
        two named models would be listed
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other return path and table fn
    data = sorted(data["sensitivities"]["models"].items(),
                  key=lambda i: -np.mean(i[1]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        msenss = np.asarray(msenss)  # also accept plain Python numbers
        if (msenss < 0.1).all():
            # small sensitivities: show only an order-of-magnitude bound
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % np.log10(msenss))
            else:
                msenssstr = " =0 "
        elif not msenss.shape:
            msenssstr = "%+6.1f" % msenss
        else:
            # swept: show the mean, plus per-point deviations if notable
            meansenss = np.mean(msenss)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:
            msenssstr = " "  # don't repeat the value from the line above
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    return lines + [""] if len(lines) > 3 else []
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns the lines of a variable-sensitivity table.

    ``data`` is either a solution, whose ``["sensitivities"]["variables"]``
    entry is used, or a mapping of variables to sensitivities. When
    ``showvars`` is non-empty only those variables are tabulated. Extra
    keyword arguments are forwarded to var_table.
    """
    sensitivities = data.get("sensitivities", {})
    if "variables" in sensitivities:
        data = sensitivities["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    table_options = dict(sortbyvals=True, skipifempty=True,
                         valfmt="%+-.2g ", vecfmt="%+-8.2g",
                         printunits=False, minval=1e-3)
    return var_table(data, title, **table_options, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns table lines for the ``nvars`` most sensitive variables.

    Variables already in ``showvars`` are left out (see topsenss_filter),
    in which case the table is titled as the *next* most sensitive ones.
    """
    topdata, already_shown = topsenss_filter(data, showvars, nvars)
    if already_shown:
        title = "Next Most Sensitive Variables"
    else:
        title = "Most Sensitive Variables"
    return senss_table(topdata, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top N variables.

    Arguments
    ---------
    data : mapping
        a solution (its ``["sensitivities"]["variables"]`` is used) or a
        mapping of variables to sensitivities
    showvars : iterable
        variables already being displayed; they are excluded, and each one
        lowers ``nvars`` by one (but never below 3)
    nvars : int
        how many variables to keep; zero or negative keeps none

    Returns
    -------
    (dict, set)
        the kept {variable: sensitivity} entries (largest mean magnitude
        last) and the set of showvars that were excluded
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    # accept any iterable of showvars, not only sets
    filter_already_shown = set(showvars or ()).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    # topk[-0:] would be the whole list, so slice from an explicit start
    kept = topk[max(len(topk) - nvars, 0):]
    return {k: data[k] for k in kept}, filter_already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns table lines for variables with small sensitivities.

    ``data`` is a solution (its ``["sensitivities"]["variables"]`` entry is
    used, as in senss_table and topsenss_filter) or a mapping of variables
    to sensitivities. Variables whose mean absolute sensitivity is below
    ``maxval`` are tabulated; the second argument (showvars) is ignored.
    """
    sensitivities = data.get("sensitivities", {})
    # check for the same key we read (this used to test for "constants",
    # then read "variables")
    if "variables" in sensitivities:
        data = sensitivities["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Returns table lines for the most sensitive constraints.

    Constraints with sensitivity >= ``tight_senss`` (the last sweep point's
    value when the solution is a sweep) are sorted by descending rounded
    sensitivity, then by string, and the first ``ntightconstrs`` shown.
    """
    title = "Most Sensitive Constraints"
    sweeping = len(self) > 1
    if sweeping:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        if sweeping:
            senss = senss[-1]
        if senss >= tight_senss:
            sensstr = "%+6.2g" % senss
            rows.append(((-float(sensstr), str(constr)),
                         sensstr, id(constr), constr))
    return constraint_table(sorted(rows)[:ntightconstrs], title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Returns table lines for constraints with negligible sensitivity.

    A constraint is listed when its sensitivity (the last sweep point's
    value when the solution is a sweep) is at most ``min_senss``.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    sweeping = len(self) > 1
    if sweeping:
        title += " (in last sweep)"
    pick = (lambda senss: senss[-1]) if sweeping else (lambda senss: senss)
    data = [(0, "", id(constr), constr)
            for constr, senss in self["sensitivities"]["constraints"].items()
            if pick(senss) <= min_senss]
    return constraint_table(data, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, id, constraint) tuples
        sorted as tuples; ``openingstr`` is printed left of the constraint
        and ``id`` (callers in this module pass id(constraint)) keeps the
        sort from ever having to compare constraints themselves
    title : str
    sortbymodel : bool
        if true, group constraints under their lineagestr model name
    showmodels : bool
        if false, omit all lineage from constraint strings

    Returns
    -------
    list of str
        title, underline, formatted rows, and a trailing blank line
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            # models are numbered by first appearance in the sorted data,
            # which fixes the order of model sections below
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # blank spacer between model sections
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        # NOTE(review): plain substring removal -- a short model name would
        # also be cut out of variable names containing it; confirm intended
        constrstr = constrstr.replace(model, "")
        minlen, maxlen = 25, 80
        # wrap the constraint into lines of at most maxlen characters,
        # preferring to break at CONSTRSPLITPATTERN separators
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # keep only the separator's first character on this line and
                # prepend the rest to the next segment, so a break can occur
                # right after that character
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:  # hard-wrap whatever is still too long
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # right-align the opening column, left-align the constraint column
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    ``self["warnings"]`` maps each warning type to a list of
    (message, data) pairs, or to a sequence with a ``shape`` that holds one
    such list per sweep point. "Unexpectedly Tight/Loose Constraints"
    warnings whose data is set are rendered with constraint_table; all other
    warnings are listed by message. The second argument is ignored.

    Returns
    -------
    list of str
        empty when there are no (non-empty) warnings
    """
    if "warnings" not in self or not self["warnings"]:
        return []
    lines = []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # a single (unswept) list of warnings
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*c.relax_sensitivity),
                         "%+6.2g" % c.relax_sensitivity, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*c.rel_diff),
                         "%.4g %s %.4g" % c.tightvalues, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title, including any " in sweep N"
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if not lines:  # every warning list was empty: don't print a banner
        return []
    header = "WARNINGS"
    lines = ["~"*len(header), header, "~"*len(header)] + lines
    lines[-1] = "~~~~~~~~"  # replaces the last section's trailing blank
    return lines + [""]
# Table names handled by SolutionArray.table through a dedicated function;
# each is called there as fn(solution, showvars, **kwargs).
TABLEFNS = dict((
    ("sensitivities", senss_table),
    ("top sensitivities", topsenss_table),
    ("insensitivities", insenss_table),
    ("model sensitivities", msenss_table),
    ("tightest constraints", tight_table),
    ("loose constraints", loose_table),
    ("warnings", warnings_table),
))
def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The returned element keeps its sign (e.g. -5 beats 4). NaN entries are
    ignored, so an array holding some NaNs still competes on its other
    entries; returns None if ``values`` is empty or entirely NaN.
    """
    finalval, absmaxest = None, 0
    for val in values:
        absval = np.abs(val)
        if np.isnan(absval).all():
            continue  # nothing comparable in this value
        absmaxval = np.nanmax(absval)  # .max() would be NaN with any NaN
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        flatidx = np.nanargmax(np.abs(finalval))
        return finalval[np.unravel_index(flatidx, finalval.shape)]
    return finalval
def cast(function, val1, val2):
    """Returns function(val1, val2), aligning arrays of different rank.

    When both arguments have a ``shape`` but different ``ndim``, the
    lower-dimensional one gets trailing length-1 axes, so the two broadcast
    along their *leading* axes (plain numpy would align trailing axes).
    With np.divide this gives the ratio val1/val2; with operator.sub the
    difference val1 - val2. Warnings raised meanwhile (e.g. divide-by-zero)
    are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if not (hasattr(val1, "shape") and hasattr(val2, "shape")):
            return function(val1, val2)
        ndim1, ndim2 = val1.ndim, val2.ndim
        if ndim1 == ndim2:
            return function(val1, val2)
        expand = ((slice(None),)*min(ndim1, ndim2)
                  + (np.newaxis,)*abs(ndim1 - ndim2))
        if ndim1 > ndim2:
            return function(val1, val2[expand])
        return function(val1[expand], val2)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    table_titles = {"sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def _mark_name_collisions(self):
        "Sets descr['necessarylineage'] on every name-colliding varkey"
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True

    def _unmark_name_collisions(self):
        "Removes the flag set by _mark_name_collisions (even if partly set)"
        for key in self.name_collision_varkeys():
            key.descr.pop("necessarylineage", None)

    def _model_senss(self):
        """Returns sensitivities["models"], or False if there are none.

        Used as the ``sortmodelsbysenss`` table argument; msenss_table also
        treats the "models" entry as optional.
        """
        senss = self["sensitivities"] if "sensitivities" in self else {}
        return senss["models"] if "models" in senss else False

    def __len__(self):
        "Number of solutions: len(cost), 1 if cost is scalar, 0 if absent"
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        Both must have the same variables, every value must agree within
        relative tolerance ``reltol``, and every variable sensitivity within
        absolute tolerance ``sens_abstol``. Vector values are compared
        elementwise.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any: a bare `if array >= tol` is ambiguous for vectors
            reldiff = abs(cast(np.divide, svars[key], ovars[key]) - 1)
            if np.any(reldiff >= reltol):
                return False
            senssdiff = abs(self["sensitivities"]["variables"][key]
                            - other["sensitivities"]["variables"][key])
            if np.any(senssdiff >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions (".pgz"
            paths as gzip-compressed ones); unpickling can execute code,
            so only load files from trusted sources
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True, show a unified diff of the two solutions' modelstr
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True, order model sections by this solution's model
            sensitivities (when it has any)

        Returns
        -------
        str
        """
        tableargs["sortmodelsbysenss"] = (self._model_senss()
                                          if sortmodelsbysenss else False)
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:  # don't leak the file handle
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)

        def add_table(diffs, title, largest_fmt, **fmtargs):
            "Appends a var_table of diffs; if it has no rows, notes the max"
            table = var_table(diffs, title, **fmtargs, **tableargs)
            lines.extend(table)
            # check the table just built (not a stale earlier line): a
            # table with no rows ends [title, "-----...", ""]
            if len(table) >= 2 and table[-2][:10] == "-"*10:
                largest = unrolled_absmax(diffs.values())
                if largest is not None:
                    lines.insert(-1, largest_fmt % largest)

        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            add_table(rel_diff, "Relative Differences |above %g%%|" % reltol,
                      "The largest is %+g%%.",
                      valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                      minval=reltol, printunits=False)
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            add_table(abs_diff, "Absolute Differences |above %g|" % abstol,
                      "The largest is %+g.",
                      valfmt="%+.2g", vecfmt="%+8.2g", minval=abstol)
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks}
            add_table(senss_delta,
                      "Sensitivity Differences |above %g|" % sensstol,
                      "The largest is %+g.",
                      valfmt="%+-.2f ", vecfmt="%+-6.2f",
                      minval=sensstol, printunits=False)
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as f:  # closed even if dump fails
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickles the solution and saves it as a gzip-compressed file."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file (as written by save_compressed).

        Unpickling can execute arbitrary code: only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict mapping variable names to their varkeys.

        Names are ``key.str_without(exclude)``, built while name-colliding
        keys carry descr["necessarylineage"]. If ``showvars`` is given,
        only those variables are included.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        try:
            self._mark_name_collisions()
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:  # never leave the temporary flag behind
            self._unmark_name_collisions()
        return names

    def savemat(self, filename="solution.mat", showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # one row per element, in C order; 1-D indices are unwrapped
                idxs = [idx[0] if len(idx) == 1 else idx
                        for idx in np.ndindex(*value.shape)]
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "original_fn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, showvars=None, filename="solution.csv", valcols=5):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Expands each entry of showvars into the set of matching varkeys"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        """Returns a summary table: top sensitivities and no constants.

        The full sensitivities table is included only when there are few
        constants or some of them are in ``showvars``.
        """
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which tables to print: "cost", any key of TABLEFNS (e.g.
            "warnings", "sensitivities", "tightest constraints"), or any
            key of this solution (e.g. "sweepvariables", "freevariables",
            "constants")
        fixedcols: accepted, but unused by this module's table functions
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        kwargs["sortmodelsbysenss"] = (self._model_senss()
                                       if sortmodelsbysenss else False)
        varlist = list(self["variables"])
        if not any(var.lineage != varlist[0].lineage for var in varlist[1:]):
            kwargs["sortbymodel"] = False  # only one model: don't group
        strs = []
        try:
            self._mark_name_collisions()
            showvars = self._parse_showvars(showvars)
            for table in tables:
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:  # never leave the temporary flag behind
            self._unmark_name_collisions()
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            # fail clearly, instead of printing and then failing to unpack
            raise ValueError("SolutionArray.plot only supports"
                             " 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table (its values are not modified)
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted row lists
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values that are all NaN or whose largest magnitude is <= minval
    sortbyvals : boolean
        If true, rows are sorted by descending mean magnitude instead of
        by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), array entries with magnitude
        <= minval are displayed as blanks
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, group rows under their model (lineagestr of the lineage)
    maxcolumns : int
        The most vector values printed on one row
    skipifempty : boolean
        If true, return [] when no row passes the filters
    sortmodelsbysenss : mapping or falsy
        If given, maps model names to sensitivities that order the sections
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values above minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # blank small entries in a copy: v may be owned by the caller
            # (e.g. a solution's sensitivity arrays) and printing must not
            # overwrite it with NaNs
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # print the remaining values ncols at a time
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad the short last row so its bracket lines up
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        # column titles depend only on the latex option, so an empty latex
        # table no longer fails on an undefined name
        coltitles = {1: [title, "Value", "Units", "Description"],
                     2: [title, "Units", "Description"],
                     3: [title, "Value", "Units"]}[latex]
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines