Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import re
3import difflib
4from operator import sub
5import warnings as pywarnings
6import pickle
7import gzip
8import pickletools
9import numpy as np
10from collections import defaultdict
11from .nomials import NomialArray
12from .small_classes import DictOfLists, Strings
13from .small_scripts import mag, try_str_without
14from .repr_conventions import unitstr, lineagestr
# Splits a constraint's string form at a single "*" (not part of "**"),
# " + ", " >= ", " <= " and " = "; constraint_table uses the pieces to wrap
# long constraint strings.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# (old, new) substitutions applied in order to formatted values by
# var_table: signed / percent NaNs are normalized first, then every
# remaining "nan" is displayed as " - ".
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
class SolSavingEnvironment:
    """Context manager that slims a solution's constraints while it is pickled.

    On entry, either (saveconstraints=True) every constraint listed in
    solarray["sensitivities"]["constraints"] has its truthy "bounded",
    "meq_bounded", "vks", "v_ss", "unsubbed" and "varkeys" attributes
    deleted and stashed, or (saveconstraints=False) the whole "constraints"
    entry is popped out and stashed. On exit the stashed data is put back.
    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        self.attrstore = {}  # attribute name -> {constraint: removed value}
        self.constraintstore = None  # the popped "constraints" entry, if any

    def __enter__(self):
        sensitivities = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = sensitivities.pop("constraints")
            return
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            stashed = {}
            for constraint in sensitivities["constraints"]:
                value = getattr(constraint, attrname, None)
                if value:  # falsy/missing attributes are left alone
                    stashed[constraint] = value
                    delattr(constraint, attrname)
            self.attrstore[attrname] = stashed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for attrname, stashed in self.attrstore.items():
            for constraint, value in stashed.items():
                setattr(constraint, attrname, value)
def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Each named model in data["sensitivities"]["models"] gets one row.
    - Models whose sensitivities are all below 0.1 show an order-of-magnitude
      bound ("<1e-N", exponent clamped to at least -3), or " =0 " when zero.
    - Other models show their mean sensitivity rounded to one decimal. Per-sweep
      deltas follow when any of them differs from that mean by more than 0.1.
    - A row whose value string repeats the previous row's is blanked.

    Rows are ordered with the larger sensitivities first. If the keyword
    sortmodelsbysenss is truthy, the title notes that models are sorted by
    sensitivity elsewhere.

    Returns a list of lines. The list is empty when there are no model
    sensitivities or when fewer than two named models produce rows.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other table function

    def sortkey(item):
        "small-sensitivity models last; within each group largest first"
        name, senss = item
        if np.all(np.asarray(senss) < 0.1):
            return (True, -np.max(senss), name)
        return (False, -round(np.mean(senss), 1), name)

    items = sorted(data["sensitivities"]["models"].items(), key=sortkey)
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in items:
        if not model:  # for now let's only do named models
            continue
        msenss = np.asarray(msenss)
        if np.all(msenss < 0.1):
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % max(-3, np.log10(msenss)))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            # only sweeps can differ from their rounded mean by > 0.1
            if np.max(np.abs(deltas)) > 0.1:
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:
            msenssstr = " "*len(msenssstr)  # don't repeat identical values
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    return lines + [""] if len(lines) > 3 else []
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    data may be a full solution, in which case its
    ["sensitivities"]["variables"] mapping is used, or an already-extracted
    mapping of varkey -> sensitivity. A non-empty showvars restricts the
    table to those keys that are present. Rows are rendered by var_table,
    sorted by value, without units, with minval=1e-3.
    """
    sensitivities = data.get("sensitivities", {})
    if "variables" in sensitivities:
        data = data["sensitivities"]["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns top sensitivity table lines.

    Shows the nvars most sensitive variables, as chosen by topsenss_filter.
    Variables already in showvars are left out, and in that case the table
    is titled "Next Most Sensitive Variables".
    """
    top, already_shown = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(top, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top N vars.

    Arguments
    ---------
    data : solution or dict
        If it has ["sensitivities"]["variables"], that mapping is used;
        otherwise data itself maps varkeys to sensitivities.
    showvars : set
        Variables already shown elsewhere. They are excluded, and for each
        one excluded nvars is reduced by one while it is above 3.
    nvars : int
        How many of the most sensitive variables (by mean absolute
        sensitivity) to keep; values <= 0 keep none.

    Returns
    -------
    (dict, set)
        The kept {varkey: sensitivity} entries (variables with any NaN
        sensitivity are never kept) and the set of excluded showvars.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    filter_already_shown = showvars.intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    # topk[-0:] is the whole list, so non-positive nvars must be guarded
    kept = topk[-nvars:] if nvars > 0 else []
    return {k: data[k] for k in kept}, filter_already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    Lists the variables whose mean absolute sensitivity is below maxval.
    data may be a full solution, in which case its
    ["sensitivities"]["variables"] mapping is used, or an already-extracted
    mapping of varkey -> sensitivity.
    """
    # check for the same key that is read below (as senss_table and
    # topsenss_filter do), so a solution dict is never filtered directly
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return constraint tightness lines.

    Lists up to ntightconstrs constraints whose sensitivity is at least
    tight_senss, most sensitive first (ties broken by the constraint's
    string form). For sweeps only the last sweep point is considered.
    """
    title = "Most Sensitive Constraints"
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        if is_sweep:
            senss = senss[-1]
        if senss >= tight_senss:
            sensstr = "%+6.2g" % abs(senss)
            rows.append(((-float(sensstr), str(constraint)),
                         sensstr, id(constraint), constraint))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return constraint looseness lines.

    Lists every constraint whose sensitivity is at most min_senss (for
    sweeps, at the last sweep point), with an empty left column.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    constraint_senss = self["sensitivities"]["constraints"]
    data = [(0, "", id(c), c) for c, s in constraint_senss.items()
            if (s[-1] if is_sweep else s) <= min_senss]
    return constraint_table(data, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    data : iterable of (sortby, openingstr, id, constraint) tuples
        Rows are taken in sorted order; openingstr becomes the left column.
    title : str
        Table title, followed by a dashed underline of the same length.
    sortbymodel : bool
        If true, rows are grouped under their constraint's lineage string,
        groups appearing in the order their first row appears.
    showmodels : bool
        If false, all lineage is excluded from the constraint strings.

    Returns a list of strings: title, underline, rows, and a trailing "".
    Long constraint strings are wrapped to continuation rows of at most
    80 characters.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)  # index = order of first appearance
        constrstr = try_str_without(constraint, excluded + (":MAGIC:"+lineagestr(constraint),))
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # blank spacer row between models
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        # constrstr = constrstr.replace(model, "")
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # keep only the delimiter's first character here and push
                # the rest onto the next segment, so any line break lands
                # just before the operator
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:  # hard-wrap anything still too long
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # per-column widths, ignoring the model-header rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Each entry of self["warnings"] maps a warning type to either a list of
    warnings or, for sweeps, an array holding one such list per sweep
    point. Per-sweep lists that are all identical are shown only once.

    "Unexpectedly Tight Constraints" and "Unexpectedly Loose Constraints"
    are rendered through constraint_table when their entries carry data
    (entry[1]). Every other type lists its messages (entry[0]) in sorted
    order under an underlined title.

    Returns the lines, framed by "~" rules, or [] when there is nothing
    to show.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title (which may include the sweep
                # index), matching constraint_table's underlines
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]
# Table name -> function producing that table's lines. SolutionArray.table
# calls each as fn(solution, showvars, **kwargs) when its name is requested.
TABLEFNS = {"sensitivities": senss_table,
            "top sensitivities": topsenss_table,
            "insensitivities": insenss_table,
            "model sensitivities": msenss_table,
            "tightest constraints": tight_table,
            "loose constraints": loose_table,
            "warnings": warnings_table,
            }
def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The result is the signed element (not its absolute value) with the
    largest magnitude across all values. On ties between values, the later
    one wins.

    NaNs are ignored, so an array holding some NaNs still competes through
    its other entries. Returns None if values is empty or holds nothing but
    NaNs / empty arrays.
    """
    finalval, absmaxest = None, 0
    for val in values:
        absval = np.abs(val)
        if np.isnan(absval).all():
            continue  # nothing comparable in this value
        absmaxval = np.nanmax(absval)
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        return finalval[np.unravel_index(np.nanargmax(np.abs(finalval)),
                                         finalval.shape)]
    return finalval
def cast(function, val1, val2):
    """Apply the binary `function` to val1 and val2, aligning leading axes.

    If both arguments are arrays (have a `shape`) of different rank, the
    lower-rank one gets extra trailing axes. Its axes then line up with the
    *leading* axes of the higher-rank one before `function` is applied;
    plain NumPy broadcasting would align the trailing axes instead.

    Otherwise `function(val1, val2)` is returned directly. Warnings raised
    during the call (e.g. divide-by-zero from np.divide) are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if not (hasattr(val1, "shape") and hasattr(val2, "shape")):
            return function(val1, val2)
        rankgap = val1.ndim - val2.ndim
        if rankgap > 0:  # val2 needs extra trailing axes
            expand = (slice(None),)*val2.ndim + (np.newaxis,)*rankgap
            return function(val1, val2[expand])
        if rankgap < 0:  # val1 needs extra trailing axes
            expand = (slice(None),)*val1.ndim + (np.newaxis,)*(-rankgap)
            return function(val1[expand], val2)
        return function(val1, val2)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def set_necessarylineage(self, clear=False):
        """Mark (or unmark) the varkeys whose short names collide.

        The first call finds each varkey whose name (without lineage or
        vector index) is shared with another key in self["variables"]. For
        each such key it computes how many trailing lineage levels are needed
        to make it unique, and caches the result on the instance.

        With clear=False, that count is stored in each colliding varkey's
        descr["necessarylineage"]. With clear=True, those entries are
        removed. Returns None.
        """
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {}
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            name_collisions = defaultdict(set)
            for key in list(keymap):
                if hasattr(key, "key"):
                    shortname = key.str_without(["lineage", "vec"])
                    if len(keymap[shortname]) > 1:
                        name_collisions[shortname].add(key)
            for vks in name_collisions.values():
                # group colliding keys by their last lineage component, then
                # keep prepending components until every group is a single key
                min_namespaced = defaultdict(set)
                for vk in vks:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                while any(len(group) > 1 for group in min_namespaced.values()):
                    for key, group in list(min_namespaced.items()):
                        if len(group) > 1:
                            del min_namespaced[key]
                            mineage, idx = key
                            idx += 1
                            for vk in group:
                                lineages = vk.lineagestr().split(".")
                                submineage = lineages[-idx] + "." + mineage
                                min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), group in min_namespaced.items():
                    vk, = group
                    self._name_collision_varkeys[vk] = idx
        if clear:
            for vk in self._name_collision_varkeys:
                # pop: tolerate entries that were never set or already cleared
                vk.descr.pop("necessarylineage", None)
        else:
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def __len__(self):
        "Number of solutions: len of cost, 1 if cost is scalar, 0 if absent"
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        "Value of posy at this solution (its .c if the result has one)"
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        True when both contain the same variables, and for each of them the
        values agree within relative tolerance reltol and the sensitivities
        agree within absolute tolerance sens_abstol. For array values
        (vectors or sweeps), any single out-of-tolerance element makes the
        solutions unequal.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any: a bare truth test on an array would raise
            if np.any(abs(cast(np.divide, svars[key], ovars[key]) - 1)
                      >= reltol):
                return False
            if np.any(abs(self["sensitivities"]["variables"][key]
                          - other["sensitivities"]["variables"][key])
                      >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (".pgz" files are gzip-compressed); only load trusted files,
            since unpickling can execute arbitrary code
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True, show a unified diff of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True (and sensitivities are present), order model sections
            by this solution's model sensitivities

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:  # closes the file promptly
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)

        def note_largest(fmt, values):
            "If the table just added shows no rows, note the largest value"
            if lines[-2][:10] == "-"*10:  # nothing above the tolerance
                largest = unrolled_absmax(values)
                if largest is not None:  # e.g. all differences were NaN
                    lines.insert(-1, fmt % largest)

        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            note_largest("The largest is %+g%%.", rel_diff.values())
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            note_largest("The largest is %+g.", abs_diff.values())
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            note_largest("The largest is %+g.", senss_delta.values())
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as f:  # flushed and closed on exit
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle the solution and write it gzip-compressed to filename."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        Unpickling can execute arbitrary code: only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns {name: varkey} for showvars (or all variables).

        Names come from str_without(exclude) while the necessarylineage
        marks are set (presumably so colliding names keep just enough
        lineage); the marks are always cleared afterwards.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        self.set_necessarylineage()
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:
            self.set_necessarylineage(clear=True)
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        # NOTE: dtype "f" stores values in single precision (float32)
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe, one row per element"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        # descr entries already covered by other columns
        shown_descr = ["name", "units", "unitrepr",
                       "idx", "shape", "veckey",
                       "value", "vecfn",
                       "lineage", "label"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # every index of value in C order; 1-D indices become ints
                idxs = [idx[0] if len(idx) == 1 else idx
                        for idx in np.ndindex(*value.shape)]
            else:
                idxs = [None]
            for idx in idxs:
                rows.append([
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx],
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in shown_descr)])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table (optionally preceded by the model) as text"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        """Saves primal solution as a CSV sorted by modelname, like the tables.

        Arguments
        ---------
        filename : str
        valcols : int
            Maximum number of value columns. Reduced to 1 when even the
            smallest non-unit array dimension exceeds it. Reduced to the
            largest non-unit dimension (1 if there are no arrays) when that
            is smaller.
        showvars : iterable, optional
            If given, only these variables are written.
        """
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                spans = [di for di in v.shape if di != 1]
                if minspan is None or min(spans) < minspan:
                    minspan = min(spans)
                maxspan = max(maxspan, max(spans))
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])  # model name shares the next row
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"], require_positive=False)

    def _parse_showvars(self, showvars):
        "Expand showvars into the set of matching varkeys in the solution"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        """Return a summary table, showing top sensitivities and no constants.

        Shows cost, warnings, swept and free variables. Then it shows the full
        sensitivity table when there are few constants or some are requested
        in showvars, the top-ntopsenss sensitivities when there are many
        constants, and the tightest constraints.
        """
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        showvars: Iterable
            If non-empty, restrict variable tables to these variables
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
            or any name in TABLEFNS
        sortmodelsbysenss: bool
            If true (and sensitivities exist), order model sections by
            model sensitivity
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        if not any(var.lineage != varlist[0].lineage for var in varlist[1:]):
            kwargs["sortbymodel"] = False  # only one model: don't group
        self.set_necessarylineage()
        try:  # always clear the lineage marks, even on error
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if "sensitivities" not in self and (
                        "sensitivities" in table or "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            self.set_necessarylineage(clear=True)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy.

        Only 1-dimensional sweeps are supported: ValueError is raised unless
        self["sweepvariables"] holds exactly one swept variable. posys may
        be a posynomial, None or "cost" (both plot the cost), or a sequence
        of those. Returns (figure, axes), with axes unwrapped when single.
        """
        if len(self["sweepvariables"]) != 1:
            raise ValueError(
                "SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            # explicit checks: `in [None, "cost"]` would invoke a nomial's
            # overloaded __eq__
            if posy is None or (isinstance(posy, Strings) and posy == "cost"):
                y = self["cost"]
            else:
                y = self(posy)
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted per-row lists instead of strings
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip entries that are all NaN or whose largest magnitude is <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), array elements with magnitude
        <= minval are blanked out; the caller's arrays are not modified.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped under their lineage (model) name.
    maxcolumns : int
        Maximum number of vector values per line.
    skipifempty : boolean
        If true, return [] when no entry survives the minval filter.
    sortmodelsbysenss : dict or falsy
        model name -> sensitivity. When given and sortbyvals is False,
        model groups are ordered by decreasing mean sensitivity.

    Returns
    -------
    list
        For plain text: the title, an underline, the formatted rows and a
        trailing "". For latex: LaTeX longtable lines. With rawlines: the
        raw row lists.
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            v = v.copy()  # don't write NaNs into the caller's array
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table above
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 4)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows holding ncols values each
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad a short final row so its "]" lines up
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
            for i, line in enumerate(lines):
                if line[0] == ("newmodelline",):
                    line = [fmts[0].format(" | "), line[1]]
                else:
                    line = [fmt.format(s) for fmt, s in zip(fmts, line)]
                lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        # defined here (not per row) so tables without rows still get headers
        coltitles = {1: [title, "Value", "Units", "Description"],
                     2: [title, "Units", "Description"],
                     3: [title, "Value", "Units"]}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles[latex])
                             + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines