Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Defines SolutionArray class""" 

2import re 

3import difflib 

4from operator import sub 

5import warnings as pywarnings 

6import pickle 

7import gzip 

8import pickletools 

9import numpy as np 

10from .nomials import NomialArray 

11from .small_classes import DictOfLists, Strings 

12from .small_scripts import mag, try_str_without 

13from .repr_conventions import unitstr, lineagestr 

14 

15 

# Splits a constraint string at its binary operators so that long constraints
# can be wrapped; the capturing groups keep each operator in the split output.
CONSTRSPLITPATTERN = re.compile(
    r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )"
)

# Ordered (before, after) substitutions applied to formatted value strings.
# The signed/percent forms come first because the bare "nan" rule would
# otherwise consume them.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]

24 

25 

class SolSavingEnvironment:
    """Context manager that slims a solution down while it is being pickled.

    With ``saveconstraints`` true, a fixed list of construction/solve
    attributes is detached from every constraint on entry and re-attached on
    exit.  Otherwise the whole ``["sensitivities"]["constraints"]`` entry is
    removed on entry and put back on exit.

    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        self.attrstore = {}          # attribute name -> {constraint: value}
        self.constraintstore = None  # holds the popped constraints dict

    def __enter__(self):
        sensitivities = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = sensitivities.pop("constraints")
            return
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            detached = {}
            for constr in sensitivities["constraints"]:
                current = getattr(constr, attrname, None)
                if current:  # only truthy attributes are stored and removed
                    detached[constr] = current
                    delattr(constr, attrname)
            self.attrstore[attrname] = detached

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for attrname, detached in self.attrstore.items():
            for constr, saved in detached.items():
                setattr(constr, attrname, saved)

60 

def msenss_table(data, _, **kwargs):
    """Return the lines of the model-sensitivities table.

    Arguments
    ---------
    data : dict-like
        Solution; ``data["sensitivities"]["models"]`` maps model names to
        sensitivities.  Values must support ``.shape`` and comparison
        followed by ``.all()`` (numpy scalars or arrays).
    _ : ignored
        Placeholder so every table function shares one call signature.
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that later sections use the same order.

    Returns
    -------
    list of str
        Empty when there is no model data or fewer than two named models.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # previously "", unlike every other return path
    by_senss = sorted(data["sensitivities"]["models"].items(),
                      key=lambda item: -np.mean(item[1]))
    lines = ["Model Sensitivities", "-------------------"]
    # .get: the option is optional when this function is called directly
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    for model, msenss in by_senss:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            # everything is small: show only an order-of-magnitude bound
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1E%i" % np.log10(msenss))
            else:
                msenssstr = " =0 "
        elif not msenss.shape:
            msenssstr = "%+6.1f" % msenss
        else:
            # swept: show the mean plus each point's offset from it
            meansenss = np.mean(msenss)
            # NOTE(review): the " - " placeholder is narrower than the
            # %+4.1f cells; confirm the intended padding against upstream
            deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                         for d in msenss - meansenss]
            msenssstr = "%+6.1f + [ %s ]" % (meansenss, " ".join(deltastrs))
        lines.append("%s : %s" % (msenssstr, model))
    return lines + [""] if len(lines) > 3 else []

90 

91 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Return the lines of a variable-sensitivities table.

    ``data`` is either a solution (its ``["sensitivities"]["variables"]``
    entry is used) or a plain dict of sensitivities.  If ``showvars`` is
    non-empty only those keys are kept.  Remaining keyword arguments are
    forwarded to ``var_table``.
    """
    senss = data
    if "variables" in data.get("sensitivities", {}):
        senss = data["sensitivities"]["variables"]
    if showvars:
        senss = {key: senss[key] for key in showvars if key in senss}
    fixed_options = dict(sortbyvals=True, skipifempty=True,
                         valfmt="%+-.2g ", vecfmt="%+-8.2g",
                         printunits=False, minval=1e-3)
    return var_table(senss, title, **fixed_options, **kwargs)

101 

