Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Defines SolutionArray class""" 

2import re 

3import difflib 

4from operator import sub 

5import warnings as pywarnings 

6import pickle 

7import gzip 

8import pickletools 

9import numpy as np 

10from .nomials import NomialArray 

11from .small_classes import DictOfLists, Strings 

12from .small_scripts import mag, try_str_without 

13from .repr_conventions import unitstr, lineagestr 

14 

15 

# Splits a constraint string at multiplications ("a*b", but never "**"),
# additions, and (in)equality signs, so that constraint_table can wrap long
# constraints between terms. Each alternative is a capturing group, so the
# separators themselves are kept in the output of .split().
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])"
                                r"|( \+ )"
                                r"|( >= )"
                                r"|( <= )"
                                r"|( = )")

# Ordered (before, after) substitutions that var_table applies to formatted
# value strings: signed and percent NaNs are first rewritten to a bare "nan"
# of equal width, and then every "nan" is displayed as a dash.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]

24 

25 

class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    This approximately halves the size of the pickled solution.

    Intended to wrap a pickling call as a context manager.

    Arguments
    ---------
    solarray : SolutionArray (or any mapping with the same layout)
        Its ``["sensitivities"]["constraints"]`` entry is modified on entry
        and restored on exit.
    saveconstraints : bool
        If True, the constraints are kept, but their "bounded",
        "meq_bounded", "vks", "v_ss", "unsubbed" and "varkeys" attributes
        (only those currently holding a truthy value) are deleted for the
        duration of the block. If False, the whole "constraints" entry is
        popped instead.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        # attribute name -> {constraint: value removed from it}
        self.attrstore = {}
        # the popped "constraints" entry when saveconstraints is False
        self.constraintstore = None

    def __enter__(self):
        senss = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = senss.pop("constraints")
            return
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            removed = {}
            for constr in senss["constraints"]:
                current = getattr(constr, attrname, None)
                if current:
                    removed[constr] = current
                    delattr(constr, attrname)
            self.attrstore[attrname] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for attrname, removed in self.attrstore.items():
            for constr, original in removed.items():
                setattr(constr, attrname, original)

60 

def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : SolutionArray or dict
        Read for ``data["sensitivities"]["models"]``, a mapping from model
        name to that model's sensitivity (a number, or an array for sweeps).
    _ : ignored
        Placeholder for the ``showvars`` argument every TABLEFNS entry gets.
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that other tables are sorted by it.

    Returns
    -------
    list of str
        Empty when there are no model sensitivities, or when fewer than two
        named models would be listed.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other table function
    data = sorted(data["sensitivities"]["models"].items(),
                  key=lambda i: (-round(np.mean(i[1]), 1), i[0]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        # accept plain Python numbers as well as numpy scalars/arrays
        msenss = np.asarray(msenss)
        if (msenss < 0.1).all():
            msenss = np.max(msenss)
            if msenss:  # show only the order of magnitude
                msenssstr = "%6s" % ("<1e%i" % np.log10(msenss))
            else:
                msenssstr = " =0 "
        elif not msenss.shape:
            msenssstr = "%+6.1f" % float(msenss)
        else:  # sweep: mean, plus per-point deviations that are noticeable
            meansenss = np.mean(msenss)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:  # don't repeat identical values
            msenssstr = " "
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    # the two header lines plus at least two models
    return lines + [""] if len(lines) > 3 else []

96 

97 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    ``data`` is either a whole solution, whose
    ["sensitivities"]["variables"] are then used, or an already-extracted
    {varkey: sensitivity} mapping. When ``showvars`` is non-empty, only
    those keys are kept. The table itself is rendered by var_table, sorted
    by value and with sensitivity-style number formats.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    table_opts = dict(sortbyvals=True, skipifempty=True,
                      valfmt="%+-.2g ", vecfmt="%+-8.2g",
                      printunits=False, minval=1e-3)
    return var_table(data, title, **table_opts, **kwargs)

107 

108 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns lines for a table of the variables with the top sensitivities.

    The selection is made by topsenss_filter. If any of the top variables
    were dropped because they are already in ``showvars``, the title says
    "Next Most Sensitive Variables". Vector entries at or below the
    table's minval are hidden.
    """
    topdata, already_shown = topsenss_filter(data, showvars, nvars)
    if already_shown:
        heading = "Next Most Sensitive Variables"
    else:
        heading = "Most Sensitive Variables"
    return senss_table(topdata, title=heading, hidebelowminval=True,
                       **kwargs)

116 

117 

def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top N variables.

    Arguments
    ---------
    data : SolutionArray or dict
        A solution (its ["sensitivities"]["variables"] are used) or an
        already-extracted {varkey: sensitivity} mapping.
    showvars : iterable
        Variables already displayed elsewhere; they are excluded here.
        Any iterable is accepted (not only sets).
    nvars : int
        How many variables to keep. Values above 3 are reduced by one, so
        at least 3 are shown when nvars >= 3. Zero or less keeps nothing.

    Returns
    -------
    (dict, set)
        The kept {varkey: sensitivity} entries, ordered from lowest to
        highest mean absolute sensitivity, and the set of top variables
        that were excluded because they are in ``showvars``.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # sensitivities containing NaN cannot be ranked, so they are skipped
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    filter_already_shown = set(showvars).intersection(topk)
    topk = [k for k in topk if k not in filter_already_shown]
    if nvars > 3:  # always show at least 3
        nvars -= 1
    if nvars <= 0:  # topk[-0:] would otherwise select everything
        return {}, filter_already_shown
    return {k: data[k] for k in topk[-nvars:]}, filter_already_shown

131 

132 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    Lists the variables whose mean absolute sensitivity is below
    ``maxval``. ``data`` may be a whole solution, in which case its
    ["sensitivities"]["variables"] are used, or an already-extracted
    {varkey: sensitivity} mapping. The second argument (the showvars that
    every TABLEFNS entry receives) is ignored.
    """
    # test for the same key that is read (as senss_table does); checking
    # "constants" here left whole solutions unextracted when only
    # "variables" was present
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

139 

140 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return lines for a table of the most sensitive constraints.

    Keeps at most ``ntightconstrs`` constraints whose sensitivity (for a
    sweep, the value at the last sweep point) is at least ``tight_senss``.
    They are ordered by decreasing magnitude (as rounded for display), then
    by their string form, and rendered by constraint_table.
    """
    is_sweep = len(self) > 1
    title = "Most Sensitive Constraints"
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        final = senss[-1] if is_sweep else senss
        if final >= tight_senss:
            shown = "%+6.2g" % abs(final)
            rows.append(((-float(shown), str(constr)), shown,
                         id(constr), constr))
    data = sorted(rows)[:ntightconstrs]
    return constraint_table(data, title, **kwargs)

156 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines for a table of constraints with negligible sensitivity.

    Lists every constraint whose sensitivity (for a sweep, the value at the
    last sweep point) is at most ``min_senss``, with an empty value column,
    rendered by constraint_table.
    """
    is_sweep = len(self) > 1
    title = "Insensitive Constraints |below %+g|" % min_senss
    if is_sweep:
        title += " (in last sweep)"
    data = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        final = senss[-1] if is_sweep else senss
        if final <= min_senss:
            data.append((0, "", id(constr), constr))
    return constraint_table(data, title, **kwargs)

170 

171 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of 4-tuples (sortby, openingstr, id, constraint)
        ``sortby`` orders rows within a model; ``openingstr`` is printed
        left of the constraint (followed by " : "). The id() entry
        presumably exists so that ties are broken without comparing the
        constraint objects themselves.
    title : str
        Table title; it is underlined with dashes.
    sortbymodel : bool
        If True, rows are grouped under the lineage string of their
        constraint, with a " | model" header line per group.
    showmodels : bool
        If False, lineage is stripped from the printed constraints.

    Returns
    -------
    list of str
        Title, underline, the formatted rows (or "(none)"), and a blank line.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    # models maps each model name to the order in which it was first seen,
    # so groups appear in the order of their best-sorted constraint
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:  # blank separator between model groups
                lines.append(["", ""])
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        # the model name is already printed in the group header
        constrstr = constrstr.replace(model, "")
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # keep only the separator's first character on this line and
                # push the rest onto the next segment, so that wrapping can
                # only occur just before an operator
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            # hard-wrap any single segment that is still too long
            while len(line) > maxlen:
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths from every non-header row: right-align the opening
    # string, left-align the constraint text
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]

234 

235 

def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Arguments
    ---------
    self : SolutionArray or dict
        Its ``["warnings"]`` maps each warning type to a sequence of
        (message, data) pairs; an entry with a ``shape`` is treated as a
        sweep, with one such sequence per sweep point.
    _ : ignored
        The showvars argument every TABLEFNS entry receives.
    **kwargs
        Passed on to constraint_table for the constraint-tightness warnings.

    Returns
    -------
    list of str
        Empty if there is nothing to report.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    header_len = len(lines)
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        if all((data == data_vec[0]).all() for data in data_vec[1:]):
            data_vec = [data_vec[0]]  # warnings identical across all sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*c.relax_sensitivity),
                         "%+6.2g" % c.relax_sensitivity, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*c.rel_diff),
                         "%.4g %s %.4g" % c.tightvalues, id(c), c)
                        for _, c in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title, including any " in sweep i"
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == header_len:  # every warning list was empty
        return []
    lines[-1] = "~~~~~~~~"  # closing rule replaces the last blank line
    return lines + [""]

272 

273 

# Dispatch table used by SolutionArray.table: each named table is rendered
# by calling fn(solution, showvars, **kwargs), which returns a list of lines.
TABLEFNS = dict([
    ("sensitivities", senss_table),
    ("top sensitivities", topsenss_table),
    ("insensitivities", insenss_table),
    ("model sensitivities", msenss_table),
    ("tightest constraints", tight_table),
    ("loose constraints", loose_table),
    ("warnings", warnings_table),
])

282 

def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The result keeps its sign: it is the single element (a scalar, or an
    entry of one of the arrays) whose absolute value is largest. NaN
    entries are ignored, so an array is still considered when only some of
    its entries are NaN. Ties go to the later value. Returns None if
    ``values`` is empty or contains only NaNs.
    """
    finalval, absmaxest = None, 0
    for val in values:
        magnitudes = np.abs(val)
        if np.all(np.isnan(magnitudes)):
            continue  # nothing comparable in this value
        absmaxval = np.nanmax(magnitudes)
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):  # non-scalar array: pick the entry
        return finalval[np.unravel_index(np.nanargmax(np.abs(finalval)),
                                         finalval.shape)]
    return finalval

294 

295 

def cast(function, val1, val2):
    """Applies ``function(val1, val2)``, aligning arrays of unequal rank.

    When both arguments have a ``shape`` but different ``ndim``, the one
    with fewer dimensions gets trailing length-1 axes, so that its existing
    axes line up with the *leading* axes of the other argument (numpy's own
    broadcasting would align the trailing axes instead). Argument order is
    preserved. All warnings, e.g. numpy divide-by-zero, are suppressed
    during the call.

    For example, ``cast(np.divide, new, old)`` gives ratios for relative
    differences, and ``cast(sub, new, old)`` gives absolute differences.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if not both_arrays or val1.ndim == val2.ndim:
            return function(val1, val2)
        if val1.ndim > val2.ndim:
            extra = val1.ndim - val2.ndim
            widened = val2[(slice(None),)*val2.ndim + (np.newaxis,)*extra]
            return function(val1, widened)
        extra = val2.ndim - val1.ndim
        widened = val1[(slice(None),)*val1.ndim + (np.newaxis,)*extra]
        return function(widened, val2)

311 

312 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:  # computed once, then cached
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    # several varkeys share this lineage-free, vec-free name
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        """Number of solutions: len of "cost", 1 if scalar, 0 if absent."""
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        """Value of ``posy`` under this solution (its .c if it has one)."""
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        Both must have the same variables. Every variable value must agree
        within a relative tolerance ``reltol``, and every variable
        sensitivity within an absolute tolerance ``sens_abstol``. Works for
        sweeps too: array values must agree element by element.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        ssenss = self["sensitivities"]["variables"]
        osenss = other["sensitivities"]["variables"]
        for key in svks:
            # np.any so that sweep (array) values don't raise
            # "truth value of an array is ambiguous"
            if np.any(np.abs(cast(np.divide, svars[key], ovars[key]) - 1)
                      >= reltol):
                return False
            if np.any(np.abs(ssenss[key] - osenss[key]) >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (gzip-compressed if they end in ".pgz"); only load trusted
            files, since unpickling can execute arbitrary code
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing

        Returns
        -------
        str
        """
        # guard on "sensitivities" exactly as table() does
        if sortmodelsbysenss and "sensitivities" in self:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:  # closed even if loading fails
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # [2:] drops unified_diff's "---"/"+++" file header lines
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        # In each section below, a dashed line right before the final blank
        # means var_table found no row above the tolerance; the largest
        # difference is then reported instead.
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as f:  # flushed and closed on exit
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        Unpickling can execute arbitrary code: only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a {name: varkey} dict, optionally with minimal unique names.

        Names are built with ``str_without(exclude)``. Keys whose names
        collide temporarily get a "necessarylineage" flag in their descr,
        which is removed again even if naming fails.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:
            for key in self.name_collision_varkeys():
                key.descr.pop("necessarylineage", None)
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("unnecessary lineage", "vec")):
        """Saves primal solution as matlab file.

        Values are stored as single precision ("f"), and dots in variable
        names are replaced with underscores.
        """
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        """Returns primal solution as pandas dataframe.

        One row per scalar variable, or per element of a vector variable
        (with its index in the "Index" column), sorted by variable name.
        """
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    # pad so the units column always lines up
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:  # sweep: substitute each solution separately
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"], require_positive=False)

    def _parse_showvars(self, showvars):
        """Returns the set of varkeys that the given names or keys refer to."""
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Print summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        # temporarily show lineage on keys whose bare names collide; the
        # flag is removed in the finally clause even if a table fails
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if "sensitivities" not in self and ("sensitivities" in table or
                                                    "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            for key in self.name_collision_varkeys():
                key.descr.pop("necessarylineage", None)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

759 

760 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table (never modified)
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If True, return the unformatted row lists instead of strings
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values with all(abs(value)) < minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), vector entries with magnitude at
        or below minval are displayed as blanks
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    maxcolumns : int
        Largest number of vector entries printed per line
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # mask a copy: v may be an array stored in the solution itself
            # (e.g. sensitivities passed through by topsenss_table), and
            # printing a table must not overwrite it with NaNs
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table above
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 1)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows, ncols values each
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
                coltitles = [title, "Value", "Units", "Description"]
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
                coltitles = [title, "Units", "Description"]
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
                coltitles = [title, "Value", "Units"]
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines