Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import re
3import difflib
4from operator import sub
5import warnings as pywarnings
6import pickle
7import gzip
8import pickletools
9import numpy as np
10from .nomials import NomialArray
11from .small_classes import DictOfLists, Strings
12from .small_scripts import mag, try_str_without
13from .repr_conventions import unitstr, lineagestr
# Matches the points where a constraint string may be split: a single "*"
# (one that is not part of "**") together with its neighbouring characters,
# " + ", and the relational operators. Every alternative is a capture group,
# so ``CONSTRSPLITPATTERN.split`` keeps the separators in its output.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# (before, after) substring pairs applied, in this order, to formatted value
# strings so that NaN entries are rendered as a dash.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
class SolSavingEnvironment:
    """Context manager that slims a solution down while it is pickled.

    With ``saveconstraints`` true, a fixed list of construction/solve
    attributes is stripped from every constraint in
    ``solarray["sensitivities"]["constraints"]`` on entry and put back on
    exit. Otherwise the whole ``"constraints"`` entry is popped on entry
    and reinstated on exit.

    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.attrstore = {}  # attribute name -> {constraint: removed value}
        self.saveconstraints = saveconstraints
        self.constraintstore = None  # holds the popped "constraints" entry

    def __enter__(self):
        sensitivities = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = sensitivities.pop("constraints")
            return
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            removed = {}
            for constr in sensitivities["constraints"]:
                attrval = getattr(constr, attrname, None)
                if attrval:  # only truthy attributes are stored and deleted
                    removed[constr] = attrval
                    delattr(constr, attrname)
            self.attrstore[attrname] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for attrname, removed in self.attrstore.items():
            for constr, attrval in removed.items():
                setattr(constr, attrname, attrval)
def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines

    Arguments
    ---------
    data : mapping
        read via ``data["sensitivities"]["models"]``, a dict of model name
        to an array of sensitivities
    _ : ignored (keeps the signature shared by the TABLEFNS functions)
    sortmodelsbysenss : optional keyword
        if truthy, a note is appended to the table's title line

    Returns
    -------
    list of str, empty when there are no model sensitivities or no named
    model produced a row.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other return path
    # large-sensitivity models first (by rounded mean), then the models whose
    # sensitivities are all below 0.1 (by max); ties are broken by name
    data = sorted(data["sensitivities"]["models"].items(),
                  key=lambda i: ((i[1] < 0.1).all(),
                                 -np.max(i[1]) if (i[1] < 0.1).all()
                                 else -round(np.mean(i[1]), 1), i[0]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):  # optional: absent means "no note"
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            msenss = np.max(msenss)
            if msenss:
                # NOTE(review): a negative max makes log10 return nan and
                # the "%i" format raise -- confirm whether that can occur.
                msenssstr = "%6s" % ("<1e%i" % np.log10(msenss))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                # show the per-entry deviation from the mean
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:
            msenssstr = " "*len(msenssstr)  # blank out a repeated value
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    return lines + [""] if len(lines) > 3 else []
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines

    ``data`` is either a solution (its ``["sensitivities"]["variables"]``
    entry is used) or already a dict of sensitivities. A non-empty
    ``showvars`` restricts the table to those keys. Remaining keyword
    arguments are forwarded to ``var_table``.
    """
    sensitivities = data.get("sensitivities", {})
    if "variables" in sensitivities:
        data = sensitivities["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    table_options = dict(sortbyvals=True, skipifempty=True,
                         valfmt="%+-.2g ", vecfmt="%+-8.2g",
                         printunits=False, minval=1e-3)
    return var_table(data, title, **table_options, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns top sensitivity table lines

    The sensitivities are narrowed by ``topsenss_filter``; when that filter
    reports variables it left out because they are in ``showvars``, the
    title reads "Next Most Sensitive Variables".
    """
    topsenss, already_shown = topsenss_filter(data, showvars, nvars)
    if already_shown:
        heading = "Next Most Sensitive Variables"
    else:
        heading = "Most Sensitive Variables"
    return senss_table(topsenss, title=heading, hidebelowminval=True,
                       **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to top N vars

    Arguments
    ---------
    data : mapping
        a solution (``["sensitivities"]["variables"]`` is used) or a dict
        of sensitivities
    showvars : iterable of keys (or None)
        keys displayed elsewhere; they are left out of the result
    nvars : int
        number of keys to keep before accounting for ``showvars``

    Returns
    -------
    (dict, set) : the kept key/sensitivity pairs, and the members of
    ``showvars`` that were removed from the ranking.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # keys whose sensitivities contain NaN are not ranked
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    topk = sorted(mean_abs_senss, key=mean_abs_senss.get)  # ascending
    # accept any iterable (tuple, list, set, None), not just sets
    filter_already_shown = set(showvars or ()).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    return {k: data[k] for k in topk[-nvars:]}, filter_already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines

    Keeps only the sensitivities whose mean magnitude is below ``maxval``
    and renders them with ``senss_table``. ``data`` is either a solution
    (its ``["sensitivities"]["variables"]`` entry is used) or a dict of
    sensitivities; ``_`` is ignored.
    """
    # test for the key that is actually read below
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return constraint tightness lines

    Lists up to ``ntightconstrs`` constraints whose sensitivity is at least
    ``tight_senss``, most sensitive first. For a solution of length > 1
    only the last entry of each sensitivity is considered.
    """
    title = "Most Sensitive Constraints"
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        sens = senss[-1] if is_sweep else senss
        if sens >= tight_senss:
            sensstr = "%+6.2g" % sens
            # sort on the rounded (displayed) value, then on the string
            rows.append(((-float(sensstr), str(constr)),
                         sensstr, id(constr), constr))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines listing the insensitive constraints

    A constraint is listed when its sensitivity is at most ``min_senss``.
    For a solution of length > 1 only the last entry of each sensitivity
    is considered.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        sens = senss[-1] if is_sweep else senss
        if sens <= min_senss:
            rows.append((0, "", id(constr), constr))
    return constraint_table(rows, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, id, constraint) tuples
        ``openingstr`` is printed left of " : "; the third item is unused
        here beyond its role in sorting the tuples
    title : string
        printed first, underlined with dashes
    sortbymodel : bool
        if True, rows are grouped under ``lineagestr(constraint)``
    showmodels : bool
        if False, "lineage" is excluded from constraint strings entirely

    Returns
    -------
    list of str: title, underline, the formatted rows, and a trailing "".
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    # models maps each model string to the order in which it first appeared
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer between model groups
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        constrstr = constrstr.replace(model, "")
        # wrap the constraint string: lines longer than minlen may be
        # broken at a separator, lines longer than maxlen are always cut
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # a separator: keep its first character on this line and
                # push the rest onto the front of the following segment
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " " # start a new line
            line += segment
            while len(line) > maxlen:
                # hard-cut anything still longer than maxlen
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths, ignoring the model-header rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"] # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Arguments
    ---------
    self : mapping
        read via ``self["warnings"]``, a dict of warning type to either a
        list of ``(msg, data)`` pairs or (for sweeps) an array of such lists
    _ : ignored (keeps the signature shared by the TABLEFNS functions)
    **kwargs : forwarded to ``constraint_table`` for the two
        constraint-based warning types

    Returns
    -------
    list of str, empty if there is nothing to report.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):  # elementwise comparison result
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title, including any " in sweep N"
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"  # the last blank line becomes the closing rule
    return lines + [""]
# Table name -> function building that table's lines. SolutionArray.table
# calls each one as ``fn(solution, showvars, **kwargs)``.
TABLEFNS = {"sensitivities": senss_table,
            "top sensitivities": topsenss_table,
            "insensitivities": insenss_table,
            "model sensitivities": msenss_table,
            "tightest constraints": tight_table,
            "loose constraints": loose_table,
            "warnings": warnings_table,
            }
def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude

    The returned value keeps its sign. Returns None when ``values`` is
    empty or when no item's magnitude compares >= 0 (e.g. all NaN).
    """
    biggest, biggest_mag = None, 0
    for candidate in values:
        magnitude = np.abs(candidate).max()
        if magnitude >= biggest_mag:  # on ties the later item wins
            biggest, biggest_mag = candidate, magnitude
    if not getattr(biggest, "shape", None):
        return biggest  # a scalar (or None)
    flat_index = np.argmax(np.abs(biggest))
    return biggest[np.unravel_index(flat_index, biggest.shape)]
def cast(function, val1, val2):
    """Applies ``function(val1, val2)``, aligning arrays of unequal rank

    When both values have a ``shape`` but different ``ndim``, the
    lower-rank one is indexed with trailing ``np.newaxis`` entries so its
    leading axes line up with the other's. Argument order is preserved.
    Warnings raised during the call are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if not both_arrays or val1.ndim == val2.ndim:
            return function(val1, val2)
        if val1.ndim > val2.ndim:
            pad = ((slice(None),)*val2.ndim
                   + (np.newaxis,)*(val1.ndim - val2.ndim))
            return function(val1, val2[pad])
        pad = ((slice(None),)*val1.ndim
               + (np.newaxis,)*(val2.ndim - val1.ndim))
        return function(val1[pad], val2)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None  # cache for name_collision_varkeys()
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    # more than one key shares this lineage-free name
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        try:
            return len(self["cost"])
        except TypeError:  # cost has no len(): a single solution
            return 1
        except KeyError:  # no cost stored at all
            return 0

    def __call__(self, posy):
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions

        Returns False if the variable sets differ, or if for any variable
        any entry's relative difference reaches ``reltol`` or any entry's
        sensitivity difference reaches ``sens_abstol``.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any: vector variables give arrays, whose truth value
            # cannot be taken directly
            reldelta = abs(cast(np.divide, svars[key], ovars[key]) - 1)
            if np.any(reldelta >= reltol):
                return False
            sensdelta = abs(self["sensitivities"]["variables"][key]
                            - other["sensitivities"]["variables"][key])
            if np.any(sensdelta >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True (and both solutions have a modelstr), show a unified
            diff of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            # NOTE: unpickling executes code from the file; only load
            # solutions from trusted sources.
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as picklefile:  # closes the handle
                    other = pickle.load(picklefile)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # [2:] drops the first two lines unified_diff produces
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as picklefile:  # closes the handle
                pickle.dump(self, picklefile, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        "Load a gzip-compressed pickle file"
        # NOTE: unpickling executes code from the file; trusted input only.
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        "Returns list of variables, optionally with minimal unique names"
        if showvars:
            showvars = self._parse_showvars(showvars)
        # temporarily flag colliding keys so their strings include lineage
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:  # remove the flags even if a lookup above fails
            for key in self.name_collision_varkeys():
                del key.descr["necessarylineage"]
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # one row per element: collect every multi-index of value
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    # pad short rows so the later columns stay aligned
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Returns the set of keys in this solution matched by showvars"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Print summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        showvars: Iterable
            If non-empty, only these variables are shown
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
            "constants", "sensitivities")
        sortmodelsbysenss: bool
            If True (and sensitivities are present), the model
            sensitivities are passed on to the table functions
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        # temporarily flag colliding keys so their strings include lineage
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if "sensitivities" not in self and ("sensitivities" in table or
                                                    "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:  # remove the flags even if a table function fails
            for key in self.name_collision_varkeys():
                del key.descr["necessarylineage"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            # NOTE(review): execution continues after this message, and the
            # single-item unpacking below then raises -- confirm intent.
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table; its values are never modified
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted rows (lists of cell strings)
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        rows whose largest magnitude is <= minval are skipped
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), vector entries with magnitude
        <= minval are displayed as missing
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped by the lineage string of their key
    maxcolumns : int
        the largest array dimension that may be laid out horizontally
    skipifempty : boolean
        If true, return [] when no row passes the minval filter
    sortmodelsbysenss : dict or falsy
        model name -> sensitivity, used to order the model groups

    Returns
    -------
    list of str (or list of rows if ``rawlines``)
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # mask a copy: writing NaN into v itself would change the
            # caller's stored values
            v = np.array(v, copy=True, subok=True)
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))  # True for vector values
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    # the unique index i sits before k and v in each tuple, so sorting
    # never has to compare keys or arrays
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])  # spacer between models
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            # pick the largest dimension that still fits in maxcolumns
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # print the remaining entries, ncols per continuation row
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad a short final row before closing the bracket
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
                coltitles = [title, "Value", "Units", "Description"]
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
                coltitles = [title, "Units", "Description"]
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
                coltitles = [title, "Value", "Units"]
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            # column widths, ignoring the model-header rows
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines