Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Defines SolutionArray class""" 

2import re 

3import json 

4import difflib 

5from operator import sub 

6import warnings as pywarnings 

7import pickle 

8import gzip 

9import pickletools 

10import numpy as np 

11from .nomials import NomialArray 

12from .small_classes import DictOfLists, Strings 

13from .small_scripts import mag, try_str_without 

14from .repr_conventions import unitstr, lineagestr 

15 

16 

# Splits a printed constraint into wrappable pieces. Every alternative is a
# capture group, so re.split keeps the separators: a "*" flanked by non-"*"
# characters (so "**" is never split) and the spaced operators " + ",
# " >= ", " <= " and " = ".
CONSTRSPLITPATTERN = re.compile("|".join((
    r"([^*]\*[^*])",
    r"( \+ )",
    r"( >= )",
    r"( <= )",
    r"( = )",
)))

# (old, new) substring replacements for formatted value strings. Presumably
# applied in order by the table-formatting code (e.g. var_table) -- confirm.
# If they are applied in order, the order matters: the bare "nan" pair comes
# last so it does not pre-empt the signed and percent forms.
VALSTR_REPLACES = list(zip(
    ("+nan", "-nan", "nan%", "nan"),
    (" nan", " nan", "nan ", " - "),
))

25 

26 

class SolSavingEnvironment:
    """Context manager that slims a solution down while it is pickled.

    On entry, with ``saveconstraints`` true, it strips a fixed set of
    construction/solve attributes from every constraint found in
    ``solarray["sensitivities"]["constraints"]`` (only attributes that are
    present and truthy). With ``saveconstraints`` false it instead removes
    the whole ``"constraints"`` entry. On exit everything removed is put
    back. This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.attrstore = {}  # attribute name -> {constraint: removed value}
        self.saveconstraints = saveconstraints
        self.constraintstore = None  # popped "constraints" entry, if any

    def __enter__(self):
        if not self.saveconstraints:
            # drop the constraints entirely; __exit__ reinstates them
            self.constraintstore = \
                self.solarray["sensitivities"].pop("constraints")
            return
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            removed = {}
            for constr in self.solarray["sensitivities"]["constraints"]:
                value = getattr(constr, attrname, None)
                if value:  # missing or falsy attributes are left alone
                    removed[constr] = value
                    delattr(constr, attrname)
            self.attrstore[attrname] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for attrname, removed in self.attrstore.items():
            for constr, value in removed.items():
                setattr(constr, attrname, value)

61 

def msenss_table(data, _, **kwargs):
    """Return the lines of the model-sensitivity table.

    Arguments
    ---------
    data : SolutionArray or dict
        Its ``["sensitivities"]["models"]`` entry maps model names to
        sensitivities (scalars, or arrays e.g. for a sweep).
    _ : ignored
        Placeholder for ``showvars``, which every TABLEFNS entry receives.
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that later sections are sorted by model.

    Returns
    -------
    list of str
        Empty when there are no model sensitivities or fewer than two
        named models to show.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other table function
    # asarray lets plain-float sensitivities use the array methods below
    models = [(name, np.asarray(senss))
              for name, senss in data["sensitivities"]["models"].items()]
    # models whose sensitivities are all below 0.1 go last; the rest are
    # ordered by descending (rounded) mean sensitivity, then by name
    models.sort(key=lambda i: ((i[1] < 0.1).all(),
                               -np.max(i[1]) if (i[1] < 0.1).all()
                               else -round(np.mean(i[1]), 1), i[0]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in models:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            # small sensitivities are shown only as an order of magnitude
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % np.log10(msenss))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                # array-valued sensitivities: append each entry's deviation
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:
            msenssstr = " "*len(msenssstr)  # blank out a repeated value
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    return lines + [""] if len(lines) > 3 else []

97 

98 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Return lines for a table of variable sensitivities.

    ``data`` may be a whole solution, whose
    ``["sensitivities"]["variables"]`` entry is then used, or a mapping of
    variables to sensitivities. A non-empty ``showvars`` restricts the rows
    to those variables. Extra keywords are forwarded to ``var_table``.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    if showvars:
        data = {var: data[var] for var in showvars if var in data}
    fixed_options = dict(sortbyvals=True, skipifempty=True,
                         valfmt="%+-.2g ", vecfmt="%+-8.2g",
                         printunits=False, minval=1e-3)
    return var_table(data, title, **fixed_options, **kwargs)

108 

109 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Return lines for a table of the ``nvars`` most sensitive variables.

    Variables in ``showvars`` are filtered out by ``topsenss_filter``; when
    any were, the title reads "Next Most Sensitive Variables".
    """
    top, already_shown = topsenss_filter(data, showvars, nvars)
    heading = ("Next Most Sensitive Variables" if already_shown
               else "Most Sensitive Variables")
    return senss_table(top, title=heading, hidebelowminval=True, **kwargs)

117 

118 

def topsenss_filter(data, showvars, nvars=5):
    """Filter sensitivities down to the ``nvars`` largest by mean magnitude.

    Arguments
    ---------
    data : SolutionArray or dict
        A solution (its ``["sensitivities"]["variables"]`` entry is used if
        present) or a mapping of variables to sensitivities.
    showvars : set
        Variables already displayed elsewhere. These are excluded, and each
        exclusion lowers ``nvars`` by one while it is above 3.
    nvars : int
        Maximum number of variables to return; 0 or less returns none.

    Returns
    -------
    (dict, set)
        The selected ``{variable: sensitivity}`` entries and the subset of
        ``showvars`` that was filtered out. Variables with any NaN
        sensitivity are never selected.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    # ascending order, so the most sensitive variables are at the end
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    filter_already_shown = showvars.intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    # topk[-0:] would be the whole list, so guard non-positive counts
    selected = topk[-nvars:] if nvars > 0 else []
    return {k: data[k] for k in selected}, filter_already_shown

132 

133 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Return lines for a table of fixed variables with small sensitivities.

    Arguments
    ---------
    data : SolutionArray or dict
        A solution (its ``["sensitivities"]["variables"]`` entry is used if
        present) or a mapping of variables to sensitivities.
    _ : ignored
        Placeholder for ``showvars``, which every TABLEFNS entry receives.
    maxval : float
        Only variables whose mean absolute sensitivity is below this value
        are listed.
    **kwargs
        Forwarded to ``senss_table``.
    """
    # check the same key that is read (previously "constants" was checked,
    # so a solution without that key fell through to iterating itself)
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

140 

141 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return lines for the table of the most sensitive constraints.

    Keeps constraints whose sensitivity is at least ``tight_senss``, orders
    them by descending sensitivity (rounded as printed, then by constraint
    text) and passes at most ``ntightconstrs`` of them to
    ``constraint_table``. For a sweep, the last sweep point is used.
    """
    title = "Most Sensitive Constraints"
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        value = senss[-1] if is_sweep else senss
        if value >= tight_senss:
            senssstr = "%+6.2g" % value
            rows.append(((-float(senssstr), str(constr)),
                         senssstr, id(constr), constr))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)

156 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines for the table of constraints with negligible sensitivity.

    Lists every constraint whose sensitivity is at most ``min_senss`` (for a
    sweep, at the last sweep point), in ``constraint_table`` row format with
    an empty opening string.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        if (senss[-1] if is_sweep else senss) <= min_senss:
            rows.append((0, "", id(constr), constr))
    return constraint_table(rows, title, **kwargs)

170 

171 

172# pylint: disable=too-many-branches,too-many-locals,too-many-statements 

def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of tuples
        ``(sortby, openingstr, id(constraint), constraint)`` rows, as built
        by tight_table, loose_table and warnings_table. Rows are ordered by
        ``sortby``, then ``openingstr``; the unique ``id`` means ties never
        fall through to comparing the constraints themselves.
        ``openingstr`` is printed right-aligned before each constraint.
    title : str
        Table heading, underlined with dashes.
    sortbymodel : bool
        If true, rows are grouped under their constraint's lineage string.
    showmodels : bool
        If false, "lineage" is also excluded when stringifying constraints.
    **_
        Other table keywords; accepted and ignored.

    Returns
    -------
    list of str
        Title, underline, the formatted rows (or "(none)"), and a final "".
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    # models maps each lineage string to its first-appearance rank, so the
    # groups keep the order of their first (best-sorted) constraint
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # blank spacer row between models
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        # the model name is already in the header row; drop it from the text
        constrstr = constrstr.replace(model, "")
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # separator: keep its first character on this line and push
                # the rest onto the next segment, so any wrap happens just
                # before the operator
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                # wrap once past minlen if this segment would overflow maxlen
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            # hard-split anything still longer than maxlen
            while len(line) > maxlen:
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths over all two-column rows (model header rows excluded)
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]

234 

235 

def warnings_table(self, _, **kwargs):
    """Return lines for a table of every warning stored in the solution.

    ``self["warnings"]`` maps a warning type to its entries: a list of
    ``(message, details)`` pairs, or an array of such lists for a sweep.
    Sweep entries that are identical at every point are printed once. The
    "Unexpectedly Tight/Loose Constraints" types are rendered through
    ``constraint_table`` when their entries carry details; every other type
    is printed as a plain list of messages under an underlined heading.

    Returns
    -------
    list of str
        The table lines, or [] when there is no warning to show.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            subtitle = wtype
            if len(data_vec) > 1:
                subtitle += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, subtitle, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, subtitle, **kwargs)
            else:
                # underline the full heading, including any sweep suffix
                lines += [subtitle] + ["-"*len(subtitle)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]

283 

284 

# Dispatch table for SolutionArray.table: maps a table name to the function
# producing its lines. Each is called as fn(solution, showvars, **kwargs).
TABLEFNS = dict([
    ("sensitivities", senss_table),
    ("top sensitivities", topsenss_table),
    ("insensitivities", insenss_table),
    ("model sensitivities", msenss_table),
    ("tightest constraints", tight_table),
    ("loose constraints", loose_table),
    ("warnings", warnings_table),
])

293 

def unrolled_absmax(values):
    """Return the entry of largest magnitude among numbers and arrays.

    Each item of ``values`` may be a scalar or an array; arrays are searched
    element-wise and the winning entry is returned with its sign. On ties the
    later item wins; items whose magnitude is NaN are skipped. Returns
    ``None`` if no item qualifies (e.g. ``values`` is empty).
    """
    best, best_peak = None, 0
    for candidate in values:
        peak = np.abs(candidate).max()
        if peak >= best_peak:
            best, best_peak = candidate, peak
    if not getattr(best, "shape", None):
        return best  # scalar (or 0-d array): nothing to index into
    flat_idx = np.argmax(np.abs(best))
    return best[np.unravel_index(flat_idx, best.shape)]

305 

306 

def cast(function, val1, val2):
    """Return ``function(val1, val2)``, aligning arrays of different rank.

    When both arguments are arrays of different ``ndim``, the lower-rank one
    is padded with trailing length-1 axes (so its axes line up with the
    *leading* axes of the other) before ``function`` is applied. In all
    other cases ``function`` is applied directly. Numpy warnings such as
    divide-by-zero are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if not both_arrays or val1.ndim == val2.ndim:
            return function(val1, val2)
        if val1.ndim > val2.ndim:
            padding = ((slice(None),)*val2.ndim
                       + (np.newaxis,)*(val1.ndim - val2.ndim))
            return function(val1, val2[padding])
        padding = ((slice(None),)*val1.ndim
                   + (np.newaxis,)*(val2.ndim - val1.ndim))
        return function(val1[padding], val2)

322 

323 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    # text of the model; written by savetxt and line-diffed by diff()
    modelstr = ""
    # cache for name_collision_varkeys(); None until first computed
    _name_collision_varkeys = None
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        """Returns the set of contained varkeys whose names are not unique.

        The result is computed once and cached on the instance, so later
        changes to ``self["variables"]`` are not reflected.
        """
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        # number of solutions: len(cost) for an array cost, 1 when cost has
        # no len() (TypeError), 0 when there is no "cost" entry at all
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        "Returns the value(s) of posy in this solution (see subinto)."
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        Returns False if the variable sets differ, if any value differs by a
        relative ``reltol`` or more, or if any variable sensitivity differs
        by ``sens_abstol`` or more. Array-valued (e.g. swept) entries are
        compared element-wise.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any: a bare `if array` raises for multi-element arrays
            if np.any(abs(cast(np.divide, svars[key], ovars[key]) - 1)
                      >= reltol):
                return False
            if np.any(abs(self["sensitivities"]["variables"][key]
                          - other["sensitivities"]["variables"][key])
                      >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (".pgz" paths as gzip-compressed pickles). Unpickling can run
            arbitrary code: only load files from trusted sources.
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True, show a unified diff of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True (and sensitivities exist), order model sections by
            model sensitivity

        Returns
        -------
        str
        """
        # mirror table(): only sort by model sensitivities when they exist
        if sortmodelsbysenss and "sensitivities" in self:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:  # close the handle promptly
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.::

            import pickle
            with open("solution.pkl", "rb") as f:
                sol = pickle.load(f)
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as f:  # ensure the file gets closed
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle the solution, optimize the pickle and gzip it into filename."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file (as written by save_compressed).

        NOTE: unpickling can execute arbitrary code; only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict of display name -> key for the solution variables.

        Names are built with ``str_without(exclude)``; keys whose names
        collide get their lineage shown so the names stay distinct.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        collisions = self.name_collision_varkeys()
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:  # never leave the temporary flag behind
            for key in collisions:
                del key.descr["necessarylineage"]
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # every index of the array in C order; 1-D indices unwrapped
                idxs = [idx[0] if len(idx) == 1 else idx
                        for idx in np.ndindex(value.shape)]
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json", showvars=None):
        "Saves solution values and units as a json file"
        sol_dict = {}
        collisions = self.name_collision_varkeys()
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:
            data = self["variables"]
            if showvars:
                showvars = self._parse_showvars(showvars)
                data = {k: data[k] for k in showvars if k in data}
            # add appropriate data for each variable to the dictionary
            for k, v in data.items():
                key = str(k)
                if isinstance(v, np.ndarray):
                    val = {"v": v.tolist(), "u": k.unitstr()}
                else:
                    val = {"v": v, "u": k.unitstr()}
                sol_dict[key] = val
        finally:  # never leave the temporary flag behind
            for key in collisions:
                del key.descr["necessarylineage"]
        with open(filename, "w") as f:
            json.dump(sol_dict, f)

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Expands showvars into the set of matching keys in self['variables']."
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Return summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        showvars: Iterable
            If given, only these variables are shown in variable tables
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities") or any TABLEFNS
                               name
        sortmodelsbysenss: bool
            If True (and sensitivities exist), order models by sensitivity
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        collisions = self.name_collision_varkeys()
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if "sensitivities" not in self and ("sensitivities" in table or
                                                    "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:  # never leave the temporary flag behind
            for key in collisions:
                del key.descr["necessarylineage"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy.

        Raises ValueError unless exactly one variable was swept.
        """
        if len(self["sweepvariables"]) != 1:
            # the single-item unpacking below would fail anyway; say why
            raise ValueError(
                "SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

792 

793 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table. It is not modified.
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text.
        1: value, units and description columns; 2: units and description
        (no values); 3: value and units (no description).
    rawlines : bool
        If true, return the unformatted rows (lists of cells) instead of
        the joined table strings.
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip variables whose values are all nan or all have
        abs(value) <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), array entries with
        abs(value) <= minval are displayed as blanks.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped under a heading for each model lineage.
    maxcolumns : int
        maximum number of array entries printed on one row
    skipifempty : boolean
        If true, return [] when no variable survives the minval filter.
    sortmodelsbysenss : dict or None
        If given, maps model names to sensitivities; when sortbyvals is
        false, model groups are ordered by decreasing (mean) sensitivity.
    Unrecognized keyword arguments are ignored.

    Returns
    -------
    list of strings, or list of row lists if rawlines is true

    Raises
    ------
    ValueError
        if latex is truthy but not one of 1, 2 or 3
    """
    if not data:
        return []
    # Column headers for each latex layout. They are resolved before any rows
    # are built so that a table with no surviving rows still gets a header.
    latex_coltitles = {1: [title, "Value", "Units", "Description"],
                       2: [title, "Units", "Description"],
                       3: [title, "Value", "Units"]}
    if latex and latex not in latex_coltitles:
        raise ValueError("Unexpected latex option, %s." % latex)
    coltitles = latex_coltitles.get(latex)

    # Each entry is a sort key tuple ending in (index, VarKey, value); the
    # unique index precedes the VarKey and value, so sorting never has to
    # compare those objects.
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # every value is nan or has abs(value) <= minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # Blank out small entries in a new array rather than assigning
            # into `v`, which belongs to the caller's `data` (and could not
            # hold nan at all if it had an integer dtype).
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)  # array sensitivities: use their mean
        models.add(model)
        isvector = bool(getattr(v, "shape", None))
        varstr = varfmt % k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, isvector, varstr, i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, isvector, varstr, i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")  # ungrouped variables are always included
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()

    previous_model, lines = None, []
    for row in decorated:
        if sortbyvals:
            model, _, _, isvector, varstr, _, var, val = row
        else:
            _, model, isvector, varstr, _, var, val = row
        if model not in models:
            continue
        if model != previous_model:
            if lines:  # blank separator row, one cell per table column
                lines.append([""] * (len(coltitles) if latex else 4))
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            # Print across each row the largest dimension that is at most
            # maxcolumns long (later dimensions win ties); if none fits,
            # fall back to the last dimension, one value per row.
            horiz_dim, ncols = len(val.shape) - 1, 1
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # make horiz_dim the last axis so each run of ncols flattened
            # values is one printed row
            flatval = np.moveaxis(val, horiz_dim, -1).flatten()
            vals = [vecfmt % x for x in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows for the rest of the array
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval) - values_remaining
                    vals = [vecfmt % x for x in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad a short final row so its bracket lines up
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            unitcol = "$%s$" % var.latex_unitstr()
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, unitcol, label])
            elif latex == 2:  # no values
                lines.append([varstr, unitcol, label])
            else:  # latex == 3: no description
                lines.append([varstr, valstr, unitcol])
    if rawlines:
        return lines
    if not latex:
        if lines:
            # widest cell per column; model headings are excluded because
            # they are printed across the row instead of column-aligned
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(dirs) == len(maxlens)
            fmts = ["{0:%s%s}" % (direc, width)
                    for direc, width in zip(dirs, maxlens)]
            for i, line in enumerate(lines):
                if line[0] == ("newmodelline",):
                    line = [fmts[0].format(" | "), line[1]]
                else:
                    line = [fmt.format(s) for fmt, s in zip(fmts, line)]
                lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(line) + " \\\\" for line in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines