Coverage for gpkit/solution_array.py: 81%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import sys
3import re
4import json
5import difflib
6from operator import sub
7import warnings as pywarnings
8import pickle
9import gzip
10import pickletools
11from collections import defaultdict
12import numpy as np
13from .nomials import NomialArray
14from .small_classes import DictOfLists, Strings, SolverLog
15from .small_scripts import mag, try_str_without
16from .repr_conventions import unitstr, lineagestr, UNICODE_EXPONENTS
17from .breakdowns import Breakdowns
# Pattern marking the points where a long printed constraint may be wrapped:
# a single multiplication, " + ", or a relational operator.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# Ordered substring replacements applied to printed values to tidy NaNs;
# order matters ("+nan"/"-nan" must be handled before the bare "nan").
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
class SolSavingEnvironment:
    """Context manager that strips construction/solve attributes from
    constraints while a solution is being pickled, restoring them on exit.

    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray          # the SolutionArray being saved
        self.attrstore = {}               # {attrname: {constraint: value}}
        self.saveconstraints = saveconstraints
        self.constraintstore = None       # holds popped constraints dict

    def __enter__(self):
        if not self.saveconstraints:
            # drop the constraints entirely; restored in __exit__
            self.constraintstore = \
                self.solarray["sensitivities"].pop("constraints")
            return
        constraints = self.solarray["sensitivities"]["constraints"]
        for attrname in ["bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"]:
            stashed = {}
            for constr in constraints:
                attrvalue = getattr(constr, attrname, None)
                if attrvalue:
                    stashed[constr] = attrvalue
                    delattr(constr, attrname)
            self.attrstore[attrname] = stashed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for attrname, stashed in self.attrstore.items():
            for constr, attrvalue in stashed.items():
                setattr(constr, attrname, attrvalue)
def msenss_table(data, _, **kwargs):
    "Returns model sensitivity table lines"
    if "models" not in data.get("sensitivities", {}):
        return ""

    def sortkey(item):
        # small-sensitivity models sort last, then by magnitude, then name
        name, senss = item
        allsmall = (senss < 0.1).all()
        magnitude = -np.max(senss) if allsmall else -round(np.mean(senss), 1)
        return (allsmall, magnitude, name)

    entries = sorted(data["sensitivities"]["models"].items(), key=sortkey)
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs["sortmodelsbysenss"]:
        lines[0] += " (sorts models in sections below)"
    lastlabel = ""
    for model, msenss in entries:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            peak = np.max(msenss)
            if peak:
                label = "%6s" % ("<1e%i" % max(-3, np.log10(peak)))
            else:
                label = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            label = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                label += " + [ %s ]" % " ".join(deltastrs)
        # blank out a label identical to the previous row's for readability
        if label == lastlabel:
            label = " "*len(label)
        else:
            lastlabel = label
        lines.append("%s : %s" % (label, model))
    return lines + [""] if len(lines) > 3 else []
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    "Returns sensitivity table lines"
    senss = data.get("sensitivities", {})
    if "variables" in senss:
        data = senss["variables"]
    if showvars:
        shown = {}
        for key in showvars:
            if key in data:
                shown[key] = data[key]
        data = shown
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    "Returns top sensitivity table lines"
    data, filtered = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if filtered
             else "Most Sensitive Variables")
    return senss_table(data, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    "Filters sensitivities down to top N vars"
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # mean |sensitivity| per key, skipping any key with NaN entries
    mean_abs_senss = {}
    for key, senss in data.items():
        if not np.isnan(senss).any():
            mean_abs_senss[key] = np.abs(senss).mean()
    ranked = sorted(mean_abs_senss, key=mean_abs_senss.get)
    already_shown = showvars.intersection(ranked)
    for key in already_shown:
        ranked.remove(key)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    return {key: data[key] for key in ranked[-nvars:]}, already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    "Returns insensitivity table lines"
    if "constants" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    insensitive = {key: senss for key, senss in data.items()
                   if np.mean(np.abs(senss)) < maxval}
    return senss_table(insensitive, title="Insensitive Fixed Variables",
                       **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    "Return constraint tightness lines"
    title = "Most Sensitive Constraints"
    sweeping = len(self) > 1
    if sweeping:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        # for sweeps, judge tightness by the last sweep point only
        senss = senss[-1] if sweeping else senss
        if senss >= tight_senss:
            rows.append(((-float("%+6.2g" % abs(senss)), str(constr)),
                         "%+6.2g" % abs(senss), id(constr), constr))
    data = sorted(rows)[:ntightconstrs]
    return constraint_table(data, title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines for loose (insensitive) constraints.

    Constraints whose sensitivity is at or below min_senss are listed;
    for swept solutions only the final sweep point is checked.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    if len(self) > 1:
        title += " (in last sweep)"
        data = [(0, "", id(c), c)
                for c, s in self["sensitivities"]["constraints"].items()
                if s[-1] <= min_senss]
    else:
        data = [(0, "", id(c), c)
                for c, s in self["sensitivities"]["constraints"].items()
                if s <= min_senss]
    return constraint_table(data, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, tiebreaker, constraint) tuples
        sortby orders rows; openingstr is printed to the left of each row.
    title : string
    sortbymodel : bool
        if True, group rows by the constraint's model lineage
    showmodels : bool
        if False, lineage is stripped from printed constraints
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = {"units"} if showmodels else {"units", "lineage"}
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            # assign models an index in first-seen order so groups stay stable
            models[model] = len(models)
        constrstr = try_str_without(
            constraint, {":MAGIC:"+lineagestr(constraint)}.union(excluded))
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])
            if model or lines:
                # sentinel first element marks a model-header row
                lines.append([("newmodelline",), model])
            previous_model = model
        minlen, maxlen = 25, 80
        # split the constraint at wrappable operators; drop empty segments
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # push the operator's trailing char onto the next segment
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:
                # hard-wrap anything that still overflows
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths from the widest non-header row
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    self["warnings"] maps warning-type names to either a list of
    (message, data) pairs or, for sweeps, an array of such lists.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            # collapse sweep entries when every sweep point warned identically
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            # these two warning types carry constraints, so render them
            # with constraint_table; anything else prints raw messages
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                lines += [title] + ["-"*len(wtype)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"  # replace the trailing blank with a footer
    return lines + [""]
def bdtable_gen(key):
    """Generator for breakdown tablefns.

    Returns a table function that captures the text plotted by
    Breakdowns(solution).plot(key) and returns it as lines.
    """

    def bdtable(self, _showvars, **_):
        "Returns the captured breakdown-plot lines for this key"
        bds = Breakdowns(self)
        original_stdout = sys.stdout
        try:
            # redirect stdout through a SolverLog so the plot's text output
            # is captured instead of printed
            sys.stdout = SolverLog(original_stdout, verbosity=0)
            bds.plot(key)
        finally:
            lines = sys.stdout.lines()
            sys.stdout = original_stdout  # always restore stdout
        return lines

    return bdtable
# Maps table names (as used in SolutionArray.table's `tables` argument)
# to the functions that generate that table's lines.
TABLEFNS = {"sensitivities": senss_table,
            "top sensitivities": topsenss_table,
            "insensitivities": insenss_table,
            "model sensitivities": msenss_table,
            "tightest constraints": tight_table,
            "loose constraints": loose_table,
            "warnings": warnings_table,
            "model sensitivities breakdown": bdtable_gen("model sensitivities"),
            "cost breakdown": bdtable_gen("cost")
            }
def unrolled_absmax(values):
    "From an iterable of numbers and arrays, returns the largest magnitude"
    winner, best = None, 0
    for candidate in values:
        peak = np.abs(candidate).max()
        if peak >= best:
            best, winner = peak, candidate
    if not getattr(winner, "shape", None):
        return winner
    # winner is an array: return its single largest-magnitude element
    flatpos = np.argmax(np.abs(winner))
    return winner[np.unravel_index(flatpos, winner.shape)]
def cast(function, val1, val2):
    """Applies function(val1, val2), broadcasting mismatched array dims.

    If both values are arrays of different dimensionality, trailing
    broadcast axes are added to the lower-dimensional one. Numpy warnings
    (e.g. those pesky divide-by-zeros) are suppressed during the call.
    """
    with pywarnings.catch_warnings():
        pywarnings.simplefilter("ignore")
        if (hasattr(val1, "shape") and hasattr(val2, "shape")
                and val1.ndim != val2.ndim):
            lessdim, dimmest = sorted([val1, val2], key=lambda v: v.ndim)
            newaxes = ((slice(None),)*lessdim.ndim
                       + (np.newaxis,)*(dimmest.ndim - lessdim.ndim))
            if dimmest is val1:
                return function(dimmest, lessdim[newaxes])
            return function(lessdim[newaxes], dimmest)
        return function(val1, val2)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    # string representation of the solved model's constraints (set externally)
    modelstr = ""
    # cache for set_necessarylineage: {varkey: lineage depth needed}
    _name_collision_varkeys = None
    # whether varkeys currently carry a "necessarylineage" descr entry
    _lineageset = False
    # maps solution-dict table names to printed table titles
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}
    def set_necessarylineage(self, clear=False):  # pylint: disable=too-many-branches
        """Marks varkeys with the minimal lineage depth that makes them unique.

        On first call, computes per-varkey how many trailing lineage
        components are needed to disambiguate colliding names, caching the
        result. Then either stamps (clear=False) or removes (clear=True)
        a "necessarylineage" entry on each varkey's descr.
        """
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {}
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            name_collisions = defaultdict(set)
            for key in keymap:
                if hasattr(key, "key"):
                    if len(keymap[key.name]) == 1:  # very unique
                        self._name_collision_varkeys[key] = 0
                    else:
                        shortname = key.str_without(["lineage", "vec"])
                        if len(keymap[shortname]) > 1:
                            name_collisions[shortname].add(key)
            for varkeys in name_collisions.values():
                # grow each key's shown lineage suffix until all are distinct
                min_namespaced = defaultdict(set)
                for vk in varkeys:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                while any(len(vks) > 1 for vks in min_namespaced.values()):
                    for key, vks in list(min_namespaced.items()):
                        if len(vks) <= 1:
                            continue
                        del min_namespaced[key]
                        mineage, idx = key
                        idx += 1
                        for vk in vks:
                            # prepend one more lineage component
                            lineages = vk.lineagestr().split(".")
                            submineage = lineages[-idx] + "." + mineage
                            min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), vks in min_namespaced.items():
                    vk, = vks
                    self._name_collision_varkeys[vk] = idx
        if clear:
            self._lineageset = False
            for vk in self._name_collision_varkeys:
                del vk.descr["necessarylineage"]
        else:
            self._lineageset = True
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx
428 def __len__(self):
429 try:
430 return len(self["cost"])
431 except TypeError:
432 return 1
433 except KeyError:
434 return 0
436 def __call__(self, posy):
437 posy_subbed = self.subinto(posy)
438 return getattr(posy_subbed, "c", posy_subbed)
440 def almost_equal(self, other, reltol=1e-3):
441 "Checks for almost-equality between two solutions"
442 svars, ovars = self["variables"], other["variables"]
443 svks, ovks = set(svars), set(ovars)
444 if svks != ovks:
445 return False
446 for key in svks:
447 if np.max(abs(cast(np.divide, svars[key], ovars[key]) - 1)) >= reltol:
448 return False
449 return True
451 # pylint: disable=too-many-locals, too-many-branches, too-many-statements
452 def diff(self, other, showvars=None, *,
453 constraintsdiff=True, senssdiff=False, sensstol=0.1,
454 absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
455 sortmodelsbysenss=True, **tableargs):
456 """Outputs differences between this solution and another
458 Arguments
459 ---------
460 other : solution or string
461 strings will be treated as paths to pickled solutions
462 senssdiff : boolean
463 if True, show sensitivity differences
464 sensstol : float
465 the smallest sensitivity difference worth showing
466 absdiff : boolean
467 if True, show absolute differences
468 abstol : float
469 the smallest absolute difference worth showing
470 reldiff : boolean
471 if True, show relative differences
472 reltol : float
473 the smallest relative difference worth showing
475 Returns
476 -------
477 str
478 """
479 if sortmodelsbysenss:
480 tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
481 else:
482 tableargs["sortmodelsbysenss"] = False
483 tableargs.update({"hidebelowminval": True, "sortbyvals": True,
484 "skipifempty": False})
485 if isinstance(other, Strings):
486 if other[-4:] == ".pgz":
487 other = SolutionArray.decompress_file(other)
488 else:
489 other = pickle.load(open(other, "rb"))
490 svars, ovars = self["variables"], other["variables"]
491 lines = ["Solution Diff",
492 "=============",
493 "(argument is the baseline solution)", ""]
494 svks, ovks = set(svars), set(ovars)
495 if showvars:
496 lines[0] += " (for selected variables)"
497 lines[1] += "========================="
498 showvars = self._parse_showvars(showvars)
499 svks = {k for k in showvars if k in svars}
500 ovks = {k for k in showvars if k in ovars}
501 if constraintsdiff and other.modelstr and self.modelstr:
502 if self.modelstr == other.modelstr:
503 lines += ["** no constraint differences **", ""]
504 else:
505 cdiff = ["Constraint Differences",
506 "**********************"]
507 cdiff.extend(list(difflib.unified_diff(
508 other.modelstr.split("\n"), self.modelstr.split("\n"),
509 lineterm="", n=3))[2:])
510 cdiff += ["", "**********************", ""]
511 lines += cdiff
512 if svks - ovks:
513 lines.append("Variable(s) of this solution"
514 " which are not in the argument:")
515 lines.append("\n".join(" %s" % key for key in svks - ovks))
516 lines.append("")
517 if ovks - svks:
518 lines.append("Variable(s) of the argument"
519 " which are not in this solution:")
520 lines.append("\n".join(" %s" % key for key in ovks - svks))
521 lines.append("")
522 sharedvks = svks.intersection(ovks)
523 if reldiff:
524 rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
525 for vk in sharedvks}
526 lines += var_table(rel_diff,
527 "Relative Differences |above %g%%|" % reltol,
528 valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
529 minval=reltol, printunits=False, **tableargs)
530 if lines[-2][:10] == "-"*10: # nothing larger than reltol
531 lines.insert(-1, ("The largest is %+g%%."
532 % unrolled_absmax(rel_diff.values())))
533 if absdiff:
534 abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
535 lines += var_table(abs_diff,
536 "Absolute Differences |above %g|" % abstol,
537 valfmt="%+.2g", vecfmt="%+8.2g",
538 minval=abstol, **tableargs)
539 if lines[-2][:10] == "-"*10: # nothing larger than abstol
540 lines.insert(-1, ("The largest is %+g."
541 % unrolled_absmax(abs_diff.values())))
542 if senssdiff:
543 ssenss = self["sensitivities"]["variables"]
544 osenss = other["sensitivities"]["variables"]
545 senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
546 for vk in svks.intersection(ovks)}
547 lines += var_table(senss_delta,
548 "Sensitivity Differences |above %g|" % sensstol,
549 valfmt="%+-.2f ", vecfmt="%+-6.2f",
550 minval=sensstol, printunits=False, **tableargs)
551 if lines[-2][:10] == "-"*10: # nothing larger than sensstol
552 lines.insert(-1, ("The largest is %+g."
553 % unrolled_absmax(senss_delta.values())))
554 return "\n".join(lines)
556 def save(self, filename="solution.pkl",
557 *, saveconstraints=True, **pickleargs):
558 """Pickles the solution and saves it to a file.
560 Solution can then be loaded with e.g.:
561 >>> import pickle
562 >>> pickle.load(open("solution.pkl"))
563 """
564 with SolSavingEnvironment(self, saveconstraints):
565 pickle.dump(self, open(filename, "wb"), **pickleargs)
    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            # pickle while construction attributes are stripped, then
            # restore them before writing the optimized pickle bytes
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))
575 @staticmethod
576 def decompress_file(file):
577 "Load a gzip-compressed pickle file"
578 with gzip.open(file, "rb") as f:
579 return pickle.Unpickler(f).load()
581 def varnames(self, showvars, exclude):
582 "Returns list of variables, optionally with minimal unique names"
583 if showvars:
584 showvars = self._parse_showvars(showvars)
585 self.set_necessarylineage()
586 names = {}
587 for key in showvars or self["variables"]:
588 for k in self["variables"].keymap[key]:
589 names[k.str_without(exclude)] = k
590 self.set_necessarylineage(clear=True)
591 return names
593 def savemat(self, filename="solution.mat", *, showvars=None,
594 excluded=("vec")):
595 "Saves primal solution as matlab file"
596 from scipy.io import savemat
597 savemat(filename,
598 {name.replace(".", "_"): np.array(self["variables"][key], "f")
599 for name, key in self.varnames(showvars, excluded).items()})
601 def todataframe(self, showvars=None, excluded=("vec")):
602 "Returns primal solution as pandas dataframe"
603 import pandas as pd # pylint:disable=import-error
604 rows = []
605 cols = ["Name", "Index", "Value", "Units", "Label",
606 "Lineage", "Other"]
607 for _, key in sorted(self.varnames(showvars, excluded).items(),
608 key=lambda k: k[0]):
609 value = self["variables"][key]
610 if key.shape:
611 idxs = []
612 it = np.nditer(np.empty(value.shape), flags=['multi_index'])
613 while not it.finished:
614 idx = it.multi_index
615 idxs.append(idx[0] if len(idx) == 1 else idx)
616 it.iternext()
617 else:
618 idxs = [None]
619 for idx in idxs:
620 row = [
621 key.name,
622 "" if idx is None else idx,
623 value if idx is None else value[idx]]
624 rows.append(row)
625 row.extend([
626 key.unitstr(),
627 key.label or "",
628 key.lineage or "",
629 ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
630 if k not in ["name", "units", "unitrepr",
631 "idx", "shape", "veckey",
632 "value", "vecfn",
633 "lineage", "label"])])
634 return pd.DataFrame(rows, columns=cols)
636 def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
637 "Saves solution table as a text file"
638 with open(filename, "w") as f:
639 if printmodel:
640 f.write(self.modelstr + "\n")
641 f.write(self.table(**kwargs))
643 def savejson(self, filename="solution.json", showvars=None):
644 "Saves solution table as a json file"
645 sol_dict = {}
646 if self._lineageset:
647 self.set_necessarylineage(clear=True)
648 data = self["variables"]
649 if showvars:
650 showvars = self._parse_showvars(showvars)
651 data = {k: data[k] for k in showvars if k in data}
652 # add appropriate data for each variable to the dictionary
653 for k, v in data.items():
654 key = str(k)
655 if isinstance(v, np.ndarray):
656 val = {"v": v.tolist(), "u": k.unitstr()}
657 else:
658 val = {"v": v, "u": k.unitstr()}
659 sol_dict[key] = val
660 with open(filename, "w") as f:
661 json.dump(sol_dict, f)
    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                # NOTE(review): maxspan starts at 1, so the is-None check
                # below is vestigial and can never fire
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    # model-header rows carry the model name in line[1]
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    # strip table brackets so values split cleanly
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    # pad so every row spans the same number of value columns
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")
704 def subinto(self, posy):
705 "Returns NomialArray of each solution substituted into posy."
706 if posy in self["variables"]:
707 return self["variables"](posy)
709 if not hasattr(posy, "sub"):
710 raise ValueError("no variable '%s' found in the solution" % posy)
712 if len(self) > 1:
713 return NomialArray([self.atindex(i).subinto(posy)
714 for i in range(len(self))])
716 return posy.sub(self["variables"], require_positive=False)
718 def _parse_showvars(self, showvars):
719 showvars_out = set()
720 for k in showvars:
721 k, _ = self["variables"].parse_and_index(k)
722 keys = self["variables"].keymap[k]
723 showvars_out.update(keys)
724 return showvars_out
726 def summary(self, showvars=(), **kwargs):
727 "Print summary table, showing no sensitivities or constants"
728 return self.table(showvars,
729 ["cost breakdown", "model sensitivities breakdown",
730 "warnings", "sweepvariables", "freevariables"],
731 **kwargs)
    def table(self, showvars=(),
              tables=("cost breakdown", "model sensitivities breakdown",
                      "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=False, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        # skip model grouping entirely if every variable shares one lineage
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        self.set_necessarylineage()
        showvars = self._parse_showvars(showvars)
        strs = []
        for table in tables:
            if "breakdown" in table:
                if len(self) > 1 or not UNICODE_EXPONENTS:
                    # no breakdowns for sweeps or no-unicode environments
                    table = table.replace(" breakdown", "")
            if "sensitivities" not in self and ("sensitivities" in table or
                                                "constraints" in table):
                continue  # nothing to show without solved sensitivities
            if table == "cost":
                cost = self["cost"]  # pylint: disable=unsubscriptable-object
                if kwargs.get("latex", None):  # cost is not printed for latex
                    continue
                strs += ["\n%s\n------------" % "Optimal Cost"]
                if len(self) > 1:
                    # show at most the first four sweep costs
                    costs = ["%-8.3g" % c for c in mag(cost[:4])]
                    strs += [" [ %s %s ]" % (" ".join(costs),
                                             "..." if len(self) > 4 else "")]
                else:
                    strs += [" %-.4g" % mag(cost)]
                strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                strs += [""]
            elif table in TABLEFNS:
                strs += TABLEFNS[table](self, showvars, **kwargs)
            elif table in self:
                data = self[table]
                if showvars:
                    showvars = self._parse_showvars(showvars)
                    data = {k: data[k] for k in showvars if k in data}
                strs += var_table(data, self.table_titles[table], **kwargs)
        if kwargs.get("latex", None):
            preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                  "% \\usepackage{booktabs}",
                                  "% \\usepackage{longtable}",
                                  "% \\usepackage{amsmath}",
                                  "% \\begin{document}\n"))
            strs = [preamble] + strs + ["% \\end{document}"]
        self.set_necessarylineage(clear=True)
        return "\n".join(strs)
    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy.

        Returns the matplotlib figure and the axes (a single axis if
        only one was used).
        """
        if len(self["sweepvariables"]) != 1:
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]  # wrap a single posy (or None) in a list
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        # exactly one swept variable: unpack its key and sweep values
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values with all(abs(value)) < minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # blank out sub-threshold vector entries (rendered as "-")
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table above
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 4)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")  # always keep the unnamed toplevel model
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        # tuple layout differs between the two sort modes above
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            # pick the largest dimension that fits in maxcolumns as the
            # horizontal axis of the printed array
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows for the rest of the vector's values
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad the closing bracket to line up with full rows
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
                coltitles = [title, "Value", "Units", "Description"]
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
                coltitles = [title, "Units", "Description"]
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
                coltitles = [title, "Value", "Units"]
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            # column widths from the widest non-header row
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines
986 return lines