Coverage for gpkit/solution_array.py: 81%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import sys
3import re
4import json
5import difflib
6from operator import sub
7import warnings as pywarnings
8import pickle
9import gzip
10import pickletools
11from collections import defaultdict
12import numpy as np
13from .nomials import NomialArray
14from .small_classes import DictOfLists, Strings, SolverLog
15from .small_scripts import mag, try_str_without
16from .repr_conventions import unitstr, lineagestr, UNICODE_EXPONENTS
17from .breakdowns import Breakdowns
# Pattern marking the points where a long constraint string may be wrapped:
# single multiplications (but not "**" powers), additions, and (in)equalities.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# Cosmetic substitutions applied (in order) to formatted table values;
# nans mark hidden/below-threshold entries and print as aligned dashes.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
class SolSavingEnvironment:
    """Context manager that strips construction/solve attributes from
    constraints while a solution is pickled.

    Removing these attributes approximately halves the pickled solution's
    size; on exit everything removed is restored in place.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        self.attrstore = {}       # attrname -> {constraint: stashed value}
        self.constraintstore = None  # stashed constraints dict, if popped

    def __enter__(self):
        if not self.saveconstraints:
            # drop the constraints mapping entirely; remember it for __exit__
            self.constraintstore = \
                self.solarray["sensitivities"].pop("constraints")
            return
        constraints = self.solarray["sensitivities"]["constraints"]
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            stashed = {}
            for cstr in constraints:
                value = getattr(cstr, attrname, None)
                if value:
                    stashed[cstr] = value
                    delattr(cstr, attrname)
            self.attrstore[attrname] = stashed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for attrname, stashed in self.attrstore.items():
            for cstr, value in stashed.items():
                setattr(cstr, attrname, value)
def msenss_table(data, _, **kwargs):
    "Returns model sensitivity table lines"
    if "models" not in data.get("sensitivities", {}):
        return ""
    # sort: models with all-small senss last (by their max), others by
    # descending mean sensitivity; ties broken by model name
    data = sorted(data["sensitivities"]["models"].items(),
                  key=lambda i: ((i[1] < 0.1).all(),
                                 -np.max(i[1]) if (i[1] < 0.1).all()
                                 else -round(np.mean(i[1]), 1), i[0]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs["sortmodelsbysenss"]:
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            # all sweep entries are small: print an order-of-magnitude bound
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % max(-3, np.log10(msenss)))
            else:
                msenssstr = " =0 "
        else:
            # print the mean, plus per-sweep deltas if they're significant
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:
            # identical consecutive entries print blank to reduce clutter
            msenssstr = " "*len(msenssstr)
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    # len(lines) > 3 means at least two named-model rows were added
    return lines + [""] if len(lines) > 3 else []
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    "Returns a list of table lines for variable sensitivities."
    sensdict = data.get("sensitivities", {})
    if "variables" in sensdict:
        data = sensdict["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    "Returns table lines for the most sensitive variables."
    data, already_shown = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(data, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    """Narrows a sensitivities dict to (roughly) its nvars largest entries.

    Returns the filtered dict and the set of keys that were skipped
    because they were already in showvars.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # mean |sensitivity| per key, skipping any entry containing a nan
    magnitudes = {key: np.abs(senss).mean() for key, senss in data.items()
                  if not np.isnan(senss).any()}
    # keys ordered from least to most sensitive (stable sort on value only)
    ranked = [key for key, _ in
              sorted(magnitudes.items(), key=lambda item: item[1])]
    already_shown = showvars.intersection(ranked)
    ranked = [key for key in ranked if key not in already_shown]
    if nvars > 3:  # always show at least 3
        nvars -= 1
    return {key: data[key] for key in ranked[-nvars:]}, already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    "Returns table lines for fixed variables with small sensitivities."
    # NOTE(review): guards on "constants" but then reads "variables";
    # presumably constants' senss live under "variables" -- confirm.
    if "constants" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    quiet = {key: senss for key, senss in data.items()
             if np.mean(np.abs(senss)) < maxval}
    return senss_table(quiet, title="Insensitive Fixed Variables", **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    "Returns table lines for the most sensitive (tightest) constraints."
    title = "Most Sensitive Constraints"
    sweeping = len(self) > 1
    if sweeping:
        title += " (in last sweep)"
    entries = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        # in a sweep, rank by the last sweep point's sensitivity
        senss = senss[-1] if sweeping else senss
        if senss >= tight_senss:
            entries.append(((-float("%+6.2g" % abs(senss)), str(constr)),
                            "%+6.2g" % abs(senss), id(constr), constr))
    data = sorted(entries)[:ntightconstrs]
    return constraint_table(data, title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    "Returns table lines for constraints with near-zero sensitivity."
    title = "Insensitive Constraints |below %+g|" % min_senss
    sweeping = len(self) > 1
    if sweeping:
        title += " (in last sweep)"
    data = [(0, "", id(constr), constr)
            for constr, senss in self["sensitivities"]["constraints"].items()
            if (senss[-1] if sweeping else senss) <= min_senss]
    return constraint_table(data, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    data : iterable of (sortkey, openingstr, id, constraint) tuples
        sortkey orders rows; openingstr is the left-column text.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = {"units"} if showmodels else {"units", "lineage"}
    models, decorated = {}, []
    # first pass: group constraints by model, stringify them
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)  # model order of first appearance
        constrstr = try_str_without(
            constraint, {":MAGIC:"+lineagestr(constraint)}.union(excluded))
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    # second pass: emit header rows per model and wrap long constraints
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        minlen, maxlen = 25, 80
        # split at operators so wraps land at readable points
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # a "x*y"-style match captured a char of each operand:
                # push the trailing char back onto the next segment
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " " # start a new line
            line += segment
            while len(line) > maxlen:
                # hard-wrap anything still too long (no operator to break at)
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths ignore the model-header marker rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    "Makes a table for all warnings in the solution."
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            # a sweep: collapse to one entry if every sweep point agrees
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):  # array-valued comparison
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            # these two warning types carry constraints and render as
            # constraint tables; everything else prints its messages directly
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                lines += [title] + ["-"*len(wtype)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]
def bdtable_gen(key):
    "Makes a table function that renders the named breakdown plot."

    def bdtable(self, _showvars, **_):
        "Cost breakdown plot"
        breakdowns = Breakdowns(self)
        saved_stdout = sys.stdout
        try:
            # capture the plot's printed output instead of emitting it
            sys.stdout = SolverLog(saved_stdout, verbosity=0)
            breakdowns.plot(key)
        finally:
            lines = sys.stdout.lines()
            sys.stdout = saved_stdout
        return lines

    return bdtable
# Maps table names (as used by SolutionArray.table) to line-generating
# functions of the form fn(solarray, showvars, **kwargs) -> list of strings.
TABLEFNS = {"sensitivities": senss_table,
            "top sensitivities": topsenss_table,
            "insensitivities": insenss_table,
            "model sensitivities": msenss_table,
            "tightest constraints": tight_table,
            "loose constraints": loose_table,
            "warnings": warnings_table,
            "model sensitivities breakdown": bdtable_gen("model sensitivities"),
            "cost breakdown": bdtable_gen("cost")
            }
def unrolled_absmax(values):
    "From an iterable of numbers and arrays, returns the largest magnitude"
    biggest, winner = 0, None
    for val in values:
        magnitude = np.abs(val).max()
        if magnitude >= biggest:
            biggest, winner = magnitude, val
    if not getattr(winner, "shape", None):
        return winner  # scalar (or None for an empty iterable)
    # pull the single signed element whose magnitude is the largest
    flat_idx = np.argmax(np.abs(winner))
    return winner[np.unravel_index(flat_idx, winner.shape)]
def cast(function, val1, val2):
    """Applies `function` to val1 and val2, padding mismatched dimensions.

    When both arguments are arrays of different rank, the lower-rank one
    is expanded with trailing new axes so numpy can broadcast; argument
    order passed to `function` is preserved. Numpy warnings (e.g.
    divide-by-zero) are suppressed during the call.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if both_arrays and val1.ndim != val2.ndim:
            lessdim, dimmest = sorted([val1, val2], key=lambda v: v.ndim)
            extra_dims = dimmest.ndim - lessdim.ndim
            grow = (slice(None),)*lessdim.ndim + (np.newaxis,)*extra_dims
            if dimmest is val1:
                return function(dimmest, lessdim[grow])
            return function(lessdim[grow], dimmest)
        return function(val1, val2)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""  # string representation of the solved model
    _name_collision_varkeys = None  # cache: varkey -> lineage depth needed
    _lineageset = False  # True while varkeys carry "necessarylineage"
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def set_necessarylineage(self, clear=False):  # pylint: disable=too-many-branches
        """Computes and toggles the minimal lineage needed to disambiguate
        each varkey's name; with clear=True removes those annotations."""
        if self._name_collision_varkeys is None:
            # first call: find, for each colliding varkey, the smallest
            # lineage suffix that makes its printed name unique
            self._name_collision_varkeys = {}
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            name_collisions = defaultdict(set)
            for key in keymap:
                if hasattr(key, "key"):
                    if len(keymap[key.name]) == 1:  # very unique
                        self._name_collision_varkeys[key] = 0
                    else:
                        shortname = key.str_without(["lineage", "vec"])
                        if len(keymap[shortname]) > 1:
                            name_collisions[shortname].add(key)
            for varkeys in name_collisions.values():
                # grow each key's lineage suffix until all are distinct
                min_namespaced = defaultdict(set)
                for vk in varkeys:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                while any(len(vks) > 1 for vks in min_namespaced.values()):
                    for key, vks in list(min_namespaced.items()):
                        if len(vks) <= 1:
                            continue
                        del min_namespaced[key]
                        mineage, idx = key
                        idx += 1
                        for vk in vks:
                            lineages = vk.lineagestr().split(".")
                            submineage = lineages[-idx] + "." + mineage
                            min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), vks in min_namespaced.items():
                    vk, = vks
                    self._name_collision_varkeys[vk] = idx
        if clear:
            self._lineageset = False
            for vk in self._name_collision_varkeys:
                del vk.descr["necessarylineage"]
        else:
            self._lineageset = True
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def __len__(self):
        """Number of sweep points (1 for a point solution, 0 if unsolved)."""
        try:
            return len(self["cost"])
        except TypeError:  # scalar cost: a single (non-sweep) solution
            return 1
        except KeyError:  # no cost yet
            return 0

    def __call__(self, posy):
        """Substitutes the solution into posy, returning a number/array."""
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3):
        "Checks for almost-equality between two solutions"
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            reldiff = np.max(abs(cast(np.divide, svars[key], ovars[key]) - 1))
            if reldiff >= reltol:
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            # treat strings as paths; .pgz files are gzip-compressed pickles
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                # NOTE(review): file handle is never closed -- consider `with`
                other = pickle.load(open(other, "rb"))
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # [2:] drops unified_diff's "---"/"+++" file headers
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl"))
        """
        # NOTE(review): the file handle is never closed -- consider `with`
        with SolSavingEnvironment(self, saveconstraints):
            pickle.dump(self, open(filename, "wb"), **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            # pickletools.optimize shrinks the pickle before compression
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        "Load a gzip-compressed pickle file"
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        "Returns list of variables, optionally with minimal unique names"
        if showvars:
            showvars = self._parse_showvars(showvars)
        # temporarily annotate varkeys so str() disambiguates collisions
        self.set_necessarylineage()
        names = {}
        for key in showvars or self["variables"]:
            for k in self["variables"].keymap[key]:
                names[k.str_without(exclude)] = k
        self.set_necessarylineage(clear=True)
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("vec")):
        # NOTE(review): ("vec") is a plain string, not a 1-tuple -- confirm
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None, excluded=("vec")):
        # NOTE(review): ("vec") is a plain string, not a 1-tuple -- confirm
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # enumerate every index of the array-valued variable
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json", showvars=None):
        "Saves solution table as a json file"
        sol_dict = {}
        if self._lineageset:
            # drop lineage annotations so str(varkey) is stable
            self.set_necessarylineage(clear=True)
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # add appropriate data for each variable to the dictionary
        for k, v in data.items():
            key = str(k)
            if isinstance(v, np.ndarray):
                val = {"v": v.tolist(), "u": k.unitstr()}
            else:
                val = {"v": v, "u": k.unitstr()}
            sol_dict[key] = val
        with open(filename, "w") as f:
            json.dump(sol_dict, f)

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    # pad so every row has exactly valcols value columns
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)
        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)
        if len(self) > 1:
            # a sweep: substitute into each point-solution separately
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])
        return posy.sub(self["variables"], require_positive=False)

    def _parse_showvars(self, showvars):
        """Expands showvars entries into the full set of matching varkeys."""
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), **kwargs):
        "Print summary table, showing no sensitivities or constants"
        return self.table(showvars,
                          ["cost breakdown", "model sensitivities breakdown",
                           "warnings", "sweepvariables", "freevariables"],
                          **kwargs)

    def table(self, showvars=(),
              tables=("cost breakdown", "model sensitivities breakdown",
                      "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=False, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        # with a single model there's nothing to group tables by
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        self.set_necessarylineage()
        showvars = self._parse_showvars(showvars)
        strs = []
        for table in tables:
            if "breakdown" in table:
                if len(self) > 1 or not UNICODE_EXPONENTS:
                    # no breakdowns for sweeps or no-unicode environments
                    table = table.replace(" breakdown", "")
            if "sensitivities" not in self and ("sensitivities" in table or
                                                "constraints" in table):
                continue
            if table == "cost":
                cost = self["cost"]  # pylint: disable=unsubscriptable-object
                if kwargs.get("latex", None):  # cost is not printed for latex
                    continue
                strs += ["\n%s\n------------" % "Optimal Cost"]
                if len(self) > 1:
                    # sweep: show up to the first four costs
                    costs = ["%-8.3g" % c for c in mag(cost[:4])]
                    strs += [" [ %s %s ]" % (" ".join(costs),
                                             "..." if len(self) > 4 else "")]
                else:
                    strs += [" %-.4g" % mag(cost)]
                strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                strs += [""]
            elif table in TABLEFNS:
                strs += TABLEFNS[table](self, showvars, **kwargs)
            elif table in self:
                data = self[table]
                if showvars:
                    showvars = self._parse_showvars(showvars)
                    data = {k: data[k] for k in showvars if k in data}
                strs += var_table(data, self.table_titles[table], **kwargs)
        if kwargs.get("latex", None):
            preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                  "% \\usepackage{booktabs}",
                                  "% \\usepackage{longtable}",
                                  "% \\usepackage{amsmath}",
                                  "% \\begin{document}\n"))
            strs = [preamble] + strs + ["% \\end{document}"]
        self.set_necessarylineage(clear=True)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        # exactly one swept variable expected here
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values with all(abs(value)) < minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    """
    if not data:
        return []
    decorated, models = [], set()
    # first pass: filter and build sortable row tuples
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # mask sub-threshold array entries; they print as dashes
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table above
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 4)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")  # always keep the unnamed top-level model
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    # second pass: emit rows grouped by model
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            # pick the largest dimension that fits in maxcolumns to lay
            # the array out horizontally
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows for the rest of the array's values
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad the closing bracket to keep columns aligned
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
                coltitles = [title, "Value", "Units", "Description"]
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
                coltitles = [title, "Units", "Description"]
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
                coltitles = [title, "Value", "Units"]
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            # column widths ignore the model-header marker rows
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines