Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import re
3import difflib
4from operator import sub
5import warnings as pywarnings
6import pickle
7import gzip
8import pickletools
9import numpy as np
10from .nomials import NomialArray
11from .small_classes import DictOfLists, Strings
12from .small_scripts import mag, try_str_without
13from .repr_conventions import unitstr, lineagestr
16CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")
18VALSTR_REPLACES = [
19 ("+nan", " nan"),
20 ("-nan", " nan"),
21 ("nan%", "nan "),
22 ("nan", " - "),
23]
26class SolSavingEnvironment:
27 """Temporarily removes construction/solve attributes from constraints.
29 This approximately halves the size of the pickled solution.
30 """
32 def __init__(self, solarray, saveconstraints):
33 self.solarray = solarray
34 self.attrstore = {}
35 self.saveconstraints = saveconstraints
36 self.constraintstore = None
39 def __enter__(self):
40 if self.saveconstraints:
41 for constraint_attr in ["bounded", "meq_bounded", "vks",
42 "v_ss", "unsubbed", "varkeys"]:
43 store = {}
44 for constraint in self.solarray["sensitivities"]["constraints"]:
45 if getattr(constraint, constraint_attr, None):
46 store[constraint] = getattr(constraint, constraint_attr)
47 delattr(constraint, constraint_attr)
48 self.attrstore[constraint_attr] = store
49 else:
50 self.constraintstore = \
51 self.solarray["sensitivities"].pop("constraints")
53 def __exit__(self, type_, val, traceback):
54 if self.saveconstraints:
55 for constraint_attr, store in self.attrstore.items():
56 for constraint, value in store.items():
57 setattr(constraint, constraint_attr, value)
58 else:
59 self.solarray["sensitivities"]["constraints"] = self.constraintstore
61def msenss_table(data, _, **kwargs):
62 "Returns model sensitivity table lines"
63 if "models" not in data.get("sensitivities", {}):
64 return ""
65 data = sorted(data["sensitivities"]["models"].items(),
66 key=lambda i: -np.mean(i[1]))
67 lines = ["Model Sensitivities", "-------------------"]
68 if kwargs["sortmodelsbysenss"]:
69 lines[0] += " (sorts models in sections below)"
70 previousmsenssstr = ""
71 for model, msenss in data:
72 if not model: # for now let's only do named models
73 continue
74 if (msenss < 0.1).all():
75 msenss = np.max(msenss)
76 if msenss:
77 msenssstr = "%6s" % ("<1e%i" % np.log10(msenss))
78 else:
79 msenssstr = " =0 "
80 elif not msenss.shape:
81 msenssstr = "%+6.1f" % msenss
82 else:
83 meansenss = np.mean(msenss)
84 msenssstr = "%+6.1f" % meansenss
85 deltas = msenss - meansenss
86 if np.max(np.abs(deltas)) > 0.1:
87 deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
88 for d in deltas]
89 msenssstr += " + [ %s ]" % " ".join(deltastrs)
90 if msenssstr == previousmsenssstr:
91 msenssstr = " "
92 else:
93 previousmsenssstr = msenssstr
94 lines.append("%s : %s" % (msenssstr, model))
95 return lines + [""] if len(lines) > 3 else []
98def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
99 "Returns sensitivity table lines"
100 if "variables" in data.get("sensitivities", {}):
101 data = data["sensitivities"]["variables"]
102 if showvars:
103 data = {k: data[k] for k in showvars if k in data}
104 return var_table(data, title, sortbyvals=True, skipifempty=True,
105 valfmt="%+-.2g ", vecfmt="%+-8.2g",
106 printunits=False, minval=1e-3, **kwargs)
109def topsenss_table(data, showvars, nvars=5, **kwargs):
110 "Returns top sensitivity table lines"
111 data, filtered = topsenss_filter(data, showvars, nvars)
112 title = "Most Sensitive Variables"
113 if filtered:
114 title = "Next Most Sensitive Variables"
115 return senss_table(data, title=title, hidebelowminval=True, **kwargs)
118def topsenss_filter(data, showvars, nvars=5):
119 "Filters sensitivities down to top N vars"
120 if "variables" in data.get("sensitivities", {}):
121 data = data["sensitivities"]["variables"]
122 mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
123 if not np.isnan(s).any()}
124 topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
125 filter_already_shown = showvars.intersection(topk)
126 for k in filter_already_shown:
127 topk.remove(k)
128 if nvars > 3: # always show at least 3
129 nvars -= 1
130 return {k: data[k] for k in topk[-nvars:]}, filter_already_shown
133def insenss_table(data, _, maxval=0.1, **kwargs):
134 "Returns insensitivity table lines"
135 if "constants" in data.get("sensitivities", {}):
136 data = data["sensitivities"]["variables"]
137 data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
138 return senss_table(data, title="Insensitive Fixed Variables", **kwargs)
141def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
142 "Return constraint tightness lines"
143 title = "Most Sensitive Constraints"
144 if len(self) > 1:
145 title += " (in last sweep)"
146 data = sorted(((-float("%+6.2g" % s[-1]), str(c)),
147 "%+6.2g" % s[-1], id(c), c)
148 for c, s in self["sensitivities"]["constraints"].items()
149 if s[-1] >= tight_senss)[:ntightconstrs]
150 else:
151 data = sorted(((-float("%+6.2g" % s), str(c)), "%+6.2g" % s, id(c), c)
152 for c, s in self["sensitivities"]["constraints"].items()
153 if s >= tight_senss)[:ntightconstrs]
154 return constraint_table(data, title, **kwargs)
156def loose_table(self, _, min_senss=1e-5, **kwargs):
157 "Return constraint tightness lines"
158 title = "Insensitive Constraints |below %+g|" % min_senss
159 if len(self) > 1:
160 title += " (in last sweep)"
161 data = [(0, "", id(c), c)
162 for c, s in self["sensitivities"]["constraints"].items()
163 if s[-1] <= min_senss]
164 else:
165 data = [(0, "", id(c), c)
166 for c, s in self["sensitivities"]["constraints"].items()
167 if s <= min_senss]
168 return constraint_table(data, title, **kwargs)
171# pylint: disable=too-many-branches,too-many-locals,too-many-statements
172def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
173 "Creates lines for tables where the right side is a constraint."
174 # TODO: this should support 1D array inputs from sweeps
175 excluded = ("units", "unnecessary lineage")
176 if not showmodels:
177 excluded = ("units", "lineage") # hide all of it
178 models, decorated = {}, []
179 for sortby, openingstr, _, constraint in sorted(data):
180 model = lineagestr(constraint) if sortbymodel else ""
181 if model not in models:
182 models[model] = len(models)
183 constrstr = try_str_without(constraint, excluded)
184 if " at 0x" in constrstr: # don't print memory addresses
185 constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
186 decorated.append((models[model], model, sortby, constrstr, openingstr))
187 decorated.sort()
188 previous_model, lines = None, []
189 for varlist in decorated:
190 _, model, _, constrstr, openingstr = varlist
191 if model != previous_model:
192 if lines:
193 lines.append(["", ""])
194 if model or lines:
195 lines.append([("newmodelline",), model])
196 previous_model = model
197 constrstr = constrstr.replace(model, "")
198 minlen, maxlen = 25, 80
199 segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
200 constraintlines = []
201 line = ""
202 next_idx = 0
203 while next_idx < len(segments):
204 segment = segments[next_idx]
205 next_idx += 1
206 if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
207 segments[next_idx] = segment[1:] + segments[next_idx]
208 segment = segment[0]
209 elif len(line) + len(segment) > maxlen and len(line) > minlen:
210 constraintlines.append(line)
211 line = " " # start a new line
212 line += segment
213 while len(line) > maxlen:
214 constraintlines.append(line[:maxlen])
215 line = " " + line[maxlen:]
216 constraintlines.append(line)
217 lines += [(openingstr + " : ", constraintlines[0])]
218 lines += [("", l) for l in constraintlines[1:]]
219 if not lines:
220 lines = [("", "(none)")]
221 maxlens = np.max([list(map(len, line)) for line in lines
222 if line[0] != ("newmodelline",)], axis=0)
223 dirs = [">", "<"] # we'll check lengths before using zip
224 assert len(list(dirs)) == len(list(maxlens))
225 fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
226 for i, line in enumerate(lines):
227 if line[0] == ("newmodelline",):
228 linelist = [fmts[0].format(" | "), line[1]]
229 else:
230 linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
231 lines[i] = "".join(linelist).rstrip()
232 return [title] + ["-"*len(title)] + lines + [""]
235def warnings_table(self, _, **kwargs):
236 "Makes a table for all warnings in the solution."
237 title = "WARNINGS"
238 lines = ["~"*len(title), title, "~"*len(title)]
239 if "warnings" not in self or not self["warnings"]:
240 return []
241 for wtype in sorted(self["warnings"]):
242 data_vec = self["warnings"][wtype]
243 if len(data_vec) == 0:
244 continue
245 if not hasattr(data_vec, "shape"):
246 data_vec = [data_vec] # not a sweep
247 if all((data == data_vec[0]).all() for data in data_vec[1:]):
248 data_vec = [data_vec[0]] # warnings identical across all sweeps
249 for i, data in enumerate(data_vec):
250 if len(data) == 0:
251 continue
252 data = sorted(data, key=lambda l: l[0]) # sort by msg
253 title = wtype
254 if len(data_vec) > 1:
255 title += " in sweep %i" % i
256 if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
257 data = [(-int(1e5*c.relax_sensitivity),
258 "%+6.2g" % c.relax_sensitivity, id(c), c)
259 for _, c in data]
260 lines += constraint_table(data, title, **kwargs)
261 elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
262 data = [(-int(1e5*c.rel_diff),
263 "%.4g %s %.4g" % c.tightvalues, id(c), c)
264 for _, c in data]
265 lines += constraint_table(data, title, **kwargs)
266 else:
267 lines += [title] + ["-"*len(wtype)]
268 lines += [msg for msg, _ in data] + [""]
269 lines[-1] = "~~~~~~~~"
270 return lines + [""]
273TABLEFNS = {"sensitivities": senss_table,
274 "top sensitivities": topsenss_table,
275 "insensitivities": insenss_table,
276 "model sensitivities": msenss_table,
277 "tightest constraints": tight_table,
278 "loose constraints": loose_table,
279 "warnings": warnings_table,
280 }
282def unrolled_absmax(values):
283 "From an iterable of numbers and arrays, returns the largest magnitude"
284 finalval, absmaxest = None, 0
285 for val in values:
286 absmaxval = np.abs(val).max()
287 if absmaxval >= absmaxest:
288 absmaxest, finalval = absmaxval, val
289 if getattr(finalval, "shape", None):
290 return finalval[np.unravel_index(np.argmax(np.abs(finalval)),
291 finalval.shape)]
292 return finalval
295def cast(function, val1, val2):
296 "Relative difference between val1 and val2 (positive if val2 is larger)"
297 with pywarnings.catch_warnings(): # skip those pesky divide-by-zeros
298 pywarnings.simplefilter("ignore")
299 if hasattr(val1, "shape") and hasattr(val2, "shape"):
300 if val1.ndim == val2.ndim:
301 return function(val1, val2)
302 lessdim, dimmest = sorted([val1, val2], key=lambda v: v.ndim)
303 dimdelta = dimmest.ndim - lessdim.ndim
304 add_axes = (slice(None),)*lessdim.ndim + (np.newaxis,)*dimdelta
305 if dimmest is val1:
306 return function(dimmest, lessdim[add_axes])
307 if dimmest is val2:
308 return function(lessdim[add_axes], dimmest)
309 return function(val1, val2)
312class SolutionArray(DictOfLists):
313 """A dictionary (of dictionaries) of lists, with convenience methods.
315 Items
316 -----
317 cost : array
318 variables: dict of arrays
319 sensitivities: dict containing:
320 monomials : array
321 posynomials : array
322 variables: dict of arrays
323 localmodels : NomialArray
324 Local power-law fits (small sensitivities are cut off)
326 Example
327 -------
328 >>> import gpkit
329 >>> import numpy as np
330 >>> x = gpkit.Variable("x")
331 >>> x_min = gpkit.Variable("x_{min}", 2)
332 >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
333 >>>
334 >>> # VALUES
335 >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
336 >>> assert all(np.array(values) == 2)
337 >>>
338 >>> # SENSITIVITIES
339 >>> senss = [sol.sens(x_min), sol.sens(x_min)]
340 >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
341 >>> assert all(np.array(senss) == 1)
342 """
343 modelstr = ""
344 _name_collision_varkeys = None
345 table_titles = {"choicevariables": "Choice Variables",
346 "sweepvariables": "Swept Variables",
347 "freevariables": "Free Variables",
348 "constants": "Fixed Variables", # TODO: change everywhere
349 "variables": "Variables"}
351 def name_collision_varkeys(self):
352 "Returns the set of contained varkeys whose names are not unique"
353 if self._name_collision_varkeys is None:
354 self["variables"].update_keymap()
355 keymap = self["variables"].keymap
356 self._name_collision_varkeys = set()
357 for key in list(keymap):
358 if hasattr(key, "key"):
359 if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
360 self._name_collision_varkeys.add(key)
361 return self._name_collision_varkeys
363 def __len__(self):
364 try:
365 return len(self["cost"])
366 except TypeError:
367 return 1
368 except KeyError:
369 return 0
371 def __call__(self, posy):
372 posy_subbed = self.subinto(posy)
373 return getattr(posy_subbed, "c", posy_subbed)
375 def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
376 "Checks for almost-equality between two solutions"
377 svars, ovars = self["variables"], other["variables"]
378 svks, ovks = set(svars), set(ovars)
379 if svks != ovks:
380 return False
381 for key in svks:
382 if abs(cast(np.divide, svars[key], ovars[key]) - 1) >= reltol:
383 return False
384 if abs(self["sensitivities"]["variables"][key]
385 - other["sensitivities"]["variables"][key]) >= sens_abstol:
386 return False
387 return True
389 # pylint: disable=too-many-locals, too-many-branches, too-many-statements
390 def diff(self, other, showvars=None, *,
391 constraintsdiff=True, senssdiff=False, sensstol=0.1,
392 absdiff=False, abstol=0, reldiff=True, reltol=1.0,
393 sortmodelsbysenss=True, **tableargs):
394 """Outputs differences between this solution and another
396 Arguments
397 ---------
398 other : solution or string
399 strings will be treated as paths to pickled solutions
400 senssdiff : boolean
401 if True, show sensitivity differences
402 sensstol : float
403 the smallest sensitivity difference worth showing
404 abssdiff : boolean
405 if True, show absolute differences
406 absstol : float
407 the smallest absolute difference worth showing
408 reldiff : boolean
409 if True, show relative differences
410 reltol : float
411 the smallest relative difference worth showing
413 Returns
414 -------
415 str
416 """
417 if sortmodelsbysenss:
418 tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
419 else:
420 tableargs["sortmodelsbysenss"] = False
421 tableargs.update({"hidebelowminval": True, "sortbyvals": True,
422 "skipifempty": False})
423 if isinstance(other, Strings):
424 if other[-4:] == ".pgz":
425 other = SolutionArray.decompress_file(other)
426 else:
427 other = pickle.load(open(other, "rb"))
428 svars, ovars = self["variables"], other["variables"]
429 lines = ["Solution Diff",
430 "=============",
431 "(argument is the baseline solution)", ""]
432 svks, ovks = set(svars), set(ovars)
433 if showvars:
434 lines[0] += " (for selected variables)"
435 lines[1] += "========================="
436 showvars = self._parse_showvars(showvars)
437 svks = {k for k in showvars if k in svars}
438 ovks = {k for k in showvars if k in ovars}
439 if constraintsdiff and other.modelstr and self.modelstr:
440 if self.modelstr == other.modelstr:
441 lines += ["** no constraint differences **", ""]
442 else:
443 cdiff = ["Constraint Differences",
444 "**********************"]
445 cdiff.extend(list(difflib.unified_diff(
446 other.modelstr.split("\n"), self.modelstr.split("\n"),
447 lineterm="", n=3))[2:])
448 cdiff += ["", "**********************", ""]
449 lines += cdiff
450 if svks - ovks:
451 lines.append("Variable(s) of this solution"
452 " which are not in the argument:")
453 lines.append("\n".join(" %s" % key for key in svks - ovks))
454 lines.append("")
455 if ovks - svks:
456 lines.append("Variable(s) of the argument"
457 " which are not in this solution:")
458 lines.append("\n".join(" %s" % key for key in ovks - svks))
459 lines.append("")
460 sharedvks = svks.intersection(ovks)
461 if reldiff:
462 rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
463 for vk in sharedvks}
464 lines += var_table(rel_diff,
465 "Relative Differences |above %g%%|" % reltol,
466 valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
467 minval=reltol, printunits=False, **tableargs)
468 if lines[-2][:10] == "-"*10: # nothing larger than sensstol
469 lines.insert(-1, ("The largest is %+g%%."
470 % unrolled_absmax(rel_diff.values())))
471 if absdiff:
472 abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
473 lines += var_table(abs_diff,
474 "Absolute Differences |above %g|" % abstol,
475 valfmt="%+.2g", vecfmt="%+8.2g",
476 minval=abstol, **tableargs)
477 if lines[-2][:10] == "-"*10: # nothing larger than sensstol
478 lines.insert(-1, ("The largest is %+g."
479 % unrolled_absmax(abs_diff.values())))
480 if senssdiff:
481 ssenss = self["sensitivities"]["variables"]
482 osenss = other["sensitivities"]["variables"]
483 senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
484 for vk in svks.intersection(ovks)}
485 lines += var_table(senss_delta,
486 "Sensitivity Differences |above %g|" % sensstol,
487 valfmt="%+-.2f ", vecfmt="%+-6.2f",
488 minval=sensstol, printunits=False, **tableargs)
489 if lines[-2][:10] == "-"*10: # nothing larger than sensstol
490 lines.insert(-1, ("The largest is %+g."
491 % unrolled_absmax(senss_delta.values())))
492 return "\n".join(lines)
494 def save(self, filename="solution.pkl",
495 *, saveconstraints=True, **pickleargs):
496 """Pickles the solution and saves it to a file.
498 Solution can then be loaded with e.g.:
499 >>> import pickle
500 >>> pickle.load(open("solution.pkl"))
501 """
502 with SolSavingEnvironment(self, saveconstraints):
503 pickle.dump(self, open(filename, "wb"), **pickleargs)
505 def save_compressed(self, filename="solution.pgz",
506 *, saveconstraints=True, **cpickleargs):
507 "Pickle a file and then compress it into a file with extension."
508 with gzip.open(filename, "wb") as f:
509 with SolSavingEnvironment(self, saveconstraints):
510 pickled = pickle.dumps(self, **cpickleargs)
511 f.write(pickletools.optimize(pickled))
513 @staticmethod
514 def decompress_file(file):
515 "Load a gzip-compressed pickle file"
516 with gzip.open(file, "rb") as f:
517 return pickle.Unpickler(f).load()
519 def varnames(self, showvars, exclude):
520 "Returns list of variables, optionally with minimal unique names"
521 if showvars:
522 showvars = self._parse_showvars(showvars)
523 for key in self.name_collision_varkeys():
524 key.descr["necessarylineage"] = True
525 names = {}
526 for key in showvars or self["variables"]:
527 for k in self["variables"].keymap[key]:
528 names[k.str_without(exclude)] = k
529 for key in self.name_collision_varkeys():
530 del key.descr["necessarylineage"]
531 return names
533 def savemat(self, filename="solution.mat", showvars=None,
534 excluded=("unnecessary lineage", "vec")):
535 "Saves primal solution as matlab file"
536 from scipy.io import savemat
537 savemat(filename,
538 {name.replace(".", "_"): np.array(self["variables"][key], "f")
539 for name, key in self.varnames(showvars, excluded).items()})
541 def todataframe(self, showvars=None,
542 excluded=("unnecessary lineage", "vec")):
543 "Returns primal solution as pandas dataframe"
544 import pandas as pd # pylint:disable=import-error
545 rows = []
546 cols = ["Name", "Index", "Value", "Units", "Label",
547 "Lineage", "Other"]
548 for _, key in sorted(self.varnames(showvars, excluded).items(),
549 key=lambda k: k[0]):
550 value = self["variables"][key]
551 if key.shape:
552 idxs = []
553 it = np.nditer(np.empty(value.shape), flags=['multi_index'])
554 while not it.finished:
555 idx = it.multi_index
556 idxs.append(idx[0] if len(idx) == 1 else idx)
557 it.iternext()
558 else:
559 idxs = [None]
560 for idx in idxs:
561 row = [
562 key.name,
563 "" if idx is None else idx,
564 value if idx is None else value[idx]]
565 rows.append(row)
566 row.extend([
567 key.unitstr(),
568 key.label or "",
569 key.lineage or "",
570 ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
571 if k not in ["name", "units", "unitrepr",
572 "idx", "shape", "veckey",
573 "value", "original_fn",
574 "lineage", "label"])])
575 return pd.DataFrame(rows, columns=cols)
577 def savetxt(self, filename="solution.txt", printmodel=True, **kwargs):
578 "Saves solution table as a text file"
579 with open(filename, "w") as f:
580 if printmodel:
581 f.write(self.modelstr + "\n")
582 f.write(self.table(**kwargs))
584 def savecsv(self, showvars=None, filename="solution.csv", valcols=5):
585 "Saves primal solution as a CSV sorted by modelname, like the tables."
586 data = self["variables"]
587 if showvars:
588 showvars = self._parse_showvars(showvars)
589 data = {k: data[k] for k in showvars if k in data}
590 # if the columns don't capture any dimensions, skip them
591 minspan, maxspan = None, 1
592 for v in data.values():
593 if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
594 minspan_ = min((di for di in v.shape if di != 1))
595 maxspan_ = max((di for di in v.shape if di != 1))
596 if minspan is None or minspan_ < minspan:
597 minspan = minspan_
598 if maxspan is None or maxspan_ > maxspan:
599 maxspan = maxspan_
600 if minspan is not None and minspan > valcols:
601 valcols = 1
602 if maxspan < valcols:
603 valcols = maxspan
604 lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
605 tables=("cost", "sweepvariables", "freevariables",
606 "constants", "sensitivities"))
607 with open(filename, "w") as f:
608 f.write("Model Name,Variable Name,Value(s)" + ","*valcols
609 + "Units,Description\n")
610 for line in lines:
611 if line[0] == ("newmodelline",):
612 f.write(line[1])
613 elif not line[1]: # spacer line
614 f.write("\n")
615 else:
616 f.write("," + line[0].replace(" : ", "") + ",")
617 vals = line[1].replace("[", "").replace("]", "").strip()
618 for el in vals.split():
619 f.write(el + ",")
620 f.write(","*(valcols - len(vals.split())))
621 f.write((line[2].replace("[", "").replace("]", "").strip()
622 + ","))
623 f.write(line[3].strip() + "\n")
625 def subinto(self, posy):
626 "Returns NomialArray of each solution substituted into posy."
627 if posy in self["variables"]:
628 return self["variables"](posy)
630 if not hasattr(posy, "sub"):
631 raise ValueError("no variable '%s' found in the solution" % posy)
633 if len(self) > 1:
634 return NomialArray([self.atindex(i).subinto(posy)
635 for i in range(len(self))])
637 return posy.sub(self["variables"])
639 def _parse_showvars(self, showvars):
640 showvars_out = set()
641 for k in showvars:
642 k, _ = self["variables"].parse_and_index(k)
643 keys = self["variables"].keymap[k]
644 showvars_out.update(keys)
645 return showvars_out
647 def summary(self, showvars=(), ntopsenss=5, **kwargs):
648 "Print summary table, showing top sensitivities and no constants"
649 showvars = self._parse_showvars(showvars)
650 out = self.table(showvars, ["cost", "warnings", "sweepvariables",
651 "freevariables"], **kwargs)
652 constants_in_showvars = showvars.intersection(self["constants"])
653 senss_tables = []
654 if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
655 senss_tables.append("sensitivities")
656 if len(self["constants"]) >= ntopsenss+2:
657 senss_tables.append("top sensitivities")
658 senss_tables.append("tightest constraints")
659 senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
660 **kwargs)
661 if senss_str:
662 out += "\n" + senss_str
663 return out
665 def table(self, showvars=(),
666 tables=("cost", "warnings", "model sensitivities",
667 "sweepvariables", "freevariables",
668 "constants", "sensitivities", "tightest constraints"),
669 sortmodelsbysenss=True, **kwargs):
670 """A table representation of this SolutionArray
672 Arguments
673 ---------
674 tables: Iterable
675 Which to print of ("cost", "sweepvariables", "freevariables",
676 "constants", "sensitivities")
677 fixedcols: If true, print vectors in fixed-width format
678 latex: int
679 If > 0, return latex format (options 1-3); otherwise plain text
680 included_models: Iterable of strings
681 If specified, the models (by name) to include
682 excluded_models: Iterable of strings
683 If specified, model names to exclude
685 Returns
686 -------
687 str
688 """
689 if sortmodelsbysenss and "sensitivities" in self:
690 kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
691 else:
692 kwargs["sortmodelsbysenss"] = False
693 varlist = list(self["variables"])
694 has_only_one_model = True
695 for var in varlist[1:]:
696 if var.lineage != varlist[0].lineage:
697 has_only_one_model = False
698 break
699 if has_only_one_model:
700 kwargs["sortbymodel"] = False
701 for key in self.name_collision_varkeys():
702 key.descr["necessarylineage"] = True
703 showvars = self._parse_showvars(showvars)
704 strs = []
705 for table in tables:
706 if "sensitivities" not in self and ("sensitivities" in table or
707 "constraints" in table):
708 continue
709 elif table == "cost":
710 cost = self["cost"] # pylint: disable=unsubscriptable-object
711 if kwargs.get("latex", None): # cost is not printed for latex
712 continue
713 strs += ["\n%s\n------------" % "Optimal Cost"]
714 if len(self) > 1:
715 costs = ["%-8.3g" % c for c in mag(cost[:4])]
716 strs += [" [ %s %s ]" % (" ".join(costs),
717 "..." if len(self) > 4 else "")]
718 else:
719 strs += [" %-.4g" % mag(cost)]
720 strs[-1] += unitstr(cost, into=" [%s]", dimless="")
721 strs += [""]
722 elif table in TABLEFNS:
723 strs += TABLEFNS[table](self, showvars, **kwargs)
724 elif table in self:
725 data = self[table]
726 if showvars:
727 showvars = self._parse_showvars(showvars)
728 data = {k: data[k] for k in showvars if k in data}
729 strs += var_table(data, self.table_titles[table], **kwargs)
730 if kwargs.get("latex", None):
731 preamble = "\n".join(("% \\documentclass[12pt]{article}",
732 "% \\usepackage{booktabs}",
733 "% \\usepackage{longtable}",
734 "% \\usepackage{amsmath}",
735 "% \\begin{document}\n"))
736 strs = [preamble] + strs + ["% \\end{document}"]
737 for key in self.name_collision_varkeys():
738 del key.descr["necessarylineage"]
739 return "\n".join(strs)
741 def plot(self, posys=None, axes=None):
742 "Plots a sweep for each posy"
743 if len(self["sweepvariables"]) != 1:
744 print("SolutionArray.plot only supports 1-dimensional sweeps")
745 if not hasattr(posys, "__len__"):
746 posys = [posys]
747 import matplotlib.pyplot as plt
748 from .interactive.plot_sweep import assign_axes
749 from . import GPBLU
750 (swept, x), = self["sweepvariables"].items()
751 posys, axes = assign_axes(swept, posys, axes)
752 for posy, ax in zip(posys, axes):
753 y = self(posy) if posy not in [None, "cost"] else self["cost"]
754 ax.plot(x, y, color=GPBLU)
755 if len(axes) == 1:
756 axes, = axes
757 return plt.gcf(), axes
760# pylint: disable=too-many-branches,too-many-locals,too-many-statements
761def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
762 varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
763 minval=0, sortbyvals=False, hidebelowminval=False,
764 included_models=None, excluded_models=None, sortbymodel=True,
765 maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
766 """
767 Pretty string representation of a dict of VarKeys
768 Iterable values are handled specially (partial printing)
770 Arguments
771 ---------
772 data : dict whose keys are VarKey's
773 data to represent in table
774 title : string
775 printunits : bool
776 latex : int
777 If > 0, return latex format (options 1-3); otherwise plain text
778 varfmt : string
779 format for variable names
780 valfmt : string
781 format for scalar values
782 vecfmt : string
783 format for vector values
784 minval : float
785 skip values with all(abs(value)) < minval
786 sortbyvals : boolean
787 If true, rows are sorted by their average value instead of by name.
788 included_models : Iterable of strings
789 If specified, the models (by name) to include
790 excluded_models : Iterable of strings
791 If specified, model names to exclude
792 """
793 if not data:
794 return []
795 decorated, models = [], set()
796 for i, (k, v) in enumerate(data.items()):
797 if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
798 continue # no values below minval
799 if minval and hidebelowminval and getattr(v, "shape", None):
800 v[np.abs(v) <= minval] = np.nan
801 model = lineagestr(k.lineage) if sortbymodel else ""
802 msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
803 if hasattr(msenss, "shape"):
804 msenss = np.mean(msenss)
805 models.add(model)
806 b = bool(getattr(v, "shape", None))
807 s = k.str_without(("lineage", "vec"))
808 if not sortbyvals:
809 decorated.append((msenss, model, b, (varfmt % s), i, k, v))
810 else: # for consistent sorting, add small offset to negative vals
811 val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
812 sort = (float("%.4g" % -val), k.name)
813 decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
814 if not decorated and skipifempty:
815 return []
816 if included_models:
817 included_models = set(included_models)
818 included_models.add("")
819 models = models.intersection(included_models)
820 if excluded_models:
821 models = models.difference(excluded_models)
822 decorated.sort()
823 previous_model, lines = None, []
824 for varlist in decorated:
825 if sortbyvals:
826 model, _, msenss, isvector, varstr, _, var, val = varlist
827 else:
828 msenss, model, isvector, varstr, _, var, val = varlist
829 if model not in models:
830 continue
831 if model != previous_model:
832 if lines:
833 lines.append(["", "", "", ""])
834 if model:
835 if not latex:
836 lines.append([("newmodelline",), model, "", ""])
837 else:
838 lines.append(
839 [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
840 previous_model = model
841 label = var.descr.get("label", "")
842 units = var.unitstr(" [%s] ") if printunits else ""
843 if not isvector:
844 valstr = valfmt % val
845 else:
846 last_dim_index = len(val.shape)-1
847 horiz_dim, ncols = last_dim_index, 1 # starting values
848 for dim_idx, dim_size in enumerate(val.shape):
849 if ncols <= dim_size <= maxcolumns:
850 horiz_dim, ncols = dim_idx, dim_size
851 # align the array with horiz_dim by making it the last one
852 dim_order = list(range(last_dim_index))
853 dim_order.insert(horiz_dim, last_dim_index)
854 flatval = val.transpose(dim_order).flatten()
855 vals = [vecfmt % v for v in flatval[:ncols]]
856 bracket = " ] " if len(flatval) <= ncols else ""
857 valstr = "[ %s%s" % (" ".join(vals), bracket)
858 for before, after in VALSTR_REPLACES:
859 valstr = valstr.replace(before, after)
860 if not latex:
861 lines.append([varstr, valstr, units, label])
862 if isvector and len(flatval) > ncols:
863 values_remaining = len(flatval) - ncols
864 while values_remaining > 0:
865 idx = len(flatval)-values_remaining
866 vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
867 values_remaining -= ncols
868 valstr = " " + " ".join(vals)
869 for before, after in VALSTR_REPLACES:
870 valstr = valstr.replace(before, after)
871 if values_remaining <= 0:
872 spaces = (-values_remaining
873 * len(valstr)//(values_remaining + ncols))
874 valstr = valstr + " ]" + " "*spaces
875 lines.append(["", valstr, "", ""])
876 else:
877 varstr = "$%s$" % varstr.replace(" : ", "")
878 if latex == 1: # normal results table
879 lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
880 label])
881 coltitles = [title, "Value", "Units", "Description"]
882 elif latex == 2: # no values
883 lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
884 coltitles = [title, "Units", "Description"]
885 elif latex == 3: # no description
886 lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
887 coltitles = [title, "Value", "Units"]
888 else:
889 raise ValueError("Unexpected latex option, %s." % latex)
890 if rawlines:
891 return lines
892 if not latex:
893 if lines:
894 maxlens = np.max([list(map(len, line)) for line in lines
895 if line[0] != ("newmodelline",)], axis=0)
896 dirs = [">", "<", "<", "<"]
897 # check lengths before using zip
898 assert len(list(dirs)) == len(list(maxlens))
899 fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
900 for i, line in enumerate(lines):
901 if line[0] == ("newmodelline",):
902 line = [fmts[0].format(" | "), line[1]]
903 else:
904 line = [fmt.format(s) for fmt, s in zip(fmts, line)]
905 lines[i] = "".join(line).rstrip()
906 lines = [title] + ["-"*len(title)] + lines + [""]
907 else:
908 colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
909 lines = (["\n".join(["{\\footnotesize",
910 "\\begin{longtable}{%s}" % colfmt[latex],
911 "\\toprule",
912 " & ".join(coltitles) + " \\\\ \\midrule"])] +
913 [" & ".join(l) + " \\\\" for l in lines] +
914 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
915 return lines