102 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Return the lines of the table of the most sensitive variables.

    The title changes to "Next Most ..." when some of the top variables were
    filtered out because they are already in ``showvars``.
    """
    top, already_shown = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(top, title=title, hidebelowminval=True, **kwargs)

110 

111 

def topsenss_filter(data, showvars, nvars=5):
    """Filter sensitivities down to the ``nvars`` of largest mean magnitude.

    Arguments
    ---------
    data : dict-like
        A solution (its ``["sensitivities"]["variables"]`` entry is used) or
        a plain dict mapping keys to sensitivities.
    showvars : iterable
        Keys shown elsewhere; they are removed from the candidates.  Any
        iterable (or None) is accepted, not only a set.
    nvars : int
        How many to keep.  Each removed key lowers this by one, but never
        below 3.

    Returns
    -------
    (dict, set)
        The kept sensitivities, and the keys removed for being in showvars.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # entries containing any NaN are never candidates
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    topk = sorted(mean_abs_senss, key=mean_abs_senss.get)  # ascending
    filter_already_shown = set(showvars or ()).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    if nvars <= 0:  # topk[-0:] would be the whole list
        return {}, filter_already_shown
    return {k: data[k] for k in topk[-nvars:]}, filter_already_shown

125 

126 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Return the lines of the table of insensitive fixed variables.

    Keeps the entries whose mean absolute sensitivity is below ``maxval`` and
    renders them with ``senss_table``.  ``data`` is a solution (its
    ``["sensitivities"]["variables"]`` entry is used) or a plain dict of
    sensitivities; the second positional argument is ignored.
    """
    # Check for the key that is actually read; the guard used to look for
    # "constants", which could raise KeyError or skip the unwrapping.
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

133 

134 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return the lines of the table of the most sensitive constraints.

    Constraints whose sensitivity (at the last sweep point, for swept
    solutions) is at least ``tight_senss`` are ordered by descending
    displayed sensitivity and then by their string form; at most
    ``ntightconstrs`` of them are passed on to ``constraint_table``.
    """
    title = "Most Sensitive Constraints"
    swept = len(self) > 1
    if swept:
        title += " (in last sweep)"
    rows = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        value = senss[-1] if swept else senss
        if value >= tight_senss:
            shown = "%+6.2g" % value
            # sort on the rounded value that is displayed, not the raw one
            rows.append(((-float(shown), str(constraint)),
                         shown, id(constraint), constraint))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)

149 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return the lines of the table of insensitive constraints.

    Lists every constraint whose sensitivity (at the last sweep point, for
    swept solutions) is at most ``min_senss``; no value is printed beside
    them.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    swept = len(self) > 1
    if swept:
        title += " (in last sweep)"
    rows = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        value = senss[-1] if swept else senss
        if value <= min_senss:
            rows.append((0, "", id(constraint), constraint))
    return constraint_table(rows, title, **kwargs)

163 

164 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Create lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of 4-tuples
        ``(sortby, openingstr, ignored, constraint)``; ``openingstr`` is the
        text printed to the left of the constraint.
    title : str
        Table title; it is underlined with dashes of the same length.
    sortbymodel : bool
        If true, rows are grouped under the constraint's lineage string.
    showmodels : bool
        If false, lineage is left out of the constraint strings entirely.

    Returns
    -------
    list of str
        Title, underline, formatted rows (or "(none)"), and a blank line.
    """
    # TODO: this should support 1D array inputs from sweeps
    if showmodels:
        excluded = ("units", "unnecessary lineage")
    else:
        excluded = ("units", "lineage")  # hide all of it
    model_order, rows = {}, []
    for sortkey, opening, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        model_order.setdefault(model, len(model_order))  # first-seen order
        text = try_str_without(constraint, excluded)
        address_at = text.find(" at 0x")
        if address_at != -1:  # don't print memory addresses
            text = text[:address_at] + ">"
        rows.append((model_order[model], model, sortkey, text, opening))
    rows.sort()

    minlen, maxlen = 25, 80
    previous_model, lines = None, []
    for _, model, _, text, opening in rows:
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer between model sections
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        text = text.replace(model, "")
        segments = [seg for seg in CONSTRSPLITPATTERN.split(text) if seg]
        wrapped, current = [], ""
        position = 0
        while position < len(segments):
            segment = segments[position]
            position += 1
            if CONSTRSPLITPATTERN.match(segment) and position < len(segments):
                # keep only the operator's first character on this line and
                # push the rest of it onto the following segment
                segments[position] = segment[1:] + segments[position]
                segment = segment[0]
            elif len(current) + len(segment) > maxlen and len(current) > minlen:
                wrapped.append(current)
                current = " "  # start a new line
            current += segment
            while len(current) > maxlen:  # hard-wrap anything still too long
                wrapped.append(current[:maxlen])
                current = " " + current[maxlen:]
        wrapped.append(current)
        lines.append((opening + " : ", wrapped[0]))
        lines.extend(("", extra) for extra in wrapped[1:])

    if not lines:
        lines = [("", "(none)")]
    maxlens = np.max([[len(cell) for cell in line] for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    formatted = []
    for line in lines:
        if line[0] == ("newmodelline",):
            cells = [fmts[0].format(" | "), line[1]]
        else:
            cells = [fmt.format(s) for fmt, s in zip(fmts, line)]
        formatted.append("".join(cells).rstrip())
    return [title] + ["-"*len(title)] + formatted + [""]

227 

228 

def warnings_table(self, _, **kwargs):
    """Make a table for all warnings in the solution.

    ``self["warnings"]`` maps a warning type to a list of ``(msg, payload)``
    pairs, or to an object with ``shape`` holding one such list per sweep
    point.  The two "Unexpectedly ... Constraints" types are rendered with
    ``constraint_table`` when their payloads are truthy; everything else is
    printed as plain messages.  Returns a list of lines, empty if there are
    no warnings.  The second positional argument is ignored.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # unswept: a single list of warnings
        for i, data in enumerate(data_vec):
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            # An empty list has no data[0]; treat it as "no constraints"
            # instead of raising IndexError.
            has_constraints = bool(data) and data[0][1]
            if wtype == "Unexpectedly Tight Constraints" and has_constraints:
                data = [(-int(1e5*c.relax_sensitivity),
                         "%+6.2g" % c.relax_sensitivity, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and has_constraints:
                data = [(-int(1e5*c.rel_diff),
                         "%.4g %s %.4g" % c.tightvalues, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                lines += [title] + ["-"*len(wtype)]
                lines += [msg for msg, _ in data] + [""]
    lines[-1] = "~~~~~~~~"  # closing rule replaces the last blank line
    return lines + [""]

259 

260 

# Maps a table name to the function that renders it; SolutionArray.table
# calls each one as fn(solution, showvars, **kwargs).
TABLEFNS = {
    "sensitivities": senss_table,
    "top sensitivities": topsenss_table,
    "insensitivities": insenss_table,
    "model sensitivities": msenss_table,
    "tightest constraints": tight_table,
    "loose constraints": loose_table,
    "warnings": warnings_table,
}

269 

def unrolled_absmax(values):
    """Return the signed entry of largest magnitude among numbers and arrays.

    ``values`` is an iterable of scalars and/or arrays.  On ties the later
    item wins.  Returns None if ``values`` is empty.
    """
    largest, winner = 0, None
    for candidate in values:
        magnitude = np.abs(candidate).max()
        if magnitude >= largest:
            largest, winner = magnitude, candidate
    if not getattr(winner, "shape", None):
        return winner  # scalar (or nothing was found)
    # the winner is an array: pick out its largest-magnitude element
    flat_position = np.argmax(np.abs(winner))
    return winner[np.unravel_index(flat_position, winner.shape)]

281 

282 

def cast(function, val1, val2):
    """Apply a binary ``function`` to val1 and val2, aligning ranks first.

    If both values have a ``shape`` but differ in ``ndim``, length-1 axes are
    appended to the lower-rank one so that its axes line up with the leading
    axes of the other.  Argument order is preserved.  Warnings raised during
    the call (e.g. divide-by-zero) are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if both_arrays and val1.ndim != val2.ndim:
            if val1.ndim > val2.ndim:
                padding = (np.newaxis,)*(val1.ndim - val2.ndim)
                index = (slice(None),)*val2.ndim + padding
                return function(val1, val2[index])
            padding = (np.newaxis,)*(val2.ndim - val1.ndim)
            index = (slice(None),)*val1.ndim + padding
            return function(val1[index], val2)
        return function(val1, val2)

298 

299 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None  # cache for name_collision_varkeys()
    table_titles = {"sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    # more than one key shares this lineage-free name
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        """Number of solutions: len of cost, 1 if unsized, 0 if missing."""
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        "Value of posy in this solution (its ``c`` attribute if it has one)"
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions

        Both must contain the same variable keys; every value must agree
        within the relative tolerance ``reltol`` and every variable
        sensitivity within the absolute tolerance ``sens_abstol``.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any: values may be arrays, whose truth value is ambiguous
            reldelta = abs(cast(np.divide, svars[key], ovars[key]) - 1)
            if np.any(reldelta >= reltol):
                return False
            sensdelta = abs(self["sensitivities"]["variables"][key]
                            - other["sensitivities"]["variables"][key])
            if np.any(sensdelta >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (paths ending in ".pgz" are read as gzip-compressed pickles)
        showvars : iterable
            if given, only these variables are compared
        constraintsdiff : boolean
            if True, show a unified diff of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True, order model sections by this solution's model
            sensitivities

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            # NOTE: unpickling executes arbitrary code; only load trusted files
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:  # close the handle promptly
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # [2:] drops the two file-header lines of the unified diff
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk])
                        for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as f:  # ensure the file is closed
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        "Load a gzip-compressed pickle file"
        # NOTE: unpickling executes arbitrary code; only load trusted files
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        "Returns list of variables, optionally with minimal unique names"
        if showvars:
            showvars = self._parse_showvars(showvars)
        collisions = self.name_collision_varkeys()
        try:
            # temporarily flag colliding keys so their lineage is printed
            for key in collisions:
                key.descr["necessarylineage"] = True
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:  # always clear the flags, even if naming failed
            for key in collisions:
                key.descr.pop("necessarylineage", None)
        return names

    def savemat(self, filename="solution.mat", showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # one row per element; 1-D indices are unwrapped to ints
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "original_fn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, showvars=None, filename="solution.csv", valcols=5):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    # pad short rows so units land in the same column
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Expands each entry of showvars into the set of keys it maps to"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Print summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "sweepvariables",
                      "model sensitivities", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        collisions = self.name_collision_varkeys()
        try:
            # temporarily flag colliding keys so their lineage is printed
            for key in collisions:
                key.descr["necessarylineage"] = True
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):
                        continue  # cost is not printed for latex
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4
                                                 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table],
                                      **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:  # always clear the flags, even if a table failed
            for key in collisions:
                key.descr.pop("necessarylineage", None)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy

        Only 1-dimensional sweeps are supported: for anything else a message
        is printed, and unpacking the single swept variable below then fails.
        """
        if len(self["sweepvariables"]) != 1:
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

742 

743 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted row lists instead of strings
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values with all(abs(value)) <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), vector entries at or below minval
        are displayed as NaN placeholders. The input arrays are not modified.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped by the lineage string of their key
    maxcolumns : int
        Largest array dimension that may be laid out horizontally
    skipifempty : boolean
        If true, return [] when no row passes the minval filter
    sortmodelsbysenss : dict or falsy
        If given, model sections are ordered by these sensitivities

    Returns
    -------
    list
        Lines of the table (row lists if ``rawlines``).
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # nothing above minval to show
        if minval and hidebelowminval and getattr(v, "shape", None):
            v = v.copy()  # mask a copy: printing must not alter caller data
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    # Column titles per latex option, defined up front so a latex table with
    # no rows still has a header instead of hitting an unbound name.
    latex_coltitles = {1: [title, "Value", "Units", "Description"],
                       2: [title, "Units", "Description"],
                       3: [title, "Value", "Units"]}
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])  # spacer between sections
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            # prefer the largest dimension that still fits in maxcolumns
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows: ncols values each, closed by a bracket
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad a short final row to the width of a full one
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        header = "\n".join(["{\\footnotesize",
                            "\\begin{longtable}{%s}" % colfmt[latex],
                            "\\toprule",
                            " & ".join(latex_coltitles[latex])
                            + " \\\\ \\midrule"])
        lines = ([header] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines

899 return lines