Coverage for gpkit/solution_array.py: 79%

638 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 16:47 -0500

1"""Defines SolutionArray class""" 

2import sys 

3import re 

4import json 

5import difflib 

6from operator import sub 

7import warnings as pywarnings 

8import pickle 

9import gzip 

10import pickletools 

11from collections import defaultdict 

12import numpy as np 

13from .nomials import NomialArray 

14from .small_classes import DictOfLists, Strings, SolverLog 

15from .small_scripts import mag, try_str_without 

16from .repr_conventions import unitstr, lineagestr, UNICODE_EXPONENTS 

17from .breakdowns import Breakdowns 

18 

19 

# Break points used to wrap long constraint strings onto several lines:
# a single "*" together with its neighbouring characters (so "**" is never
# split), or one of the " + ", " >= ", " <= ", " = " separators. Every
# alternative is a capturing group, so re.split keeps the separators.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# Ordered (old, new) substring substitutions for NaN entries in formatted
# values; presumably applied in sequence by the value-formatting code
# elsewhere in this file — TODO confirm against its callers.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]

28 

29 

class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    This approximately halves the size of the pickled solution.

    Use as a context manager around pickling. On entry it strips data
    reachable from ``solarray["sensitivities"]["constraints"]``; on exit it
    puts that data back.

    Arguments
    ---------
    solarray : SolutionArray (or dict-like)
        The solution about to be pickled.
    saveconstraints : bool
        If True, the constraints are kept but the attributes named in
        ``STRIPPED_ATTRS`` are deleted from each one (when truthy) and
        restored on exit. If False, the whole "constraints" entry is popped
        from the sensitivities and restored on exit.
    """

    # constraint attributes deleted (and later restored) when saveconstraints
    STRIPPED_ATTRS = ("bounded", "meq_bounded", "vks",
                      "v_ss", "unsubbed", "varkeys")

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.attrstore = {}
        self.saveconstraints = saveconstraints
        self.constraintstore = None

    def __enter__(self):
        if "sensitivities" not in self.solarray:
            return  # nothing to strip
        sensitivities = self.solarray["sensitivities"]
        if self.saveconstraints:
            for constraint_attr in self.STRIPPED_ATTRS:
                store = {}
                for constraint in sensitivities["constraints"]:
                    # falsy/missing attributes are left alone
                    if getattr(constraint, constraint_attr, None):
                        store[constraint] = getattr(constraint, constraint_attr)
                        delattr(constraint, constraint_attr)
                self.attrstore[constraint_attr] = store
        else:
            self.constraintstore = sensitivities.pop("constraints")

    def __exit__(self, type_, val, traceback):
        if self.saveconstraints:
            for constraint_attr, store in self.attrstore.items():
                for constraint, value in store.items():
                    setattr(constraint, constraint_attr, value)
        elif self.constraintstore is not None:
            # `is not None` (not truthiness) so an empty dict is restored too
            self.solarray["sensitivities"]["constraints"] = self.constraintstore

66 

def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : dict-like
        A solution; model sensitivities are read from
        ``data["sensitivities"]["models"]`` (model name -> sensitivities).
    _ : ignored
        (showvars; present for a uniform table-function signature)
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that models below are sorted by it.

    Returns
    -------
    list of str
        The table lines, or an empty list when there are no model
        sensitivities or fewer than two named models to show.
    """
    if "models" not in data.get("sensitivities", {}):
        return []

    def is_small(senss):
        "True when every sensitivity entry is below 0.1"
        return np.all(np.asarray(senss) < 0.1)

    def sortkey(item):
        "Large-sensitivity models first (by rounded mean), then small ones"
        name, senss = item
        if is_small(senss):
            return (True, -np.max(senss), name)
        return (False, -round(np.mean(senss), 1), name)

    models = sorted(data["sensitivities"]["models"].items(), key=sortkey)
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in models:
        if not model:  # for now let's only do named models
            continue
        if is_small(msenss):
            msenss = np.max(msenss)
            if msenss:  # order of magnitude, floored at 1e-3
                msenssstr = "%6s" % ("<1e%i" % max(-3, np.log10(msenss)))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:  # sweep: show per-point offsets
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        # blank out repeats so groups of equal sensitivity stand out
        if msenssstr == previousmsenssstr:
            msenssstr = " "*len(msenssstr)
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    return lines + [""] if len(lines) > 3 else []

102 

103 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    ``data`` may be a whole solution, in which case its
    ``["sensitivities"]["variables"]`` entry is tabulated, or an already
    extracted mapping of variables to sensitivities. A non-empty
    ``showvars`` restricts the table to those keys. Extra keyword arguments
    are forwarded to ``var_table``.
    """
    sensitivities = data.get("sensitivities", {})
    if "variables" in sensitivities:
        data = sensitivities["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)

113 

114 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns lines for a table of the most sensitive variables.

    Up to ``nvars`` variables are chosen by ``topsenss_filter``. When some
    of the top variables were excluded because they are already in
    ``showvars``, the table is titled as the *next* most sensitive ones.
    """
    top, already_shown = topsenss_filter(data, showvars, nvars)
    if already_shown:
        heading = "Next Most Sensitive Variables"
    else:
        heading = "Most Sensitive Variables"
    return senss_table(top, title=heading, hidebelowminval=True, **kwargs)

122 

123 

def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top N variables.

    Arguments
    ---------
    data : dict-like
        Either a solution (its ``["sensitivities"]["variables"]`` entry is
        used) or a mapping of variables to sensitivities.
    showvars : iterable
        Variables already being shown elsewhere. They are excluded, and each
        exclusion lowers ``nvars`` by one, but never below 3.
    nvars : int
        Maximum number of variables to return.

    Returns
    -------
    (dict, set)
        The selected variables mapped to their sensitivities, and the subset
        of ``showvars`` that was filtered out.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    # rank by mean absolute sensitivity; entries containing NaN are skipped
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    # set() so tuples/lists of showvars work as well as sets
    filter_already_shown = set(showvars).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    # explicit start index: topk[-0:] would return the whole list
    start = max(len(topk) - nvars, 0)
    return {k: data[k] for k in topk[start:]}, filter_already_shown

137 

138 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    Tabulates the variables whose mean absolute sensitivity is below
    ``maxval``. ``data`` is either a solution (its
    ``["sensitivities"]["variables"]`` entry is used) or a mapping of
    variables to sensitivities; ``_`` (showvars) is ignored.
    """
    # Test for the key that is actually read (as senss_table and
    # topsenss_filter do). Otherwise a solution lacking a "constants" entry
    # would be treated as a sensitivity mapping and crash below.
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

145 

146 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return constraint tightness lines.

    Lists up to ``ntightconstrs`` constraints whose sensitivity magnitude is
    at least ``tight_senss``, most sensitive first (ties broken by the
    constraint's string). For sweeps only the last sweep point is used.
    """
    title = "Most Sensitive Constraints"
    sweep = len(self) > 1
    if sweep:
        title += " (in last sweep)"
    data = []
    for c, s in self["sensitivities"]["constraints"].items():
        # filter on magnitude, consistent with the abs() used for display
        # and sorting, so large negative sensitivities aren't dropped
        senss = abs(s[-1] if sweep else s)
        if senss >= tight_senss:
            senssstr = "%+6.2g" % senss
            data.append(((-float(senssstr), str(c)), senssstr, id(c), c))
    # id(c) is unique, so sorting never has to compare constraints directly
    data = sorted(data)[:ntightconstrs]
    return constraint_table(data, title, **kwargs)

162 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines listing insensitive (loose) constraints.

    A constraint is listed when the magnitude of its sensitivity is at most
    ``min_senss``. For sweeps only the last sweep point is used.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    sweep = len(self) > 1
    if sweep:
        title += " (in last sweep)"
    # magnitude, so a strongly negative sensitivity isn't called insensitive
    data = [(0, "", id(c), c)
            for c, s in self["sensitivities"]["constraints"].items()
            if abs(s[-1] if sweep else s) <= min_senss]
    return constraint_table(data, title, **kwargs)

176 

177 

178# pylint: disable=too-many-branches,too-many-locals,too-many-statements 

def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, id, constraint) tuples
        ``sortby`` orders rows within a model; ``openingstr`` is printed,
        right-aligned, to the left of the constraint; the id entry only
        takes part in the initial ``sorted(data)``.
    title : str
        Table title, underlined with dashes.
    sortbymodel : bool
        If True, rows are grouped by ``lineagestr(constraint)`` with a header
        line before each group; otherwise everything is one group.
    showmodels : bool
        If False, lineage is also excluded from the constraint strings.

    Returns
    -------
    list of str
        Title, underline, the rows, and a trailing empty string. A single
        "(none)" row is shown when ``data`` is empty.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = {"units"} if showmodels else {"units", "lineage"}
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        # models are numbered in order of first appearance in sorted data
        if model not in models:
            models[model] = len(models)
        constrstr = try_str_without(
            constraint, {":MAGIC:"+lineagestr(constraint)}.union(excluded))
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer between model groups
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        # wrap long constraints: prefer breaking at CONSTRSPLITPATTERN
        # separators once a line is over minlen, hard-wrap at maxlen
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # keep only the first char of a separator match here and
                # push the rest onto the next segment (mutates segments)
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths from the two-column rows (model headers excluded)
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]

238 

239 

def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Arguments
    ---------
    self : dict-like
        The solution. Its ``"warnings"`` entry maps each warning type to a
        list of tuples whose first element is the message. For sweeps it is
        an array of such lists, one per sweep point; identical lists are
        shown once.
    _ : ignored (showvars)
    **kwargs
        Forwarded to ``constraint_table`` for the "Unexpectedly Tight/Loose
        Constraints" warning types.

    Returns
    -------
    list of str
        The table lines, or an empty list if there is nothing to show.
    """
    if "warnings" not in self or not self["warnings"]:
        return []
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):  # elementwise array comparison
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the whole title, including any sweep suffix
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]

287 

def bdtable_gen(key):
    """Generator for breakdown tablefns.

    Returns a table function that builds ``Breakdowns`` for a solution and
    calls its ``plot(key)`` while stdout is redirected into a ``SolverLog``.
    It returns what that log's ``lines()`` reports instead of printing.
    """

    def bdtable(self, _showvars, **_):
        "Cost breakdown plot"
        bds = Breakdowns(self)
        original_stdout = sys.stdout
        # Build the capture before redirecting. A failure here then can't
        # leave `finally` calling .lines() on the real stdout and masking
        # the original error.
        capture = SolverLog(original_stdout, verbosity=0)
        sys.stdout = capture
        try:
            bds.plot(key)
        finally:
            sys.stdout = original_stdout
        return capture.lines()

    return bdtable

304 

305 

# Table name -> function producing that table's lines. SolutionArray.table
# looks names up here and calls them as ``fn(solution, showvars, **kwargs)``.
TABLEFNS = dict([
    ("sensitivities", senss_table),
    ("top sensitivities", topsenss_table),
    ("insensitivities", insenss_table),
    ("model sensitivities", msenss_table),
    ("tightest constraints", tight_table),
    ("loose constraints", loose_table),
    ("warnings", warnings_table),
    ("model sensitivities breakdown", bdtable_gen("model sensitivities")),
    ("cost breakdown", bdtable_gen("cost")),
])

316 

def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The value is returned with its sign, e.g. ``[1, -3, 2]`` gives ``-3``.
    For an array, its single element of largest magnitude is returned. NaN
    entries are ignored. Returns None if ``values`` is empty or holds only
    NaNs. On ties, the later value wins.
    """
    finalval, absmaxest = None, 0
    with pywarnings.catch_warnings():  # nanmax warns on all-NaN input
        pywarnings.simplefilter("ignore")
        for val in values:
            # nanmax, so one NaN element doesn't hide the rest of an array
            absmaxval = np.nanmax(np.abs(val))
            if absmaxval >= absmaxest:  # an all-NaN val compares False
                absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        return finalval[np.unravel_index(np.nanargmax(np.abs(finalval)),
                                         finalval.shape)]
    return finalval

328 

329 

def cast(function, val1, val2):
    """Applies ``function(val1, val2)``, aligning arrays of unequal rank.

    When both values are arrays with different numbers of dimensions, the
    lower-dimensional one gets new *trailing* axes. Its axes then line up
    with the *leading* axes of the other (plain numpy broadcasting would
    align trailing axes instead). Otherwise ``function`` is applied
    directly. Python warnings raised during the call, such as numpy's
    divide-by-zero RuntimeWarning, are suppressed.

    This file uses it as ``cast(np.divide, a, b)`` (ratio, e.g. for relative
    differences) and ``cast(sub, a, b)`` (difference).
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if (hasattr(val1, "shape") and hasattr(val2, "shape")
                and val1.ndim != val2.ndim):
            dimdelta = abs(val1.ndim - val2.ndim)
            if val1.ndim > val2.ndim:
                add_axes = (slice(None),)*val2.ndim + (np.newaxis,)*dimdelta
                return function(val1, val2[add_axes])
            add_axes = (slice(None),)*val1.ndim + (np.newaxis,)*dimdelta
            return function(val1[add_axes], val2)
        return function(val1, val2)

345 

346 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
        constraints: dict of constraints to sensitivities
        models: dict of model names to sensitivities
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    _lineageset = False
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def set_necessarylineage(self, clear=False):  # pylint: disable=too-many-branches
        """Sets (or, if ``clear``, removes) each VarKey's "necessarylineage".

        The first call computes and caches a count for contained VarKeys:
        0 where the bare name is unique, otherwise the number of trailing
        lineage components needed to tell same-named keys apart. Each call
        then writes that count into ``vk.descr["necessarylineage"]``, or
        removes it when ``clear`` is True, and records the state in
        ``self._lineageset``. Returns None.
        """
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {}
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            name_collisions = defaultdict(set)
            for key in keymap:
                if hasattr(key, "key"):  # skip non-VarKey entries (names)
                    if len(keymap[key.name]) == 1:  # very unique
                        self._name_collision_varkeys[key] = 0
                    else:
                        shortname = key.str_without(["lineage", "vec"])
                        if len(keymap[shortname]) > 1:
                            name_collisions[shortname].add(key)
            for varkeys in name_collisions.values():
                # start from the innermost lineage component and keep adding
                # outer components until every group has a single key
                min_namespaced = defaultdict(set)
                for vk in varkeys:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                while any(len(vks) > 1 for vks in min_namespaced.values()):
                    for key, vks in list(min_namespaced.items()):
                        if len(vks) <= 1:
                            continue
                        del min_namespaced[key]
                        mineage, idx = key
                        idx += 1
                        for vk in vks:
                            lineages = vk.lineagestr().split(".")
                            submineage = lineages[-idx] + "." + mineage
                            min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), vks in min_namespaced.items():
                    vk, = vks
                    self._name_collision_varkeys[vk] = idx
        if clear:
            self._lineageset = False
            for vk in self._name_collision_varkeys:
                # pop, not del: clearing when nothing is set mustn't raise
                vk.descr.pop("necessarylineage", None)
        else:
            self._lineageset = True
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def __len__(self):
        "Number of solutions: len(cost) for sweeps, 1 for scalar, 0 if none"
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        """Returns ``posy`` evaluated at this solution (see ``subinto``).

        A result that has a ``c`` attribute (a substituted nomial) is reduced
        to that attribute.
        """
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3):
        """Checks for almost-equality between two solutions.

        True if both have the same variables and every variable's value
        ratio differs from 1 by less than ``reltol``. Only the variables are
        compared.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            reldiff = np.max(abs(cast(np.divide, svars[key], ovars[key]) - 1))
            if reldiff >= reltol:
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (".pgz" paths are read with ``decompress_file``)
        showvars : iterable, optional
            if given, only compare these variables
        constraintsdiff : boolean
            if True, show a unified diff of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True (and sensitivities exist), sort models by sensitivity

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                # NOTE: unpickling can run arbitrary code; only diff against
                # solution files from trusted sources
                with open(other, "rb") as fil:
                    other = pickle.load(fil)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])  # [2:] drops the ---/+++ header
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)

        def note_largest(values, fmt):
            """If the table just added shows nothing above its tolerance,
            insert a line stating the largest value (when there is one)."""
            if lines[-2][:10] == "-"*10:
                largest = unrolled_absmax(values)
                if largest is not None:  # None when values empty/all-NaN
                    lines.insert(-1, fmt % largest)

        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            note_largest(rel_diff.values(), "The largest is %+g%%.")
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            note_largest(abs_diff.values(), "The largest is %+g.")
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            # only variables that have a sensitivity in both solutions
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks
                           if vk in ssenss and vk in osenss}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            note_largest(senss_delta.values(), "The largest is %+g.")
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as fil:  # closed even if dump fails
                pickle.dump(self, fil, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        """Pickles the solution and writes it gzip-compressed to ``filename``.

        The pickle is passed through ``pickletools.optimize`` before
        compression. Load it again with ``SolutionArray.decompress_file``.
        """
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        NOTE: unpickling can run arbitrary code; only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict of variable-name strings to VarKeys.

        Names are ``str_without(exclude)`` with just enough lineage to be
        unique (via ``set_necessarylineage``). If ``showvars`` is given,
        only those variables are included.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        self.set_necessarylineage()
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:  # never leave the lineage markers set on VarKeys
            self.set_necessarylineage(clear=True)
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("vec",)):
        """Saves primal solution as matlab file.

        Dots in names become underscores; values are stored as
        single-precision ("f") arrays.
        """
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None, excluded=("vec",)):
        """Returns primal solution as pandas dataframe.

        One row per scalar variable or per element of a vector variable,
        with columns Name, Index, Value, Units, Label, Lineage and Other
        (remaining descr entries).
        """
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        """Saves solution table as a text file (UTF-8).

        If ``printmodel``, the model string is written before the table.
        """
        # explicit encoding: tables may contain unicode exponents
        with open(filename, "w", encoding="utf-8") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json", showvars=None):
        """Saves solution table as a json file.

        Each variable maps to ``{"v": value, "u": unitstr}``; arrays are
        written as nested lists.
        """
        sol_dict = {}
        if self._lineageset:
            self.set_necessarylineage(clear=True)
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # add appropriate data for each variable to the dictionary
        for k, v in data.items():
            # numpy arrays *and* numpy scalars -> plain JSON-able values
            value = v.tolist() if hasattr(v, "tolist") else v
            sol_dict[str(k)] = {"v": value, "u": k.unitstr()}
        with open(filename, "w") as f:
            json.dump(sol_dict, f)

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        """Saves primal solution as a CSV sorted by modelname, like the tables.

        Vector values are spread over up to ``valcols`` value columns.
        """
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            spans = [di for di in (getattr(v, "shape", None) or ()) if di != 1]
            if spans:
                if minspan is None or min(spans) < minspan:
                    minspan = min(spans)
                maxspan = max(maxspan, max(spans))
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        # explicit encoding: units/descriptions may contain unicode
        with open(filename, "w", encoding="utf-8") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    elements = vals.split()
                    for el in elements:
                        f.write(el + ",")
                    f.write(","*(valcols - len(elements)))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        """Returns NomialArray of each solution substituted into posy.

        A variable of this solution returns its value directly. For sweeps,
        each sweep point is substituted separately.

        Raises
        ------
        ValueError
            if posy is neither a solution variable nor substitutable.
        """
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"], require_positive=False)

    def _parse_showvars(self, showvars):
        "Returns the set of VarKeys matching each name/key in showvars"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), **kwargs):
        "Print summary table, showing no sensitivities or constants"
        return self.table(showvars,
                          ["cost breakdown", "model sensitivities breakdown",
                           "warnings", "sweepvariables", "freevariables"],
                          **kwargs)

    def table(self, showvars=(),
              tables=("cost breakdown", "model sensitivities breakdown",
                      "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=False, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        showvars: Iterable
            If given, restrict variable tables to these variables
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
            or any name in TABLEFNS
        sortmodelsbysenss: bool
            If True (and sensitivities exist), sort models by sensitivity
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        self.set_necessarylineage()
        try:  # always clear lineage markers, even if a table fails
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if "breakdown" in table:
                    if (len(self) > 1 or not UNICODE_EXPONENTS
                            or "sensitivities" not in self):
                        # no breakdowns for sweeps or no-unicode environments
                        table = table.replace(" breakdown", "")
                if "sensitivities" not in self and ("sensitivities" in table or
                                                    "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
        finally:
            self.set_necessarylineage(clear=True)
        if kwargs.get("latex", None):
            preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                  "% \\usepackage{booktabs}",
                                  "% \\usepackage{longtable}",
                                  "% \\usepackage{amsmath}",
                                  "% \\begin{document}\n"))
            strs = [preamble] + strs + ["% \\end{document}"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy.

        ``None`` or "cost" plots the cost. Only 1-dimensional sweeps are
        supported. Returns ``(figure, axes)``; ``axes`` is unwrapped when
        there is only one.

        Raises
        ------
        ValueError
            if the solution does not have exactly one swept variable.
        """
        if len(self["sweepvariables"]) != 1:
            # previously printed this and then failed unpacking below with
            # a less helpful ValueError
            raise ValueError("SolutionArray.plot only supports"
                             " 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            # identity/type tests avoid calling a (possibly overloaded)
            # `==` on nomials, which `posy in [None, "cost"]` would do
            if posy is None or (isinstance(posy, str) and posy == "cost"):
                y = self["cost"]
            else:
                y = self(posy)
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

832 

833 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table; its values are not modified
    title : string
        table title (also the first column title in latex mode)
    printunits : bool
        whether to show units in plain-text mode
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text.
        1: value, units and description; 2: units and description;
        3: value and units.
    rawlines : bool
        If true, return the unformatted rows instead of finished lines.
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip variables whose values are all NaN or all satisfy
        abs(value) <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true and minval is nonzero, array entries with
        abs(value) <= minval are shown as NaN (rendered per VALSTR_REPLACES).
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped under their lineage string.
    maxcolumns : int
        maximum number of vector entries printed on one row
    skipifempty : boolean
        If true, return [] when no variable survives the minval filter.
    sortmodelsbysenss : dict, optional
        model name -> sensitivity; orders models by descending mean
        sensitivity (primary sort key only when sortbyvals is false)

    Returns
    -------
    list
        lines of the table, or the raw rows if rawlines is true

    Raises
    ------
    ValueError
        if latex is truthy but not one of 1, 2 or 3
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # every value is NaN or no larger than minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # Blank out small entries on a copy: item assignment into v
            # would write NaNs into the caller's own arrays.
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table above
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 4)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    # Column titles depend only on the latex option, so settle them before
    # the row loop; binding them per-row left them undefined (NameError)
    # for a latex table that ends up with no rows.
    if not latex:
        coltitles = None
        ntablecols = 4  # variable, value, units, label
    elif latex in (1, 2, 3):
        coltitles = {1: [title, "Value", "Units", "Description"],  # normal
                     2: [title, "Units", "Description"],  # no values
                     3: [title, "Value", "Units"]}[latex]  # no description
        ntablecols = len(coltitles)
    else:
        raise ValueError("Unexpected latex option, %s." % latex)
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                # spacer row: one cell per column, or LaTeX rejects the
                # surplus "&" in 3-column tables
                lines.append([""] * ntablecols)
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % x for x in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows for vector entries past the first row
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % x for x in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad a short final row so its "]" lines up
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            texunits = "$%s$" % var.latex_unitstr()
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, texunits, label])
            elif latex == 2:  # no values
                lines.append([varstr, texunits, label])
            else:  # latex == 3: no description
                lines.append([varstr, valstr, texunits])
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(row) + " \\\\" for row in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines