Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Defines SolutionArray class""" 

2import re 

3import difflib 

4from operator import sub 

5import warnings as pywarnings 

6import pickle 

7import gzip 

8import pickletools 

9import numpy as np 

10from .nomials import NomialArray 

11from .small_classes import DictOfLists, Strings 

12from .small_scripts import mag, try_str_without 

13from .repr_conventions import unitstr, lineagestr 

14 

15 

# Matches the places where constraint_table may break a long constraint
# string: a single "*" with a non-"*" character on each side (so "**" is
# never matched), or one of " + ", " >= ", " <= ", " = ".
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# Ordered (before, after) substring replacements that var_table applies to
# every formatted value string. They are applied in sequence, so the signed
# and percent forms of "nan" are normalised first and then every remaining
# "nan" is replaced by a dash placeholder.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]

24 

25 

class SolSavingEnvironment:
    """Context manager that slims a solution down while it is pickled.

    If ``saveconstraints`` is true, a fixed set of construction/solve
    attributes is removed from every constraint found in
    ``solarray["sensitivities"]["constraints"]`` on entry and put back on
    exit. Otherwise the whole ``"constraints"`` entry is popped on entry and
    reinserted on exit.

    This approximately halves the size of the pickled solution.
    """

    # attributes temporarily removed from each constraint
    _STRIPPED_ATTRS = ("bounded", "meq_bounded", "vks",
                       "v_ss", "unsubbed", "varkeys")

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        # {attribute name: {constraint: removed value}}
        self.attrstore = {}
        # the popped "constraints" entry when saveconstraints is false
        self.constraintstore = None

    def __enter__(self):
        senss = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = senss.pop("constraints")
            return
        for name in self._STRIPPED_ATTRS:
            removed = {}
            for constr in senss["constraints"]:
                current = getattr(constr, name, None)
                if current:  # missing or falsy attributes are left alone
                    removed[constr] = current
                    delattr(constr, name)
            self.attrstore[name] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for name, removed in self.attrstore.items():
            for constr, value in removed.items():
                setattr(constr, name, value)

60 

def msenss_table(data, _, **kwargs):
    """Return the lines of the "Model Sensitivities" table.

    Arguments
    ---------
    data : dict-like
        Solution; the table is built from data["sensitivities"]["models"],
        a mapping of model name to a scalar or array of sensitivities.
    _ : ignored
        Placeholder so the signature matches the other table functions.
    sortmodelsbysenss : optional keyword
        If truthy, a note is appended to the table title.

    Returns
    -------
    list of str
        Empty if there are no model sensitivities or fewer than two named
        models; otherwise the title, underline, one line per named model
        (largest mean sensitivity first) and a trailing blank line.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # always a list, like every other return path
    by_mean = sorted(data["sensitivities"]["models"].items(),
                     key=lambda item: -np.mean(item[1]))
    lines = ["Model Sensitivities", "-------------------"]
    # .get(): the flag is optional when this function is called directly
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    for model, msenss in by_mean:
        if not model:  # for now let's only do named models
            continue
        if not getattr(msenss, "shape", None):  # scalar (incl. plain floats)
            msenssstr = "%+5.2f" % msenss
        else:
            # swept sensitivities: show the mean plus per-sweep deviations
            meansenss = np.mean(msenss)
            deltas = msenss - meansenss
            deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                         for d in deltas]
            msenssstr = "%+5.2f + [ %s ]" % (meansenss, " ".join(deltastrs))
        lines.append(" %s : %s" % (msenssstr, model))
    if len(lines) <= 3:  # two header lines plus at most one model
        return []
    return lines + [""]

84 

85 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Return the lines of a variable-sensitivity table.

    `data` is either a solution (in which case its
    ["sensitivities"]["variables"] entry is used) or already a mapping of
    variable keys to sensitivities. If `showvars` is non-empty only those
    keys are kept. Remaining keyword arguments are passed to var_table.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    # fixed presentation for sensitivities: signed, unitless, sorted by
    # magnitude, dropping anything at or below 1e-3
    fixed_options = {
        "sortbyvals": True,
        "skipifempty": True,
        "valfmt": "%+-.2g ",
        "vecfmt": "%+-8.2g",
        "printunits": False,
        "minval": 1e-3,
    }
    return var_table(data, title, **fixed_options, **kwargs)

95 

96 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Return the lines of the table of the most sensitive variables.

    The title reads "Next Most ..." when topsenss_filter reports that some
    of the top variables were left out because they are in `showvars`.
    """
    topdata, already_shown = topsenss_filter(data, showvars, nvars)
    if already_shown:
        title = "Next Most Sensitive Variables"
    else:
        title = "Most Sensitive Variables"
    return senss_table(topdata, title=title, hidebelowminval=True, **kwargs)

104 

105 

def topsenss_filter(data, showvars, nvars=5):
    """Filter sensitivities down to the top `nvars` variables.

    Arguments
    ---------
    data : dict-like
        A solution (its ["sensitivities"]["variables"] is used) or a mapping
        of keys to sensitivities.
    showvars : set
        Keys that are shown elsewhere; they are removed from the ranking.
    nvars : int
        How many variables to keep.

    Returns
    -------
    (dict, set)
        The kept key -> sensitivity entries (ascending by mean magnitude)
        and the subset of `showvars` that was removed from the ranking.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    magnitudes = {}
    for key, senss in data.items():
        if not np.isnan(senss).any():  # entries containing nan are unranked
            magnitudes[key] = np.abs(senss).mean()
    ranked = sorted(magnitudes, key=magnitudes.get)
    already_shown = showvars.intersection(ranked)
    for key in already_shown:
        ranked.remove(key)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    kept = {key: data[key] for key in ranked[-nvars:]}
    return kept, already_shown

119 

120 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Return the lines of the table of insensitive fixed variables.

    Keeps the entries whose mean absolute sensitivity is below `maxval` and
    renders them with senss_table. The second positional argument is
    ignored; it keeps the signature in line with the other table functions.
    """
    # NOTE(review): the presence check is on "constants" but the values are
    # read from "variables" -- confirm this asymmetry is intended.
    if "constants" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    insensitive = {}
    for key, senss in data.items():
        if np.mean(np.abs(senss)) < maxval:
            insensitive[key] = senss
    return senss_table(insensitive, title="Insensitive Fixed Variables",
                       **kwargs)

127 

128 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return the lines of the table of the most sensitive constraints.

    Constraints whose sensitivity (the last sweep's value, if the solution
    holds more than one) is at least `tight_senss` are ranked by their
    rounded sensitivity, largest first, and the first `ntightconstrs` are
    passed to constraint_table. The second positional argument is ignored.
    """
    title = "Most Sensitive Constraints"
    swept = len(self) > 1
    if swept:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        value = senss[-1] if swept else senss
        if value >= tight_senss:
            shown = "%+6.2g" % value
            # sort on the rounded (displayed) value, then on the string
            rows.append(((-float(shown), str(constr)),
                         shown, id(constr), constr))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)

143 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return the lines of the table of insensitive constraints.

    Lists every constraint whose sensitivity (the last sweep's value, if the
    solution holds more than one) is at most `min_senss`. The second
    positional argument is ignored.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    swept = len(self) > 1
    if swept:
        title += " (in last sweep)"
    rows = [(0, "", id(constr), constr)
            for constr, senss in self["sensitivities"]["constraints"].items()
            if (senss[-1] if swept else senss) <= min_senss]
    return constraint_table(rows, title, **kwargs)

157 

158 

159# pylint: disable=too-many-branches,too-many-locals,too-many-statements 

def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of 4-tuples
        Each item is (sortby, openingstr, tiebreaker, constraint); the third
        element only takes part in sorting and is otherwise unused here.
    title : str
        Table title; it is underlined with dashes of the same length.
    sortbymodel : bool
        If true, rows are grouped under the lineage string of each
        constraint; otherwise all rows share the unnamed group "".
    showmodels : bool
        If false, "lineage" instead of "unnecessary lineage" is excluded
        when the constraint is converted to a string.

    Returns
    -------
    list of str
        Title, underline, the formatted rows and a trailing empty string.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    # models maps each model name to the order in which it first appears
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    # order rows by first appearance of their model, then by sortby
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer row between models
            if model or lines:
                # header row; skipped only for an unnamed model that
                # comes first
                lines.append([("newmodelline",), model])
            previous_model = model
        # the model name is in the header row, so drop it from the row text
        constrstr = constrstr.replace(model, "")
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # the segment is a separator: its first character stays on
                # the current line, the rest is pushed onto the next segment
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            # hard-wrap anything that is still longer than maxlen
            while len(line) > maxlen:
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        # only the first physical line carries the opening string
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths, ignoring the model header rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    # first column right-aligned, second column left-aligned
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]

221 

222 

def warnings_table(self, _, **kwargs):
    """Return the lines of a table listing every warning in the solution.

    self["warnings"] maps a warning type to a collection of
    (message, object) pairs, or to an array of such collections (one per
    sweep). The two "Unexpectedly ... Constraints" types are rendered with
    constraint_table when their first entry carries a truthy object; every
    other type is listed message by message. Returns [] if there are no
    warnings. The second positional argument is ignored.
    """
    if "warnings" not in self or not self["warnings"]:
        return []
    banner = "~"*len("WARNINGS")
    lines = [banner, "WARNINGS", banner]
    for wtype in sorted(self["warnings"]):
        per_sweep = self["warnings"][wtype]
        if not hasattr(per_sweep, "shape"):
            per_sweep = [per_sweep]  # a single solve: wrap it like a sweep
        for i, entries in enumerate(per_sweep):
            entries = sorted(entries, key=lambda entry: entry[0])  # by msg
            heading = wtype
            if len(per_sweep) > 1:
                heading += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and entries[0][1]:
                rows = [(-int(1e5*c.relax_sensitivity),
                         "%+6.2g" % c.relax_sensitivity, id(c), c)
                        for _, c in entries]
                lines += constraint_table(rows, heading, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and entries[0][1]:
                rows = [(-int(1e5*c.rel_diff),
                         "%.4g %s %.4g" % c.tightvalues, id(c), c)
                        for _, c in entries]
                lines += constraint_table(rows, heading, **kwargs)
            else:
                lines.append(heading)
                lines.append("-"*len(wtype))
                lines.extend(msg for msg, _ in entries)
                lines.append("")
    lines[-1] = "~~~~~~~~"  # closing banner replaces the last line
    return lines + [""]

253 

254 

# Dispatch table used by SolutionArray.table: maps a table name to the
# function producing its lines. Each function is called as
# fn(solution, showvars, **kwargs).
TABLEFNS = {"sensitivities": senss_table,
            "top sensitivities": topsenss_table,
            "insensitivities": insenss_table,
            "model sensitivities": msenss_table,
            "tightest constraints": tight_table,
            "loose constraints": loose_table,
            "warnings": warnings_table,
            }

263 

def unrolled_absmax(values):
    """Return the (signed) element of largest magnitude.

    `values` is an iterable of numbers and/or arrays. On ties the later
    value wins. Returns None for an empty iterable.
    """
    winner, largest = None, 0
    for candidate in values:
        magnitude = np.abs(candidate).max()
        if magnitude >= largest:
            winner, largest = candidate, magnitude
    if not getattr(winner, "shape", None):
        return winner  # scalar (or nothing at all)
    # the winner is an array: pick out its largest-magnitude element
    flat_position = np.argmax(np.abs(winner))
    return winner[np.unravel_index(flat_position, winner.shape)]

275 

276 

def cast(function, val1, val2):
    """Apply a binary function to two values of possibly different rank.

    If both values have a `shape` and differ in number of dimensions, the
    lower-dimensional one gets trailing axes appended so that its leading
    axes line up with those of the other. Argument order is preserved.
    All warnings raised during the call are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if not both_arrays or val1.ndim == val2.ndim:
            return function(val1, val2)
        if val1.ndim > val2.ndim:
            extra = val1.ndim - val2.ndim
            pad = (slice(None),)*val2.ndim + (np.newaxis,)*extra
            return function(val1, val2[pad])
        extra = val2.ndim - val1.ndim
        pad = (slice(None),)*val1.ndim + (np.newaxis,)*extra
        return function(val1[pad], val2)

292 

293 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    # string form of the solved model; diff() and savetxt() use it
    modelstr = ""
    # cache for name_collision_varkeys()
    _name_collision_varkeys = None
    # table key -> printed title, used by table()
    table_titles = {"sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        try:
            return len(self["cost"])
        except TypeError:  # the cost has no len(): a single solution
            return 1
        except KeyError:  # no cost stored at all
            return 0

    def __call__(self, posy):
        posy_subbed = self.subinto(posy)
        # return the bare coefficient if the substituted result has one
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions

        Returns False if the variable sets differ, if any value differs
        relatively by `reltol` or more, or if any variable sensitivity
        differs absolutely by `sens_abstol` or more; True otherwise.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any: vector variables compare elementwise, and a bare
            # array is not a valid `if` condition
            reldelta = abs(cast(np.divide, svars[key], ovars[key]) - 1)
            if np.any(reldelta >= reltol):
                return False
            sensdelta = abs(self["sensitivities"]["variables"][key]
                            - other["sensitivities"]["variables"][key])
            if np.any(sensdelta >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0, reldiff=True, reltol=1.0, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (paths ending in ".pgz" are read as gzip-compressed pickles)
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True, and both solutions have a modelstr, diff the two
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing

        Returns
        -------
        str
        """
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                # NOTE: unpickling can execute arbitrary code; only pass
                # paths to files from a trusted source.
                with open(other, "rb") as pickled:
                    other = pickle.load(pickled)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # [2:] drops the first two lines of the unified diff
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            # a dashed underline right before the trailing blank line means
            # the table has no rows: nothing larger than reltol
            if lines[-2][:10] == "-"*10:
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with SolSavingEnvironment(self, saveconstraints):
            # context manager: the file is flushed and closed even if
            # pickling raises
            with open(filename, "wb") as outfile:
                pickle.dump(self, outfile, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        "Load a gzip-compressed pickle file"
        # NOTE: unpickling can execute arbitrary code; only load files from
        # a trusted source.
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        "Returns list of variables, optionally with minimal unique names"
        if showvars:
            showvars = self._parse_showvars(showvars)
        collisions = self.name_collision_varkeys()
        # temporarily flag colliding keys so their names include lineage
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:
            # always undo the flag, even if a lookup above raised
            for key in collisions:
                del key.descr["necessarylineage"]
        return names

    def savemat(self, filename="solution.mat", showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # one row per element of a vector variable
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "original_fn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, showvars=None, filename="solution.csv", valcols=5):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    # model name only; the next row continues this CSV line
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    # pad so that units always land in the same column
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Returns the set of all keys that the entries of showvars map to"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Print summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "sweepvariables",
                      "model sensitivities", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        showvars: Iterable
            If non-empty, only these variables are shown
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
            or any key of TABLEFNS
        sortmodelsbysenss: bool
            If true, model sensitivities are passed on for sorting models
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        collisions = self.name_collision_varkeys()
        # temporarily flag colliding keys so their names include lineage
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        # show at most the first four swept costs
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4
                                                 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table],
                                      **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            # always undo the flag, even if building a table raised
            for key in collisions:
                del key.descr["necessarylineage"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        # unpacking requires exactly one swept variable
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

731 

732 

733# pylint: disable=too-many-branches,too-many-locals,too-many-statements 

def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted rows (lists of cell strings)
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        entries whose largest magnitude is <= minval are skipped
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), elements of array values with
        magnitude <= minval are replaced by nan before printing
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped by the lineage string of their key
    maxcolumns : int
        The largest array dimension that may be laid out horizontally
    skipifempty : boolean
        If true, return [] when no entry survives the minval filter
    sortmodelsbysenss : dict or falsy
        If given, maps model names to sensitivities used to order models

    Returns
    -------
    list
        Formatted lines (or raw rows if rawlines is true)
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # NOTE(review): this writes nan into the array held by `data`
            # itself, i.e. the caller's array is modified in place.
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        # negated so that models with larger sensitivity sort first
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))  # True for array values
        s = k.str_without(("lineage", "vec"))
        # the tuple layout is the sort key; `i` breaks ties before the
        # key and value themselves would have to be compared
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")  # the unnamed model is always included
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])  # spacer row between models
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            # choose the largest dimension that still fits in maxcolumns
            # (the later one on ties) to run horizontally
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            # close the bracket here only if everything fits on one row
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows, ncols values at a time
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # last row: close the bracket and pad for the
                        # values missing from a partially filled row
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            # NOTE(review): coltitles is only bound inside this loop, so a
            # latex table with no surviving rows would hit an unbound name
            # below -- confirm whether that case can occur.
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
                coltitles = [title, "Value", "Units", "Description"]
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
                coltitles = [title, "Units", "Description"]
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
                coltitles = [title, "Value", "Units"]
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            # column widths, ignoring the model header rows
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        # longtable column spec for each latex option
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines