Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Defines SolutionArray class""" 

2import re 

3import difflib 

4from operator import sub 

5import warnings as pywarnings 

6import pickle 

7import gzip 

8import pickletools 

9import numpy as np 

10from .nomials import NomialArray 

11from .small_classes import DictOfLists, Strings 

12from .small_scripts import mag, try_str_without 

13from .repr_conventions import unitstr, lineagestr 

14 

15 

# Splits a constraint's string representation at products, sums and
# (in)equality signs, so that long constraints can be wrapped between terms.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])" r"|( \+ )" r"|( >= )"
                                r"|( <= )" r"|( = )")

# Ordered (before, after) substitutions applied to formatted values.
# Order matters: the signed and percent forms of "nan" must be rewritten
# before the final rule turns every remaining "nan" into a dash.
VALSTR_REPLACES = [(sign + "nan", " nan") for sign in "+-"]
VALSTR_REPLACES += [("nan%", "nan "), ("nan", " - ")]

24 

25 

class SolSavingEnvironment:
    """Context manager that strips bulky constraint data while pickling.

    With `saveconstraints` true, selected construction/solve attributes are
    removed from every constraint in solarray["sensitivities"]["constraints"]
    on entry and put back on exit. Otherwise that whole "constraints" entry is
    popped on entry and reinstated on exit.

    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.attrstore = {}  # attribute name -> {constraint: removed value}
        self.saveconstraints = saveconstraints
        self.constraintstore = None  # popped constraints dict, if any

    def __enter__(self):
        if not self.saveconstraints:
            self.constraintstore = \
                self.solarray["sensitivities"].pop("constraints")
            return
        stripped = ("bounded", "meq_bounded", "vks",
                    "v_ss", "unsubbed", "varkeys")
        for attrname in stripped:
            removed = {}
            for constr in self.solarray["sensitivities"]["constraints"]:
                # only truthy values are stashed and deleted
                if getattr(constr, attrname, None):
                    removed[constr] = getattr(constr, attrname)
                    delattr(constr, attrname)
            self.attrstore[attrname] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for attrname, removed in self.attrstore.items():
            for constr, value in removed.items():
                setattr(constr, attrname, value)

60 

def msenss_table(data, _, **kwargs):
    """Return table lines summarizing model sensitivities.

    Reads data["sensitivities"]["models"], a mapping from model name to that
    model's sensitivity (a scalar, or an array for sweeps), and lists named
    models in decreasing order of mean sensitivity. Unnamed models ("") are
    skipped.

    Arguments
    ---------
    data : dict-like
        Solution whose "sensitivities" entry may contain "models".
    _ : ignored
        Placeholder for the `showvars` argument shared by all TABLEFNS.
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that the other tables are sorted by it.

    Returns
    -------
    list of str
        Empty when there are no model sensitivities or fewer than two named
        models to compare.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other table function
    data = sorted(data["sensitivities"]["models"].items(),
                  key=lambda item: -np.mean(item[1]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        msenss = np.asarray(msenss)  # also accept plain Python floats
        if (msenss < 0.1).all():
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % np.log10(msenss))
            else:
                msenssstr = " =0 "
        elif not msenss.shape:
            msenssstr = "%+6.1f" % msenss
        else:
            meansenss = np.mean(msenss)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:
            # blank out repeats, padded so the " : model" column stays aligned
            msenssstr = " " * len(msenssstr)
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    return lines + [""] if len(lines) > 3 else []

96 

97 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Return table lines for variable sensitivities.

    `data` is either a whole solution, in which case its
    ["sensitivities"]["variables"] mapping is used, or already a mapping of
    variable -> sensitivity. A non-empty `showvars` restricts the table to
    those variables. Remaining keyword arguments are forwarded to var_table.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    if showvars:
        selected = {}
        for key in showvars:
            if key in data:
                selected[key] = data[key]
        data = selected
    table_options = dict(sortbyvals=True, skipifempty=True,
                         valfmt="%+-.2g ", vecfmt="%+-8.2g",
                         printunits=False, minval=1e-3)
    return var_table(data, title, **table_options, **kwargs)

107 

108 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Return table lines for the `nvars` most sensitive variables.

    Variables already in `showvars` are left out (they appear elsewhere), in
    which case the title says "Next Most Sensitive" instead of "Most".
    """
    topdata, already_shown = topsenss_filter(data, showvars, nvars)
    if already_shown:
        heading = "Next Most Sensitive Variables"
    else:
        heading = "Most Sensitive Variables"
    return senss_table(topdata, title=heading, hidebelowminval=True, **kwargs)

116 

117 

def topsenss_filter(data, showvars, nvars=5):
    """Filter sensitivities down to the `nvars` largest not already shown.

    Arguments
    ---------
    data : dict-like
        A solution (its ["sensitivities"]["variables"] are used) or a mapping
        of variable -> sensitivity.
    showvars : set
        Variables displayed elsewhere; they are excluded from the result, and
        each one shrinks `nvars` by one, but never below 3.
    nvars : int
        Maximum number of variables to return; values <= 0 return none.

    Returns
    -------
    (dict, set)
        The selected {variable: sensitivity} entries (sensitivities containing
        NaN are never selected), and the members of `showvars` that were
        dropped because they are already shown.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    # ascending by mean |sensitivity|, so the largest are at the end
    ranked = sorted(mean_abs_senss.items(), key=lambda item: item[1])
    filter_already_shown = showvars.intersection(k for k, _ in ranked)
    topk = [k for k, _ in ranked if k not in filter_already_shown]
    if nvars > 3:  # always show at least 3
        nvars = max(3, nvars - len(filter_already_shown))
    if nvars <= 0:  # topk[-0:] would otherwise select everything
        return {}, filter_already_shown
    return {k: data[k] for k in topk[-nvars:]}, filter_already_shown

131 

132 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Return table lines for variables whose sensitivity is negligible.

    Arguments
    ---------
    data : dict-like
        A solution (its ["sensitivities"]["variables"] are used) or a mapping
        of variable -> sensitivity.
    _ : ignored
        Placeholder for the `showvars` argument shared by all TABLEFNS.
    maxval : float
        Variables whose mean |sensitivity| is below this are listed.
    """
    # Check for the same key that is read (and that senss_table /
    # topsenss_filter check), rather than for the "constants" alias.
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

139 

140 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return table lines for the most sensitive (tightest) constraints.

    Lists up to `ntightconstrs` constraints whose sensitivity is at least
    `tight_senss`, largest first. For sweeps only the last sweep point is
    considered.
    """
    title = "Most Sensitive Constraints"
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        value = senss[-1] if is_sweep else senss
        if value >= tight_senss:
            shown = "%+6.2g" % value
            # sort key: rounded sensitivity (descending), then constraint text
            rows.append(((-float(shown), str(constr)), shown, id(constr),
                         constr))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)

155 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return table lines for constraints with negligible sensitivity.

    Lists every constraint whose sensitivity is at most `min_senss`. For
    sweeps only the last sweep point is considered.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        value = senss[-1] if is_sweep else senss
        if value <= min_senss:
            rows.append((0, "", id(constr), constr))
    return constraint_table(rows, title, **kwargs)

169 

170 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of 4-tuples (sortby, openingstr, id, constraint)
        Rows are ordered by sorting these tuples (the id field keeps the
        tuples from ever comparing constraints directly). `openingstr` is
        printed left of " : " and the constraint's string to its right.
    title : str
        Printed above the table and underlined with dashes.
    sortbymodel : bool
        If True, rows are grouped under the `lineagestr` of their constraint.
    showmodels : bool
        If False, lineage is excluded from the constraint strings entirely.

    Returns
    -------
    list of str
        The title, its underline, the formatted rows, and a trailing "".
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            # models are numbered in order of first appearance, so groups
            # keep the order of their best-sorted constraint
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # blank spacer between model groups
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        constrstr = constrstr.replace(model, "")
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # keep only the separator's first character on this line and
                # prepend the rest (the operator) to the next segment, so a
                # wrapped line starts with the operator
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:
                # hard-wrap a single segment that is longer than maxlen
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths ignore model header rows, which are formatted separately
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]

233 

234 

def warnings_table(self, _, **kwargs):
    """Make table lines for every warning recorded in the solution.

    Arguments
    ---------
    self : dict-like
        Solution whose optional "warnings" entry maps a warning type to a
        list of (message, data) pairs, or, for sweeps, to an object with a
        `shape` holding one such list per sweep point.
    _ : ignored
        Placeholder for the `showvars` argument shared by all TABLEFNS.
    **kwargs
        Forwarded to constraint_table for the constraint-based warning types.

    Returns
    -------
    list of str
        Empty if there are no non-empty warnings; otherwise a "~"-framed
        WARNINGS section.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep: wrap as a one-point sweep
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda item: item[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*c.relax_sensitivity),
                         "%+6.2g" % c.relax_sensitivity, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*c.rel_diff),
                         "%.4g %s %.4g" % c.tightvalues, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title, including any " in sweep i"
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # only the banner: every warning list was empty
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]

269 

270 

# Table names accepted by SolutionArray.table that are rendered by a helper
# function; each helper is called as fn(solution, showvars, **kwargs) and
# returns a list of lines.
TABLEFNS = dict([
    ("sensitivities", senss_table),
    ("top sensitivities", topsenss_table),
    ("insensitivities", insenss_table),
    ("model sensitivities", msenss_table),
    ("tightest constraints", tight_table),
    ("loose constraints", loose_table),
    ("warnings", warnings_table),
])

279 

def unrolled_absmax(values):
    """Return the single entry of largest magnitude among numbers and arrays.

    Arguments
    ---------
    values : iterable of numbers and/or numpy arrays

    Returns
    -------
    The element (with its sign) whose absolute value is largest. NaNs are
    ignored, and empty or all-NaN inputs are skipped. On ties the later value
    wins. Returns None if nothing comparable was found.
    """
    finalval, absmaxest = None, 0
    for val in values:
        absval = np.abs(val)
        if np.isnan(absval).all():  # empty or all-NaN: nothing to compare
            continue
        absmaxval = np.nanmax(absval)
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        flatidx = np.nanargmax(np.abs(finalval))
        return finalval[np.unravel_index(flatidx, finalval.shape)]
    return finalval
290 return finalval 

291 

292 

def cast(function, val1, val2):
    """Apply `function(val1, val2)`, aligning arrays of unequal rank.

    If both arguments are arrays with different numbers of dimensions, the
    lower-dimensional one gets trailing singleton axes. Its axes then line up
    with the LEADING axes of the other (numpy's default broadcasting aligns
    trailing axes instead). This lets, e.g., a per-sweep-point vector combine
    with a (sweep, n) array. Argument order is preserved. All warnings raised
    during the call (e.g. numpy divide-by-zero) are suppressed.

    Arguments
    ---------
    function : callable taking two arguments, e.g. np.divide or operator.sub
    val1, val2 : numbers or arrays

    Returns
    -------
    Whatever `function` returns.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if (hasattr(val1, "shape") and hasattr(val2, "shape")
                and val1.ndim != val2.ndim):
            if val1.ndim > val2.ndim:
                padded = val2[(slice(None),)*val2.ndim
                              + (np.newaxis,)*(val1.ndim - val2.ndim)]
                return function(val1, padded)
            padded = val1[(slice(None),)*val1.ndim
                          + (np.newaxis,)*(val2.ndim - val1.ndim)]
            return function(padded, val2)
        return function(val1, val2)

308 

309 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None  # cache for name_collision_varkeys()
    table_titles = {"sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        "Number of solutions: len of the cost, 1 if scalar, 0 if no cost."
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        "Value of `posy` at this solution (the constant of the substitution)."
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        Both must contain the same variables. For every variable, the largest
        elementwise relative difference of the values must be below `reltol`,
        and the largest absolute difference of the sensitivities below
        `sens_abstol`. Vector and swept values are reduced with np.max, since
        comparing a whole array with `>=` is ambiguous.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            reldiff = abs(cast(np.divide, svars[key], ovars[key]) - 1)
            if np.max(reldiff) >= reltol:
                return False
            sensdiff = abs(self["sensitivities"]["variables"][key]
                           - other["sensitivities"]["variables"][key])
            if np.max(sensdiff) >= sens_abstol:
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (".pgz" paths are read as gzip-compressed pickles). Unpickling
            can execute arbitrary code: only load trusted files.
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True, show a unified diff of the two models' strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True, order model sections by this solution's model
            sensitivities

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as fil:  # close the file afterwards
                    other = pickle.load(fil)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # [2:] drops unified_diff's "---"/"+++" file header lines
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)

        def note_largest(values, template):
            "If the table just added has no rows, report its largest value."
            # var_table ends with [title, dashes, *rows, ""], so a dashed
            # second-to-last line means every value was below the tolerance
            if lines[-2][:10] == "-"*10:
                largest = unrolled_absmax(values)
                if largest is not None:
                    lines.insert(-1, template % largest)

        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            note_largest(rel_diff.values(), "The largest is %+g%%.")
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            note_largest(abs_diff.values(), "The largest is %+g.")
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            note_largest(senss_delta.values(), "The largest is %+g.")
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as fil:  # closed even if dump fails
                pickle.dump(self, fil, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        """Pickle the solution, optimize the pickle, and gzip it to a file.

        Load it back with SolutionArray.decompress_file(filename).
        """
        with gzip.open(filename, "wb") as fil:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            fil.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        NOTE: unpickling can execute arbitrary code; only load trusted files.
        """
        with gzip.open(file, "rb") as fil:
            return pickle.Unpickler(fil).load()

    def varnames(self, showvars, exclude):
        """Map display names to VarKeys, optionally for `showvars` only.

        Names are built with `str_without(exclude)`. Keys whose names collide
        keep their lineage, so every name in the returned dict is distinct.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        names = {}
        try:
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:  # never leave the temporary lineage flags behind
            for key in self.name_collision_varkeys():
                key.descr.pop("necessarylineage", None)
        return names

    def savemat(self, filename="solution.mat", showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file (dots in names become '_')"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe, one row per element"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # row-major element indices; 1-D indices are unwrapped to ints
                idxs = [idx[0] if len(idx) == 1 else idx
                        for idx in np.ndindex(*value.shape)]
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "original_fn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as fil:
            if printmodel:
                fil.write(self.modelstr + "\n")
            fil.write(self.table(**kwargs))

    def savecsv(self, showvars=None, filename="solution.csv", valcols=5):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min(di for di in v.shape if di != 1)
                maxspan_ = max(di for di in v.shape if di != 1)
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as fil:
            fil.write("Model Name,Variable Name,Value(s)" + ","*valcols
                      + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    # no newline: the model name fills the first column of
                    # the next variable row
                    fil.write(line[1])
                elif not line[1]:  # spacer line
                    fil.write("\n")
                else:
                    fil.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        fil.write(el + ",")
                    fil.write(","*(valcols - len(vals.split())))
                    fil.write((line[2].replace("[", "").replace("]", "").strip()
                               + ","))
                    fil.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Expand each entry of `showvars` into the set of matching VarKeys."
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        """Return a summary table string: cost, warnings, swept and free
        variables, then sensitivities and the tightest constraints.

        Fixed variables are not listed. When there are many of them (at
        least ntopsenss+2), only the `ntopsenss` most sensitive are shown.
        """
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print, in order: "cost", a key of this solution such as
            "sweepvariables", "freevariables" or "constants", or a TABLEFNS
            name ("sensitivities", "top sensitivities", "insensitivities",
            "model sensitivities", "tightest constraints",
            "loose constraints", "warnings"). Other names are skipped.
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        if not any(var.lineage != varlist[0].lineage for var in varlist[1:]):
            kwargs["sortbymodel"] = False  # only one model: no grouping
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:  # never leave the temporary lineage flags behind
            for key in self.name_collision_varkeys():
                key.descr.pop("necessarylineage", None)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy.

        Raises ValueError unless exactly one variable was swept.
        """
        if len(self["sweepvariables"]) != 1:
            raise ValueError(
                "SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

752 

753 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table. It is never modified.
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If True, return the unformatted rows (lists of cells)
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip variables whose largest |value| (ignoring NaN) is <= minval
    hidebelowminval : bool
        If True (and minval is nonzero), array elements with |value| <= minval
        are shown as " - " in the table.
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    maxcolumns : int
        Maximum number of array elements per printed row
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # build a new (float) array instead of writing NaNs into the
            # caller's data, which may be the solution's own arrays
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        # the unique index i keeps sorting from ever comparing keys or arrays
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    # latex modes 2 and 3 have three columns; plain text and latex 1 have four
    if latex == 2:  # no values
        coltitles = [title, "Units", "Description"]
    elif latex == 3:  # no description
        coltitles = [title, "Value", "Units"]
    else:
        coltitles = [title, "Value", "Units", "Description"]
    spacer_width = len(coltitles) if latex else 4
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append([""] * spacer_width)
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            # (the other axes keep their relative order)
            dim_order = [d for d in range(val.ndim) if d != horiz_dim]
            dim_order.append(horiz_dim)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows, ncols values each
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad a short last row so its bracket lines up
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines