Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Defines SolutionArray class""" 

2import re 

3import difflib 

4from operator import sub 

5import warnings as pywarnings 

6import pickle 

7import gzip 

8import pickletools 

9import numpy as np 

10from .nomials import NomialArray 

11from .small_classes import DictOfLists, Strings 

12from .small_scripts import mag, try_str_without 

13from .repr_conventions import unitstr, lineagestr 

14 

15 

# Points at which a constraint string may be wrapped onto a new line:
# a lone "*" flanked by two non-"*" characters (so "**" stays intact),
# or one of the binary operators " + ", " >= ", " <= ", " = ".
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# (old, new) substring substitutions applied in this order to every
# formatted value string, so that NaN entries lose their sign/percent
# decoration and are finally displayed as a dash.
VALSTR_REPLACES = [("+nan", " nan"), ("-nan", " nan"),  # drop NaN signs
                   ("nan%", "nan "),  # drop a trailing percent sign
                   ("nan", " - ")]  # show the NaN itself as a dash

24 

25 

class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    This approximately halves the size of the pickled solution.

    With ``saveconstraints`` true, each of the attributes "bounded",
    "meq_bounded", "vks", "v_ss", "unsubbed" and "varkeys" that is truthy
    on a constraint in ``solarray["sensitivities"]["constraints"]`` is
    deleted on entry and set back on exit. With ``saveconstraints`` false,
    the whole "constraints" entry is popped on entry and restored on exit.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.attrstore = {}  # attribute name -> {constraint: removed value}
        self.saveconstraints = saveconstraints
        self.constraintstore = None  # popped "constraints" entry, if any

    def __enter__(self):
        senss = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = senss.pop("constraints")
            return
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            removed = {}
            for constr in senss["constraints"]:
                if getattr(constr, attrname, None):
                    removed[constr] = getattr(constr, attrname)
                    delattr(constr, attrname)
            self.attrstore[attrname] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = self.constraintstore
            return
        for attrname, removed in self.attrstore.items():
            for constr, value in removed.items():
                setattr(constr, attrname, value)

60 

61 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Return the lines of a variable-sensitivity table.

    ``data`` is either a solution, whose ["sensitivities"]["variables"]
    entry is then used, or a mapping of variable keys to sensitivities.
    A non-empty ``showvars`` restricts the rows to those keys. Rows are
    sorted by value, units are omitted, and sensitivities of magnitude
    <= 1e-3 are skipped.
    """
    has_variable_senss = "variables" in data.get("sensitivities", {})
    senss = data["sensitivities"]["variables"] if has_variable_senss else data
    if showvars:
        senss = {key: senss[key] for key in showvars if key in senss}
    return var_table(senss, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)

71 

72 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Return the lines of a table of the most sensitive variables.

    The rows come from ``topsenss_filter(data, showvars, nvars)``. If some
    of ``showvars`` were filtered out (already shown elsewhere), the table
    is titled as the *next* most sensitive variables. Vector entries below
    the table's minimum value are hidden.
    """
    top, already_shown = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(top, title=title, hidebelowminval=True, **kwargs)

80 

81 

def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to top N vars.

    Arguments
    ---------
    data : dict
        Either a solution, whose ["sensitivities"]["variables"] entry is
        then used, or a mapping of variable keys to sensitivities.
    showvars : iterable
        Keys already displayed elsewhere. They are excluded from the
        result. If ``nvars`` > 3, it is reduced by one for each excluded
        key, but never below 3.
    nvars : int
        How many of the most sensitive variables to return. Values <= 0
        return nothing.

    Returns
    -------
    (dict, set)
        The selected keys mapped to their sensitivities, in ascending order
        of mean absolute sensitivity, and the set of ``showvars`` that were
        filtered out.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # entries containing any NaN are not ranked at all
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    # ascending by mean |sensitivity|, so the most sensitive are at the end
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    # set() also accepts tuples/lists of keys, not only sets
    filter_already_shown = set(showvars).intersection(topk)
    topk = [k for k in topk if k not in filter_already_shown]
    if nvars > 3:  # one fewer per already-shown var, but always show 3
        nvars = max(3, nvars - len(filter_already_shown))
    if nvars <= 0:  # topk[-0:] would select the whole list
        return {}, filter_already_shown
    return {k: data[k] for k in topk[-nvars:]}, filter_already_shown

95 

96 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    Arguments
    ---------
    data : dict
        Either a solution, whose ["sensitivities"]["constants"] entry (the
        fixed variables) is then used, or a mapping of variable keys to
        sensitivities.
    _ : ignored
        Positional slot for ``showvars`` (the common table-function signature).
    maxval : float
        Variables whose mean absolute sensitivity is below this are listed.

    Returns
    -------
    list of str
        Lines of the "Insensitive Fixed Variables" table.
    """
    # The membership test and the lookup must use the same key. The
    # original tested "constants" but then read "variables".
    if "constants" in data.get("sensitivities", {}):
        data = data["sensitivities"]["constants"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

103 

104 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return table lines for the most sensitive (tightest) constraints.

    Lists at most ``ntightconstrs`` constraints whose sensitivity is
    >= ``tight_senss``. They are ordered by descending sensitivity (rounded
    to two significant figures), then by constraint string. For a sweep,
    only the sensitivity at the last sweep point is considered.
    """
    title = "Most Sensitive Constraints"
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        sens = senss[-1] if is_sweep else senss
        if sens >= tight_senss:
            rows.append(((-float("%+6.2g" % sens), str(constr)),
                         "%+6.2g" % sens, id(constr), constr))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)

119 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return table lines for constraints with sensitivity <= ``min_senss``.

    For a sweep, only the sensitivity at the last sweep point is considered.
    All rows share the same sort key and carry an empty left-hand label.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = [(0, "", id(constr), constr)
            for constr, senss in self["sensitivities"]["constraints"].items()
            if (senss[-1] if is_sweep else senss) <= min_senss]
    return constraint_table(rows, title, **kwargs)

133 

134 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of 4-tuples (sortby, openingstr, id, constraint)
        Rows are sorted by these tuples. ``openingstr`` is printed to the
        left of the constraint (callers pass e.g. a formatted sensitivity).
    title : str
        Table title; it is underlined with dashes.
    sortbymodel : bool
        If True, group rows under the lineage string of their constraint.
    showmodels : bool
        If False, strip all lineage from the constraint strings.

    Returns
    -------
    list of str
        Title, underline, one or more lines per constraint (long constraints
        are wrapped), and a trailing empty string. An empty ``data`` yields a
        single "(none)" row.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    # models maps each model string to its first-seen position
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # blank spacer between models
            if model or lines:
                lines.append([("modelname",), model])  # model header row
            previous_model = model
        constrstr = constrstr.replace(model, "")
        minlen, maxlen = 25, 80
        # split into operands and separators (see CONSTRSPLITPATTERN)
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # Keep only the separator's first character here and prepend
                # the rest to the next segment. Any line break then falls just
                # before the operator.
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:  # hard-split any overlong segment
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths, ignoring model header rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("modelname",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("modelname",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]

197 

198 

def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Arguments
    ---------
    self : solution (dict-like)
        Its "warnings" entry maps each warning type to a list of
        (message, data) pairs. For sweeps it is a sequence of such lists,
        detected by having a ``shape`` attribute.
    _ : ignored
        Positional slot for ``showvars`` (the common table-function signature).
    **kwargs
        Passed on to ``constraint_table`` for constraint-type warnings.

    Returns
    -------
    list of str
        An empty list if there are no warnings. Otherwise a "WARNINGS"
        banner, one section per warning type (and sweep point), and a
        closing "~~~~~~~~" line.
    """
    if "warnings" not in self or not self["warnings"]:
        return []
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep: one list of warnings
        for i, data in enumerate(data_vec):
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            # Constraint warnings get a constraint table when their data
            # slot holds constraint objects. An empty list falls through to
            # the plain table instead of raising IndexError.
            if (wtype == "Unexpectedly Tight Constraints"
                    and data and data[0][1]):
                data = [(-int(1e5*c.relax_sensitivity),
                         "%+6.2g" % c.relax_sensitivity, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            elif (wtype == "Unexpectedly Loose Constraints"
                  and data and data[0][1]):
                data = [(-int(1e5*c.rel_diff),
                         "%.4g %s %.4g" % c.tightvalues, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title (it may include " in sweep N")
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    lines[-1] = "~~~~~~~~"
    return lines + [""]

229 

230 

# Dispatch table used by SolutionArray.table: maps a table name to the
# function rendering it, called as fn(solution, showvars, **kwargs).
TABLEFNS = dict([
    ("sensitivities", senss_table),
    ("top sensitivities", topsenss_table),
    ("insensitivities", insenss_table),
    ("tightest constraints", tight_table),
    ("loose constraints", loose_table),
    ("warnings", warnings_table),
])

238 

def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The element with the largest absolute value is returned with its sign.
    NaN entries are ignored, so an array containing some NaNs still competes
    through its other entries. Values that are entirely NaN are skipped. On
    ties the later value wins. Returns None if ``values`` is empty or
    entirely NaN.
    """
    finalval, absmaxest = None, 0
    for val in values:
        absval = np.abs(val)
        if np.isnan(absval).all():
            continue
        # nanmax: a single NaN must not hide the rest of the array
        absmaxval = np.nanmax(absval)
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        absfinal = np.abs(finalval)
        return finalval[np.unravel_index(np.nanargmax(absfinal),
                                         absfinal.shape)]
    return finalval

250 

251 

def cast(function, val1, val2):
    """Apply ``function(val1, val2)``, aligning arrays of unequal rank.

    If both arguments are arrays (have a ``shape``) with different ``ndim``,
    the lower-rank one gets new trailing axes. Its existing axes then line
    up with the *leading* axes of the other before ``function`` is applied.
    Argument order is preserved. All warnings (e.g. divide-by-zero) raised
    during the call are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if not (hasattr(val1, "shape") and hasattr(val2, "shape")):
            return function(val1, val2)
        if val1.ndim == val2.ndim:
            return function(val1, val2)
        extra_dims = abs(val1.ndim - val2.ndim)
        if val1.ndim > val2.ndim:
            expand = (slice(None),)*val2.ndim + (np.newaxis,)*extra_dims
            return function(val1, val2[expand])
        expand = (slice(None),)*val1.ndim + (np.newaxis,)*extra_dims
        return function(val1[expand], val2)

267 

268 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    table_titles = {"sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        """Returns the set of contained varkeys whose names are not unique.

        Computed on first call and cached on the instance.
        """
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        """Number of solutions: len(cost) for a sweep, 1 if the cost has no
        length, 0 if there is no cost."""
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        """Substitute this solution into ``posy``.

        Returns the substituted result's ``c`` attribute when it has one,
        otherwise the substituted result itself.
        """
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        Returns False if the variable sets differ, if any element of a
        variable's value differs by a relative amount >= ``reltol``, or if any
        element of a variable's sensitivity differs by >= ``sens_abstol``.
        Array (e.g. swept) values are compared elementwise.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any: comparing arrays with a plain `if` raises ValueError
            reldiff = np.abs(cast(np.divide, svars[key], ovars[key]) - 1)
            if np.any(reldiff >= reltol):
                return False
            sensdiff = np.abs(cast(sub,
                                   self["sensitivities"]["variables"][key],
                                   other["sensitivities"]["variables"][key]))
            if np.any(sensdiff >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0, reldiff=True, reltol=1.0, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (those ending in ".pgz" are read as gzip-compressed pickles)
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True, show a unified diff of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing, in percent

        Returns
        -------
        str
        """
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                # NOTE: unpickling can execute arbitrary code; only load
                # solution files from trusted sources.
                with open(other, "rb") as f:
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)

        def note_largest(diffs, fmt):
            """If the table just added has no rows (its underline is the
            second-to-last line), insert a note giving the largest value."""
            if not diffs or lines[-2][:10] != "-"*10:
                return
            largest = unrolled_absmax(diffs.values())
            if largest is not None:  # None when every value is NaN
                lines.insert(-1, fmt % largest)

        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            note_largest(rel_diff, "The largest is %+g%%.")
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            note_largest(abs_diff, "The largest is %+g.")
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            note_largest(senss_delta, "The largest is %+g.")
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with SolSavingEnvironment(self, saveconstraints), \
                open(filename, "wb") as f:
            pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        """Pickle the solution, optimize the pickle, and write it to
        ``filename`` gzip-compressed."""
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        NOTE: unpickling can execute arbitrary code; only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict mapping display names to variable keys.

        Names are made with ``str_without(exclude)``. Keys whose names
        collide temporarily get "necessarylineage" set, so their names stay
        unique; the flag is always removed again, even on error.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        collisions = self.name_collision_varkeys()
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:
            for key in collisions:
                del key.descr["necessarylineage"]
        return names

    def savemat(self, filename="solution.mat", showvars=None,
                excluded=("unnecessary lineage", "vec")):
        """Saves primal solution as matlab file (values stored as float32;
        "." in names is replaced by "_")."""
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        """Returns primal solution as pandas dataframe, one row per scalar
        entry (vector variables are expanded by index)."""
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # every index in C order; 1-D indices unwrapped to plain ints
                idxs = [idx[0] if len(idx) == 1 else idx
                        for idx in np.ndindex(*value.shape)]
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "original_fn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", printmodel=True, **kwargs):
        """Saves solution table as a text file, optionally preceded by the
        model string."""
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, showvars=None, filename="solution.csv", valcols=5):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols)
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("modelname",):
                    # no newline: the model name fills the first column of
                    # the next variable row, which starts with ","
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            # (posy,) so a tuple argument can't break the formatting
            raise ValueError("no variable '%s' found in the solution" % (posy,))

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        """Expand each requested variable (name or key) into the set of
        matching keys via the variables' keymap."""
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Print summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        showvars: Iterable
            If given, only these variables are shown in variable tables
        tables: Iterable
            Which to print: "cost"; a key of TABLEFNS ("sensitivities",
            "top sensitivities", "insensitivities", "tightest constraints",
            "loose constraints", "warnings"); or a key of this solution
            titled in ``table_titles`` ("sweepvariables", "freevariables",
            "constants", "variables")
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        varlist = list(self["variables"])
        has_only_one_model = not any(var.lineage != varlist[0].lineage
                                     for var in varlist[1:])
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        collisions = self.name_collision_varkeys()
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:  # always clear the temporary "necessarylineage" flags
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            for key in collisions:
                del key.descr["necessarylineage"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy.

        Only 1-dimensional sweeps are supported (ValueError otherwise).
        ``posys`` may be one posynomial, None/"cost" (plots the cost), or a
        list of these. Returns (figure, axes); axes is unwrapped when only
        one is used.
        """
        if len(self["sweepvariables"]) != 1:
            # previously printed this and then failed unpacking below
            raise ValueError(
                "SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

699 

700 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table (never modified)
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted row lists instead of strings
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values whose absolute values are all <= minval (or all NaN)
    sortbyvals : boolean
        If true, rows are sorted by their average absolute value instead of
        by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), vector entries with absolute value
        <= minval are displayed as NaN placeholders (on a copy).
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped under their key's lineage string
    maxcolumns : int
        maximum number of vector entries printed per row
    skipifempty : boolean
        If true, return [] when no row survives the minval filter

    Returns
    -------
    list
        table lines (or row lists when rawlines is true)
    """
    if not data:
        return []
    # Latex column titles per option. They are looked up up front so that a
    # latex table with no rows still gets a header (this used to be a
    # NameError).
    coltitles = {1: [title, "Value", "Units", "Description"],
                 2: [title, "Units", "Description"],
                 3: [title, "Value", "Units"]}.get(latex, [])
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values above minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # Mask on a copy: `v` may be an array owned by the caller (e.g. a
            # solution's sensitivities), which must not be overwritten.
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if not sortbyvals:
            model, isvector, varstr, _, var, val = varlist
        else:
            model, _, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("modelname",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            # pick the last axis whose size is in [ncols, maxcolumns]
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # Move horiz_dim to the last axis so each printed row runs along
            # it. (The old permutation was inverted for ndim >= 3.)
            dim_order = [d for d in range(len(val.shape)) if d != horiz_dim]
            dim_order.append(horiz_dim)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows holding the remaining vector entries
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad the short final row so the "]" lines up
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            # column widths, ignoring model header rows
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("modelname",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("modelname",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines