Coverage for gpkit/solution_array.py: 73%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import sys
3import re
4import json
5import difflib
6from operator import sub
7import warnings as pywarnings
8import pickle
9import gzip
10import pickletools
11from collections import defaultdict
12import numpy as np
13from .nomials import NomialArray
14from .small_classes import DictOfLists, Strings, SolverLog
15from .small_scripts import mag, try_str_without
16from .repr_conventions import unitstr, lineagestr, UNICODE_EXPONENTS
17from .breakdowns import Breakdowns
# Matches natural break points in a constraint string (single `*`,
# ` + `, and (in)equality operators) so long constraints can be wrapped.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# (search, replace) pairs applied to formatted values so NaN entries
# render as blanks/dashes instead of "nan" noise in tables.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    This approximately halves the size of the pickled solution.

    Arguments
    ---------
    solarray : SolutionArray (or dict)
        the solution being pickled
    saveconstraints : bool
        if True, keep the constraints but strip their bulky attributes
        (restoring them on exit); if False, drop the constraints entirely
        for the duration of the `with` block
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.attrstore = {}  # {attr name: {constraint: original value}}
        self.saveconstraints = saveconstraints
        self.constraintstore = None  # holds popped constraints when not saving

    def __enter__(self):
        if "sensitivities" not in self.solarray:
            pass  # nothing to strip
        elif self.saveconstraints:
            # strip heavyweight construction attributes, remembering their
            # values so __exit__ can put them back
            for constraint_attr in ["bounded", "meq_bounded", "vks",
                                    "v_ss", "unsubbed", "varkeys"]:
                store = {}
                for constraint in self.solarray["sensitivities"]["constraints"]:
                    if getattr(constraint, constraint_attr, None):
                        store[constraint] = getattr(constraint, constraint_attr)
                        delattr(constraint, constraint_attr)
                self.attrstore[constraint_attr] = store
        else:
            # remove the constraints from the solution entirely
            self.constraintstore = \
                self.solarray["sensitivities"].pop("constraints")

    def __exit__(self, type_, val, traceback):
        # BUGFIX: restoration must mirror __enter__. When saveconstraints is
        # True, constraintstore stays None, so the previous version's
        # `if not self.constraintstore: pass` guard skipped restoration and
        # left the stripped attributes permanently deleted.
        if self.saveconstraints:
            for constraint_attr, store in self.attrstore.items():
                for constraint, value in store.items():
                    setattr(constraint, constraint_attr, value)
        elif self.constraintstore:
            self.solarray["sensitivities"]["constraints"] = self.constraintstore
def msenss_table(data, _, **kwargs):
    "Returns model sensitivity table lines"
    if "models" not in data.get("sensitivities", {}):
        return ""
    # sort: models whose sensitivities are all small go last (ranked by
    # their max), others by descending rounded mean; ties broken by name
    data = sorted(data["sensitivities"]["models"].items(),
                  key=lambda i: ((i[1] < 0.1).all(),
                                 -np.max(i[1]) if (i[1] < 0.1).all()
                                 else -round(np.mean(i[1]), 1), i[0]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs["sortmodelsbysenss"]:
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            # all-small models: print an order-of-magnitude bound instead
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % max(-3, np.log10(msenss)))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            # for sweeps, also show per-point deltas from the mean if large
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:
            # blank out repeats so equal values are only printed once
            msenssstr = " "*len(msenssstr)
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    # only return the table if it has at least one model row
    return lines + [""] if len(lines) > 3 else []
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    "Returns sensitivity table lines"
    # accept either a full solution dict or a bare {varkey: senss} mapping
    sensitivities = data.get("sensitivities", {})
    if "variables" in sensitivities:
        data = sensitivities["variables"]
    if showvars:
        selected = {}
        for key in showvars:
            if key in data:
                selected[key] = data[key]
        data = selected
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    "Returns top sensitivity table lines"
    data, already_shown = topsenss_filter(data, showvars, nvars)
    # retitle when some top variables were filtered out as already shown
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(data, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top N most sensitive variables.

    Arguments
    ---------
    data : dict
        a full solution dict (read via data["sensitivities"]["variables"])
        or a bare {varkey: sensitivity} mapping
    showvars : set
        variables already shown elsewhere; removed from the result, with
        the table shrinking by one per removal (never below 3 rows)
    nvars : int
        nominal number of variables to return

    Returns
    -------
    (dict, set)
        the filtered {varkey: sensitivity} data, and the subset of
        showvars that were present here (and so were removed)
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # rank by mean absolute sensitivity, skipping anything with NaNs
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    filter_already_shown = showvars.intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        # BUGFIX: shrink once per removed variable (previously this was
        # outside the loop, decrementing at most once regardless of how
        # many already-shown variables were removed)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    return {k: data[k] for k in topk[-nvars:]}, filter_already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    "Returns insensitivity table lines"
    # NOTE(review): gates on "constants" but reads "variables" — preserved
    # as-is; confirm against upstream intent before changing
    if "constants" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    insensitive = {}
    for key, senss in data.items():
        if np.mean(np.abs(senss)) < maxval:
            insensitive[key] = senss
    return senss_table(insensitive, title="Insensitive Fixed Variables",
                       **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    "Return constraint tightness lines"
    title = "Most Sensitive Constraints"
    issweep = len(self) > 1
    if issweep:
        title += " (in last sweep)"
    constraint_senss = self["sensitivities"]["constraints"]

    def row(constraint, senss):
        "One sortable table row: (sort key, opening string, id, constraint)."
        return ((-float("%+6.2g" % abs(senss)), str(constraint)),
                "%+6.2g" % abs(senss), id(constraint), constraint)

    if issweep:
        # for sweeps, rank by the sensitivity at the last sweep point
        rows = (row(c, s[-1]) for c, s in constraint_senss.items()
                if s[-1] >= tight_senss)
    else:
        rows = (row(c, s) for c, s in constraint_senss.items()
                if s >= tight_senss)
    data = sorted(rows)[:ntightconstrs]
    return constraint_table(data, title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    "Return constraint tightness lines"
    title = "Insensitive Constraints |below %+g|" % min_senss
    issweep = len(self) > 1
    if issweep:
        title += " (in last sweep)"
    constraint_senss = self["sensitivities"]["constraints"]
    if issweep:
        # for sweeps, judge looseness by the last sweep point's sensitivity
        data = [(0, "", id(c), c) for c, s in constraint_senss.items()
                if s[-1] <= min_senss]
    else:
        data = [(0, "", id(c), c) for c, s in constraint_senss.items()
                if s <= min_senss]
    return constraint_table(data, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    "Creates lines for tables where the right side is a constraint."
    # TODO: this should support 1D array inputs from sweeps
    excluded = {"units"} if showmodels else {"units", "lineage"}
    models, decorated = {}, []
    # decorate each row with a model-arrival index so rows group by model
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)
        constrstr = try_str_without(
            constraint, {":MAGIC:"+lineagestr(constraint)}.union(excluded))
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            # blank spacer plus a model-name header row for each new model
            if lines:
                lines.append(["", ""])
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        # greedily wrap the constraint string: break at operator matches
        # once a line is past minlen and would exceed maxlen
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # operator match: keep only its first character here and
                # push the remainder onto the following segment
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:  # hard-wrap anything still too long
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        # first wrapped line carries the opening string (e.g. sensitivity)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths, ignoring model-header rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    "Makes a table for all warnings in the solution."
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            # collapse sweeps whose warnings are identical at every point
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            # constraint-style warnings get the full constraint table;
            # everything else is printed as plain message lines
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                lines += [title] + ["-"*len(wtype)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]
def bdtable_gen(key):
    "Generator for breakdown tablefns"

    def bdtable(self, _showvars, **_):
        "Cost breakdown plot"
        bds = Breakdowns(self)
        original_stdout = sys.stdout
        try:
            # capture the breakdown plot's printed output via a SolverLog
            sys.stdout = SolverLog(original_stdout, verbosity=0)
            bds.plot(key)
        finally:
            # collect captured lines and always restore stdout
            lines = sys.stdout.lines()
            sys.stdout = original_stdout
        return lines

    return bdtable
# maps table names (as requested in SolutionArray.table) to the functions
# that generate their lines; each takes (solution, showvars, **kwargs)
TABLEFNS = {"sensitivities": senss_table,
            "top sensitivities": topsenss_table,
            "insensitivities": insenss_table,
            "model sensitivities": msenss_table,
            "tightest constraints": tight_table,
            "loose constraints": loose_table,
            "warnings": warnings_table,
            "model sensitivities breakdown": bdtable_gen("model sensitivities"),
            "cost breakdown": bdtable_gen("cost")
            }
def unrolled_absmax(values):
    "From an iterable of numbers and arrays, returns the largest magnitude"
    best, best_absmax = None, 0
    for candidate in values:
        candidate_absmax = np.abs(candidate).max()
        if candidate_absmax >= best_absmax:  # later ties win
            best_absmax = candidate_absmax
            best = candidate
    if not getattr(best, "shape", None):
        return best  # a scalar (or None if values was empty)
    # unroll the winning array down to its largest-magnitude entry
    flat_idx = np.argmax(np.abs(best))
    return best[np.unravel_index(flat_idx, best.shape)]
def cast(function, val1, val2):
    "Relative difference between val1 and val2 (positive if val2 is larger)"
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if (hasattr(val1, "shape") and hasattr(val2, "shape")
                and val1.ndim != val2.ndim):
            # pad the lower-dimensional array with trailing axes so the
            # two broadcast together, keeping argument order intact
            lessdim, dimmest = sorted([val1, val2], key=lambda v: v.ndim)
            dimdelta = dimmest.ndim - lessdim.ndim
            add_axes = (slice(None),)*lessdim.ndim + (np.newaxis,)*dimdelta
            if dimmest is val1:
                return function(dimmest, lessdim[add_axes])
            return function(lessdim[add_axes], dimmest)
        return function(val1, val2)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""  # string representation of the solved model
    _name_collision_varkeys = None  # cache: {varkey: lineage depth needed}
    _lineageset = False  # whether necessarylineage is currently set on keys
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def set_necessarylineage(self, clear=False):  # pylint: disable=too-many-branches
        """Returns the set of contained varkeys whose names are not unique

        (Despite the name, this sets or clears each varkey's
        descr["necessarylineage"] in place rather than returning anything.)
        """
        if self._name_collision_varkeys is None:
            # first call: work out, for each colliding name, the minimal
            # lineage suffix depth needed to disambiguate it
            self._name_collision_varkeys = {}
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            name_collisions = defaultdict(set)
            for key in keymap:
                if hasattr(key, "key"):
                    if len(keymap[key.name]) == 1:  # very unique
                        self._name_collision_varkeys[key] = 0
                    else:
                        shortname = key.str_without(["lineage", "vec"])
                        if len(keymap[shortname]) > 1:
                            name_collisions[shortname].add(key)
            for varkeys in name_collisions.values():
                min_namespaced = defaultdict(set)
                for vk in varkeys:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                # grow each colliding key's lineage suffix one level at a
                # time until every suffix maps to exactly one varkey
                while any(len(vks) > 1 for vks in min_namespaced.values()):
                    for key, vks in list(min_namespaced.items()):
                        if len(vks) <= 1:
                            continue
                        del min_namespaced[key]
                        mineage, idx = key
                        idx += 1
                        for vk in vks:
                            lineages = vk.lineagestr().split(".")
                            submineage = lineages[-idx] + "." + mineage
                            min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), vks in min_namespaced.items():
                    vk, = vks
                    self._name_collision_varkeys[vk] = idx
        if clear:
            self._lineageset = False
            for vk in self._name_collision_varkeys:
                del vk.descr["necessarylineage"]
        else:
            self._lineageset = True
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def __len__(self):
        # number of sweep points; 1 for a single solution, 0 if unsolved
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        # substitute the solution into posy; return a scalar if possible
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3):
        "Checks for almost-equality between two solutions"
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            reldiff = np.max(abs(cast(np.divide, svars[key], ovars[key]) - 1))
            if reldiff >= reltol:
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            # treat a string as a path: .pgz is gzip-compressed pickle
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                other = pickle.load(open(other, "rb"))
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                # unified diff of the two models' string representations
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl"))
        """
        with SolSavingEnvironment(self, saveconstraints):
            pickle.dump(self, open(filename, "wb"), **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            # pickletools.optimize shrinks the pickle before compression
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        "Load a gzip-compressed pickle file"
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        "Returns list of variables, optionally with minimal unique names"
        if showvars:
            showvars = self._parse_showvars(showvars)
        self.set_necessarylineage()
        names = {}
        for key in showvars or self["variables"]:
            for k in self["variables"].keymap[key]:
                names[k.str_without(exclude)] = k
        self.set_necessarylineage(clear=True)
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None, excluded=("vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # one dataframe row per element of a vector variable
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json", showvars=None):
        "Saves solution table as a json file"
        sol_dict = {}
        if self._lineageset:
            self.set_necessarylineage(clear=True)
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # add appropriate data for each variable to the dictionary
        for k, v in data.items():
            key = str(k)
            if isinstance(v, np.ndarray):
                val = {"v": v.tolist(), "u": k.unitstr()}
            else:
                val = {"v": v, "u": k.unitstr()}
            sol_dict[key] = val
        with open(filename, "w") as f:
            json.dump(sol_dict, f)

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                # NOTE: maxspan starts at 1, never None; guard kept as-is
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    # pad so the units column always lines up
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            # sweep: substitute into each sweep point's solution
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"], require_positive=False)

    def _parse_showvars(self, showvars):
        # resolve each requested variable to the full set of matching keys
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), **kwargs):
        "Print summary table, showing no sensitivities or constants"
        return self.table(showvars,
                          ["cost breakdown", "model sensitivities breakdown",
                           "warnings", "sweepvariables", "freevariables"],
                          **kwargs)

    def table(self, showvars=(),
              tables=("cost breakdown", "model sensitivities breakdown",
                      "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=False, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        # with a single model there's no point grouping rows by model
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        self.set_necessarylineage()
        showvars = self._parse_showvars(showvars)
        strs = []
        for table in tables:
            if "breakdown" in table:
                if (len(self) > 1 or not UNICODE_EXPONENTS
                        or "sensitivities" not in self):
                    # no breakdowns for sweeps or no-unicode environments
                    table = table.replace(" breakdown", "")
            if "sensitivities" not in self and ("sensitivities" in table or
                                                "constraints" in table):
                continue
            if table == "cost":
                cost = self["cost"]  # pylint: disable=unsubscriptable-object
                if kwargs.get("latex", None):  # cost is not printed for latex
                    continue
                strs += ["\n%s\n------------" % "Optimal Cost"]
                if len(self) > 1:
                    # show only the first few sweep costs
                    costs = ["%-8.3g" % c for c in mag(cost[:4])]
                    strs += [" [ %s %s ]" % (" ".join(costs),
                                             "..." if len(self) > 4 else "")]
                else:
                    strs += [" %-.4g" % mag(cost)]
                strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                strs += [""]
            elif table in TABLEFNS:
                strs += TABLEFNS[table](self, showvars, **kwargs)
            elif table in self:
                data = self[table]
                if showvars:
                    showvars = self._parse_showvars(showvars)
                    data = {k: data[k] for k in showvars if k in data}
                strs += var_table(data, self.table_titles[table], **kwargs)
        if kwargs.get("latex", None):
            preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                  "% \\usepackage{booktabs}",
                                  "% \\usepackage{longtable}",
                                  "% \\usepackage{amsmath}",
                                  "% \\begin{document}\n"))
            strs = [preamble] + strs + ["% \\end{document}"]
        self.set_necessarylineage(clear=True)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values with all(abs(value)) < minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # mask sub-threshold vector entries so they print as blanks
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table above
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 4)
        models.add(model)
        b = bool(getattr(v, "shape", None))  # is this a vector value?
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")  # always allow the top-level (unnamed) model
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        # unpack according to which decoration layout was used above
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            # spacer row plus a model-name header row for each new model
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            # choose which dimension to lay out horizontally: the last one
            # whose size fits within maxcolumns
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # print remaining vector values on continuation rows
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # close the bracket, padding a partial final row
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
                coltitles = [title, "Value", "Units", "Description"]
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
                coltitles = [title, "Units", "Description"]
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
                coltitles = [title, "Value", "Units"]
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            # column widths, ignoring model-header rows
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines
992 return lines