Coverage for gpkit\solution_array.py: 0%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import sys
3import re
4import json
5import difflib
6from operator import sub
7import warnings as pywarnings
8import pickle
9import gzip
10import pickletools
11from collections import defaultdict
12import numpy as np
13from .nomials import NomialArray
14from .small_classes import DictOfLists, Strings, SolverLog
15from .small_scripts import mag, try_str_without
16from .repr_conventions import unitstr, lineagestr
17from .breakdowns import Breakdowns
# Delimiters at which constraint_table may break a long constraint string:
# a single "*" (never part of the "**" power operator), " + ", and the
# relation operators " >= ", " <= " and " = ".
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# Ordered (before, after) substitutions that var_table applies to every
# formatted value string; the final pair turns any remaining "nan" into a
# dash.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
class SolSavingEnvironment:
    """Context manager that slims a solution down while it is pickled.

    If ``saveconstraints`` is true, a fixed list of construction/solve
    attributes is removed from every constraint found in
    ``solarray["sensitivities"]["constraints"]`` on entry and restored on
    exit. Otherwise the whole ``"constraints"`` entry is popped on entry and
    re-inserted on exit.

    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        self.attrstore = {}  # attribute name -> {constraint: removed value}
        self.constraintstore = None  # the popped "constraints" entry, if any

    def __enter__(self):
        senss = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = senss.pop("constraints")
            return
        for attr in ("bounded", "meq_bounded", "vks",
                     "v_ss", "unsubbed", "varkeys"):
            removed = {}
            for constraint in senss["constraints"]:
                value = getattr(constraint, attr, None)
                if value:  # missing or falsy attributes are left alone
                    removed[constraint] = value
                    delattr(constraint, attr)
            self.attrstore[attr] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for attr, removed in self.attrstore.items():
            for constraint, value in removed.items():
                setattr(constraint, attr, value)
def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : dict-like
        Read for ``data["sensitivities"]["models"]``, a mapping from model
        name to its sensitivity value(s).
    _ : ignored
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that models are sorted by sensitivity.

    Returns
    -------
    list of str
        Empty if there are no model sensitivities, or if fewer than two
        named models produced a line.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # always a list, so callers can safely extend with it

    def _sortkey(item):
        "Large sensitivities first (by rounded mean), then the small ones"
        model, senss = item
        if (senss < 0.1).all():
            return (True, -np.max(senss), model)
        return (False, -round(np.mean(senss), 1), model)

    data = sorted(data["sensitivities"]["models"].items(), key=_sortkey)
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            # every value is small: only show an order-of-magnitude bound
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % max(-3, np.log10(msenss)))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                # values differ noticeably: list each offset from the mean
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:
            # blank out a value identical to the one on the previous line
            msenssstr = " "*len(msenssstr)
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    return lines + [""] if len(lines) > 3 else []
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    ``data`` is either a solution-like mapping, in which case its
    ``["sensitivities"]["variables"]`` entry is used, or already a mapping
    from variable key to sensitivity. A non-empty ``showvars`` restricts the
    table to those keys. Remaining keyword arguments go to ``var_table``.
    """
    senss = data.get("sensitivities", {})
    if "variables" in senss:
        data = senss["variables"]
    if showvars:
        data = {vk: data[vk] for vk in showvars if vk in data}
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns top sensitivity table lines.

    The sensitivities are first narrowed with ``topsenss_filter``. If that
    filter dropped any variables because they are in ``showvars``, the table
    is titled "Next Most Sensitive Variables".
    """
    data, filtered = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if filtered
             else "Most Sensitive Variables")
    return senss_table(data, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to top N vars.

    Arguments
    ---------
    data : dict-like
        Solution-like mapping (its ``["sensitivities"]["variables"]`` entry
        is used) or a mapping from variable key to sensitivity.
    showvars : set
        Keys that are excluded from the result.
    nvars : int
        Number of variables to keep; values above 3 are reduced by one.

    Returns
    -------
    (dict, set)
        The kept key -> sensitivity entries, ordered by increasing mean
        absolute sensitivity, and the ``showvars`` keys that were excluded.
    """
    senss = data.get("sensitivities", {})
    if "variables" in senss:
        data = senss["variables"]
    mean_abs_senss = {}
    for vk, s in data.items():
        if not np.isnan(s).any():  # entries containing NaN are never ranked
            mean_abs_senss[vk] = np.abs(s).mean()
    ranked = sorted(mean_abs_senss, key=mean_abs_senss.get)
    filter_already_shown = showvars.intersection(ranked)
    ranked = [vk for vk in ranked if vk not in filter_already_shown]
    if nvars > 3:  # always show at least 3
        nvars -= 1
    return {vk: data[vk] for vk in ranked[-nvars:]}, filter_already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    Keeps only entries whose mean absolute sensitivity is below ``maxval``
    and renders them with ``senss_table``.
    """
    # NOTE(review): the guard tests for "constants" but the data is read from
    # "variables" -- confirm this is intended.
    if "constants" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    insensitive = {}
    for vk, s in data.items():
        if np.mean(np.abs(s)) < maxval:
            insensitive[vk] = s
    return senss_table(insensitive, title="Insensitive Fixed Variables",
                       **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return lines listing the most sensitive constraints.

    Only constraints whose sensitivity is at least ``tight_senss`` are
    considered, and at most ``ntightconstrs`` of them are shown. For a sweep
    (``len(self) > 1``) the last sweep point's sensitivity is used.
    """
    title = "Most Sensitive Constraints"
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for c, s in self["sensitivities"]["constraints"].items():
        senss = s[-1] if is_sweep else s
        if senss >= tight_senss:
            senssstr = "%+6.2g" % abs(senss)
            # sort on the printed (rounded) value, largest first, with the
            # constraint's string as the tie-breaker
            rows.append(((-float(senssstr), str(c)), senssstr, id(c), c))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines listing the insensitive constraints.

    A constraint is listed when its sensitivity is at most ``min_senss``.
    For a sweep (``len(self) > 1``) the last sweep point's sensitivity is
    used.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for c, s in self["sensitivities"]["constraints"].items():
        senss = s[-1] if is_sweep else s
        if senss <= min_senss:
            rows.append((0, "", id(c), c))
    return constraint_table(rows, title, **kwargs)
176# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def _wrap_constrstr(constrstr, minlen=25, maxlen=80):
    "Breaks a constraint string into display lines at operator boundaries."
    segments = [seg for seg in CONSTRSPLITPATTERN.split(constrstr) if seg]
    wrapped, current = [], ""
    position = 0
    while position < len(segments):
        segment = segments[position]
        position += 1
        if CONSTRSPLITPATTERN.match(segment) and position < len(segments):
            # a delimiter: keep only its first character here and push the
            # remainder onto the front of the following segment
            segments[position] = segment[1:] + segments[position]
            segment = segment[0]
        elif len(current) + len(segment) > maxlen and len(current) > minlen:
            wrapped.append(current)
            current = " "  # start a new line
        current += segment
        while len(current) > maxlen:  # hard-wrap anything still too long
            wrapped.append(current[:maxlen])
            current = " " + current[maxlen:]
    wrapped.append(current)
    return wrapped


def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, id, constraint) tuples
        ``openingstr`` is printed to the left of the constraint.
    title : string
    sortbymodel : bool
        If true, constraints are grouped under their lineage string.
    showmodels : bool
        If false, "lineage" is also excluded when printing constraints.

    Returns
    -------
    list of str
        Title, underline, table lines, and a trailing empty string.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = {"units"} if showmodels else {"units", "lineage"}
    modelorder, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        # models are numbered in order of first appearance
        modelorder.setdefault(model, len(modelorder))
        constrstr = try_str_without(
            constraint, {":MAGIC:"+lineagestr(constraint)}.union(excluded))
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append(
            (modelorder[model], model, sortby, constrstr, openingstr))
    decorated.sort()

    previous_model, lines = None, []
    for _, model, _, constrstr, openingstr in decorated:
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer between models
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        first, *rest = _wrap_constrstr(constrstr)
        lines.append((openingstr + " : ", first))
        lines.extend(("", continuation) for continuation in rest)
    if not lines:
        lines = [("", "(none)")]

    # column widths come from every line except the model headers
    maxlens = np.max([[len(cell) for cell in line] for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, width)
            for direc, width in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            cells = [fmts[0].format(" | "), line[1]]
        else:
            cells = [fmt.format(cell) for fmt, cell in zip(fmts, line)]
        lines[i] = "".join(cells).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Arguments
    ---------
    self : dict-like
        Read for ``self["warnings"]``, a mapping from warning type to either
        a sequence of ``(msg, data)`` pairs or, for sweeps, an array-like
        (anything with a ``shape``) holding one such sequence per sweep
        point.
    _ : ignored
    kwargs : passed on to ``constraint_table``

    Returns
    -------
    list of str
        Empty if there is nothing to report.
    """
    if "warnings" not in self or not self["warnings"]:
        return []
    header = "WARNINGS"
    lines = ["~"*len(header), header, "~"*len(header)]
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    break
            else:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the title actually printed, which may carry an
                # " in sweep N" suffix
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"  # the final spacer becomes the closing rule
    return lines + [""]
286# TODO: deduplicate these two functions
def bdtable_gen(key):
    """Generator for breakdown tablefns.

    Returns a table function that plots the ``key`` breakdown of a solution
    and hands back the lines captured while plotting.
    """

    def bdtable(self, _showvars, **_):
        "Breakdown plot for this table function's key, as captured lines"
        breakdowns = Breakdowns(self)
        saved_stdout = sys.stdout
        try:
            # route stdout through a SolverLog for the duration of the plot
            sys.stdout = SolverLog(saved_stdout, verbosity=0)
            breakdowns.plot(key)
        finally:
            captured = sys.stdout.lines()
            sys.stdout = saved_stdout
        return captured

    return bdtable
# Table names that SolutionArray.table renders through a dedicated function;
# each function is called as fn(solution, showvars, **kwargs) and returns
# the table's lines.
TABLEFNS = {
    "sensitivities": senss_table,
    "top sensitivities": topsenss_table,
    "insensitivities": insenss_table,
    "model sensitivities": msenss_table,
    "tightest constraints": tight_table,
    "loose constraints": loose_table,
    "warnings": warnings_table,
    "model sensitivities breakdown": bdtable_gen("model sensitivities"),
    "cost breakdown": bdtable_gen("cost"),
}
def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The returned element keeps its sign. Among equal magnitudes the later
    item of ``values`` wins. Returns None if ``values`` is empty.
    """
    biggest, biggest_mag = None, 0
    for candidate in values:
        magnitude = np.abs(candidate).max()
        if magnitude >= biggest_mag:
            biggest_mag, biggest = magnitude, candidate
    if not getattr(biggest, "shape", None):
        return biggest  # a scalar (or nothing at all)
    # an array: pick out its largest-magnitude element
    flat_index = np.argmax(np.abs(biggest))
    return biggest[np.unravel_index(flat_index, biggest.shape)]
def cast(function, val1, val2):
    """Applies a binary function to two values, aligning array dimensions.

    If both values have a ``shape`` but different numbers of dimensions, the
    one with fewer dimensions gets trailing length-1 axes appended before
    ``function(val1, val2)`` is evaluated. Warnings raised during the call
    (such as divide-by-zero) are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if hasattr(val1, "shape") and hasattr(val2, "shape"):
            ndim_gap = val1.ndim - val2.ndim
            if ndim_gap > 0:  # val2 needs the extra trailing axes
                pad = (slice(None),)*val2.ndim + (np.newaxis,)*ndim_gap
                return function(val1, val2[pad])
            if ndim_gap < 0:  # val1 needs the extra trailing axes
                pad = (slice(None),)*val1.ndim + (np.newaxis,)*(-ndim_gap)
                return function(val1[pad], val2)
        return function(val1, val2)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    # varkey -> number of lineage levels needed to make its name unique;
    # computed lazily by set_necessarylineage
    _name_collision_varkeys = None
    # True while "necessarylineage" is set on the varkeys' descr
    _lineageset = False
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def set_necessarylineage(self, clear=False):  # pylint: disable=too-many-branches
        """Sets (or clears) the lineage depth each varkey needs to be unique.

        On first use, works out for every varkey how many trailing lineage
        levels are required to tell it apart from other varkeys sharing its
        short name, and caches the result. Then either writes that depth to
        each varkey's ``descr["necessarylineage"]`` or, with ``clear=True``,
        deletes that entry again. Returns nothing.
        """
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {}
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            name_collisions = defaultdict(set)
            for key in keymap:
                if hasattr(key, "key"):
                    if len(keymap[key.name]) == 1:  # very unique
                        self._name_collision_varkeys[key] = 0
                    else:
                        shortname = key.str_without(["lineage", "vec"])
                        if len(keymap[shortname]) > 1:
                            name_collisions[shortname].add(key)
            for varkeys in name_collisions.values():
                # group by the last lineage level, then keep prepending
                # parent levels until every group holds a single varkey
                min_namespaced = defaultdict(set)
                for vk in varkeys:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                while any(len(vks) > 1 for vks in min_namespaced.values()):
                    for key, vks in list(min_namespaced.items()):
                        if len(vks) <= 1:
                            continue
                        del min_namespaced[key]
                        mineage, idx = key
                        idx += 1
                        for vk in vks:
                            lineages = vk.lineagestr().split(".")
                            submineage = lineages[-idx] + "." + mineage
                            min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), vks in min_namespaced.items():
                    vk, = vks
                    self._name_collision_varkeys[vk] = idx
        if clear:
            self._lineageset = False
            for vk in self._name_collision_varkeys:
                del vk.descr["necessarylineage"]
        else:
            self._lineageset = True
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def __len__(self):
        "Length of the cost entry: 1 if it has no length, 0 if it is absent"
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        "Value of posy in this solution (its `c` attribute if it has one)"
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        Both solutions must have the same variable keys. Every value must
        agree to within the relative tolerance ``reltol`` and every variable
        sensitivity to within the absolute tolerance ``sens_abstol``. For
        array-valued entries each element is compared.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any lets array-valued entries be compared without raising
            reldelta = abs(cast(np.divide, svars[key], ovars[key]) - 1)
            if np.any(reldelta >= reltol):
                return False
            sensdelta = abs(self["sensitivities"]["variables"][key]
                            - other["sensitivities"]["variables"][key])
            if np.any(sensdelta >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (paths ending in ".pgz" are read as gzip-compressed pickles)
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True and both solutions have a modelstr, show a diff of them
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            # NOTE: unpickling can execute arbitrary code; only load
            # solution files from trusted sources.
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # [2:] drops the two file-header lines of the unified diff
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk])
                        for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> with open("solution.pkl", "rb") as f:
        ...     sol = pickle.load(f)
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as f:  # closed even if pickling fails
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickles the solution and writes it gzip-compressed to filename."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        NOTE: unpickling can execute arbitrary code; only load files from
        trusted sources.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict mapping variable name strings to their varkeys.

        Names are generated while the necessary lineage is set, so that
        they carry the lineage needed to be unique; ``exclude`` is passed
        to each key's ``str_without``.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        self.set_necessarylineage()
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:
            # never leave the varkeys' descr modified, even on error
            self.set_necessarylineage(clear=True)
        return names

    # NOTE(review): the default ("vec") below is the string "vec", not a
    # one-element tuple -- confirm str_without treats both the same way.
    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None, excluded=("vec")):
        """Returns primal solution as pandas dataframe.

        One row is produced per scalar variable and per element of each
        vector variable.
        """
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json", showvars=None):
        """Saves solution table as a json file.

        Each variable is stored under ``str(key)`` as
        ``{"v": value, "u": unit string}``.
        """
        sol_dict = {}
        if self._lineageset:
            self.set_necessarylineage(clear=True)
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # add appropriate data for each variable to the dictionary
        for k, v in data.items():
            key = str(k)
            if isinstance(v, np.ndarray):
                val = {"v": v.tolist(), "u": k.unitstr()}
            else:
                val = {"v": v, "u": k.unitstr()}
            sol_dict[key] = val
        with open(filename, "w") as f:
            json.dump(sol_dict, f)

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    # pad so the units always land in the same column
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        """Returns NomialArray of each solution substituted into posy.

        Raises ValueError if posy is neither a variable of the solution nor
        something with a ``sub`` method.
        """
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"], require_positive=False)

    def _parse_showvars(self, showvars):
        "Expands each entry of showvars into the set of keys it maps to"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), **kwargs):
        "Print summary table, showing no sensitivities or constants"
        return self.table(showvars,
                          ["cost breakdown", "model sensitivities breakdown",
                           "warnings", "sweepvariables", "freevariables"],
                          **kwargs)

    def table(self, showvars=(),
              tables=("cost breakdown", "model sensitivities breakdown",
                      "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=False, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        self.set_necessarylineage()
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if len(self) > 1 and "breakdown" in table:
                    # no breakdowns for sweeps
                    table = table.replace(" breakdown", "")
                if "sensitivities" not in self and (
                        "sensitivities" in table or "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):
                        continue  # cost is not printed for latex
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4
                                                 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table],
                                      **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            # never leave the varkeys' descr modified, even on error
            self.set_necessarylineage(clear=True)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy.

        Returns the current matplotlib figure and the axes (a single axis
        object if there is only one).
        """
        if len(self["sweepvariables"]) != 1:
            # NOTE(review): only a message is printed here; the unpacking of
            # self["sweepvariables"] below still requires exactly one entry.
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
833# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table (its values are never modified)
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted list of row cells instead
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values with all(abs(value)) <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), vector elements with
        abs(value) <= minval are shown as missing.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped under their key's lineage string.
    maxcolumns : int
        The largest vector dimension that may be laid out horizontally.
    skipifempty : boolean
        If true, return [] when no row passes the minval filter.
    sortmodelsbysenss : dict, optional
        Model name -> sensitivities, used to order the models.

    Returns
    -------
    list
        Table lines (strings), or lists of cells if rawlines is true.
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # nothing above minval to show
        if minval and hidebelowminval and getattr(v, "shape", None):
            # mask a copy, so that the caller's array (for instance the
            # solution's own sensitivities) is left untouched
            v = v.copy()
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table above
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 4)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    # column titles per latex option; defined up front so they exist even if
    # every row gets filtered out below
    latex_coltitles = {1: [title, "Value", "Units", "Description"],
                       2: [title, "Units", "Description"],
                       3: [title, "Value", "Units"]}
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # remaining elements go on continuation lines of ncols each
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad a short final line so the bracket lines up
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex not in latex_coltitles:
                raise ValueError("Unexpected latex option, %s." % latex)
            unitcell = "$%s$" % var.latex_unitstr()
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, unitcell, label])
            elif latex == 2:  # no values
                lines.append([varstr, unitcell, label])
            else:  # latex == 3: no description
                lines.append([varstr, valstr, unitcell])
    if rawlines:
        return lines
    if not latex:
        if lines:
            # column widths come from every line except the model headers
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(latex_coltitles[latex])
                             + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines