# NOTE: stray coverage-report navigation text removed from the top of this file.
1"""Defines SolutionArray class"""
2import re
3import json
4import difflib
5from operator import sub
6import warnings as pywarnings
7import pickle
8import gzip
9import pickletools
10import numpy as np
11from .nomials import NomialArray
12from .small_classes import DictOfLists, Strings
13from .small_scripts import mag, try_str_without
14from .repr_conventions import unitstr, lineagestr
# Matches the operator tokens (single *, " + ", " >= ", " <= ", " = ") on which
# long constraint strings are split for line-wrapping in constraint_table.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# (pattern, replacement) pairs applied in order to printed value strings,
# tidying how NaNs are displayed in tables.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    This approximately halves the size of the pickled solution; everything
    removed in __enter__ is restored when the context exits.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.attrstore = {}
        self.saveconstraints = saveconstraints
        self.constraintstore = None

    def __enter__(self):
        if not self.saveconstraints:
            # drop the constraints mapping entirely; remember it for __exit__
            self.constraintstore = \
                self.solarray["sensitivities"].pop("constraints")
            return
        # keep the constraints but strip their bulky construction attributes
        for attrname in ["bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"]:
            removed = {}
            for constr in self.solarray["sensitivities"]["constraints"]:
                attrvalue = getattr(constr, attrname, None)
                if attrvalue:
                    removed[constr] = attrvalue
                    delattr(constr, attrname)
            self.attrstore[attrname] = removed

    def __exit__(self, type_, val, traceback):
        if self.saveconstraints:
            # put every stripped attribute back on its constraint
            for attrname, removed in self.attrstore.items():
                for constr, attrvalue in removed.items():
                    setattr(constr, attrname, attrvalue)
        else:
            self.solarray["sensitivities"]["constraints"] = self.constraintstore
def msenss_table(data, _, **kwargs):
    "Returns model sensitivity table lines"
    if "models" not in data.get("sensitivities", {}):
        return ""

    def sortkey(item):
        # models whose sensitivities are all tiny sort last; within each
        # group, larger (mean or max) sensitivities come first, then by name
        name, senss = item
        alltiny = (senss < 0.1).all()
        magnitude = -np.max(senss) if alltiny else -round(np.mean(senss), 1)
        return (alltiny, magnitude, name)

    entries = sorted(data["sensitivities"]["models"].items(), key=sortkey)
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs["sortmodelsbysenss"]:
        lines[0] += " (sorts models in sections below)"
    lastshown = ""
    for name, senss in entries:
        if not name:  # for now let's only do named models
            continue
        if (senss < 0.1).all():
            peak = np.max(senss)
            if peak:
                senssstr = "%6s" % ("<1e%i" % np.log10(peak))
            else:
                senssstr = " =0 "
        else:
            mean = round(np.mean(senss), 1)
            senssstr = "%+6.1f" % mean
            offsets = senss - mean
            if np.max(np.abs(offsets)) > 0.1:
                offsetstrs = ["%+4.1f" % off if abs(off) >= 0.1 else " - "
                              for off in offsets]
                senssstr += " + [ %s ]" % " ".join(offsetstrs)
        if senssstr == lastshown:
            # blank out a repeated sensitivity string for readability
            senssstr = " "*len(senssstr)
        else:
            lastshown = senssstr
        lines.append("%s : %s" % (senssstr, name))
    return lines + [""] if len(lines) > 3 else []
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    "Returns sensitivity table lines"
    if "variables" in data.get("sensitivities", {}):
        # a full solution was passed in; drill down to its sensitivities
        data = data["sensitivities"]["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    "Returns top sensitivity table lines"
    data, filtered = topsenss_filter(data, showvars, nvars)
    # if some variables were already shown, this table shows the next tier
    title = ("Next Most Sensitive Variables" if filtered
             else "Most Sensitive Variables")
    return senss_table(data, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    "Filters sensitivities down to top N vars"
    if "variables" in data.get("sensitivities", {}):
        # a full solution was passed in; drill down to its sensitivities
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {key: np.abs(senss).mean() for key, senss in data.items()
                      if not np.isnan(senss).any()}
    # rank ascending by mean |sensitivity|; the top vars are at the end
    ranked = sorted(mean_abs_senss.items(), key=lambda item: item[1])
    topk = [key for key, _ in ranked]
    already_shown = showvars.intersection(topk)
    for key in already_shown:
        topk.remove(key)
    if nvars > 3:  # always show at least 3
        nvars -= 1
    return {key: data[key] for key in topk[-nvars:]}, already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    "Returns insensitivity table lines"
    # NOTE(review): checks for "constants" but reads "variables" — preserved
    # exactly as the original behaves.
    if "constants" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {key: senss for key, senss in data.items()
            if np.mean(np.abs(senss)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    "Return constraint tightness lines"
    title = "Most Sensitive Constraints"
    constrsenss = self["sensitivities"]["constraints"]
    if len(self) > 1:
        # for sweeps, judge tightness at the last sweep point
        title += " (in last sweep)"
        rows = sorted(((-float("%+6.2g" % s[-1]), str(c)),
                       "%+6.2g" % s[-1], id(c), c)
                      for c, s in constrsenss.items()
                      if s[-1] >= tight_senss)
    else:
        rows = sorted(((-float("%+6.2g" % s), str(c)), "%+6.2g" % s, id(c), c)
                      for c, s in constrsenss.items()
                      if s >= tight_senss)
    return constraint_table(rows[:ntightconstrs], title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    "Return lines for constraints whose sensitivity is below min_senss"
    title = "Insensitive Constraints |below %+g|" % min_senss
    constrsenss = self["sensitivities"]["constraints"]
    if len(self) > 1:
        # for sweeps, judge looseness at the last sweep point
        title += " (in last sweep)"
        rows = [(0, "", id(c), c) for c, s in constrsenss.items()
                if s[-1] <= min_senss]
    else:
        rows = [(0, "", id(c), c) for c, s in constrsenss.items()
                if s <= min_senss]
    return constraint_table(rows, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    data : iterable of (sortby, openingstr, tiebreaker, constraint) tuples;
           the tuples are sorted, so sortby determines row order.
    title : string printed (with an underline) above the table.
    sortbymodel : if True, rows are grouped by the constraint's lineage.
    showmodels : if False, lineage is hidden from the constraint strings.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            # assign each model an index in first-seen order for grouping
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])
            if model or lines:
                # sentinel row marking the start of a new model section
                lines.append([("newmodelline",), model])
            previous_model = model
        constrstr = constrstr.replace(model, "")
        # wrap long constraint strings: prefer breaking at operator tokens,
        # hard-break anything that still exceeds maxlen
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # the "*" pattern captured a character on each side; give the
                # trailing character back to the next segment
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:
                # hard wrap: no operator to break at within maxlen
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths from the widest entry in each column (model rows excluded)
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Returns [] when there are no warnings (or only the header would print).
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            # collapse sweep entries when every sweep raised identical warnings
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                # reshape into constraint_table's (sortby, str, id, c) rows
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # NOTE(review): underline uses len(wtype), not len(title), so
                # it is short when " in sweep %i" was appended — confirm intent
                lines += [title] + ["-"*len(wtype)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]
# Maps table names (as requested via SolutionArray.table's `tables` argument)
# to the functions that generate their lines.
TABLEFNS = {"sensitivities": senss_table,
            "top sensitivities": topsenss_table,
            "insensitivities": insenss_table,
            "model sensitivities": msenss_table,
            "tightest constraints": tight_table,
            "loose constraints": loose_table,
            "warnings": warnings_table,
            }
def unrolled_absmax(values):
    "From an iterable of numbers and arrays, returns the largest magnitude"
    winner, best = None, 0
    for candidate in values:
        candidate_max = np.abs(candidate).max()
        if candidate_max >= best:
            best, winner = candidate_max, candidate
    if getattr(winner, "shape", None):
        # winner is an array: return its single largest-magnitude element
        flat_idx = np.argmax(np.abs(winner))
        return winner[np.unravel_index(flat_idx, winner.shape)]
    return winner
def cast(function, val1, val2):
    """Applies function to val1 and val2, aligning array dimensions first.

    (The original docstring, "Relative difference between val1 and val2",
    described one particular caller, not this function's actual contract.)

    If both values are arrays with differing ndim, the lower-dimensional one
    is padded with trailing np.newaxis dimensions so the two broadcast;
    argument order is preserved. Divide-by-zero warnings are suppressed
    because some callers pass np.divide over data containing zeros.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if hasattr(val1, "shape") and hasattr(val2, "shape"):
            if val1.ndim == val2.ndim:
                return function(val1, val2)
            # pad the lower-dimensional array with trailing new axes
            lessdim, dimmest = sorted([val1, val2], key=lambda v: v.ndim)
            dimdelta = dimmest.ndim - lessdim.ndim
            add_axes = (slice(None),)*lessdim.ndim + (np.newaxis,)*dimdelta
            if dimmest is val1:
                return function(dimmest, lessdim[add_axes])
            if dimmest is val2:
                return function(lessdim[add_axes], dimmest)
        return function(val1, val2)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""  # string representation of the solved model, if set
    _name_collision_varkeys = None  # cache for name_collision_varkeys()
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    # a bare (lineage- and vec-free) name mapping to more than
                    # one key means that name is ambiguous
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        "Number of sweep points (1 for a single solve, 0 when unsolved)"
        try:
            return len(self["cost"])
        except TypeError:
            # cost is a scalar: a single (non-swept) solution
            return 1
        except KeyError:
            # no cost yet: empty solution array
            return 0

    def __call__(self, posy):
        "Substitutes the solution into posy, returning a number/array"
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        "Checks for almost-equality between two solutions"
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # values compared by relative tolerance...
            if abs(cast(np.divide, svars[key], ovars[key]) - 1) >= reltol:
                return False
            # ...sensitivities by absolute tolerance
            if abs(self["sensitivities"]["variables"][key]
                   - other["sensitivities"]["variables"][key]) >= sens_abstol:
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            # treat strings as paths; .pgz files are gzip-compressed pickles
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                other = pickle.load(open(other, "rb"))
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                # unified diff of the two models' string representations;
                # [2:] drops the "---"/"+++" file-header lines
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl"))
        """
        # NOTE(review): the file handle is never explicitly closed; closure
        # is left to garbage collection
        with SolSavingEnvironment(self, saveconstraints):
            pickle.dump(self, open(filename, "wb"), **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            # pickletools.optimize shrinks the pickle before compression
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        "Load a gzip-compressed pickle file"
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        "Returns list of variables, optionally with minimal unique names"
        if showvars:
            showvars = self._parse_showvars(showvars)
        # temporarily mark colliding keys so their lineage is printed
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        names = {}
        for key in showvars or self["variables"]:
            for k in self["variables"].keymap[key]:
                names[k.str_without(exclude)] = k
        for key in self.name_collision_varkeys():
            del key.descr["necessarylineage"]
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # enumerate every multi-dimensional index of the array;
                # 1-element indices are flattened to a plain int
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json", showvars=None):
        "Saves solution table as a json file"
        sol_dict = {}
        # mark colliding keys so str(k) below includes their lineage
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # add appropriate data for each variable to the dictionary
        for k, v in data.items():
            key = str(k)
            if isinstance(v, np.ndarray):
                val = {"v": v.tolist(), "u": k.unitstr()}
            else:
                val = {"v": v, "u": k.unitstr()}
            sol_dict[key] = val
        for key in self.name_collision_varkeys():
            del key.descr["necessarylineage"]
        with open(filename, "w") as f:
            json.dump(sol_dict, f)

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                # NOTE(review): maxspan starts at 1, never None, so the
                # "is None" half of this check cannot fire — confirm intent
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    # pad with empty cells so Units lands in a fixed column
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            # sweep: substitute into each sweep-point's solution
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Expands showvars into the full set of matching varkeys"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Print summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        # with few constants (or explicitly requested ones) show them all;
        # otherwise only the most sensitive few
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        # skip model grouping if every variable shares one lineage
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        # mark colliding keys so their lineage is printed
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        showvars = self._parse_showvars(showvars)
        strs = []
        for table in tables:
            if "sensitivities" not in self and ("sensitivities" in table or
                                                "constraints" in table):
                continue
            if table == "cost":
                cost = self["cost"]  # pylint: disable=unsubscriptable-object
                if kwargs.get("latex", None):  # cost is not printed for latex
                    continue
                strs += ["\n%s\n------------" % "Optimal Cost"]
                if len(self) > 1:
                    # sweep: show up to the first 4 costs
                    costs = ["%-8.3g" % c for c in mag(cost[:4])]
                    strs += [" [ %s %s ]" % (" ".join(costs),
                                             "..." if len(self) > 4 else "")]
                else:
                    strs += [" %-.4g" % mag(cost)]
                strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                strs += [""]
            elif table in TABLEFNS:
                strs += TABLEFNS[table](self, showvars, **kwargs)
            elif table in self:
                data = self[table]
                if showvars:
                    showvars = self._parse_showvars(showvars)
                    data = {k: data[k] for k in showvars if k in data}
                strs += var_table(data, self.table_titles[table], **kwargs)
        if kwargs.get("latex", None):
            preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                  "% \\usepackage{booktabs}",
                                  "% \\usepackage{longtable}",
                                  "% \\usepackage{amsmath}",
                                  "% \\begin{document}\n"))
            strs = [preamble] + strs + ["% \\end{document}"]
        for key in self.name_collision_varkeys():
            del key.descr["necessarylineage"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values with all(abs(value)) < minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # blank out (as NaN) array entries below minval
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")  # always keep the unnamed "root" model
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        # tuple layout differs between the two sort modes above
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    # sentinel row marking the start of a new model section
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            # pick the largest dimension <= maxcolumns as the horizontal one
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # print remaining vector entries on continuation rows
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad the closing bracket to line up with full rows
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
                coltitles = [title, "Value", "Units", "Description"]
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
                coltitles = [title, "Units", "Description"]
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
                coltitles = [title, "Value", "Units"]
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            # column widths from the widest entry in each column
            # (model rows excluded)
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
            for i, line in enumerate(lines):
                if line[0] == ("newmodelline",):
                    line = [fmts[0].format(" | "), line[1]]
                else:
                    line = [fmt.format(s) for fmt, s in zip(fmts, line)]
                lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines