Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Defines SolutionArray class""" 

2import re 

3import difflib 

4from operator import sub 

5import warnings as pywarnings 

6import pickle 

7import gzip 

8import pickletools 

9import numpy as np 

10from .nomials import NomialArray 

11from .small_classes import DictOfLists, Strings 

12from .small_scripts import mag, try_str_without 

13from .repr_conventions import unitstr, lineagestr 

14 

15 

# Places where a constraint's string form may be line-wrapped by
# constraint_table: a "*" flanked by two non-"*" characters, or one of the
# operators " + ", " >= ", " <= ", " = ".  Every alternative is a capturing
# group, so re.split() keeps the delimiters in its output.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# Ordered (substring, replacement) pairs that var_table applies to each
# formatted value string, so NaN entries render as a placeholder.
VALSTR_REPLACES = [("+nan", " nan"), ("-nan", " nan"),
                   ("nan%", "nan "), ("nan", " - ")]

24 

25 

class SolSavingEnvironment:
    """Context manager that slims a solution down while it is pickled.

    With ``saveconstraints`` true, a fixed set of construction/solve
    attributes is removed from every constraint keyed in
    ``solarray["sensitivities"]["constraints"]`` on entry and restored on
    exit (this approximately halves the size of the pickled solution).
    Only attributes whose value is truthy are removed.  With
    ``saveconstraints`` false, the whole "constraints" entry is popped on
    entry and put back on exit.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        # attribute name -> {constraint: removed value}
        self.attrstore = {}
        # the popped "constraints" mapping when saveconstraints is false
        self.constraintstore = None

    def __enter__(self):
        senss = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = senss.pop("constraints")
            return
        for attr in ("bounded", "meq_bounded", "vks",
                     "v_ss", "unsubbed", "varkeys"):
            removed = {}
            for constraint in senss["constraints"]:
                if getattr(constraint, attr, None):
                    removed[constraint] = getattr(constraint, attr)
                    delattr(constraint, attr)
            self.attrstore[attr] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for attr, removed in self.attrstore.items():
            for constraint, value in removed.items():
                setattr(constraint, attr, value)

60 

def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : SolutionArray or mapping
        Read for ``data["sensitivities"]["models"]``, a mapping from model
        name to its sensitivity (a scalar, or an array for sweeps).
    _ : ignored
        The showvars argument shared by all TABLEFNS entries.
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that later sections are sorted by it.

    Returns
    -------
    list of str
        Empty if there are no model sensitivities, or fewer than two named
        models to list.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other table function
    items = sorted(data["sensitivities"]["models"].items(),
                   key=lambda i: -np.mean(i[1]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    for model, msenss in items:
        if not model:  # for now let's only do named models
            continue
        if not np.shape(msenss):  # scalar (also works for plain floats)
            msenssstr = "%+5.2f" % msenss
        else:  # sweep: mean, then per-point deviations from the mean
            meansenss = np.mean(msenss)
            deltas = msenss - meansenss
            deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                         for d in deltas]
            msenssstr = "%+5.2f + [ %s ]" % (meansenss, " ".join(deltastrs))

        lines.append(" %s : %s" % (msenssstr, model))
    # two header lines plus at least two models; one model isn't a ranking
    return lines + [""] if len(lines) > 3 else []

84 

85 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    `data` is either a solution, in which case its
    ``["sensitivities"]["variables"]`` mapping is used, or such a mapping
    directly.  When `showvars` is non-empty, only those keys present in the
    mapping are shown.  Values of magnitude at most 1e-3 are not listed.
    """
    senss = data
    if "variables" in senss.get("sensitivities", {}):
        senss = senss["sensitivities"]["variables"]
    if showvars:
        senss = {key: senss[key] for key in showvars if key in senss}
    return var_table(senss, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)

95 

96 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns lines for the table of the most sensitive variables.

    Variables already in `showvars` are left out (they are shown
    elsewhere), in which case the title says "Next Most ...".
    """
    top, already_shown = topsenss_filter(data, showvars, nvars)
    if already_shown:
        title = "Next Most Sensitive Variables"
    else:
        title = "Most Sensitive Variables"
    return senss_table(top, title=title, hidebelowminval=True, **kwargs)

104 

105 

def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top `nvars` variables.

    Arguments
    ---------
    data : solution or mapping
        A solution (its ``["sensitivities"]["variables"]`` is used) or a
        mapping of variable -> sensitivity (scalar or array).
    showvars : set or None
        Variables already displayed elsewhere; these are excluded, and each
        one reduces `nvars` by one (never below 3 if it started above 3).
    nvars : int
        How many variables to keep; 0 or less keeps none.

    Returns
    -------
    (dict, set)
        The kept variable -> sensitivity entries, and the subset of
        `showvars` that was filtered out.  Variables with any NaN
        sensitivity are never kept.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    # ascending order: the most sensitive variables come last
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    filter_already_shown = set(showvars or ()).intersection(topk)
    topk = [k for k in topk if k not in filter_already_shown]
    if nvars > 3:  # each already-shown var frees a slot; always show >= 3
        nvars = max(3, nvars - len(filter_already_shown))
    if nvars <= 0:  # topk[-0:] would otherwise select everything
        return {}, filter_already_shown
    return {k: data[k] for k in topk[-nvars:]}, filter_already_shown

119 

120 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    Lists variables whose mean absolute sensitivity is below `maxval`.
    `data` is a solution (its ``["sensitivities"]["variables"]`` is used)
    or a mapping of variable -> sensitivity.  The second positional
    argument (showvars) is ignored.
    """
    # Check for the key we actually read, as senss_table and
    # topsenss_filter do.  Testing for "constants" while reading
    # "variables" left `data` as the whole solution when "constants" was
    # absent, and then iterated the solution's items as sensitivities.
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

127 

128 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return constraint tightness lines.

    Shows up to `ntightconstrs` constraints whose sensitivity (for a sweep,
    the sensitivity at the last sweep point) is at least `tight_senss`,
    most sensitive first.
    """
    title = "Most Sensitive Constraints"
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        last = senss[-1] if is_sweep else senss
        if last >= tight_senss:
            shown = "%+6.2g" % last
            # sort key: rounded sensitivity descending, then constraint text
            rows.append(((-float(shown), str(constraint)),
                         shown, id(constraint), constraint))
    return constraint_table(sorted(rows)[:ntightconstrs], title, **kwargs)

143 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines listing constraints with sensitivity at most `min_senss`.

    For a sweep, the sensitivity at the last sweep point is used.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    data = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        if (senss[-1] if is_sweep else senss) <= min_senss:
            data.append((0, "", id(constraint), constraint))
    return constraint_table(data, title, **kwargs)

157 

158 

def _wrap_constraint_str(constrstr, minlen=25, maxlen=80):
    """Split a constraint string into display lines of at most `maxlen`.

    Soft breaks are taken at CONSTRSPLITPATTERN boundaries once the current
    line is longer than `minlen`; runs longer than `maxlen` are hard-split.
    Continuation lines start with a space.
    """
    pieces = [p for p in CONSTRSPLITPATTERN.split(constrstr) if p]
    wrapped, current = [], ""
    idx = 0
    while idx < len(pieces):
        piece = pieces[idx]
        idx += 1
        if CONSTRSPLITPATTERN.match(piece) and idx < len(pieces):
            # keep only the delimiter's first character on this line; the
            # rest is prepended to the next piece, so a break lands before it
            pieces[idx] = piece[1:] + pieces[idx]
            piece = piece[0]
        elif len(current) + len(piece) > maxlen and len(current) > minlen:
            wrapped.append(current)
            current = " "  # start a new line
        current += piece
        while len(current) > maxlen:  # hard-split anything still too long
            wrapped.append(current[:maxlen])
            current = " " + current[maxlen:]
    wrapped.append(current)
    return wrapped


# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, id, constraint) tuples
        `openingstr` is shown left of the constraint; rows are ordered by
        model (first-seen order) and then by `sortby`.
    title : str
    sortbymodel : bool
        If true, group rows under their model's lineage string.
    showmodels : bool
        If false, lineage is stripped from constraint strings entirely.

    Returns
    -------
    list of str
    """
    # TODO: this should support 1D array inputs from sweeps
    if showmodels:
        excluded = ("units", "unnecessary lineage")
    else:
        excluded = ("units", "lineage")  # hide all of it
    model_order, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        model_order.setdefault(model, len(model_order))
        constrstr = try_str_without(constraint, excluded)
        address_at = constrstr.find(" at 0x")
        if address_at != -1:  # don't print memory addresses
            constrstr = constrstr[:address_at] + ">"
        decorated.append((model_order[model], model, sortby,
                          constrstr, openingstr))
    decorated.sort()

    lines, previous_model = [], None
    for _, model, _, constrstr, openingstr in decorated:
        if model != previous_model:
            if lines:
                lines.append(["", ""])
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        first, *rest = _wrap_constraint_str(constrstr.replace(model, ""))
        lines.append((openingstr + " : ", first))
        lines.extend(("", extra) for extra in rest)
    if not lines:
        lines = [("", "(none)")]

    widths = np.max([list(map(len, line)) for line in lines
                     if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(dirs) == len(widths)
    fmts = ["{0:%s%s}" % (direc, width) for direc, width in zip(dirs, widths)]
    formatted = []
    for line in lines:
        if line[0] == ("newmodelline",):
            cells = [fmts[0].format(" | "), line[1]]
        else:
            cells = [fmt.format(cell) for fmt, cell in zip(fmts, line)]
        formatted.append("".join(cells).rstrip())
    return [title, "-"*len(title)] + formatted + [""]

221 

222 

def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    ``self["warnings"]`` maps a warning type to a list of (msg, data)
    tuples, or (for sweeps) an array holding one such list per sweep point.
    "Unexpectedly Tight/Loose Constraints" entries whose data is a
    constraint are rendered with constraint_table; other types list their
    messages sorted.  Empty lists (e.g. a sweep point without that warning)
    are skipped.

    Returns
    -------
    list of str
        Empty if there are no non-empty warning lists.
    """
    if "warnings" not in self or not self["warnings"]:
        return []
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    nheader = len(lines)
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # a single solve, not a sweep
        for i, data in enumerate(data_vec):
            if len(data) == 0:  # data[0] below would raise IndexError
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*c.relax_sensitivity),
                         "%+6.2g" % c.relax_sensitivity, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*c.rel_diff),
                         "%.4g %s %.4g" % c.tightvalues, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title, as constraint_table does
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == nheader:  # every warning list was empty
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]

253 

254 

# Table name (as accepted in SolutionArray.table's `tables`) -> renderer.
# Each renderer is called as fn(solution, showvars, **kwargs) and returns
# a list of lines.
TABLEFNS = dict([
    ("sensitivities", senss_table),
    ("top sensitivities", topsenss_table),
    ("insensitivities", insenss_table),
    ("model sensitivities", msenss_table),
    ("tightest constraints", tight_table),
    ("loose constraints", loose_table),
    ("warnings", warnings_table),
])

263 

def unrolled_absmax(values):
    """Return the single number of largest magnitude among `values`.

    `values` may mix scalars and arrays.  Arrays are searched elementwise,
    and the winning element itself (with its sign) is returned.  On ties
    the later value wins; entries whose maximum is NaN never win.  An
    empty iterable yields None.
    """
    best, best_abs = None, 0
    for candidate in values:
        candidate_abs = np.abs(candidate).max()
        if candidate_abs >= best_abs:
            best, best_abs = candidate, candidate_abs
    if not getattr(best, "shape", None):
        return best
    flat_idx = np.argmax(np.abs(best))
    return best[np.unravel_index(flat_idx, best.shape)]

275 

276 

def cast(function, val1, val2):
    """Return ``function(val1, val2)``, aligning arrays of unequal rank.

    When both arguments are arrays with different ``ndim``, the lower-rank
    one gets new *trailing* axes, so its existing axes line up with the
    leading axes of the other (numpy's own broadcasting would align the
    trailing axes instead).  Argument order is preserved.  All warnings
    raised during the call (e.g. divide-by-zero) are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if not both_arrays or val1.ndim == val2.ndim:
            return function(val1, val2)
        if val1.ndim > val2.ndim:
            padding = (np.newaxis,)*(val1.ndim - val2.ndim)
            return function(val1, val2[(slice(None),)*val2.ndim + padding])
        padding = (np.newaxis,)*(val2.ndim - val1.ndim)
        return function(val1[(slice(None),)*val1.ndim + padding], val2)

292 

293 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
        constraints: dict of constraint -> sensitivity
        models: dict of model name -> sensitivity
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    table_titles = {"sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        False if the variable sets differ, if any value differs relatively
        by `reltol` or more, or if any variable sensitivity differs by
        `sens_abstol` or more.  Vector/swept values are compared
        elementwise; one differing element is enough.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any: a plain `if abs(array) >= tol` raises for arrays
            reldiff = np.abs(cast(np.divide, svars[key], ovars[key]) - 1)
            if np.any(reldiff >= reltol):
                return False
            sensdiff = np.abs(self["sensitivities"]["variables"][key]
                              - other["sensitivities"]["variables"][key])
            if np.any(sensdiff >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0, reldiff=True, reltol=1.0,
             sortmodelsbysenss= True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (".pgz" paths are read with decompress_file).  Unpickling can
            execute arbitrary code: only load files you trust.
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True, show a unified diff of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True, order model sections by this solution's model
            sensitivities

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                # NOTE: pickle.load on an untrusted file can run arbitrary code
                with open(other, "rb") as fil:
                    other = pickle.load(fil)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            # only variables that have a sensitivity in both solutions
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks
                           if vk in ssenss and vk in osenss}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.::

            import pickle
            with open("solution.pkl", "rb") as fil:
                sol = pickle.load(fil)
        """
        with SolSavingEnvironment(self, saveconstraints):
            # `with` guarantees the file is flushed and closed
            with open(filename, "wb") as fil:
                pickle.dump(self, fil, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        """Pickle the solution, optimize the pickle, and gzip it to `filename`.

        Load it back with SolutionArray.decompress_file.
        """
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        Unpickling can execute arbitrary code: only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict of display name -> VarKey for the variables.

        Names are built with `exclude` removed, but keys whose names collide
        temporarily keep the lineage needed to tell them apart.  Only
        `showvars` are included when given.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        collisions = self.name_collision_varkeys()
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:
            # restore the keys even if naming raised
            for key in collisions:
                key.descr.pop("necessarylineage", None)
        return names

    def savemat(self, filename="solution.mat", showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "original_fn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, showvars=None, filename="solution.csv", valcols=5):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Expands each requested name/key into the set of matching VarKeys."
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Print summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "sweepvariables",
                      "model sensitivities", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities") or of the
            TABLEFNS names ("top sensitivities", "insensitivities",
            "model sensitivities", "tightest constraints",
            "loose constraints", "warnings")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        collisions = self.name_collision_varkeys()
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            # restore the keys even if a table function raised
            for key in collisions:
                key.descr.pop("necessarylineage", None)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            # the unpacking below needs exactly one swept variable
            raise ValueError(
                "SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

736 

737 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table (never modified)
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted row lists
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values with all(abs(value)) < minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), vector entries with magnitude at
        or below minval are shown as blanks
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    maxcolumns : int
        maximum number of vector entries per row
    skipifempty : boolean
        If true, return [] when no row survives the filtering
    sortmodelsbysenss : dict or None
        model name -> sensitivity; if given, model sections are ordered
        by decreasing sensitivity
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # nothing to show: every value is NaN or <= minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # Blank small entries on a copy: assigning into `v` in place
            # overwrote the caller's array (e.g. the solution's own
            # sensitivities when called via topsenss_table).
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        # Column titles depend only on the latex option, so they are known
        # even when no row was emitted (previously a NameError).
        latex_coltitles = {1: [title, "Value", "Units", "Description"],
                           2: [title, "Units", "Description"],
                           3: [title, "Value", "Units"]}
        if latex not in latex_coltitles:
            raise ValueError("Unexpected latex option, %s." % latex)
        coltitles = latex_coltitles[latex]
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines