Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Defines SolutionArray class""" 

2import re 

3import difflib 

4from operator import sub 

5import warnings as pywarnings 

6import pickle 

7import gzip 

8import pickletools 

9import numpy as np 

10from .nomials import NomialArray 

11from .small_classes import DictOfLists, Strings 

12from .small_scripts import mag, try_str_without 

13from .repr_conventions import unitstr, lineagestr 

14 

15 

# Points at which a constraint's string form may be wrapped: a lone "*"
# (with one neighbouring character on each side, so "**" never matches),
# or a spaced "+", ">=", "<=" or "=".
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# Ordered (before, after) substitutions applied to formatted value strings.
# Signed and percent NaNs are normalized first, so the final rule can turn
# every remaining "nan" into a placeholder.
VALSTR_REPLACES = [("+nan", " nan"), ("-nan", " nan"),
                   ("nan%", "nan "), ("nan", " - ")]

24 

25 

class SolSavingEnvironment:
    """Context manager that strips heavy constraint data while pickling.

    If ``saveconstraints`` is true, several construction/solve attributes
    are deleted from every constraint in
    ``solarray["sensitivities"]["constraints"]`` on entry and restored on
    exit. Only truthy attribute values are removed and restored. Otherwise
    the whole "constraints" entry is popped from the sensitivities on entry
    and put back on exit.

    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.attrstore = {}
        self.saveconstraints = saveconstraints
        self.constraintstore = None

    def __enter__(self):
        senss = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = senss.pop("constraints")
            return
        for attr in ("bounded", "meq_bounded", "vks",
                     "v_ss", "unsubbed", "varkeys"):
            removed = {}
            for constraint in senss["constraints"]:
                # falsy or missing attributes are left untouched
                if getattr(constraint, attr, None):
                    removed[constraint] = getattr(constraint, attr)
                    delattr(constraint, attr)
            self.attrstore[attr] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = self.constraintstore
            return
        for attr, removed in self.attrstore.items():
            for constraint, value in removed.items():
                setattr(constraint, attr, value)

60 

def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : SolutionArray (or dict-like)
        Read for ``data["sensitivities"]["models"]``, a mapping of model name
        to sensitivity. Values are expected to be numpy scalars or arrays
        (``.all()`` and ``.shape`` are used on them).
    _ : ignored
        Placeholder for the ``showvars`` argument every TABLEFNS entry gets.
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that models are sorted by sensitivity.

    Returns
    -------
    list of str
        Empty when there are no model sensitivities, or when fewer than two
        named models would be listed.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other table function returns
    items = sorted(data["sensitivities"]["models"].items(),
                   key=lambda i: (-round(np.mean(i[1]), 1), i[0]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in items:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            # everything is small: show only an order-of-magnitude bound
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % np.log10(msenss))
            else:
                msenssstr = " =0 "
        elif not msenss.shape:
            msenssstr = "%+6.1f" % msenss
        else:
            # sweep: print the mean, plus per-point deltas when they matter
            meansenss = np.mean(msenss)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:
            msenssstr = " "  # blank out a repeat of the previous value
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    return lines + [""] if len(lines) > 3 else []

96 

97 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    ``data`` is either a solution, whose ``["sensitivities"]["variables"]``
    entry is used when present, or an already-extracted mapping of variable
    to sensitivity. A non-empty ``showvars`` restricts the table to those
    variables. Remaining keyword arguments are forwarded to ``var_table``.
    """
    senss = data.get("sensitivities", {})
    source = data["sensitivities"]["variables"] if "variables" in senss else data
    if showvars:
        selected = {key: source[key] for key in showvars if key in source}
    else:
        selected = source
    table_opts = dict(sortbyvals=True, skipifempty=True,
                      valfmt="%+-.2g ", vecfmt="%+-8.2g",
                      printunits=False, minval=1e-3)
    return var_table(selected, title, **table_opts, **kwargs)

107 

108 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns lines for a table of the most sensitive variables.

    Variables in ``showvars`` are filtered out by ``topsenss_filter``. If
    any were removed, the title reads "Next Most Sensitive Variables".
    Entries below the table's minimum value are hidden.
    """
    top, already_shown = topsenss_filter(data, showvars, nvars)
    prefix = "Next Most" if already_shown else "Most"
    return senss_table(top, title=prefix + " Sensitive Variables",
                       hidebelowminval=True, **kwargs)

116 

117 

def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top ``nvars`` variables.

    Arguments
    ---------
    data : SolutionArray or dict
        Either a solution (its ``["sensitivities"]["variables"]`` is used)
        or a mapping of variable key to sensitivity value(s).
    showvars : iterable
        Variables to leave out of the result. For each one removed,
        ``nvars`` shrinks by one while it is above 3.
    nvars : int
        Maximum number of variables to return; 0 or less returns none.

    Returns
    -------
    (dict, set)
        The highest mean-absolute-sensitivity variables mapped to their
        sensitivities, and the subset of ``showvars`` that was removed.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # variables with any NaN sensitivity are not ranked at all
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    # set() so tuples/lists (e.g. the () default of senss_table) work too
    filter_already_shown = set(showvars).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    # topk[-0:] would be the whole list, not an empty one
    top = topk[-nvars:] if nvars > 0 else []
    return {k: data[k] for k in top}, filter_already_shown

131 

132 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns lines for a table of variables with small sensitivities.

    Arguments
    ---------
    data : SolutionArray or dict
        Either a solution (its ``["sensitivities"]["variables"]`` is used)
        or a mapping of variable key to sensitivity value(s).
    _ : ignored
        Placeholder for the ``showvars`` argument every TABLEFNS entry gets.
    maxval : float
        Keep variables whose mean absolute sensitivity is below this.
    **kwargs :
        Forwarded to ``senss_table``.
    """
    # Check for the key that is actually read, as senss_table and
    # topsenss_filter do; otherwise `data` stays the whole solution.
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

139 

140 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Returns lines for a table of the most sensitive constraints.

    Lists at most ``ntightconstrs`` constraints whose sensitivity is at
    least ``tight_senss``, most sensitive first. For a sweep
    (``len(self) > 1``), the sensitivity at the last sweep point is used.
    """
    title = "Most Sensitive Constraints"
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        if is_sweep:
            senss = senss[-1]
        if senss >= tight_senss:
            senssstr = "%+6.2g" % senss
            # sort on the rounded value (descending), then on the text
            rows.append(((-float(senssstr), str(constr)),
                         senssstr, id(constr), constr))
    data = sorted(rows)[:ntightconstrs]
    return constraint_table(data, title, **kwargs)

155 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Returns lines for a table of the least sensitive constraints.

    Lists every constraint whose sensitivity is at most ``min_senss``. For a
    sweep (``len(self) > 1``), the sensitivity at the last sweep point is
    used.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    data = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        latest = senss[-1] if is_sweep else senss
        if latest <= min_senss:
            data.append((0, "", id(constr), constr))
    return constraint_table(data, title, **kwargs)

169 

170 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of 4-tuples (sortby, openingstr, tiebreak, constraint)
        Rows are ordered by ``sortby`` within each model group, and
        ``openingstr`` is printed left of the constraint. Callers in this
        file pass ``id(constraint)`` as the third entry, which breaks ties
        so that sorting never has to compare constraints themselves.
    title : str
        Table heading, underlined with dashes of the same length.
    sortbymodel : bool
        If true, rows are grouped under each constraint's ``lineagestr``.
    showmodels : bool
        If false, all lineage is excluded from the printed constraint strings.
    **_ :
        Extra table options are accepted and ignored.

    Returns
    -------
    list of str
        Title, underline, the formatted rows (or "(none)"), and a final "".
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        # Number models in order of first appearance in the sorted data, so
        # each group lands where its best-sorted row would have been.
        if model not in models:
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer row between model groups
            if model or lines:
                lines.append([("newmodelline",), model])  # group header row
            previous_model = model
        # strip the model text from the constraint; the header row shows it
        constrstr = constrstr.replace(model, "")
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # Keep only the separator's first character on this line and
                # prepend the rest to the next segment, so that any wrap
                # happens just before the operator.
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                # wrap only once the current line is longer than minlen
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            # hard-wrap anything still longer than maxlen
            while len(line) > maxlen:
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths, taken over every row except the model header rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            # header rows: " | " right-aligned in column one, then the name
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]

233 

234 

def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Arguments
    ---------
    self : SolutionArray (or dict-like)
        Read for ``self["warnings"]``, a mapping of warning type to a list
        of (message, payload) pairs. An entry with a ``shape`` is treated as
        one such list per sweep point.
    _ : ignored
        Placeholder for the ``showvars`` argument every TABLEFNS entry gets.
    **kwargs :
        Forwarded to ``constraint_table`` for the constraint warning tables.

    Returns
    -------
    list of str
        Empty when there is nothing to report.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            first = data_vec[0]
            identical = True
            for data in data_vec[1:]:
                equal = data == first
                # comparing list entries gives a plain bool with no .all()
                if hasattr(equal, "all"):
                    equal = equal.all()
                if not equal:
                    identical = False
                    break
            if identical:
                data_vec = [first]  # warnings identical across all sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*c.relax_sensitivity),
                         "%+6.2g" % c.relax_sensitivity, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*c.rel_diff),
                         "%.4g %s %.4g" % c.tightvalues, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title, as constraint_table does
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # only the header: every warning list was empty
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]

271 

272 

# Table names accepted by SolutionArray.table, each mapped to the function
# that builds it. Every function is called as fn(solution, showvars, **kwargs).
TABLEFNS = {
    "sensitivities": senss_table,
    "top sensitivities": topsenss_table,
    "insensitivities": insenss_table,
    "model sensitivities": msenss_table,
    "tightest constraints": tight_table,
    "loose constraints": loose_table,
    "warnings": warnings_table,
}

281 

def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The returned entry keeps its sign. For an array, the single element of
    largest magnitude is returned. NaN entries are ignored, so an array
    containing some NaNs still takes part. Returns None when ``values`` has
    no non-NaN entry at all. Ties go to the later value.
    """
    finalval, absmaxest = None, 0
    for val in values:
        absval = np.abs(val)
        if np.isnan(absval).all():
            continue  # nothing comparable here
        absmaxval = np.nanmax(absval)
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        # nanargmax: a NaN must not be picked as the extreme element
        return finalval[np.unravel_index(np.nanargmax(np.abs(finalval)),
                                         finalval.shape)]
    return finalval

293 

294 

def cast(function, val1, val2):
    """Returns ``function(val1, val2)``, aligning arrays of unequal rank.

    If both arguments are arrays with different ``ndim``, trailing length-1
    axes are appended to the lower-rank one before the call. This aligns
    their *leading* axes, unlike NumPy's default trailing-axis
    broadcasting. Python warnings raised during the call (such as NumPy's
    divide-by-zero RuntimeWarning) are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if not both_arrays or val1.ndim == val2.ndim:
            return function(val1, val2)
        lessdim, dimmest = sorted([val1, val2], key=lambda v: v.ndim)
        expand = ((slice(None),) * lessdim.ndim
                  + (np.newaxis,) * (dimmest.ndim - lessdim.ndim))
        if dimmest is val1:
            return function(dimmest, lessdim[expand])
        return function(lessdim[expand], dimmest)

310 

311 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        "Returns the value(s) of `posy` under this solution"
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        "Checks for almost-equality between two solutions"
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any: swept (array) values are compared elementwise instead
            # of raising "truth value of an array is ambiguous"
            if np.any(abs(cast(np.divide, svars[key], ovars[key]) - 1)
                      >= reltol):
                return False
            if np.any(abs(self["sensitivities"]["variables"][key]
                          - other["sensitivities"]["variables"][key])
                      >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (".pgz" paths are read as gzip-compressed); only load files you
            trust, since unpickling can execute arbitrary code
        showvars : iterable, optional
            if given, only compare these variables
        constraintsdiff : boolean
            if True, show a unified diff of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True (and sensitivities exist), order models by sensitivity

        Returns
        -------
        str
        """
        # guarded like table(): a solution may have no sensitivities
        if sortmodelsbysenss and "sensitivities" in self:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:  # closed even if loading fails
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        # the file is closed (and so flushed) before constraints are restored
        with SolSavingEnvironment(self, saveconstraints), \
                open(filename, "wb") as f:
            pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickles the solution and writes it gzip-compressed to `filename`."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        NOTE: unpickling can execute arbitrary code; only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict mapping printed variable names to their varkeys.

        Names are built with ``str_without(exclude)``. Varkeys whose names
        collide temporarily get lineage added, so their names stay distinct.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:
            # always remove the temporary flag, even if naming fails
            for key in self.name_collision_varkeys():
                del key.descr["necessarylineage"]
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Returns the set of all varkeys matching any entry of `showvars`"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Returns a summary table: top sensitivities and no constants table"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which tables to print, in order. Each is "cost", a name in
            TABLEFNS (e.g. "warnings", "model sensitivities",
            "tightest constraints"), or a key of this solution
            (e.g. "sweepvariables", "freevariables", "constants").
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if "sensitivities" not in self and (
                        "sensitivities" in table or "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (
                            " ".join(costs), "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            # always remove the temporary flag, even if a table fails
            for key in self.name_collision_varkeys():
                del key.descr["necessarylineage"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy

        Raises
        ------
        ValueError
            If the solution does not have exactly one swept variable.
        """
        if len(self["sweepvariables"]) != 1:
            # previously printed this and then failed unpacking (ValueError)
            raise ValueError(
                "SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

758 

759 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table (not modified)
    title : string
    printunits : bool
        If true, include a units column
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted row lists
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip variables whose values are all NaN or all have abs <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), blank out array entries whose
        abs(value) <= minval
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, group rows under each variable's lineage
    maxcolumns : int
        at most this many array entries are printed per line
    skipifempty : boolean
        If true, return [] when no variable survives the minval filter
    sortmodelsbysenss : dict or None
        If given, maps model names to sensitivities; used (before names)
        to order model groups when sortbyvals is false
    """
    if not data:
        return []
    # Column titles per latex mode, defined up front so that an empty latex
    # table still gets a header instead of hitting an unbound name.
    latex_coltitles = {1: [title, "Value", "Units", "Description"],
                       2: [title, "Units", "Description"],
                       3: [title, "Value", "Units"]}
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # all values NaN or no larger than minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            v = v.copy()  # mask a copy; never overwrite the caller's data
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        # `i` is unique, so sorting never needs to compare keys or values
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows, ncols entries each
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad a short final row so its bracket lines up
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
            for i, line in enumerate(lines):
                if line[0] == ("newmodelline",):
                    line = [fmts[0].format(" | "), line[1]]
                else:
                    line = [fmt.format(s) for fmt, s in zip(fmts, line)]
                lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(latex_coltitles[latex])
                             + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines