Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import re
3import difflib
4from operator import sub
5import warnings as pywarnings
6import pickle
7import gzip
8import json
9import pickletools
10import numpy as np
11from .nomials import NomialArray
12from .small_classes import DictOfLists, Strings
13from .small_scripts import mag, try_str_without
14from .repr_conventions import unitstr, lineagestr
# Splits a constraint string on its operators (a single "*", " + ", " >= ",
# " <= ", " = ") so that long constraints can be wrapped across table lines
# at sensible points.  Double-star ("**") is deliberately not matched.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# (old, new) replacement pairs applied to formatted value strings so that
# NaN entries print as aligned blanks/dashes in tables.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
class SolSavingEnvironment:
    """Context manager that strips construction/solve attributes from
    constraints for the duration of a pickling operation.

    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.attrstore = {}
        self.saveconstraints = saveconstraints
        self.constraintstore = None

    def __enter__(self):
        if not self.saveconstraints:
            # drop constraints entirely; they're restored on exit
            self.constraintstore = \
                self.solarray["sensitivities"].pop("constraints")
            return
        # keep the constraints but shed their heavyweight attributes
        for attrname in ["bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"]:
            stored = {}
            for constraint in self.solarray["sensitivities"]["constraints"]:
                if getattr(constraint, attrname, None):
                    stored[constraint] = getattr(constraint, attrname)
                    delattr(constraint, attrname)
            self.attrstore[attrname] = stored

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        # restore every attribute that was removed on entry
        for attrname, stored in self.attrstore.items():
            for constraint, value in stored.items():
                setattr(constraint, attrname, value)
def msenss_table(data, _, **kwargs):
    "Returns model sensitivity table lines"
    if "models" not in data.get("sensitivities", {}):
        return ""

    def modelsort(item):
        "Low-sensitivity models last, then by magnitude, then by name."
        name, senss = item
        if (senss < 0.1).all():
            return ((senss < 0.1).all(), -np.max(senss), name)
        return ((senss < 0.1).all(), -round(np.mean(senss), 1), name)

    sorted_models = sorted(data["sensitivities"]["models"].items(),
                           key=modelsort)
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs["sortmodelsbysenss"]:
        lines[0] += " (sorts models in sections below)"
    previousstr = ""
    for name, senss in sorted_models:
        if not name:  # for now let's only do named models
            continue
        if (senss < 0.1).all():
            peak = np.max(senss)
            if peak:
                senssstr = "%6s" % ("<1e%i" % np.log10(peak))
            else:
                senssstr = " =0 "
        else:
            meansenss = round(np.mean(senss), 1)
            senssstr = "%+6.1f" % meansenss
            offsets = senss - meansenss
            if np.max(np.abs(offsets)) > 0.1:
                offsetstrs = ["%+4.1f" % o if abs(o) >= 0.1 else " - "
                              for o in offsets]
                senssstr += " + [ %s ]" % " ".join(offsetstrs)
        # repeated sensitivity strings are blanked out for readability
        if senssstr == previousstr:
            senssstr = " "*len(senssstr)
        else:
            previousstr = senssstr
        lines.append("%s : %s" % (senssstr, name))
    return lines + [""] if len(lines) > 3 else []
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    "Returns sensitivity table lines"
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    if showvars:
        # restrict to the requested keys that are actually present
        data = {key: data[key] for key in showvars if key in data}
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    "Returns top sensitivity table lines"
    data, filtered = topsenss_filter(data, showvars, nvars)
    # vary the title when some top vars were already shown elsewhere
    title = ("Next Most Sensitive Variables" if filtered
             else "Most Sensitive Variables")
    return senss_table(data, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    "Filters sensitivities down to top N vars"
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # rank keys by mean absolute sensitivity, NaN-containing entries dropped
    mean_abs_senss = {key: np.abs(senss).mean()
                      for key, senss in data.items()
                      if not np.isnan(senss).any()}
    ranked = sorted(mean_abs_senss.items(), key=lambda item: item[1])
    topk = [key for key, _ in ranked]
    filter_already_shown = showvars.intersection(topk)
    for key in filter_already_shown:
        topk.remove(key)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    return {key: data[key] for key in topk[-nvars:]}, filter_already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    "Returns insensitivity table lines"
    if "constants" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # keep only entries whose mean |sensitivity| is below maxval
    insensitive = {key: senss for key, senss in data.items()
                   if np.mean(np.abs(senss)) < maxval}
    return senss_table(insensitive, title="Insensitive Fixed Variables",
                       **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    "Return constraint tightness lines"
    title = "Most Sensitive Constraints"
    constraint_senss = self["sensitivities"]["constraints"]
    if len(self) > 1:
        title += " (in last sweep)"
        # swept solutions: judge tightness by the last sweep's sensitivity
        entries = [((-float("%+6.2g" % s[-1]), str(c)),
                    "%+6.2g" % s[-1], id(c), c)
                   for c, s in constraint_senss.items()
                   if s[-1] >= tight_senss]
    else:
        entries = [((-float("%+6.2g" % s), str(c)), "%+6.2g" % s, id(c), c)
                   for c, s in constraint_senss.items()
                   if s >= tight_senss]
    data = sorted(entries)[:ntightconstrs]
    return constraint_table(data, title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    "Returns insensitive-constraint table lines"
    title = "Insensitive Constraints |below %+g|" % min_senss
    constraint_senss = self["sensitivities"]["constraints"]
    sweeping = len(self) > 1
    if sweeping:
        title += " (in last sweep)"
    # swept solutions carry per-sweep sensitivities; use the last one
    senss_of = (lambda s: s[-1]) if sweeping else (lambda s: s)
    data = [(0, "", id(c), c) for c, s in constraint_senss.items()
            if senss_of(s) <= min_senss]
    return constraint_table(data, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, tiebreaker, constraint) tuples
        sortby orders the rows, openingstr is printed in the left column,
        the third element (typically id(constraint)) only breaks sort ties.
    title : string
        table heading
    sortbymodel : bool
        if True, group rows under their constraint's model lineage
    showmodels : bool
        if False, lineage is stripped from the printed constraints

    Returns a list of formatted table lines.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            # number models in order of first appearance for stable grouping
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            # emit a spacer and a model-header marker at each model change
            if lines:
                lines.append(["", ""])
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        constrstr = constrstr.replace(model, "")
        # wrap long constraint strings between minlen and maxlen columns
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # operator match captured a leading character: push it back
                # onto the next segment and keep only the operator itself
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:
                # hard-wrap anything that still overflows
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths exclude model-header marker rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Reads self["warnings"], a dict mapping warning-type names to either a
    single list of (message, data) pairs or, for sweeps, an array of such
    lists (one per sweep point).  Returns [] when there are no warnings.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            # collapse the sweep dimension when every sweep point produced
            # exactly the same warnings
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            # these two warning types carry constraints and are rendered
            # as constraint tables; anything else prints its messages
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                lines += [title] + ["-"*len(wtype)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]
# Registry mapping table names (as accepted by SolutionArray.table's
# `tables` argument) to the functions that generate those tables' lines.
TABLEFNS = {"sensitivities": senss_table,
            "top sensitivities": topsenss_table,
            "insensitivities": insenss_table,
            "model sensitivities": msenss_table,
            "tightest constraints": tight_table,
            "loose constraints": loose_table,
            }
def unrolled_absmax(values):
    "From an iterable of numbers and arrays, returns the largest magnitude"
    biggest, biggest_mag = None, 0
    for candidate in values:
        candidate_mag = np.abs(candidate).max()
        if candidate_mag >= biggest_mag:
            biggest_mag, biggest = candidate_mag, candidate
    if not getattr(biggest, "shape", None):
        return biggest  # scalar winner: return as-is
    # array winner: return its single largest-magnitude element
    flatpos = np.argmax(np.abs(biggest))
    return biggest[np.unravel_index(flatpos, biggest.shape)]
def cast(function, val1, val2):
    "Relative difference between val1 and val2 (positive if val2 is larger)"
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if not both_arrays or val1.ndim == val2.ndim:
            return function(val1, val2)
        # dimensions differ: broadcast the lower-dim array against the
        # higher-dim one by appending new axes to it
        lessdim, dimmest = sorted([val1, val2], key=lambda v: v.ndim)
        dimdelta = dimmest.ndim - lessdim.ndim
        add_axes = (slice(None),)*lessdim.ndim + (np.newaxis,)*dimdelta
        if dimmest is val1:
            return function(dimmest, lessdim[add_axes])
        return function(lessdim[add_axes], dimmest)
def diff_retrieval(self, other, sharedvks, showvars=None, *, jsondiff=False,
                   senssdiff=False, absdiff=False, reldiff=False):
    """A helper function for the generalized diff method.

    Retrieves the requested difference dictionaries between two solutions.

    Arguments
    ---------
    self, other : mappings with a "variables" dict
        (and a "sensitivities" dict when senssdiff is True)
    sharedvks : iterable of varkeys present in both solutions
    showvars : unused; kept for interface compatibility
    jsondiff : bool
        if True, keys are stringified and numpy arrays converted to lists
        so the result is JSON-serializable
    senssdiff, absdiff, reldiff : bool
        which of the difference tables to compute

    Returns
    -------
    dict with (a subset of) keys {"rel", "abs", "sens"}
    """
    svars = self["variables"]
    ovars = other["variables"]
    diff_dict = {}

    def _jsonsafe(val):
        "Converts numpy arrays to plain lists for JSON output."
        return val.tolist() if isinstance(val, np.ndarray) else val

    if reldiff:
        rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                    for vk in sharedvks}
        if jsondiff:
            rel_diff = {str(vk): _jsonsafe(v) for vk, v in rel_diff.items()}
        diff_dict['rel'] = rel_diff
    if absdiff:
        abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
        if jsondiff:
            abs_diff = {str(vk): _jsonsafe(v) for vk, v in abs_diff.items()}
        diff_dict['abs'] = abs_diff
    if senssdiff:
        ssenss = self["sensitivities"]["variables"]
        osenss = other["sensitivities"]["variables"]
        senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                       for vk in sharedvks}
        if jsondiff:
            # bugfix: this branch previously assigned an undefined name
            # (`senss_delta` was never built in the json path) and left
            # numpy arrays unserializable; stringify keys and listify values
            senss_delta = {str(vk): _jsonsafe(v)
                           for vk, v in senss_delta.items()}
        diff_dict['sens'] = senss_delta
    return diff_dict
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    # cache for name_collision_varkeys(); None until first computed
    _name_collision_varkeys = None
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        # number of sweep points; 1 for a scalar cost, 0 if never solved
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        "Checks for almost-equality between two solutions"
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            if abs(cast(np.divide, svars[key], ovars[key]) - 1) >= reltol:
                return False
            if abs(self["sensitivities"]["variables"][key]
                   - other["sensitivities"]["variables"][key]) >= sens_abstol:
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *, jsondiff=False,
             filename="solution.json", constraintsdiff=True,
             senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
        jsondiff : boolean
            if True, dump the diff dict to `filename` as JSON and return it
        filename : string
            output path used when jsondiff is True
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing

        Returns
        -------
        str (or dict, when jsondiff is True)
        """
        if sortmodelsbysenss:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                # bugfix: close the file handle after unpickling
                with open(other, "rb") as f:
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # drop unified_diff's two file-header lines
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)

        # retrieve diff data
        # bugfix: this previously called the misspelled name `diff_retreval`,
        # which raised a NameError on every diff
        diff_dict = diff_retrieval(self, other, sharedvks, showvars=showvars,
                                   jsondiff=jsondiff, senssdiff=senssdiff,
                                   absdiff=absdiff, reldiff=reldiff)
        if jsondiff:
            with open(filename, "w") as f:
                json.dump(diff_dict, f)
            return diff_dict

        if reldiff:
            rel_diff = diff_dict['rel']
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = diff_dict['abs']
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            senss_delta = diff_dict['sens']
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl"))
        """
        with SolSavingEnvironment(self, saveconstraints):
            # bugfix: close the file handle after dumping
            with open(filename, "wb") as f:
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            # pickletools.optimize shrinks the pickle before gzipping
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        "Load a gzip-compressed pickle file"
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        "Returns list of variables, optionally with minimal unique names"
        if showvars:
            showvars = self._parse_showvars(showvars)
        # temporarily mark colliding names so they print with lineage
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        names = {}
        for key in showvars or self["variables"]:
            for k in self["variables"].keymap[key]:
                names[k.str_without(exclude)] = k
        for key in self.name_collision_varkeys():
            del key.descr["necessarylineage"]
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # enumerate every multi-index of the array value
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
                rows.append(row)
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Expands a showvars iterable into the full set of matching varkeys."
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Print summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
            "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        # skip model grouping when every variable shares one lineage
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        showvars = self._parse_showvars(showvars)
        strs = []
        for table in tables:
            if "sensitivities" not in self and ("sensitivities" in table or
                                                "constraints" in table):
                continue
            if table == "cost":
                cost = self["cost"]  # pylint: disable=unsubscriptable-object
                if kwargs.get("latex", None):  # cost is not printed for latex
                    continue
                strs += ["\n%s\n------------" % "Optimal Cost"]
                if len(self) > 1:
                    # swept solution: show the first few costs
                    costs = ["%-8.3g" % c for c in mag(cost[:4])]
                    strs += [" [ %s %s ]" % (" ".join(costs),
                                             "..." if len(self) > 4 else "")]
                else:
                    strs += [" %-.4g" % mag(cost)]
                strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                strs += [""]
            elif table in TABLEFNS:
                strs += TABLEFNS[table](self, showvars, **kwargs)
            elif table in self:
                data = self[table]
                if showvars:
                    showvars = self._parse_showvars(showvars)
                    data = {k: data[k] for k in showvars if k in data}
                strs += var_table(data, self.table_titles[table], **kwargs)
        if kwargs.get("latex", None):
            preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                  "% \\usepackage{booktabs}",
                                  "% \\usepackage{longtable}",
                                  "% \\usepackage{amsmath}",
                                  "% \\begin{document}\n"))
            strs = [preamble] + strs + ["% \\end{document}"]
        for key in self.name_collision_varkeys():
            del key.descr["necessarylineage"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values with all(abs(value)) < minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true, individual sub-minval entries of arrays are blanked (NaN'd).
        NOTE: this mutates the input arrays in place.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, group rows by model lineage
    maxcolumns : int
        maximum number of vector entries printed per line
    rawlines : boolean
        If true, return the unformatted list-of-columns rows
    sortmodelsbysenss : dict or None
        model-sensitivity map used to order model groups
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # NOTE: mutates the input array, blanking sub-minval entries
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        # negative so that higher-sensitivity models sort first
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")  # always keep the toplevel (unnamed) model
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        # tuple layouts differ between the two sort modes above
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            # spacer row plus a model-header marker at each model change
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            # pick the largest dimension that fits in maxcolumns to lay
            # out horizontally; remaining entries wrap onto later lines
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows for the rest of the vector's entries
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad the closing bracket to line up with full rows
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
                coltitles = [title, "Value", "Units", "Description"]
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
                coltitles = [title, "Units", "Description"]
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
                coltitles = [title, "Value", "Units"]
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            # column widths computed over all non-header rows
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
            for i, line in enumerate(lines):
                if line[0] == ("newmodelline",):
                    line = [fmts[0].format(" | "), line[1]]
                else:
                    line = [fmt.format(s) for fmt, s in zip(fmts, line)]
                lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines