Coverage for gpkit/solution_array.py: 80% (640 statements)
coverage.py v7.4.0, created at 2024-01-07 22:13 -0500

1"""Defines SolutionArray class""" 

2import sys 

3import re 

4import json 

5import difflib 

6from operator import sub 

7import warnings as pywarnings 

8import pickle 

9import gzip 

10import pickletools 

11from collections import defaultdict 

12import numpy as np 

13from .nomials import NomialArray 

14from .small_classes import DictOfLists, Strings, SolverLog 

15from .small_scripts import mag, try_str_without 

16from .repr_conventions import unitstr, lineagestr, UNICODE_EXPONENTS 

17from .breakdowns import Breakdowns 

18 

19 

CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")
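# Split points (a single "*", " + ", " >= ", " <= ", or " = ") at which
# constraint_table() below may wrap a long constraint string onto a new line.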


VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]
# pylint: disable=consider-using-f-string  # some would be less readable

class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.attrstore = {}
        self.saveconstraints = saveconstraints
        self.constraintstore = None

    def __enter__(self):
        if "sensitivities" not in self.solarray:
            pass
        elif self.saveconstraints:
            for constraint_attr in ["bounded", "meq_bounded", "vks",
                                    "v_ss", "unsubbed", "varkeys"]:
                store = {}
                for constraint in self.solarray["sensitivities"]["constraints"]:
                    if getattr(constraint, constraint_attr, None):
                        store[constraint] = getattr(constraint, constraint_attr)
                        delattr(constraint, constraint_attr)
                self.attrstore[constraint_attr] = store
        else:
            self.constraintstore = \
                self.solarray["sensitivities"].pop("constraints")

    def __exit__(self, type_, val, traceback):
        if self.saveconstraints:
            for constraint_attr, store in self.attrstore.items():
                for constraint, value in store.items():
                    setattr(constraint, constraint_attr, value)
        elif self.constraintstore:
            self.solarray["sensitivities"]["constraints"] = self.constraintstore

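# Usage sketch (this mirrors SolutionArray.save / save_compressed below;
# `sol` is assumed to be an already-solved SolutionArray):
#
#     with open("solution.pkl", "wb") as f:
#         with SolSavingEnvironment(sol, saveconstraints=True):
#             pickle.dump(sol, f)
#
# On exit the stripped attributes are restored, so `sol` itself is unchanged.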

def msenss_table(data, _, **kwargs):
    "Returns model sensitivity table lines"
    if "models" not in data.get("sensitivities", {}):
        return ""
    data = sorted(data["sensitivities"]["models"].items(),
                  key=lambda i: ((i[1] < 0.1).all(),
                                 -np.max(i[1]) if (i[1] < 0.1).all()
                                 else -round(np.mean(i[1]), 1), i[0]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs["sortmodelsbysenss"]:
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            msenss = np.max(msenss)
            if msenss:
                # pylint: disable=consider-using-f-string
                msenssstr = "%6s" % ("<1e%i" % max(-3, np.log10(msenss)))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = f"{meansenss:+6.1f}"
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                deltastrs = [f"{d:+4.1f}" if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += f" + [ {' '.join(deltastrs)} ]"
        if msenssstr == previousmsenssstr:
            msenssstr = " "*len(msenssstr)
        else:
            previousmsenssstr = msenssstr
        lines.append(f"{msenssstr} : {model}")
    return lines + [""] if len(lines) > 3 else []


def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    "Returns sensitivity table lines"
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    if showvars:
        data = {k: data[k] for k in showvars if k in data}
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)


def topsenss_table(data, showvars, nvars=5, **kwargs):
    "Returns top sensitivity table lines"
    data, filtered = topsenss_filter(data, showvars, nvars)
    title = "Most Sensitive Variables"
    if filtered:
        title = "Next Most Sensitive Variables"
    return senss_table(data, title=title, hidebelowminval=True, **kwargs)


def topsenss_filter(data, showvars, nvars=5):
    "Filters sensitivities down to top N vars"
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    filter_already_shown = showvars.intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    return {k: data[k] for k in topk[-nvars:]}, filter_already_shown


def insenss_table(data, _, maxval=0.1, **kwargs):
    "Returns insensitivity table lines"
    if "constants" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)


def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    "Return constraint tightness lines"
    title = "Most Sensitive Constraints"
    if len(self) > 1:
        title += " (in last sweep)"
        data = sorted(((-float(f"{abs(s[-1]):+6.2g}"), str(c)),
                       f"{abs(s[-1]):+6.2g}", id(c), c)
                      for c, s in self["sensitivities"]["constraints"].items()
                      if s[-1] >= tight_senss)[:ntightconstrs]
    else:
        data = sorted(((-float(f"{abs(s):+6.2g}"), str(c)),
                       f"{abs(s):+6.2g}", id(c), c)
                      for c, s in self["sensitivities"]["constraints"].items()
                      if s >= tight_senss)[:ntightconstrs]
    return constraint_table(data, title, **kwargs)


def loose_table(self, _, min_senss=1e-5, **kwargs):
    "Return insensitive constraint lines"
    title = f"Insensitive Constraints |below {min_senss:+g}|"
    if len(self) > 1:
        title += " (in last sweep)"
        data = [(0, "", id(c), c)
                for c, s in self["sensitivities"]["constraints"].items()
                if s[-1] <= min_senss]
    else:
        data = [(0, "", id(c), c)
                for c, s in self["sensitivities"]["constraints"].items()
                if s <= min_senss]
    return constraint_table(data, title, **kwargs)


# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    "Creates lines for tables where the right side is a constraint."
    # TODO: this should support 1D array inputs from sweeps
    excluded = {"units"} if showmodels else {"units", "lineage"}
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)
        constrstr = try_str_without(
            constraint, {":MAGIC:"+lineagestr(constraint)}.union(excluded))
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        minlen, maxlen = 25, 80
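        # Wrap the constraint string: split at CONSTRSPLITPATTERN matches,
        # start a new line once the current one passes minlen characters,
        # and hard-wrap anything that still runs past maxlen.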

        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]


def warnings_table(self, _, **kwargs):
    "Makes a table for all warnings in the solution."
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            all_equal = True
            for data in data_vec[1:]:
                eq_i = data == data_vec[0]
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += f" in sweep {i}"
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         f"{relax_sensitivity:+6.2g}", id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                lines += [title] + ["-"*len(wtype)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]


def bdtable_gen(key):
    "Generator for breakdown tablefns"

    def bdtable(self, _showvars, **_):
        "Breakdown plot for the given key"
        bds = Breakdowns(self)
        original_stdout = sys.stdout
        try:
            sys.stdout = SolverLog(original_stdout, verbosity=0)
            bds.plot(key)
        finally:
            lines = sys.stdout.lines()
            sys.stdout = original_stdout
        return lines

    return bdtable

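# bdtable_gen is used below to build the "cost breakdown" and
# "model sensitivities breakdown" entries of TABLEFNS: each generated
# function returns the lines Breakdowns.plot would print, captured by
# temporarily swapping sys.stdout for a SolverLog.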

TABLEFNS = {"sensitivities": senss_table,
            "top sensitivities": topsenss_table,
            "insensitivities": insenss_table,
            "model sensitivities": msenss_table,
            "tightest constraints": tight_table,
            "loose constraints": loose_table,
            "warnings": warnings_table,
            "model sensitivities breakdown": bdtable_gen("model sensitivities"),
            "cost breakdown": bdtable_gen("cost")
            }

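# Maps the table names accepted by SolutionArray.table(tables=...) to the
# functions that render each table's lines.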

def unrolled_absmax(values):
    "From an iterable of numbers and arrays, returns the largest magnitude"
    finalval, absmaxest = None, 0
    for val in values:
        absmaxval = np.abs(val).max()
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        return finalval[np.unravel_index(np.argmax(np.abs(finalval)),
                                         finalval.shape)]
    return finalval


def cast(function, val1, val2):
    "Applies function to val1 and val2, broadcasting if their dimensions differ"
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if hasattr(val1, "shape") and hasattr(val2, "shape"):
            if val1.ndim == val2.ndim:
                return function(val1, val2)
            lessdim, dimmest = sorted([val1, val2], key=lambda v: v.ndim)
            dimdelta = dimmest.ndim - lessdim.ndim
            add_axes = (slice(None),)*lessdim.ndim + (np.newaxis,)*dimdelta
            if dimmest is val1:
                return function(dimmest, lessdim[add_axes])
            if dimmest is val2:
                return function(lessdim[add_axes], dimmest)
        return function(val1, val2)

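# Minimal sketch of cast() in use (array shapes here are illustrative), as
# when a swept solution is compared against a point solution:
#     cast(np.divide, np.ones((2, 3)), np.array([1.0, 2.0]))
# The 1-D operand is given trailing np.newaxis slots, reshaping it to (2, 1)
# so that NumPy broadcasting can line it up against the (2, 3) array.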

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    _lineageset = False
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def set_necessarylineage(self, clear=False):  # pylint: disable=too-many-branches
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {}
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            name_collisions = defaultdict(set)
            for key in keymap:
                if hasattr(key, "key"):
                    if len(keymap[key.name]) == 1:  # very unique
                        self._name_collision_varkeys[key] = 0
                    else:
                        shortname = key.str_without(["lineage", "vec"])
                        if len(keymap[shortname]) > 1:
                            name_collisions[shortname].add(key)
            for varkeys in name_collisions.values():
                min_namespaced = defaultdict(set)
                for vk in varkeys:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
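                # Extend each colliding key's lineage one namespace level at
                # a time until every key in this collision group is unique.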

                while any(len(vks) > 1 for vks in min_namespaced.values()):
                    for key, vks in list(min_namespaced.items()):
                        if len(vks) <= 1:
                            continue
                        del min_namespaced[key]
                        mineage, idx = key
                        idx += 1
                        for vk in vks:
                            lineages = vk.lineagestr().split(".")
                            submineage = lineages[-idx] + "." + mineage
                            min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), vks in min_namespaced.items():
                    vk, = vks
                    self._name_collision_varkeys[vk] = idx
        if clear:
            self._lineageset = False
            for vk in self._name_collision_varkeys:
                del vk.descr["necessarylineage"]
        else:
            self._lineageset = True
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def __len__(self):
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3):
        "Checks for almost-equality between two solutions"
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            reldiff = np.max(abs(cast(np.divide, svars[key], ovars[key]) - 1))
            if reldiff >= reltol:
                return False
        return True

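    # Sketch (names are illustrative): sol.almost_equal(other_sol, reltol=1e-6)
    # is True only when both solutions contain the same varkeys and every
    # shared value agrees to within the given relative tolerance.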

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    # pylint: disable=too-many-arguments
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing

        Returns
        -------
        str
        """
        if sortmodelsbysenss:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(f" {key}" for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(f" {key}" for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               f"Relative Differences |above {reltol:g}%|",
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               f"Absolute Differences |above {abstol:g}|",
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)}
            lines += var_table(senss_delta,
                               f"Sensitivity Differences |above {sensstol:g}|",
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

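    # Usage sketch (the path is illustrative; both .pkl and .pgz files saved
    # by the methods below can be compared):
    #     print(sol.diff("last_verified_solution.pkl"))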

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        "Load a gzip-compressed pickle file"
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

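    # Round-trip sketch (filename is illustrative):
    #     sol.save_compressed("mysolution.pgz")
    #     same_sol = SolutionArray.decompress_file("mysolution.pgz")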

    def varnames(self, showvars, exclude):
        "Returns list of variables, optionally with minimal unique names"
        if showvars:
            showvars = self._parse_showvars(showvars)
        self.set_necessarylineage()
        names = {}
        for key in showvars or self["variables"]:
            for k in self["variables"].keymap[key]:
                names[k.str_without(exclude)] = k
        self.set_necessarylineage(clear=True)
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded="vec"):
        "Saves primal solution as matlab file"
        from scipy.io import savemat  # pylint: disable=import-outside-toplevel
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None, excluded="vec"):
        "Returns primal solution as pandas dataframe"
        # pylint: disable=import-outside-toplevel
        import pandas as pd  # pylint: disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join(f"{k}={v}" for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w", encoding="UTF-8") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json", showvars=None):
        "Saves solution table as a json file"
        sol_dict = {}
        if self._lineageset:
            self.set_necessarylineage(clear=True)
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # add appropriate data for each variable to the dictionary
        for k, v in data.items():
            key = str(k)
            if isinstance(v, np.ndarray):
                val = {"v": v.tolist(), "u": k.unitstr()}
            else:
                val = {"v": v, "u": k.unitstr()}
            sol_dict[key] = val
        with open(filename, "w", encoding="UTF-8") as f:
            json.dump(sol_dict, f)

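    # The written JSON maps each variable's name to {"v": value, "u": units},
    # e.g. {"x": {"v": 2.0, "u": "m"}} (the values shown are illustrative).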

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w", encoding="UTF-8") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError(f"no variable '{posy}' found in the solution")

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"], require_positive=False)

    def _parse_showvars(self, showvars):
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), **kwargs):
        "Print summary table, showing no sensitivities or constants"
        return self.table(showvars,
                          ["cost breakdown", "model sensitivities breakdown",
                           "warnings", "sweepvariables", "freevariables"],
                          **kwargs)

    def table(self, showvars=(),
              tables=("cost breakdown", "model sensitivities breakdown",
                      "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=False, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        self.set_necessarylineage()
        showvars = self._parse_showvars(showvars)
        strs = []
        for table in tables:
            if "breakdown" in table:
                if (len(self) > 1 or not UNICODE_EXPONENTS
                        or "sensitivities" not in self):
                    # no breakdowns for sweeps or no-unicode environments
                    table = table.replace(" breakdown", "")
            if "sensitivities" not in self and ("sensitivities" in table or
                                                "constraints" in table):
                continue
            if table == "cost":
                cost = self["cost"]  # pylint: disable=unsubscriptable-object
                if kwargs.get("latex", None):  # cost is not printed for latex
                    continue
                strs += ["\nOptimal Cost\n------------"]
                if len(self) > 1:
                    costs = [f"{c:-8.3g}" for c in mag(cost[:4])]
                    strs += [" [ %s %s ]" % (" ".join(costs),
                                             "..." if len(self) > 4 else "")]
                else:
                    strs += [f" {mag(cost):-.4g}"]
                strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                strs += [""]
            elif table in TABLEFNS:
                strs += TABLEFNS[table](self, showvars, **kwargs)
            elif table in self:
                data = self[table]
                if showvars:
                    showvars = self._parse_showvars(showvars)
                    data = {k: data[k] for k in showvars if k in data}
                strs += var_table(data, self.table_titles[table], **kwargs)
        if kwargs.get("latex", None):
            preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                  "% \\usepackage{booktabs}",
                                  "% \\usepackage{longtable}",
                                  "% \\usepackage{amsmath}",
                                  "% \\begin{document}\n"))
            strs = [preamble] + strs + ["% \\end{document}"]
        self.set_necessarylineage(clear=True)
        return "\n".join(strs)

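    # Typical use (variable names in the showvars list are illustrative):
    #     print(sol.table())                  # the full default set of tables
    #     print(sol.summary(["x", "x_{min}"]))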

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            print("SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        # pylint: disable=import-outside-toplevel
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

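    # Sweep-plot sketch (assumes `sol` came from a 1-D sweep solve): with no
    # arguments the cost is plotted against the single swept variable.
    #     fig, ax = sol.plot()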

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
# pylint: disable=too-many-arguments
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values with all(abs(value)) < minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table above
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 4)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float(f"{-val:.4g}"), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
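            # Vector values: pick the axis whose length best fits maxcolumns,
            # move it to the end, then print the flattened array ncols at a time.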

            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, f"${var.latex_unitstr()}$",
                              label])
                coltitles = [title, "Value", "Units", "Description"]
            elif latex == 2:  # no values
                lines.append([varstr, f"${var.latex_unitstr()}$", label])
                coltitles = [title, "Units", "Description"]
            elif latex == 3:  # no description
                lines.append([varstr, valstr, f"${var.latex_unitstr()}$"])
                coltitles = [title, "Value", "Units"]
            else:
                raise ValueError(f"Unexpected latex option, {latex}.")
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines