Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Defines SolutionArray class""" 

2import re 

3import difflib 

4from operator import sub 

5import warnings as pywarnings 

6import pickle 

7import gzip 

8import pickletools 

9import numpy as np 

10from .nomials import NomialArray 

11from .small_classes import DictOfLists, Strings 

12from .small_scripts import mag, try_str_without 

13from .repr_conventions import unitstr, lineagestr 

14 

15 

# Split points used when wrapping a constraint's string across lines: a
# single "*" together with one neighbouring character on each side (a "**"
# power is not matched), and the spaced operators " + ", " >= ", " <= ",
# " = ". Every alternative is a capture group, so re.split keeps the matched
# separators in its output.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# (before, after) substitutions applied, in this order, to formatted value
# strings so NaN entries end up printed as " - ". Order matters: the signed
# ("+nan"/"-nan") and percent ("nan%") forms are normalized first, and only
# then is the bare "nan" replaced.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]

24 

25 

class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    This approximately halves the size of the pickled solution.

    On entry, each truthy attribute named in `_STRIPPED_ATTRS` is deleted
    from every constraint in ``solarray["sensitivities"]["constraints"]``
    and remembered. On exit the remembered values are put back. If entry
    itself fails part-way, the removals made so far are undone before the
    error propagates, because Python does not call ``__exit__`` when
    ``__enter__`` raises.
    """

    _STRIPPED_ATTRS = ("mfm", "pmap", "bounded", "meq_bounded",
                       "v_ss", "unsubbed", "varkeys")

    def __init__(self, solarray):
        self.solarray = solarray
        self.attrstore = {}  # attribute name -> {constraint: removed value}

    def __enter__(self):
        self.attrstore = {}
        try:
            for constraint_attr in self._STRIPPED_ATTRS:
                store = self.attrstore[constraint_attr] = {}
                for constraint in self.solarray["sensitivities"]["constraints"]:
                    value = getattr(constraint, constraint_attr, None)
                    if value:
                        delattr(constraint, constraint_attr)
                        # record only after a successful delete, so a failed
                        # delete is never "restored" below
                        store[constraint] = value
        except BaseException:
            # __exit__ will not run for a failed __enter__: undo removals now
            self._restore()
            raise

    def __exit__(self, type_, val, traceback):
        self._restore()

    def _restore(self):
        "Puts every removed attribute back onto its constraint."
        for constraint_attr, store in self.attrstore.items():
            for constraint, value in store.items():
                setattr(constraint, constraint_attr, value)

50 

51 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    Arguments
    ---------
    data : dict
        A solution (its ["sensitivities"]["variables"] entry is used when
        present) or a mapping of variable keys to sensitivities.
    showvars : iterable
        If non-empty, restricts the table to these keys (those present).
    title : str
        Heading of the table.
    **kwargs
        Forwarded to `var_table`.

    Returns
    -------
    list of str
    """
    has_variable_senss = "variables" in data.get("sensitivities", {})
    if has_variable_senss:
        data = data["sensitivities"]["variables"]
    if showvars:
        selected = {}
        for key in showvars:
            if key in data:
                selected[key] = data[key]
        data = selected
    table_options = dict(sortbyvals=True, skipifempty=True,
                         valfmt="%+-.2g ", vecfmt="%+-8.2g",
                         printunits=False, minval=1e-3)
    return var_table(data, title, **table_options, **kwargs)

61 

62 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns top sensitivity table lines.

    The `nvars` most sensitive variables are shown. If some of them were
    already displayed (they are in `showvars`), the heading becomes
    "Next Most Sensitive Variables".
    """
    top_data, already_shown = topsenss_filter(data, showvars, nvars)
    if already_shown:
        heading = "Next Most Sensitive Variables"
    else:
        heading = "Most Sensitive Variables"
    return senss_table(top_data, title=heading, hidebelowminval=True,
                       **kwargs)

70 

71 

def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to top N vars.

    Arguments
    ---------
    data : dict
        A solution (its ["sensitivities"]["variables"] entry is used when
        present) or a mapping of variable keys to sensitivities.
    showvars : set
        Keys already displayed elsewhere; they are excluded from the result.
    nvars : int
        How many variables to keep. It is reduced by one for each
        already-shown variable, but never below 3 by that reduction.

    Returns
    -------
    (dict, set)
        The selected {key: sensitivity} entries (largest mean |sensitivity|
        last), and the set of keys that were filtered out as already shown.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # entries containing any NaN are never ranked
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    filter_already_shown = showvars.intersection(topk)
    for _ in filter_already_shown:
        if nvars > 3:  # always show at least 3
            nvars -= 1
    topk = [k for k in topk if k not in filter_already_shown]
    if nvars <= 0:  # topk[-0:] would select *every* variable
        return {}, filter_already_shown
    return {k: data[k] for k in topk[-nvars:]}, filter_already_shown

85 

86 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    Arguments
    ---------
    data : dict
        A solution (its ["sensitivities"]["variables"] entry is used when
        present) or a mapping of variable keys to sensitivities.
    _ : ignored
        Placeholder for the `showvars` argument all TABLEFNS receive.
    maxval : float
        Variables whose mean |sensitivity| is below this are listed.
    **kwargs
        Forwarded to `senss_table`.
    """
    # check for the key actually read below (the old check looked for
    # "constants" and then read "variables", crashing when only the latter
    # existed); this matches senss_table and topsenss_filter
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

93 

94 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return constraint tightness lines.

    Lists up to `ntightconstrs` constraints whose sensitivity is at least
    `tight_senss`, most sensitive first. For sweeps, only the last sweep
    point's sensitivity is considered.
    """
    is_sweep = len(self) > 1
    heading = "Most Sensitive Constraints"
    if is_sweep:
        heading += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        value = senss[-1] if is_sweep else senss
        if value >= tight_senss:
            shown = "%+6.2g" % value
            # sort key: descending rounded sensitivity, then constraint text
            rows.append(((-float(shown), str(constr)), shown, id(constr),
                         constr))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], heading, **kwargs)

109 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return constraint looseness lines.

    Lists constraints whose sensitivity is at most `min_senss`. For sweeps,
    only the last sweep point's sensitivity is considered.
    """
    heading = "Insensitive Constraints |below %+g|" % min_senss
    is_sweep = len(self) > 1
    if is_sweep:
        heading += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        value = senss[-1] if is_sweep else senss
        if value <= min_senss:
            rows.append((0, "", id(constr), constr))
    return constraint_table(rows, heading, **kwargs)

123 

124 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of 4-tuples (sortby, openingstr, id, constraint)
        `sortby` orders rows within a model; `openingstr` is printed to the
        left of " : ". Callers pass id(constraint) as the third item,
        presumably so sorting never has to compare constraints themselves.
    title : str
        Table heading; it is underlined with "-".
    sortbymodel : bool
        If True, rows are grouped under the model name from lineagestr.
    showmodels : bool
        If False, lineage is hidden from the constraint strings entirely.

    Returns
    -------
    list of str
        Title, underline, the formatted rows, and a trailing "".
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    # models: model name -> order of first appearance in sorted(data)
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # blank spacer between models
            if model or lines:
                lines.append([("modelname",), model])
            previous_model = model
        # drop the model name from the constraint text; it is the heading
        constrstr = constrstr.replace(model, "")
        minlen, maxlen = 25, 80
        # wrap the constraint string into lines of at most maxlen chars,
        # preferring to break at CONSTRSPLITPATTERN separators
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # keep only the separator's first character on this line and
                # prepend the rest to the next segment, so a wrap can fall
                # just before the operator
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            # hard-split anything still too long
            while len(line) > maxlen:
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths, ignoring model-heading rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("modelname",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("modelname",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]

187 

188 

def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Arguments
    ---------
    self : mapping
        The solution; its optional "warnings" entry maps a warning type to
        a list of (message, details) pairs, or to an array of such lists
        (one per sweep point).
    _ : ignored
        Placeholder for the `showvars` argument all TABLEFNS receive.
    **kwargs
        Forwarded to `constraint_table` for constraint-type warnings.

    Returns
    -------
    list of str
        Empty when there are no warnings; otherwise a "~"-banner-delimited
        block of tables ending in "".
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # unswept: a single list of warnings
        for i, data in enumerate(data_vec):
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            # `data` may be empty for a sweep point without such warnings
            has_details = bool(data) and bool(data[0][1])
            if wtype == "Unexpectedly Tight Constraints" and has_details:
                data = [(-int(1e5*c.relax_sensitivity),
                         "%+6.2g" % c.relax_sensitivity, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and has_details:
                data = [(-int(1e5*c.rel_diff),
                         "%.4g %s %.4g" % c.tightvalues, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title (incl. sweep index), as
                # constraint_table does
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    lines[-1] = "~~~~~~~~"  # closing banner replaces the final spacer
    return lines + [""]

219 

220 

# Table names accepted by SolutionArray.table that are built by a function
# rather than read directly from the solution. SolutionArray.table calls
# each as fn(solution, showvars, **kwargs) and extends its output with the
# returned list of lines.
TABLEFNS = {"sensitivities": senss_table,
            "top sensitivities": topsenss_table,
            "insensitivities": insenss_table,
            "tightest constraints": tight_table,
            "loose constraints": loose_table,
            "warnings": warnings_table,
           }

228 

def unrolled_absmax(values):
    """Return the single entry of largest magnitude among numbers/arrays.

    Scans `values` (scalars or arrays), keeping the last item whose max |x|
    ties or beats the best so far. If that item is an array with a
    non-empty shape, its element of largest magnitude (sign preserved) is
    returned; otherwise the item itself. Returns None for an empty input.
    """
    best, best_mag = None, 0
    for candidate in values:
        candidate_mag = np.abs(candidate).max()
        if candidate_mag >= best_mag:
            best, best_mag = candidate, candidate_mag
    if not getattr(best, "shape", None):
        return best
    flat_position = np.argmax(np.abs(best))
    return best[np.unravel_index(flat_position, best.shape)]

240 

241 

def cast(function, val1, val2):
    """Apply ``function(val1, val2)``, aligning arrays of different rank.

    When both arguments have a ``shape`` but different ``ndim``, the
    lower-rank one gets trailing new axes. Its axes therefore line up with
    the *leading* axes of the higher-rank one, unlike numpy's default
    trailing-axis broadcasting. Argument order is preserved in the call.
    All warnings raised inside (e.g. divide-by-zero) are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if not (hasattr(val1, "shape") and hasattr(val2, "shape")):
            return function(val1, val2)
        if val1.ndim == val2.ndim:
            return function(val1, val2)
        low, high = sorted([val1, val2], key=lambda arr: arr.ndim)
        extra_axes = high.ndim - low.ndim
        expander = (slice(None),)*low.ndim + (np.newaxis,)*extra_axes
        if high is val1:
            return function(high, low[expander])
        if high is val2:
            return function(low[expander], high)
        return function(val1, val2)

257 

258 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    table_titles = {"sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:  # computed once, then cached
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def _flag_name_collisions(self, flag):
        """Sets (flag=True) or removes (flag=False) the "necessarylineage"
        descr entry on every varkey whose name is not unique."""
        for key in self.name_collision_varkeys():
            if flag:
                key.descr["necessarylineage"] = True
            else:
                key.descr.pop("necessarylineage", None)

    def __len__(self):
        try:
            return len(self["cost"])
        except TypeError:  # scalar cost: a single solution
            return 1
        except KeyError:  # no cost at all: empty
            return 0

    def __call__(self, posy):
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        True when both contain the same variables, every value agrees to
        within relative tolerance `reltol`, and every variable sensitivity
        agrees to within absolute tolerance `sens_abstol`. Works for swept
        solutions too: any out-of-tolerance sweep point makes them unequal.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any: values are arrays for sweeps, where a bare `if`
            # on an elementwise comparison would raise ValueError
            rel_error = np.abs(cast(np.divide, svars[key], ovars[key]) - 1)
            if np.any(rel_error >= reltol):
                return False
            sens_error = np.abs(self["sensitivities"]["variables"][key]
                                - other["sensitivities"]["variables"][key])
            if np.any(sens_error >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0, reldiff=True, reltol=1.0, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (".pgz" paths are read with `decompress_file`)
        showvars : iterable
            if given, only compare these variables
        constraintsdiff : boolean
            if True, show a unified diff of the two models' constraints
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing

        Returns
        -------
        str
        """
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            # NOTE(security): unpickling can execute arbitrary code; only
            # diff against solution files from trusted sources.
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])  # [2:] drops the ---/+++ header
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl", **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with SolSavingEnvironment(self):
            with open(filename, "wb") as f:
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz", **cpickleargs):
        """Pickles the solution and writes it gzip-compressed to `filename`.

        The pickle is optimized with `pickletools.optimize` before being
        compressed; load it back with `SolutionArray.decompress_file`.
        """
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        NOTE(security): unpickling can execute arbitrary code; only load
        files from trusted sources.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict mapping variable names to varkeys.

        Names are built with `str_without(exclude)`. Keys whose names
        collide keep their lineage while the names are computed. If
        `showvars` is given, only those variables are included.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        self._flag_name_collisions(True)
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:  # never leave the temporary flags behind
            self._flag_name_collisions(False)
        return names

    def savemat(self, filename="solution.mat", showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "original_fn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr)
            f.write(self.table(**kwargs))

    def savecsv(self, showvars=None, filename="solution.csv", valcols=5):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols)
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("modelname",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Expands each entry of showvars into the set of varkeys it refers to."
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Print summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities"), or any key of
            TABLEFNS (e.g. "warnings", "tightest constraints")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        self._flag_name_collisions(True)
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:  # never leave the temporary flags behind
            self._flag_name_collisions(False)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

687 

688 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table; its values are never modified
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted row lists instead of strings
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values that are all-NaN or whose largest |value| is <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), array entries with |value| <= minval
        are shown as blanks
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped by model (lineage)
    maxcolumns : int
        The most vector entries printed on one row
    skipifempty : boolean
        If true, return [] when no rows survive the minval filter
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # mask a copy: `data` often holds the solution's own arrays,
            # which must not be overwritten with NaNs just by printing
            v = v.copy()
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if not sortbyvals:
            model, isvector, varstr, _, var, val = varlist
        else:
            model, _, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("modelname",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows for the rest of the vector
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
                coltitles = [title, "Value", "Units", "Description"]
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
                coltitles = [title, "Units", "Description"]
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
                coltitles = [title, "Value", "Units"]
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("modelname",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("modelname",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines