Coverage for gpkit/solution_array.py : 83%
![Show keyboard shortcuts](keybd_closed.png)
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import re
3import difflib
4from operator import sub
5import warnings as pywarnings
6import pickle
7import gzip
8import pickletools
9import numpy as np
10from .nomials import NomialArray
11from .small_classes import DictOfLists, Strings
12from .small_scripts import mag, try_str_without
13from .repr_conventions import unitstr, lineagestr
# Splits a constraint string at breakable points (a single "*", " + ",
# " >= ", " <= ", " = ") so long constraints can wrap across table lines.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# (old, new) substring replacements applied, in order, to formatted value
# strings: renders NaN entries as " - " while preserving column alignment.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray          # the SolutionArray being pickled
        self.attrstore = {}               # {attrname: {constraint: value}}
        self.saveconstraints = saveconstraints
        self.constraintstore = None       # holds popped constraints dict

    def __enter__(self):
        if not self.saveconstraints:
            # drop the constraints entirely; restored in __exit__
            self.constraintstore = \
                self.solarray["sensitivities"].pop("constraints")
            return
        constraints = self.solarray["sensitivities"]["constraints"]
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            stashed = {}
            for cstr in constraints:
                attrval = getattr(cstr, attrname, None)
                if attrval:  # only truthy attributes are stripped
                    stashed[cstr] = attrval
                    delattr(cstr, attrname)
            self.attrstore[attrname] = stashed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        # put every stripped attribute back on its constraint
        for attrname, stashed in self.attrstore.items():
            for cstr, attrval in stashed.items():
                setattr(cstr, attrname, attrval)
def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : dict
        solution dict; reads data["sensitivities"]["models"]
    kwargs : dict
        must contain "sortmodelsbysenss" (truthy adds a note to the title)

    Returns
    -------
    list of strings; empty when there are no model sensitivities or when
    fewer than two named-model lines would be shown
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # fixed: was "", inconsistent with the list returns below
    # sort by descending mean sensitivity (rounded), then by model name
    data = sorted(data["sensitivities"]["models"].items(),
                  key=lambda i: (-round(np.mean(i[1]), 1), i[0]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs["sortmodelsbysenss"]:
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            msenss = np.max(msenss)
            if msenss:  # small but nonzero: show order of magnitude
                msenssstr = "%6s" % ("<1e%i" % np.log10(msenss))
            else:
                msenssstr = " =0 "
        elif not msenss.shape:  # scalar sensitivity
            msenssstr = "%+6.1f" % msenss
        else:  # sweep: show the mean, then per-point deltas if significant
            meansenss = np.mean(msenss)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:  # don't repeat identical values
            msenssstr = " "
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    # require at least two model lines beyond the 2-line header
    return lines + [""] if len(lines) > 3 else []
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    "Returns sensitivity table lines"
    senss = data.get("sensitivities", {})
    if "variables" in senss:
        data = senss["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    "Returns top sensitivity table lines"
    data, filtered = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if filtered
             else "Most Sensitive Variables")
    return senss_table(data, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    "Filters sensitivities down to top N vars"
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # mean |sensitivity| per key, skipping any key with NaN entries
    mean_abs_senss = {key: np.abs(senss).mean()
                      for key, senss in data.items()
                      if not np.isnan(senss).any()}
    # ascending by magnitude, so the most sensitive keys end up last
    ranked = sorted(mean_abs_senss, key=mean_abs_senss.get)
    already_shown = showvars.intersection(ranked)
    ranked = [key for key in ranked if key not in already_shown]
    if nvars > 3:  # always show at least 3
        nvars -= 1
    return {key: data[key] for key in ranked[-nvars:]}, already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    "Returns insensitivity table lines"
    # NOTE(review): the guard checks for "constants" but then reads
    # "variables" — verify this asymmetry is intentional (the variables
    # dict may be a superset that includes the constants).
    if "constants" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # keep only keys whose mean |sensitivity| is below maxval
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    "Return constraint tightness lines"
    title = "Most Sensitive Constraints"
    sweeping = len(self) > 1
    if sweeping:
        title += " (in last sweep)"
    rows = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        senss = senss[-1] if sweeping else senss  # last sweep point only
        if senss >= tight_senss:
            rows.append(((-float("%+6.2g" % abs(senss)), str(constraint)),
                         "%+6.2g" % abs(senss), id(constraint), constraint))
    # largest |sensitivity| first (ties broken by constraint string)
    data = sorted(rows)[:ntightconstrs]
    return constraint_table(data, title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    "Return insensitive-constraint lines"
    title = "Insensitive Constraints |below %+g|" % min_senss
    sweeping = len(self) > 1
    if sweeping:
        title += " (in last sweep)"
    rows = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        senss = senss[-1] if sweeping else senss  # last sweep point only
        if senss <= min_senss:
            rows.append((0, "", id(constraint), constraint))
    return constraint_table(rows, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, id, constraint) tuples
        sortby determines row order; openingstr is the left-column text
    title : string
        table title, underlined in the output
    sortbymodel : bool
        if True, group rows by the constraint's model lineage
    showmodels : bool
        if False, strip all lineage information from constraint strings

    Returns
    -------
    list of strings
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            # remember first-seen order so groups keep their original order
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # blank spacer between model groups
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        constrstr = constrstr.replace(model, "")
        minlen, maxlen = 25, 80
        # split the constraint at operators so it can wrap across lines
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # a "x*y"-style match: push the trailing char into the next
                # segment and keep only the leading char here
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:  # hard-wrap anything still too long
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths ignore the model-header marker rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Reads self["warnings"], a dict mapping warning-type names to either a
    list of (message, data) pairs or a per-sweep sequence of such lists.

    Returns a list of strings (empty when there are no warnings).
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        if all((data == data_vec[0]).all() for data in data_vec[1:]):
            data_vec = [data_vec[0]]  # warnings identical across all sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*c.relax_sensitivity),
                         "%+6.2g" % c.relax_sensitivity, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*c.rel_diff),
                         "%.4g %s %.4g" % c.tightvalues, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # fixed: underline the printed title (which may carry an
                # " in sweep %i" suffix), not just the bare wtype
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    lines[-1] = "~~~~~~~~"
    return lines + [""]
# Maps the table names accepted by SolutionArray.table()'s `tables`
# argument to the functions that generate those tables' lines.
TABLEFNS = {"sensitivities": senss_table,
            "top sensitivities": topsenss_table,
            "insensitivities": insenss_table,
            "model sensitivities": msenss_table,
            "tightest constraints": tight_table,
            "loose constraints": loose_table,
            "warnings": warnings_table,
            }
def unrolled_absmax(values):
    "From an iterable of numbers and arrays, returns the largest magnitude"
    largest, largest_mag = None, 0
    for value in values:
        value_mag = np.abs(value).max()
        if value_mag >= largest_mag:
            largest_mag, largest = value_mag, value
    if not getattr(largest, "shape", None):
        return largest  # scalar winner: return it directly
    # array winner: return its single largest-magnitude element (signed)
    flat_idx = np.argmax(np.abs(largest))
    return largest[np.unravel_index(flat_idx, largest.shape)]
def cast(function, val1, val2):
    "Relative difference between val1 and val2 (positive if val2 is larger)"
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if not (hasattr(val1, "shape") and hasattr(val2, "shape")):
            return function(val1, val2)
        if val1.ndim == val2.ndim:
            return function(val1, val2)
        # different ranks: pad the lower-rank array with trailing new axes
        # so the two broadcast together
        if val1.ndim < val2.ndim:
            lessdim, dimmest = val1, val2
        else:
            lessdim, dimmest = val2, val1
        extra_dims = dimmest.ndim - lessdim.ndim
        padded = lessdim[(slice(None),)*lessdim.ndim
                         + (np.newaxis,)*extra_dims]
        if dimmest is val1:
            return function(dimmest, padded)
        return function(padded, dimmest)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables : dict of arrays
    sensitivities : dict containing:
        monomials : array
        posynomials : array
        variables : dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    # string form of the solved model's constraints; set externally and used
    # by diff() and savetxt()
    modelstr = ""
    # cache for name_collision_varkeys(); None means not yet computed
    _name_collision_varkeys = None
    # maps solution-dict section names to their printed table titles
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                # only VarKey-like entries (those with a .key) count; a
                # lineage-free name mapping to >1 key is a collision
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        "Number of solution points: sweep length, 1 if scalar, 0 if no cost"
        try:
            return len(self["cost"])
        except TypeError:  # cost is a scalar: a single solution
            return 1
        except KeyError:  # no cost stored yet
            return 0

    def __call__(self, posy):
        "Substitutes the solution into posy; unwraps constant results"
        posy_subbed = self.subinto(posy)
        # a fully-substituted monomial exposes its value as .c
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        "Checks for almost-equality between two solutions"
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # NOTE(review): these scalar comparisons assume scalar values;
            # vector-valued keys would make `abs(...) >= tol` ambiguous —
            # confirm callers only use this on scalar solutions
            if abs(cast(np.divide, svars[key], ovars[key]) - 1) >= reltol:
                return False
            if abs(self["sensitivities"]["variables"][key]
                   - other["sensitivities"]["variables"][key]) >= sens_abstol:
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
        showvars : iterable, optional
            if given, restrict the diff to these variables
        constraintsdiff : boolean
            if True, show constraint differences (needs modelstr on both)
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True, sort table sections by model sensitivity

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                other = pickle.load(open(other, "rb"))
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # drop the two "---/+++" file-header lines of the diff
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%%  ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f  ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        # NOTE(review): the file handle is not explicitly closed
        with SolSavingEnvironment(self, saveconstraints):
            pickle.dump(self, open(filename, "wb"), **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            # optimize() shrinks the pickle stream before compression
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        "Load a gzip-compressed pickle file"
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        "Returns list of variables, optionally with minimal unique names"
        if showvars:
            showvars = self._parse_showvars(showvars)
        # temporarily flag colliding keys so their string form keeps lineage
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        names = {}
        for key in showvars or self["variables"]:
            for k in self["variables"].keymap[key]:
                names[k.str_without(exclude)] = k
        for key in self.name_collision_varkeys():
            del key.descr["necessarylineage"]
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # enumerate every index of the (possibly N-d) value array
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                # extending after append still mutates the appended row
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                # NOTE(review): maxspan starts at 1, never None, so the
                # first half of this check is dead code
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        # NOTE(review): var_table has no `tables` parameter — this keyword
        # is swallowed by its **_ catch-all; confirm it is intentional
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    # pad out any unused value columns
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:  # sweep: substitute each solution point separately
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"], require_positive=False)

    def _parse_showvars(self, showvars):
        "Expands a showvars iterable into the set of matching varkeys"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Print summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        # with few constants show them all; otherwise show only the top N
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
            "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:  # no grouping needed for a single model
            kwargs["sortbymodel"] = False
        # flag colliding names so they print with lineage; unflagged below
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        showvars = self._parse_showvars(showvars)
        strs = []
        for table in tables:
            if "sensitivities" not in self and ("sensitivities" in table or
                                                "constraints" in table):
                continue  # can't print sensitivity tables without senss data
            if table == "cost":
                cost = self["cost"]  # pylint: disable=unsubscriptable-object
                if kwargs.get("latex", None):  # cost is not printed for latex
                    continue
                strs += ["\n%s\n------------" % "Optimal Cost"]
                if len(self) > 1:  # sweep: show the first few costs
                    costs = ["%-8.3g" % c for c in mag(cost[:4])]
                    strs += [" [ %s %s ]" % (" ".join(costs),
                                             "..." if len(self) > 4 else "")]
                else:
                    strs += [" %-.4g" % mag(cost)]
                strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                strs += [""]
            elif table in TABLEFNS:
                strs += TABLEFNS[table](self, showvars, **kwargs)
            elif table in self:
                data = self[table]
                if showvars:
                    showvars = self._parse_showvars(showvars)
                    data = {k: data[k] for k in showvars if k in data}
                strs += var_table(data, self.table_titles[table], **kwargs)
        if kwargs.get("latex", None):
            preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                  "% \\usepackage{booktabs}",
                                  "% \\usepackage{longtable}",
                                  "% \\usepackage{amsmath}",
                                  "% \\begin{document}\n"))
            strs = [preamble] + strs + ["% \\end{document}"]
        for key in self.name_collision_varkeys():
            del key.descr["necessarylineage"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        # exactly one swept variable expected here
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values with all(abs(value)) < minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true, replace sub-minval vector entries with NaN (prints as " - ")
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    rawlines : boolean
        If true, return the unformatted per-row lists instead of strings
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # NOTE(review): mutates the caller's array in place
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table above
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 1)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")  # always keep the un-modeled group
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        # tuple layout differs between the two sort modes above
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])  # spacer between model groups
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            # pick the axis whose size best fits in maxcolumns as the
            # horizontal (printed) dimension
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # print the remaining vector entries, ncols per line
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad the final, possibly-short row before closing
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
                coltitles = [title, "Value", "Units", "Description"]
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
                coltitles = [title, "Units", "Description"]
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
                coltitles = [title, "Value", "Units"]
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            # column widths ignore the model-header marker rows
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines