Coverage for gpkit/solution_array.py : 84%
![Show keyboard shortcuts](keybd_closed.png)
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Defines SolutionArray class"""
2import re
3import json
4import difflib
5from operator import sub
6import warnings as pywarnings
7import pickle
8import gzip
9import pickletools
10import numpy as np
11from .nomials import NomialArray
12from .small_classes import DictOfLists, Strings
13from .small_scripts import mag, try_str_without
14from .repr_conventions import unitstr, lineagestr
# Break points used by constraint_table to wrap long constraint strings:
# a lone "*" (the neighbouring characters are part of the match, so "**"
# is never split), or a space-padded "+", ">=", "<=" or "=".
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# (before, after) substitutions applied, in order, to formatted values in
# var_table, so that NaN entries (such as values hidden below minval)
# print as " - " instead of "nan".
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
class SolSavingEnvironment:
    """Context manager that slims a solution's constraints while it is pickled.

    With ``saveconstraints`` true, a fixed set of construction/solve
    attributes is removed from every constraint in
    ``solarray["sensitivities"]["constraints"]`` on entry and restored on
    exit. This approximately halves the size of the pickled solution.
    Otherwise, the whole ``"constraints"`` entry is popped on entry and
    put back on exit.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        self.attrstore = {}
        self.constraintstore = None

    def __enter__(self):
        sens = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = sens.pop("constraints")
            return
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            removed = {}
            for constr in sens["constraints"]:
                if getattr(constr, attrname, None):  # only truthy ones
                    removed[constr] = getattr(constr, attrname)
                    delattr(constr, attrname)
            self.attrstore[attrname] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = self.constraintstore
            return
        for attrname, removed in self.attrstore.items():
            for constr, value in removed.items():
                setattr(constr, attrname, value)
def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : dict-like
        A solution; its ``["sensitivities"]["models"]`` entry maps model
        names to (possibly swept) numpy sensitivity arrays.
    _ : unused
        Placeholder for the ``showvars`` argument shared by all TABLEFNS.
    sortmodelsbysenss : keyword, optional
        If truthy, the title notes that models are sorted by sensitivity.

    Returns
    -------
    list of str
        The table lines, or an empty list if there are no model
        sensitivities or fewer than two named models to show.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other table function
    items = sorted(data["sensitivities"]["models"].items(),
                   key=_msenss_sortkey)
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in items:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % np.log10(msenss))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            # for sweeps, show per-point deviations from the mean if notable
            if np.max(np.abs(deltas)) > 0.1:
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        # blank out a value identical to the one on the previous line
        if msenssstr == previousmsenssstr:
            msenssstr = " "*len(msenssstr)
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    return lines + [""] if len(lines) > 3 else []


def _msenss_sortkey(item):
    "Sort key: models with any sensitivity >= 0.1 first, by descending mean"
    name, senss = item
    tiny = (senss < 0.1).all()
    primary = -np.max(senss) if tiny else -round(np.mean(senss), 1)
    return (tiny, primary, name)
def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    If ``data`` is a solution holding ``["sensitivities"]["variables"]``,
    that mapping is tabulated; otherwise ``data`` itself is treated as the
    mapping of variables to sensitivities. A non-empty ``showvars``
    restricts the table to those of its keys that are present.
    """
    sensitivities = data.get("sensitivities", {})
    if "variables" in sensitivities:
        data = sensitivities["variables"]
    if not showvars:
        selected = data
    else:
        selected = {key: data[key] for key in showvars if key in data}
    return var_table(selected, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)
def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns lines for a table of the ``nvars`` most sensitive variables.

    Variables already in ``showvars`` are left out. When any were left
    out, the title says these are the *next* most sensitive variables.
    """
    top, already_shown = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(top, title=title, hidebelowminval=True, **kwargs)
def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top ``nvars`` variables.

    Arguments
    ---------
    data : dict
        Either a solution containing ``["sensitivities"]["variables"]`` or a
        mapping of variable keys to sensitivity values/arrays.
    showvars : iterable
        Variables already shown elsewhere. They are dropped from the
        result; for each one dropped, ``nvars`` is decremented while it is
        above 3.
    nvars : int
        How many of the most sensitive variables to keep (<= 0 keeps none).

    Returns
    -------
    (dict, set)
        The kept ``{key: sensitivity}`` entries, and the set of ranked
        variables that were skipped because they were in ``showvars``.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # variables with any NaN sensitivity can't be ranked, so skip them
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    # ascending by mean |sensitivity|: the most sensitive are at the end
    topk = [k for k, _ in sorted(mean_abs_senss.items(),
                                 key=lambda kv: kv[1])]
    filter_already_shown = set(showvars).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    if nvars <= 0:  # topk[-0:] would otherwise return every variable
        return {}, filter_already_shown
    return {k: data[k] for k in topk[-nvars:]}, filter_already_shown
def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns lines for a table of fixed variables with small sensitivities.

    Arguments
    ---------
    data : dict
        Either a solution containing ``["sensitivities"]["variables"]`` or a
        mapping of variable keys to sensitivity values/arrays.
    _ : unused
        Placeholder for the ``showvars`` argument shared by all TABLEFNS.
    maxval : float
        Variables whose mean absolute sensitivity is below this are listed.
    **kwargs
        Passed on to senss_table.
    """
    # check for the same key that is read (as senss_table and
    # topsenss_filter do), so a solution is never mistaken for the
    # sensitivity mapping itself
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)
def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return lines for a table of the most sensitive (tightest) constraints.

    Constraints whose sensitivity (at the last sweep point, for sweeps) is
    at least ``tight_senss`` are ranked by that sensitivity rounded to two
    significant figures. Ties are broken by the constraint string, and the
    top ``ntightconstrs`` are tabulated.
    """
    title = "Most Sensitive Constraints"
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        if is_sweep:
            senss = senss[-1]
        if senss >= tight_senss:
            senssstr = "%+6.2g" % senss
            rows.append(((-float(senssstr), str(constr)), senssstr,
                         id(constr), constr))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)
def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines for a table of constraints with negligible sensitivity.

    A constraint is listed when its sensitivity (at the last sweep point,
    for sweeps) is at most ``min_senss``.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"

    def relevant(senss):
        "The sensitivity that decides inclusion"
        return senss[-1] if is_sweep else senss

    rows = [(0, "", id(constr), constr)
            for constr, senss in self["sensitivities"]["constraints"].items()
            if relevant(senss) <= min_senss]
    return constraint_table(rows, title, **kwargs)
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of 4-tuples
        ``(sortby, openingstr, id, constraint)`` rows, as built by
        tight_table, loose_table and warnings_table. Rows are sorted as
        tuples, and ``openingstr`` is printed left of the constraint.
    title : str
        Table title, underlined with dashes.
    sortbymodel : bool
        If true, group constraints under the model name from lineagestr.
    showmodels : bool
        If false, "lineage" is excluded when stringifying constraints.

    Returns
    -------
    list of str
        The title, its underline, the formatted rows, and a trailing "".
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        # number models in order of first appearance in the sorted rows
        if model not in models:
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    # group rows by model (first-appearance order), then by sortby
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer between model sections
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        # the model name is already shown in the section header
        constrstr = constrstr.replace(model, "")
        minlen, maxlen = 25, 80
        # split at CONSTRSPLITPATTERN separators, dropping empty/None pieces
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # keep the separator's first character on this segment and
                # push the rest of it onto the start of the next segment
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            # hard-wrap anything still longer than maxlen
            while len(line) > maxlen:
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths, ignoring model header rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]
def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Arguments
    ---------
    self : solution (dict-like)
        Its ``"warnings"`` entry maps warning types to lists of
        ``(msg, data)`` tuples. For sweeps, it maps them to a container
        with a ``shape`` holding one such list per sweep point.
    _ : unused
        Placeholder for the ``showvars`` argument shared by all TABLEFNS.
    **kwargs
        Passed on to constraint_table for constraint warnings.

    Returns
    -------
    list of str
        Table lines framed by "~" rules, or [] if there is nothing to show.
    """
    if "warnings" not in self or not self["warnings"]:
        return []
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        elif _sweep_warnings_identical(data_vec):
            data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda w: w[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the whole title (including any sweep index),
                # as constraint_table does
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]


def _sweep_warnings_identical(data_vec):
    "True if every sweep point's warnings equal those of the first point."
    first = data_vec[0]
    for data in data_vec[1:]:
        eq_i = (data == first)
        if hasattr(eq_i, "all"):  # elementwise comparison result
            eq_i = eq_i.all()
        if not eq_i:
            return False
    return True
# Table names accepted by SolutionArray.table that are rendered by a
# function rather than read straight from the solution. Each function is
# called as fn(solution, showvars, **kwargs), and its result is added to
# the list of output lines.
TABLEFNS = {"sensitivities": senss_table,
            "top sensitivities": topsenss_table,
            "insensitivities": insenss_table,
            "model sensitivities": msenss_table,
            "tightest constraints": tight_table,
            "loose constraints": loose_table,
            "warnings": warnings_table,
            }
def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The returned entry keeps its sign. For arrays, the single element of
    largest magnitude is returned. NaNs are ignored, and entries that are
    entirely NaN are skipped. Ties go to the later value (within an array,
    to the first occurrence). Returns None if no non-NaN value is found.
    """
    finalval, absmaxest = None, 0
    for val in values:
        absval = np.abs(val)
        if np.isnan(absval).all():
            continue  # nothing comparable in an all-NaN entry
        absmaxval = np.nanmax(absval)
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        flatidx = np.nanargmax(np.abs(finalval))
        return finalval[np.unravel_index(flatidx, finalval.shape)]
    return finalval
def cast(function, val1, val2):
    """Apply ``function(val1, val2)``, lining arrays up on their leading axes.

    When both arguments are arrays of different dimensionality, trailing
    length-1 axes are appended to the lower-dimensional one, so the two
    align along their leading axes (numpy's own broadcasting aligns
    trailing axes). The argument order given to ``function`` is preserved.
    Warnings raised during the call, such as divide-by-zero, are silenced.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if not both_arrays or val1.ndim == val2.ndim:
            return function(val1, val2)
        extra = abs(val1.ndim - val2.ndim)
        if val1.ndim > val2.ndim:
            widened = val2[(slice(None),)*val2.ndim + (np.newaxis,)*extra]
            return function(val1, widened)
        widened = val1[(slice(None),)*val1.ndim + (np.newaxis,)*extra]
        return function(widened, val2)
class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        The solutions are almost equal when they contain the same variables,
        each variable's value ratio is within ``reltol`` of 1, and each
        sensitivity differs by less than ``sens_abstol``. Swept (array)
        values must satisfy the tolerances at every sweep point.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            ratio = cast(np.divide, svars[key], ovars[key])
            # np.any so swept arrays don't raise "truth value is ambiguous"
            if np.any(np.abs(ratio - 1) >= reltol):
                return False
            senss_delta = (self["sensitivities"]["variables"][key]
                           - other["sensitivities"]["variables"][key])
            if np.any(np.abs(senss_delta) >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (".pgz" files are read as gzip-compressed pickles)
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True, show a unified diff of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True, order model sections by this solution's model
            sensitivities

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            # NOTE: unpickling can execute arbitrary code; only diff against
            # solution files from trusted sources.
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:  # close the file when done
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.::

            import pickle
            with open("solution.pkl", "rb") as f:
                sol = pickle.load(f)
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as f:  # flushed and closed on exit
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        NOTE: unpickling can execute arbitrary code; only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict mapping variable names to varkeys.

        Names are built with ``str_without(exclude)``, after marking the keys
        whose names collide as needing lineage. Only ``showvars`` are
        included, if given.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:  # never leave the temporary flag behind
            for key in self.name_collision_varkeys():
                del key.descr["necessarylineage"]
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # every index in C order; 1-D indices are unwrapped to ints
                idxs = [idx[0] if len(idx) == 1 else idx
                        for idx in np.ndindex(*value.shape)]
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json", printjson=False, showvars=None):
        """Saves solution table as a json file.

        Each variable maps to ``{"v": value(s), "u": units}``. If
        ``printjson`` is true, the dict's ``str`` is returned and no file is
        written.
        """
        sol_dict = {}
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            data = self["variables"]
            if showvars:
                showvars = self._parse_showvars(showvars)
                data = {k: data[k] for k in showvars if k in data}
            # add appropriate data for each variable to the dictionary
            for k, v in data.items():
                if isinstance(v, np.ndarray):
                    val = {"v": v.tolist(), "u": k.unitstr()}
                else:
                    val = {"v": v, "u": k.unitstr()}
                sol_dict[str(k)] = val
        finally:  # never leave the temporary flag behind
            for key in self.name_collision_varkeys():
                del key.descr["necessarylineage"]
        if printjson:
            return str(sol_dict)
        with open(filename, "w") as f:
            json.dump(sol_dict, f)
        return None

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Expands showvars into the set of matching keys in the keymap"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Returns summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if "sensitivities" not in self and ("sensitivities" in table or
                                                    "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:  # never leave the temporary flag behind
            for key in self.name_collision_varkeys():
                del key.descr["necessarylineage"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy.

        Raises ValueError unless exactly one variable was swept.
        """
        if len(self["sweepvariables"]) != 1:
            raise ValueError("SolutionArray.plot only supports"
                             " 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted row lists instead of strings
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values with all(abs(value)) <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is set), array entries with abs <= minval are
        shown as blanks. The caller's arrays are not modified.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, group rows under their model (lineage) names
    maxcolumns : int
        Maximum number of vector entries printed per line
    skipifempty : boolean
        If true, return [] when no rows survive the minval filter
    sortmodelsbysenss : dict or falsy
        Model name -> sensitivity; if given, more sensitive models sort first
    """
    if not data:
        return []
    # column titles for each latex option (None for plain text/unknown)
    coltitles = {1: [title, "Value", "Units", "Description"],
                 2: [title, "Units", "Description"],
                 3: [title, "Value", "Units"]}.get(latex)
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # skip rows that are all NaN or all at/below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # mask a copy: assigning NaN in place would overwrite the
            # caller's arrays (e.g. a solution's sensitivities)
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation lines holding the rest of the vector
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            # column widths, ignoring model header rows
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
            for i, line in enumerate(lines):
                if line[0] == ("newmodelline",):
                    line = [fmts[0].format(" | "), line[1]]
                else:
                    line = [fmt.format(s) for fmt, s in zip(fmts, line)]
                lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines