Coverage for gpkit\solution_array.py: 0%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import sys
3import re
4import json
5import difflib
6from operator import sub
7import warnings as pywarnings
8import pickle
9import gzip
10import pickletools
11from collections import defaultdict
12import numpy as np
13from .nomials import NomialArray
14from .small_classes import DictOfLists, Strings, SolverLog
15from .small_scripts import mag, try_str_without
16from .repr_conventions import unitstr, lineagestr
17from .breakdowns import Breakdowns
# Splits a constraint string at operator boundaries (a binary "*", " + ",
# " >= ", " <= ", " = ") so long constraints can be wrapped across lines.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# Cosmetic substitutions applied to formatted value strings: NaN entries
# are rendered as aligned blanks/dashes instead of literal "nan".
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        self.attrstore = {}  # {attrname: {constraint: stripped value}}
        self.constraintstore = None  # holds "constraints" when not saving them

    def __enter__(self):
        if not self.saveconstraints:
            # drop the constraints entirely; restored on exit
            self.constraintstore = \
                self.solarray["sensitivities"].pop("constraints")
            return
        # keep the constraints but strip their bulky construction attributes
        constraints = self.solarray["sensitivities"]["constraints"]
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            stashed = {}
            for constr in constraints:
                if getattr(constr, attrname, None):
                    stashed[constr] = getattr(constr, attrname)
                    delattr(constr, attrname)
            self.attrstore[attrname] = stashed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = self.constraintstore
            return
        for attrname, stashed in self.attrstore.items():
            for constr, attrvalue in stashed.items():
                setattr(constr, attrname, attrvalue)
def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Always returns a list (possibly empty), like the other tablefns, so
    callers may concatenate the result; previously the no-models path
    returned "" which would TypeError under list concatenation.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # was "": now consistent with the other tablefns
    # sort: models with all-small sensitivities last (by max), the rest by
    # descending mean sensitivity, ties broken by model name
    data = sorted(data["sensitivities"]["models"].items(),
                  key=lambda i: ((i[1] < 0.1).all(),
                                 -np.max(i[1]) if (i[1] < 0.1).all()
                                 else -round(np.mean(i[1]), 1), i[0]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs["sortmodelsbysenss"]:
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            msenss = np.max(msenss)
            if msenss:
                # show an order-of-magnitude bound, floored at 1e-3
                msenssstr = "%6s" % ("<1e%i" % max(-3, np.log10(msenss)))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            # append per-sweep deltas when they differ noticeably
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:
            msenssstr = " "*len(msenssstr)  # blank out repeats for readability
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    return lines + [""] if len(lines) > 3 else []
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    "Returns sensitivity table lines"
    # a full solution dict gets narrowed to its variable sensitivities
    senss = data.get("sensitivities", {})
    if "variables" in senss:
        data = senss["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    "Returns top sensitivity table lines"
    data, already_shown = topsenss_filter(data, showvars, nvars)
    # vary the title when some top variables were filtered out as shown
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(data, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    "Filters sensitivities down to top N vars"
    senss = data.get("sensitivities", {})
    if "variables" in senss:
        data = senss["variables"]
    # rank keys by mean absolute sensitivity, skipping any with NaNs
    magnitudes = {}
    for key, senssval in data.items():
        if not np.isnan(senssval).any():
            magnitudes[key] = np.abs(senssval).mean()
    ranked = sorted(magnitudes, key=magnitudes.get)
    already_shown = showvars.intersection(ranked)
    ranked = [key for key in ranked if key not in already_shown]
    if nvars > 3:  # always show at least 3
        nvars -= 1
    return {key: data[key] for key in ranked[-nvars:]}, already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    "Returns insensitivity table lines"
    # NOTE(review): the guard checks for "constants" but then reads
    # "variables" -- looks deliberate (the table only makes sense when
    # fixed variables exist), but confirm upstream.
    if "constants" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    insensitive = {key: senssval for key, senssval in data.items()
                   if np.mean(np.abs(senssval)) < maxval}
    return senss_table(insensitive, title="Insensitive Fixed Variables",
                       **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    "Return constraint tightness lines"
    title = "Most Sensitive Constraints"
    in_sweep = len(self) > 1
    if in_sweep:
        title += " (in last sweep)"
    # build one row per sufficiently-sensitive constraint; in a sweep only
    # the last sweep point's sensitivity is considered
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        sensval = senss[-1] if in_sweep else senss
        if sensval >= tight_senss:
            rows.append(((-float("%+6.2g" % abs(sensval)), str(constr)),
                         "%+6.2g" % abs(sensval), id(constr), constr))
    data = sorted(rows)[:ntightconstrs]
    return constraint_table(data, title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    "Return insensitive-constraint table lines"
    title = "Insensitive Constraints |below %+g|" % min_senss
    in_sweep = len(self) > 1
    if in_sweep:
        title += " (in last sweep)"
    # rows are unsorted (sort key 0) since all are essentially zero-senss;
    # in a sweep only the last sweep point's sensitivity is considered
    data = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        sensval = senss[-1] if in_sweep else senss
        if sensval <= min_senss:
            data.append((0, "", id(constr), constr))
    return constraint_table(data, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, uid, constraint) tuples
        sortby orders rows; openingstr is the left-column text (the
        unpacking below establishes the 4-tuple shape).
    title : string
    sortbymodel : bool
        if True, group rows by the constraint's lineage (model) name.
    showmodels : bool
        if False, lineage is also stripped from the constraint strings.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = {"units"} if showmodels else {"units", "lineage"}
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)  # first-seen ordering of models
        constrstr = try_str_without(
            constraint, {":MAGIC:"+lineagestr(constraint)}.union(excluded))
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            # start a new model section: spacer (if needed) + header row
            if lines:
                lines.append(["", ""])
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        minlen, maxlen = 25, 80  # wrap bounds for the constraint text
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # an operator match: keep its first char here, push the
                # rest onto the following segment so wraps land on operators
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
        while len(line) > maxlen:  # hard-wrap any still-too-long remainder
            constraintlines.append(line[:maxlen])
            line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths over all non-header rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Returns [] when there are no warnings (or only the header would be
    emitted); otherwise a list of table lines.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            # collapse a sweep's warnings if every sweep point's are equal
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            # tight/loose constraint warnings carry structured payloads and
            # are rendered as constraint tables; everything else as messages
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                lines += [title] + ["-"*len(wtype)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]
def bdtable_gen(key):
    "Generator for breakdown tablefns"

    def bdtable(self, _showvars, **_):
        "Cost breakdown plot"
        breakdowns = Breakdowns(self)
        # Breakdowns.plot prints to stdout; capture that output by
        # temporarily swapping in a SolverLog, always restoring stdout.
        real_stdout = sys.stdout
        try:
            sys.stdout = SolverLog(real_stdout, verbosity=0)
            breakdowns.plot(key)
        finally:
            lines = sys.stdout.lines()
            sys.stdout = real_stdout
        return lines

    return bdtable
# Maps the table names accepted by SolutionArray.table to the functions
# that generate their lines; each is called as fn(solution, showvars, **kwargs).
TABLEFNS = {"sensitivities": senss_table,
            "top sensitivities": topsenss_table,
            "insensitivities": insenss_table,
            "model sensitivities": msenss_table,
            "tightest constraints": tight_table,
            "loose constraints": loose_table,
            "warnings": warnings_table,
            "model sensitivities breakdown": bdtable_gen("model sensitivities"),
            "cost breakdown": bdtable_gen("cost")
            }
def unrolled_absmax(values):
    "From an iterable of numbers and arrays, returns the largest magnitude"
    biggest, biggest_mag = None, 0
    for candidate in values:
        candidate_mag = np.abs(candidate).max()
        # >= so that an all-zero iterable still yields its last element
        if candidate_mag >= biggest_mag:
            biggest_mag, biggest = candidate_mag, candidate
    if getattr(biggest, "shape", None):
        # unroll an array winner down to its single largest-magnitude entry
        flat_idx = np.argmax(np.abs(biggest))
        return biggest[np.unravel_index(flat_idx, biggest.shape)]
    return biggest
def cast(function, val1, val2):
    """Apply a two-argument function to val1 and val2, broadcasting shapes.

    (The previous docstring -- "Relative difference between val1 and val2"
    -- described one particular caller, not this helper.)

    If both values are arrays of unequal ndim, trailing new axes are added
    to the lower-dimensional one so that it broadcasts against the other;
    otherwise the function is applied directly. Divide-by-zero warnings
    are suppressed, since NaN/inf entries are filtered by the callers.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if hasattr(val1, "shape") and hasattr(val2, "shape"):
            if val1.ndim == val2.ndim:
                return function(val1, val2)
            lessdim, dimmest = sorted([val1, val2], key=lambda v: v.ndim)
            dimdelta = dimmest.ndim - lessdim.ndim
            # append np.newaxis slots so lessdim broadcasts against dimmest
            add_axes = (slice(None),)*lessdim.ndim + (np.newaxis,)*dimdelta
            if dimmest is val1:
                return function(dimmest, lessdim[add_axes])
            if dimmest is val2:
                return function(lessdim[add_axes], dimmest)
        return function(val1, val2)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""  # string form of the solved model, used by diff/savetxt
    _name_collision_varkeys = None  # cache: varkey -> lineage depth needed
    _lineageset = False  # True while "necessarylineage" is set on varkeys
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def set_necessarylineage(self, clear=False):  # pylint: disable=too-many-branches
        """Sets (or, with clear=True, removes) the minimal lineage suffix
        each varkey needs for its printed name to be unambiguous.

        The computation is cached in _name_collision_varkeys on first call.
        """
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {}
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            name_collisions = defaultdict(set)
            for key in keymap:
                if hasattr(key, "key"):
                    if len(keymap[key.name]) == 1:  # very unique
                        self._name_collision_varkeys[key] = 0
                    else:
                        shortname = key.str_without(["lineage", "vec"])
                        if len(keymap[shortname]) > 1:
                            name_collisions[shortname].add(key)
            for varkeys in name_collisions.values():
                # grow each colliding key's lineage suffix, one namespace
                # level at a time, until every key is uniquely named
                min_namespaced = defaultdict(set)
                for vk in varkeys:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                while any(len(vks) > 1 for vks in min_namespaced.values()):
                    for key, vks in list(min_namespaced.items()):
                        if len(vks) <= 1:
                            continue
                        del min_namespaced[key]
                        mineage, idx = key
                        idx += 1
                        for vk in vks:
                            lineages = vk.lineagestr().split(".")
                            submineage = lineages[-idx] + "." + mineage
                            min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), vks in min_namespaced.items():
                    vk, = vks
                    self._name_collision_varkeys[vk] = idx
        if clear:
            self._lineageset = False
            for vk in self._name_collision_varkeys:
                del vk.descr["necessarylineage"]
        else:
            self._lineageset = True
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def __len__(self):
        # number of sweep points; 1 for a scalar solve, 0 before solving
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        # substitute the solution into posy; a constant result is returned
        # as its coefficient rather than as a Monomial
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        "Checks for almost-equality between two solutions"
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # values compared relatively, sensitivities absolutely
            if abs(cast(np.divide, svars[key], ovars[key]) - 1) >= reltol:
                return False
            if abs(self["sensitivities"]["variables"][key]
                   - other["sensitivities"]["variables"][key]) >= sens_abstol:
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            # treat a string as a path to a (possibly compressed) pickle
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                other = pickle.load(open(other, "rb"))
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # [2:] drops the "---"/"+++" file-header lines of the diff
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl"))
        """
        # NOTE(review): the file handle is left for the GC to close;
        # consider "with open(filename, 'wb') as f" -- confirm no callers
        # rely on the current behavior.
        with SolSavingEnvironment(self, saveconstraints):
            pickle.dump(self, open(filename, "wb"), **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            # optimize() shrinks the pickle stream before compression
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        "Load a gzip-compressed pickle file"
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        "Returns list of variables, optionally with minimal unique names"
        if showvars:
            showvars = self._parse_showvars(showvars)
        self.set_necessarylineage()
        names = {}
        for key in showvars or self["variables"]:
            for k in self["variables"].keymap[key]:
                names[k.str_without(exclude)] = k
        self.set_necessarylineage(clear=True)
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("vec")):
        "Saves primal solution as matlab file"
        # NOTE(review): ("vec") is the string "vec", not a 1-tuple;
        # it works via substring membership in str_without -- confirm.
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None, excluded=("vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # enumerate every index of the (possibly N-d) array value
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                # row is appended first, then extended in place below
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json", showvars=None):
        "Saves solution table as a json file"
        sol_dict = {}
        if self._lineageset:
            self.set_necessarylineage(clear=True)
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # add appropriate data for each variable to the dictionary
        for k, v in data.items():
            key = str(k)
            if isinstance(v, np.ndarray):
                val = {"v": v.tolist(), "u": k.unitstr()}
            else:
                val = {"v": v, "u": k.unitstr()}
            sol_dict[key] = val
        with open(filename, "w") as f:
            json.dump(sol_dict, f)

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    # pad so every row has the same number of value columns
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)
        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)
        if len(self) > 1:
            # sweep: substitute into each sweep point's solution
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])
        return posy.sub(self["variables"], require_positive=False)

    def _parse_showvars(self, showvars):
        # expand each requested variable to all varkeys it maps to
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), **kwargs):
        "Print summary table, showing no sensitivities or constants"
        return self.table(showvars,
                          ["cost breakdown", "model sensitivities breakdown",
                           "warnings", "sweepvariables", "freevariables"],
                          **kwargs)

    def table(self, showvars=(),
              tables=("cost breakdown", "model sensitivities breakdown",
                      "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=False, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        self.set_necessarylineage()
        showvars = self._parse_showvars(showvars)
        strs = []
        for table in tables:
            if len(self) > 1 and "breakdown" in table:
                # no breakdowns for sweeps
                table = table.replace(" breakdown", "")
            if "sensitivities" not in self and ("sensitivities" in table or
                                                "constraints" in table):
                continue
            if table == "cost":
                cost = self["cost"]  # pylint: disable=unsubscriptable-object
                if kwargs.get("latex", None):  # cost is not printed for latex
                    continue
                strs += ["\n%s\n------------" % "Optimal Cost"]
                if len(self) > 1:
                    # show at most the first four sweep costs
                    costs = ["%-8.3g" % c for c in mag(cost[:4])]
                    strs += [" [ %s %s ]" % (" ".join(costs),
                                             "..." if len(self) > 4 else "")]
                else:
                    strs += [" %-.4g" % mag(cost)]
                strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                strs += [""]
            elif table in TABLEFNS:
                strs += TABLEFNS[table](self, showvars, **kwargs)
            elif table in self:
                data = self[table]
                if showvars:
                    showvars = self._parse_showvars(showvars)
                    data = {k: data[k] for k in showvars if k in data}
                strs += var_table(data, self.table_titles[table], **kwargs)
        if kwargs.get("latex", None):
            preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                  "% \\usepackage{booktabs}",
                                  "% \\usepackage{longtable}",
                                  "% \\usepackage{amsmath}",
                                  "% \\begin{document}\n"))
            strs = [preamble] + strs + ["% \\end{document}"]
        self.set_necessarylineage(clear=True)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values with all(abs(value)) < minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # blank out sub-threshold entries of vectors (shown as " - ")
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table above
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 4)
        models.add(model)
        b = bool(getattr(v, "shape", None))  # is this value a vector?
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        # tuple layout depends on the sort mode chosen above
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            # start a new model section: spacer row + model header row
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            # choose which dimension to lay out horizontally: the largest
            # one that still fits within maxcolumns
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # print the remaining vector entries on continuation rows
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad the final partial row so the bracket aligns
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
                coltitles = [title, "Value", "Units", "Description"]
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
                coltitles = [title, "Units", "Description"]
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
                coltitles = [title, "Value", "Units"]
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            # column widths over all non-header rows
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines