Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import re
3import difflib
4from operator import sub
5import warnings as pywarnings
6import pickle
7import gzip
8import pickletools
9from collections import defaultdict
10import numpy as np
11from .nomials import NomialArray
12from .small_classes import DictOfLists, Strings
13from .small_scripts import mag, try_str_without
14from .repr_conventions import unitstr, lineagestr
# Splits a constraint string into segments at "a*b" product boundaries and
# at the " + ", " >= ", " <= " and " = " operators; the separators are kept
# (captured groups), which constraint_table uses to choose line breaks.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# (before, after) substitutions applied in order to formatted value strings,
# so that NaN entries in printed tables render as a dash.
VALSTR_REPLACES = list(zip(("+nan", "-nan", "nan%", "nan"),
                           (" nan", " nan", "nan ", " - ")))
class SolSavingEnvironment:
    """Context manager that slims down a solution's constraints while the
    solution is being pickled, and puts everything back on exit.

    This approximately halves the size of the pickled solution.

    Arguments
    ---------
    solarray : SolutionArray (or mapping)
        Must provide ``solarray["sensitivities"]["constraints"]``, a mapping
        whose keys are constraint objects.
    saveconstraints : bool
        If True, the constraints are kept but their (truthy) "bounded",
        "meq_bounded", "vks", "v_ss", "unsubbed" and "varkeys" attributes
        are deleted for the duration of the block. If False, the whole
        "constraints" entry is popped for the duration of the block.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        self.attrstore = {}  # attribute name -> {constraint: removed value}
        self.constraintstore = None  # popped "constraints" entry, if any

    def __enter__(self):
        sensitivities = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = sensitivities.pop("constraints")
            return
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            removed = {}
            for constr in sensitivities["constraints"]:
                if not getattr(constr, attrname, None):
                    continue  # missing or falsy attributes are left alone
                removed[constr] = getattr(constr, attrname)
                delattr(constr, attrname)
            self.attrstore[attrname] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = self.constraintstore
            return
        for attrname, removed in self.attrstore.items():
            for constr, value in removed.items():
                setattr(constr, attrname, value)
def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : SolutionArray or mapping
        Reads ``data["sensitivities"]["models"]``, a mapping from model name
        to an array of sensitivities.
    _ : ignored (the showvars slot of the TABLEFNS calling convention)
    sortmodelsbysenss : optional keyword, truthy
        If truthy, the title notes that later sections are sorted by it.

    Returns
    -------
    list of str
        The table lines plus a trailing blank line, or ``[]`` when there are
        no model sensitivities or fewer than two named models to list.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other return path

    def _sortkey(item):
        # "insensitive" models (all < 0.1) sort last; within each group the
        # more sensitive models come first, ties broken by model name
        name, senss = item
        allsmall = (senss < 0.1).all()
        magnitude = -np.max(senss) if allsmall else -round(np.mean(senss), 1)
        return (allsmall, magnitude, name)

    data = sorted(data["sensitivities"]["models"].items(), key=_sortkey)
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % max(-3, np.log10(msenss)))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            # show per-element deviations only if some exceed 0.1
            if np.max(np.abs(deltas)) > 0.1:
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        # blank out a value identical to the one on the previous line
        if msenssstr == previousmsenssstr:
            msenssstr = " "*len(msenssstr)
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    return lines + [""] if len(lines) > 3 else []
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    ``data`` is either a solution, in which case its
    ``["sensitivities"]["variables"]`` mapping is used, or such a mapping
    directly. If ``showvars`` is non-empty, only those of its keys present
    in the data are tabulated. Extra keyword arguments go to var_table.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    if showvars:
        selected = {}
        for key in showvars:
            if key in data:
                selected[key] = data[key]
        data = selected
    table_options = dict(sortbyvals=True, skipifempty=True,
                         valfmt="%+-.2g ", vecfmt="%+-8.2g",
                         printunits=False, minval=1e-3)
    return var_table(data, title, **table_options, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns lines for a table of the ``nvars`` most sensitive variables.

    Variables already in ``showvars`` are left out (see topsenss_filter);
    when any were, the table is titled "Next Most Sensitive Variables".
    """
    topdata, already_shown = topsenss_filter(data, showvars, nvars)
    if already_shown:
        heading = "Next Most Sensitive Variables"
    else:
        heading = "Most Sensitive Variables"
    return senss_table(topdata, title=heading, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top N variables.

    Arguments
    ---------
    data : SolutionArray or mapping
        A solution (its ``["sensitivities"]["variables"]`` is used) or a
        mapping of variable keys to sensitivity values/arrays.
    showvars : iterable
        Keys already displayed elsewhere; they are excluded from the result.
    nvars : int
        How many variables to return. Each excluded key lowers this by one
        while it is above 3. A value <= 0 selects nothing.

    Returns
    -------
    (dict, set)
        The selected keys (highest mean absolute sensitivity, entries with
        any NaN skipped), in ascending sensitivity order, mapped to their
        values; and the set of ``showvars`` keys that were excluded.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    topk = [k for k, _ in sorted(mean_abs_senss.items(),
                                 key=lambda item: item[1])]
    # set(): accept any iterable of keys, not just sets
    filter_already_shown = set(showvars).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    # topk[-0:] would be the whole list, so nvars <= 0 must select nothing
    kept = topk[-nvars:] if nvars > 0 else []
    return {k: data[k] for k in kept}, filter_already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns lines for a table of variables whose mean absolute
    sensitivity is below ``maxval``.

    ``data`` is either a solution (its ``["sensitivities"]["variables"]``
    mapping is used) or such a mapping directly. The second positional
    argument (showvars in the TABLEFNS calling convention) is ignored.
    """
    # Test for the same key that is read, as senss_table and topsenss_filter
    # do. Testing "constants" let a solution without that key fall through,
    # so its top-level entries were treated as sensitivities.
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    insensitive = {k: s for k, s in data.items()
                   if np.mean(np.abs(s)) < maxval}
    return senss_table(insensitive, title="Insensitive Fixed Variables",
                       **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Returns lines for a table of the most sensitive constraints.

    Lists up to ``ntightconstrs`` constraints whose absolute sensitivity is
    at least ``tight_senss``, most sensitive first (ties broken by the
    constraint's string). For a sweep (len > 1) only the last sweep
    point's sensitivities are used.
    """
    title = "Most Sensitive Constraints"
    sweep = len(self) > 1
    if sweep:
        title += " (in last sweep)"
    data = []
    for c, s in self["sensitivities"]["constraints"].items():
        senss = abs(s[-1] if sweep else s)
        # filter on the magnitude, the same quantity that is shown/sorted;
        # filtering on the signed value dropped large negative sensitivities
        if senss >= tight_senss:
            sensstr = "%+6.2g" % senss
            # sort by the displayed (rounded) value, descending
            data.append(((-float(sensstr), str(c)), sensstr, id(c), c))
    return constraint_table(sorted(data)[:ntightconstrs], title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Returns lines for a table of constraints whose absolute sensitivity
    is at most ``min_senss``.

    For a sweep (len > 1) only the last sweep point's sensitivities are
    used.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    sweep = len(self) > 1
    if sweep:
        title += " (in last sweep)"
    # compare magnitudes: a large negative sensitivity is not "insensitive"
    data = [(0, "", id(c), c)
            for c, s in self["sensitivities"]["constraints"].items()
            if abs(s[-1] if sweep else s) <= min_senss]
    return constraint_table(data, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, id, constraint) tuples
        Rows are sorted by these tuples. The id element is distinct per
        constraint, so constraints themselves are never compared.
        ``openingstr`` is printed to the left of each constraint.
    title : str
    sortbymodel : bool
        If True, rows are grouped under their constraint's lineage string.
    showmodels : bool
        If False, lineage is excluded from the constraint strings entirely.

    Returns
    -------
    list of str
        Title, underline, the formatted rows (or "(none)"), and a trailing
        blank line.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    # models: model string -> order of first appearance (in sorted data)
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)
        # NOTE(review): the ":MAGIC:<lineage>" entry presumably tells
        # try_str_without to drop the constraint's own lineage prefix;
        # confirm against small_scripts.try_str_without.
        constrstr = try_str_without(
            constraint, excluded + (":MAGIC:"+lineagestr(constraint),))
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    # groups by model (first-appearance order), then by sortby within each
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer between model groups
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        # wrap the constraint string to at most maxlen characters per line,
        # only breaking once a line is longer than minlen
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # a separator: keep its first character on this line and
                # glue the remainder onto the front of the next segment, so
                # a break can only fall just after that first character
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            # hard-split anything still too long
            while len(line) > maxlen:
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths, ignoring model-header rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    ``self["warnings"]`` maps each warning type to a list of
    ``(message, data)`` pairs or, for a sweep, to an array-like (anything
    with a ``shape``) holding one such list per sweep point. If every sweep
    point carries identical warnings they are shown once. For "Unexpectedly
    Tight Constraints" and "Unexpectedly Loose Constraints" with data, the
    rows are rendered by constraint_table; other types list their messages.

    Returns
    -------
    list of str
        Table lines framed by tilde rules, or ``[]`` if there is nothing to
        show.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            wtitle = wtype
            if len(data_vec) > 1:
                wtitle += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, wtitle, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, wtitle, **kwargs)
            else:
                # underline the full title, including any " in sweep" suffix
                lines += [wtitle] + ["-"*len(wtitle)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]
# Table name (as accepted by SolutionArray.table) -> rendering function.
# Each function is called as fn(solution, showvars, **kwargs) and returns a
# list of lines.
TABLEFNS = dict([
    ("sensitivities", senss_table),
    ("top sensitivities", topsenss_table),
    ("insensitivities", insenss_table),
    ("model sensitivities", msenss_table),
    ("tightest constraints", tight_table),
    ("loose constraints", loose_table),
    ("warnings", warnings_table),
])
def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The result keeps its sign: for a scalar entry the entry itself, for an
    array entry its single largest-magnitude element. NaNs are ignored, so
    an array with some NaNs still competes on its finite elements. Returns
    None if ``values`` is empty or contains only NaNs.
    """
    finalval, absmaxest = None, 0
    for val in values:
        absval = np.abs(val)
        if np.isnan(absval).all():
            continue  # nothing comparable in this entry
        absmaxval = np.nanmax(absval)
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        # nanargmax: a NaN element must not be picked as "largest"
        return finalval[np.unravel_index(np.nanargmax(np.abs(finalval)),
                                         finalval.shape)]
    return finalval
def cast(function, val1, val2):
    """Applies the binary ``function`` to val1 and val2, aligning the
    leading axes of two arrays of different dimensionality.

    NumPy broadcasting aligns *trailing* axes. Here, when both arguments
    have a ``shape`` and their ``ndim`` differs, the lower-dimensional one
    instead gets new trailing axes, so its axes line up with the leading
    axes of the other (e.g. a per-sweep-point scalar against a
    (sweep, vector) array). Argument order is preserved: the call is always
    ``function(<val1 or its expansion>, <val2 or its expansion>)``. All
    warnings raised during the call (such as divide-by-zero) are
    suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if not both_arrays or val1.ndim == val2.ndim:
            return function(val1, val2)
        lessdim, dimmest = sorted([val1, val2], key=lambda v: v.ndim)
        dimdelta = dimmest.ndim - lessdim.ndim
        add_axes = (slice(None),)*lessdim.ndim + (np.newaxis,)*dimdelta
        if dimmest is val1:
            return function(dimmest, lessdim[add_axes])
        return function(lessdim[add_axes], dimmest)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def set_necessarylineage(self, clear=False):
        """Sets (or, with clear=True, removes) ``descr["necessarylineage"]``
        on every variable key whose short name collides with another's.

        The stored value is the number of trailing lineage components that
        make the key's name unique. The colliding keys and their counts are
        computed on first use and cached in ``self._name_collision_varkeys``.
        Returns None.
        """
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {}
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            name_collisions = defaultdict(set)
            for key in keymap:
                if hasattr(key, "key"):
                    shortname = key.str_without(["lineage", "vec"])
                    if len(keymap[shortname]) > 1:
                        name_collisions[shortname].add(key)
            for varkeys in name_collisions.values():
                # group by the last lineage component, then keep prepending
                # earlier components until every group holds a single key
                min_namespaced = defaultdict(set)
                for vk in varkeys:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                while any(len(vks) > 1 for vks in min_namespaced.values()):
                    for key, vks in list(min_namespaced.items()):
                        if len(vks) <= 1:
                            continue
                        del min_namespaced[key]
                        mineage, idx = key
                        idx += 1
                        for vk in vks:
                            lineages = vk.lineagestr().split(".")
                            submineage = lineages[-idx] + "." + mineage
                            min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), vks in min_namespaced.items():
                    vk, = vks
                    self._name_collision_varkeys[vk] = idx
        if clear:
            for vk in self._name_collision_varkeys:
                # pop: clearing twice (or after a failed set) is harmless
                vk.descr.pop("necessarylineage", None)
        else:
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def __len__(self):
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        Returns False if the variable keys differ, if any element of a
        variable differs relatively by ``reltol`` or more, or if any element
        of its ``["sensitivities"]["variables"]`` entry differs absolutely
        by ``sens_abstol`` or more; otherwise True.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any: values may be arrays (vector variables, sweeps), whose
            # truth value in a plain `if` would be ambiguous
            reldiff = np.abs(cast(np.divide, svars[key], ovars[key]) - 1)
            if np.any(reldiff >= reltol):
                return False
            sensdiff = np.abs(self["sensitivities"]["variables"][key]
                              - other["sensitivities"]["variables"][key])
            if np.any(sensdiff >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions (files
            ending in ".pgz" are read as gzip-compressed). Unpickling can
            execute arbitrary code: only load files you trust.
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True and both solutions have a modelstr, show a unified diff
            of the models
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True, sort models by this solution's model sensitivities

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                # NOTE: pickle.load runs arbitrary code from the file
                with open(other, "rb") as f:  # closes the file handle
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.::

            import pickle
            with open("solution.pkl", "rb") as f:
                sol = pickle.load(f)
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as f:  # closed even if pickling fails
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        """Pickles the solution, optimizes the pickle with pickletools, and
        writes it gzip-compressed to ``filename`` (load with
        decompress_file)."""
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        Unpickling can execute arbitrary code: only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict mapping display names to variable keys.

        Names are built with ``str_without(exclude)`` while colliding keys
        temporarily carry their necessary lineage, so names stay unique.
        If ``showvars`` is given, only those variables are included.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        self.set_necessarylineage()
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:
            # always undo the temporary descr change, even on error
            self.set_necessarylineage(clear=True)
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"], require_positive=False)

    def _parse_showvars(self, showvars):
        "Expands each showvars entry into the set of matching variable keys."
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Print summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        showvars: Iterable
            If non-empty, only these variables are shown in variable tables
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
            or any key of TABLEFNS
        sortmodelsbysenss: bool
            If true (and sensitivities exist), sort models by sensitivity
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        self.set_necessarylineage()
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if "sensitivities" not in self and ("sensitivities" in table or
                                                    "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            # always undo the temporary descr change, even on error
            self.set_necessarylineage(clear=True)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy (None or "cost" plots the cost).

        Raises
        ------
        ValueError
            If the solution does not have exactly one swept variable.
        """
        if len(self["sweepvariables"]) != 1:
            # previously this only printed and then failed while unpacking
            raise ValueError("SolutionArray.plot only supports"
                             " 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table (not modified)
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted row lists instead of strings
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip entries that are all NaN or whose largest abs(value) <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), vector entries with
        abs(value) <= minval are shown as blanks
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped under their key's lineage string
    maxcolumns : int
        Maximum number of vector values printed per line
    skipifempty : boolean
        If true, return [] when no entry survives filtering
    sortmodelsbysenss : mapping or falsy
        If given, model groups are ordered by these model sensitivities
    """
    if not data:
        return []
    # column titles for each latex option (unused for plain text); set up
    # front so they exist even when every row is filtered out
    coltitles = {1: [title, "Value", "Units", "Description"],
                 2: [title, "Units", "Description"],
                 3: [title, "Value", "Units"]}.get(latex)
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # nothing above minval (or all NaN): skip this row
        if minval and hidebelowminval and getattr(v, "shape", None):
            # Blank small entries in a NEW array: assigning into v would
            # mutate the caller's data, e.g. arrays taken straight from a
            # solution's sensitivities.
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table above
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 4)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    # the unique index i precedes k and v, so keys/arrays are never compared
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            # lay out the largest dimension that fits in maxcolumns across
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows for the remaining vector values
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines