Coverage for gpkit/solution_array.py: 81%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

635 statements  

1"""Defines SolutionArray class""" 

2import sys 

3import re 

4import json 

5import difflib 

6from operator import sub 

7import warnings as pywarnings 

8import pickle 

9import gzip 

10import pickletools 

11from collections import defaultdict 

12import numpy as np 

13from .nomials import NomialArray 

14from .small_classes import DictOfLists, Strings, SolverLog 

15from .small_scripts import mag, try_str_without 

16from .repr_conventions import unitstr, lineagestr, UNICODE_EXPONENTS 

17from .breakdowns import Breakdowns 

18 

19 

# Split points used when wrapping long constraint strings: a lone "*" (not
# part of "**") together with one character on each side, or a spaced
# "+", ">=", "<=" or "=" operator.
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])"
                                r"|( \+ )"
                                r"|( >= )"
                                r"|( <= )"
                                r"|( = )")

# (old, new) substring replacements for formatted numeric strings; applied
# in order, NaNs lose their sign and end up displayed as " - ".
# NOTE(review): not used in the visible part of this file - presumably
# consumed by the value formatting in var_table; confirm.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]

28 

29 

class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    This approximately halves the size of the pickled solution.  Used as a
    context manager around pickling; everything removed on entry is put
    back on exit.

    Arguments
    ---------
    solarray : SolutionArray (or mapping with a "sensitivities" entry)
    saveconstraints : bool
        If True, the constraints in ``["sensitivities"]["constraints"]`` are
        kept but their truthy "bounded", "meq_bounded", "vks", "v_ss",
        "unsubbed" and "varkeys" attributes are deleted while inside the
        context.  If False, the whole ``["sensitivities"]["constraints"]``
        entry is removed while inside the context.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        self.attrstore = {}  # attribute name -> {constraint: removed value}
        self.constraintstore = None  # popped constraints entry, if any

    def __enter__(self):
        sensitivities = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = sensitivities.pop("constraints")
            return
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            removed = {}
            for constr in sensitivities["constraints"]:
                current = getattr(constr, attrname, None)
                if current:  # falsy/missing attributes are left untouched
                    removed[constr] = current
                    delattr(constr, attrname)
            self.attrstore[attrname] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for attrname, removed in self.attrstore.items():
            for constr, current in removed.items():
                setattr(constr, attrname, current)

64 

def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : SolutionArray or dict
        A table is only produced if ``data["sensitivities"]["models"]``
        exists; it maps model names to numpy sensitivity values.
    _ : ignored
        Placeholder for the ``showvars`` argument every table function takes.
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that models are sorted by sensitivity.

    Returns
    -------
    list of str
        The table lines (ending with ""), or [] when there are no model
        sensitivities or fewer than two named models to list.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other return of the table functions

    def _sortkey(item):
        name, senss = item
        small = (senss < 0.1).all()
        # models with all-small values go last, ordered by their largest
        # value; the rest are ordered by (rounded) mean, largest first
        return (small,
                -np.max(senss) if small else -round(np.mean(senss), 1),
                name)

    rows = sorted(data["sensitivities"]["models"].items(), key=_sortkey)
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in rows:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % max(-3, np.log10(msenss)))
            else:
                msenssstr = "  =0  "  # same 6-char width as "%6s" above
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                # "  - " has the same 4-char width as "%+4.1f"
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else "  - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:
            # blank out a value identical to the previous row's
            msenssstr = " "*len(msenssstr)
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    # only shown when at least two named models were listed
    return lines + [""] if len(lines) > 3 else []

100 

101 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    ``data`` is either a solution, in which case its
    ``["sensitivities"]["variables"]`` mapping is tabulated, or already a
    mapping of variable to sensitivity.  A non-empty ``showvars`` restricts
    the table to those variables.  Extra keyword arguments go to var_table.
    """
    sensitivities = data.get("sensitivities", {})
    if "variables" in sensitivities:
        data = data["sensitivities"]["variables"]
    if showvars:
        data = {var: data[var] for var in showvars if var in data}
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)

111 

112 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns top sensitivity table lines.

    Variables already in ``showvars`` are left out (see topsenss_filter);
    if any were left out, the title reads "Next Most Sensitive Variables".
    """
    top, already_shown = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(top, title=title, hidebelowminval=True, **kwargs)

120 

121 

def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top ``nvars`` vars.

    Arguments
    ---------
    data : SolutionArray or dict
        A solution (its ``["sensitivities"]["variables"]`` is used) or a
        mapping of variable to sensitivity value(s).  Entries containing
        any NaN are never selected.
    showvars : iterable (or None)
        Variables already shown elsewhere.  They are removed from the
        ranking, and each one removed lowers ``nvars`` by one while
        ``nvars`` is above 3.
    nvars : int
        Maximum number of variables to return; 0 or less returns none.

    Returns
    -------
    (dict, set)
        The selected ``{variable: sensitivity}`` entries, ranked by mean
        absolute sensitivity, and the set of ``showvars`` filtered out.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    # ascending by mean |sensitivity|, so the largest are at the end
    topk = [k for k, _ in sorted(mean_abs_senss.items(),
                                 key=lambda item: item[1])]
    # accept any iterable of variables, not only sets
    shown = set(showvars) if showvars is not None else set()
    filter_already_shown = shown.intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    # topk[-0:] would be the *whole* list, so nvars <= 0 is handled here
    top = {k: data[k] for k in topk[-nvars:]} if nvars > 0 else {}
    return top, filter_already_shown

135 

136 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    Lists the variables whose mean absolute sensitivity is below
    ``maxval``.  ``data`` is a solution (its
    ``["sensitivities"]["variables"]`` is used) or a mapping of variable to
    sensitivity value(s); the second positional argument (showvars) is
    ignored.  Extra keyword arguments are forwarded to senss_table.
    """
    # test for the same key that is read next, as senss_table does; a
    # mismatch here made data fall through as the whole solution
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

143 

144 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return constraint tightness lines.

    Shows up to ``ntightconstrs`` constraints whose sensitivity is at least
    ``tight_senss``, largest magnitude first.  For sweeps, only the last
    sweep point's sensitivity is considered.
    """
    title = "Most Sensitive Constraints"
    in_sweep = len(self) > 1
    if in_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        if in_sweep:
            senss = senss[-1]
        if senss >= tight_senss:
            senssstr = "%+6.2g" % abs(senss)
            rows.append(((-float(senssstr), str(constr)),
                         senssstr, id(constr), constr))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)

160 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return constraint looseness lines.

    Lists every constraint whose sensitivity is at most ``min_senss``.
    For sweeps, only the last sweep point's sensitivity is considered.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    in_sweep = len(self) > 1
    if in_sweep:
        title += " (in last sweep)"

    def _considered(senss):
        return senss[-1] if in_sweep else senss

    data = [(0, "", id(constr), constr)
            for constr, senss in self["sensitivities"]["constraints"].items()
            if _considered(senss) <= min_senss]
    return constraint_table(data, title, **kwargs)

174 

175 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of 4-tuples ``(sortby, openingstr, tiebreak, constraint)``
        ``sortby`` orders the rows, ``openingstr`` is printed right-aligned
        in the left column, and ``constraint`` is printed in the right
        column.  Callers in this module pass ``id(constraint)`` as
        ``tiebreak``, so ``sorted`` never has to compare two constraints.
    title : str
        Table title; it is underlined with dashes of the same length.
    sortbymodel : bool
        If True, rows are grouped under one header line per model lineage
        (``lineagestr(constraint)``), groups ordered by first appearance.
    showmodels : bool
        If False, "lineage" is also excluded from each constraint's string.

    Returns
    -------
    list of str
        ``[title, underline, *rows, ""]``; a single "(none)" row is shown
        when ``data`` is empty.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = {"units"} if showmodels else {"units", "lineage"}
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)  # group index = order first seen
        # NOTE(review): the ":MAGIC:<lineage>" entry presumably asks
        # try_str_without to drop this model's own lineage from names
        constrstr = try_str_without(
            constraint, {":MAGIC:"+lineagestr(constraint)}.union(excluded))
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer between model groups
            if model or lines:
                lines.append([("newmodelline",), model])  # group header
            previous_model = model
        # wrap the constraint string at operators into lines of <= maxlen
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # keep only the first character of an operator match here
                # and move the rest onto the following segment, so a wrap
                # can only happen just before the operator
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:  # hard-split unbreakable runs
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths, ignoring the model header lines
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            # header: " | " right-aligned in the first column, then model
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]

236 

237 

def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Arguments
    ---------
    self : SolutionArray or dict
        Warnings are read from ``self["warnings"]``: a mapping of warning
        type to a list of ``(message, data)`` tuples, or (for sweeps) an
        array-like with one such list per sweep point.  Sweep lists that are
        all equal are shown once.
    _ : ignored
        Placeholder for the ``showvars`` argument every table function takes.
    **kwargs
        Forwarded to constraint_table for the "Unexpectedly Tight/Loose
        Constraints" warning types.

    Returns
    -------
    list of str
        The table lines, or [] if there is nothing to show.
    """
    header = "WARNINGS"
    lines = ["~"*len(header), header, "~"*len(header)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda warning: warning[0])  # by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the whole title, including any " in sweep i"
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]

285 

def bdtable_gen(key):
    """Generator for breakdown tablefns.

    Returns a table function that renders ``Breakdowns(solution).plot(key)``
    and returns the printed output as lines instead of writing to stdout.
    """

    def bdtable(self, _showvars, **_):
        "Breakdown plot for `key`, captured as table lines"
        bds = Breakdowns(self)
        original_stdout = sys.stdout
        # build the log before redirecting, so a failure here can't leave
        # stdout redirected or be masked by a later error
        log = SolverLog(original_stdout, verbosity=0)
        sys.stdout = log
        try:
            bds.plot(key)
        finally:
            # always restore stdout first; reading the captured lines can't
            # then leave the process with a hijacked stdout
            sys.stdout = original_stdout
        return log.lines()

    return bdtable

302 

303 

# Registry of the named tables SolutionArray.table can print; each value is
# called as fn(solution, showvars, **kwargs) and returns the table's lines.
TABLEFNS = dict((
    ("sensitivities", senss_table),
    ("top sensitivities", topsenss_table),
    ("insensitivities", insenss_table),
    ("model sensitivities", msenss_table),
    ("tightest constraints", tight_table),
    ("loose constraints", loose_table),
    ("warnings", warnings_table),
    ("model sensitivities breakdown", bdtable_gen("model sensitivities")),
    ("cost breakdown", bdtable_gen("cost")),
))

314 

def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest-magnitude element.

    The signed element is returned, not its absolute value.  NaN entries
    and empty arrays are ignored.  Returns None if ``values`` holds no
    non-NaN number.  On ties, the later value wins.
    """
    finalval, absmaxest = None, 0
    with pywarnings.catch_warnings():
        # nanmax warns on all-NaN input; such values are simply never chosen
        pywarnings.simplefilter("ignore", RuntimeWarning)
        for val in values:
            if np.size(val) == 0:
                continue  # nothing to compare (and max() would raise)
            # nanmax so one NaN doesn't hide the array's real magnitude
            absmaxval = np.nanmax(np.abs(val))
            if absmaxval >= absmaxest:
                absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        # finalval was chosen via a non-NaN max, so nanargmax is safe here
        flat_idx = np.nanargmax(np.abs(finalval))
        return finalval[np.unravel_index(flat_idx, finalval.shape)]
    return finalval

326 

327 

def cast(function, val1, val2):
    """Applies ``function(val1, val2)``, aligning arrays of different rank.

    When both arguments are numpy arrays with different ``ndim``, the
    lower-rank one gets trailing length-1 axes appended, so its existing
    axes line up with the leading axes of the other.  Argument order is
    preserved, and all warnings raised inside (e.g. divide-by-zero) are
    suppressed.
    """
    def _pad(arr, ndim):
        # append trailing new axes until arr has ndim dimensions
        return arr[(slice(None),)*arr.ndim + (np.newaxis,)*(ndim - arr.ndim)]

    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if not (hasattr(val1, "shape") and hasattr(val2, "shape")):
            return function(val1, val2)
        if val1.ndim == val2.ndim:
            return function(val1, val2)
        if val1.ndim > val2.ndim:
            return function(val1, _pad(val2, val1.ndim))
        return function(_pad(val1, val2.ndim), val2)

343 

344 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
        constraints: dict of constraint sensitivities
        models: dict of model sensitivities
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None  # varkey -> lineage depth; cached lazily
    _lineageset = False  # whether descr["necessarylineage"] is currently set
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def set_necessarylineage(self, clear=False):  # pylint: disable=too-many-branches
        """Sets (or clears) the minimal lineage that makes names unique.

        On first call, computes and caches, for varkeys in
        ``self["variables"]``, how many trailing lineage components are
        needed to tell apart varkeys sharing a short name (0 if the name
        alone is unique).  Then writes each count into the varkey's
        ``descr["necessarylineage"]``, or removes it if ``clear`` is True.
        Returns None.
        """
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {}
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            name_collisions = defaultdict(set)
            for key in keymap:
                if hasattr(key, "key"):
                    if len(keymap[key.name]) == 1:  # very unique
                        self._name_collision_varkeys[key] = 0
                    else:
                        shortname = key.str_without(["lineage", "vec"])
                        if len(keymap[shortname]) > 1:
                            name_collisions[shortname].add(key)
            for varkeys in name_collisions.values():
                # start from the innermost lineage component, then add
                # outer components until every group has a single varkey
                min_namespaced = defaultdict(set)
                for vk in varkeys:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                while any(len(vks) > 1 for vks in min_namespaced.values()):
                    for key, vks in list(min_namespaced.items()):
                        if len(vks) <= 1:
                            continue
                        del min_namespaced[key]
                        mineage, idx = key
                        idx += 1
                        for vk in vks:
                            lineages = vk.lineagestr().split(".")
                            submineage = lineages[-idx] + "." + mineage
                            min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), vks in min_namespaced.items():
                    vk, = vks
                    self._name_collision_varkeys[vk] = idx
        if clear:
            self._lineageset = False
            for vk in self._name_collision_varkeys:
                # pop, not del: clearing twice (or before setting) is harmless
                vk.descr.pop("necessarylineage", None)
        else:
            self._lineageset = True
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def __len__(self):
        "Number of solutions: len(cost), 1 for a scalar cost, 0 if no cost"
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        """Returns ``posy`` evaluated at this solution.

        If the substituted result has a ``.c`` attribute (presumably the
        coefficient of a fully-substituted nomial), that is returned.
        """
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3):
        """Checks for almost-equality between two solutions.

        True if both have the same variables and every value's relative
        difference is below ``reltol``.  NaN ratios (e.g. 0/0) do not
        count as differences.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            reldiff = np.max(abs(cast(np.divide, svars[key], ovars[key]) - 1))
            if reldiff >= reltol:
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions (".pgz"
            files are read with decompress_file).  Only load files you
            trust: unpickling can execute arbitrary code.
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True, show a unified diff of the two solutions' modelstr
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True (and this solution has sensitivities), sort table
            sections by model sensitivity

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:  # close the file promptly
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])  # drop the ---/+++ file header
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            # skip variables lacking a sensitivity in either solution
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks
                           if vk in ssenss and vk in osenss}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.::

            import pickle
            with open("solution.pkl", "rb") as f:
                sol = pickle.load(f)
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as f:  # close the file promptly
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        """Pickle the solution, optimize the pickle, and gzip it to a file.

        Load it again with ``SolutionArray.decompress_file(filename)``.
        """
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        Only use on trusted files: unpickling can execute arbitrary code.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns {name: varkey}, with names using minimal unique lineage.

        ``exclude`` is passed to each varkey's ``str_without``.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        self.set_necessarylineage()
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:
            # never leave necessarylineage set on the varkeys
            self.set_necessarylineage(clear=True)
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("vec",)):
        "Saves primal solution as matlab file (values stored as float32)"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None, excluded=("vec",)):
        "Returns primal solution as pandas dataframe, one row per element"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # one row per element, in C order; 1-d indices are unwrapped
                idxs = [idx[0] if len(idx) == 1 else idx
                        for idx in np.ndindex(*value.shape)]
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        # utf-8: tables may contain unicode exponents/units
        with open(filename, "w", encoding="utf-8") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json", showvars=None):
        "Saves solution values and units as a json file"
        sol_dict = {}
        if self._lineageset:
            self.set_necessarylineage(clear=True)
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # add appropriate data for each variable to the dictionary
        for k, v in data.items():
            if isinstance(v, (np.ndarray, np.generic)):
                # numpy arrays *and* scalars (e.g. int64) aren't JSON types
                v = v.tolist()
            sol_dict[str(k)] = {"v": v, "u": k.unitstr()}
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(sol_dict, f)

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w", encoding="utf-8") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        """Returns NomialArray of each solution substituted into posy.

        Variables are looked up directly.  Sweeps give one result per
        solution.  Raises ValueError if ``posy`` is neither a solution
        variable nor substitutable.
        """
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"], require_positive=False)

    def _parse_showvars(self, showvars):
        "Expand names/keys in showvars into the set of matching varkeys"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), **kwargs):
        "Print summary table, showing no sensitivities or constants"
        return self.table(showvars,
                          ["cost breakdown", "model sensitivities breakdown",
                           "warnings", "sweepvariables", "freevariables"],
                          **kwargs)

    def table(self, showvars=(),
              tables=("cost breakdown", "model sensitivities breakdown",
                      "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=False, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Names of the tables to print, in order: "cost", any key of
            TABLEFNS (e.g. "sensitivities", "warnings", "tightest
            constraints"), or a key of ``table_titles`` present in the
            solution (e.g. "sweepvariables", "freevariables", "constants").
            "... breakdown" tables fall back to their plain version for
            sweeps or when unicode exponents are unavailable.
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        self.set_necessarylineage()
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if "breakdown" in table:
                    if len(self) > 1 or not UNICODE_EXPONENTS:
                        # no breakdowns for sweeps or no-unicode environments
                        table = table.replace(" breakdown", "")
                if "sensitivities" not in self and ("sensitivities" in table or
                                                    "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            # never leave necessarylineage set on the varkeys, even on error
            self.set_necessarylineage(clear=True)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy.

        Only 1-dimensional sweeps are supported; anything else raises
        ValueError.  ``None`` or "cost" plots the cost.  Returns
        ``(figure, axes)``, with a single axes object unwrapped.
        """
        if len(self["sweepvariables"]) != 1:
            # the unpacking below needs exactly one swept variable
            raise ValueError("SolutionArray.plot only supports"
                             " 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

829 

830 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table (the values are not modified)
    title : string
        table title (plain text) or first column title (latex)
    printunits : bool
        If true, include each key's unit string in plain-text tables
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text.
        1: value, units and description; 2: units and description;
        3: value and units
    rawlines : bool
        If true, return the unformatted rows instead of strings
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip variables whose values are all NaN or whose largest absolute
        value is <= minval
    sortbyvals : boolean
        If true, rows within each model are sorted by decreasing average
        absolute value instead of by name.
    hidebelowminval : bool
        If true (and minval is nonzero), array entries with an absolute
        value <= minval are replaced by NaN in a copy of the array, which
        is rendered as " - "
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : bool
        If true, rows are grouped under a header naming the key's lineage;
        otherwise all rows share one unnamed group
    maxcolumns : int
        maximum number of vector entries printed per line
    skipifempty : bool
        If true, return [] when no variable passes the minval filter
    sortmodelsbysenss : mapping of model name to sensitivities, or None
        If given (and sortbyvals is false), models with a larger mean
        sensitivity are listed first

    Returns
    -------
    list of strings, or the list of unformatted rows if rawlines is true

    Raises
    ------
    ValueError
        if latex is truthy but not one of 1, 2 or 3
    """
    if not data:
        return []
    # column titles of each supported latex option; also used to validate
    # `latex` when no row was formatted (so the per-row check never ran)
    latex_coltitles = {1: [title, "Value", "Units", "Description"],
                       2: [title, "Units", "Description"],
                       3: [title, "Value", "Units"]}
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values above minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # mask into a new array: assigning NaN in place would overwrite
            # the caller's data (and fails outright for integer arrays)
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 4)
        models.add(model)
        b = bool(getattr(v, "shape", None))  # True if v has a dimension
        s = k.str_without(("lineage", "vec"))
        # the unique index i precedes k and v in each tuple, so sorting
        # never has to compare keys or array values
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")  # rows without a model name are kept
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])  # blank row between models
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            # print along the largest dimension with at most maxcolumns
            # entries (later dimensions win ties); else one value per line
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # move horiz_dim to the end (other axes keep their order) so
            # that flattening yields consecutive rows along horiz_dim
            dim_order = [dim for dim in range(len(val.shape))
                         if dim != horiz_dim] + [horiz_dim]
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % x for x in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows carry the remaining values, ncols each
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % x for x in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad a short final row by roughly the width of
                        # its missing values before closing the bracket
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
            for i, line in enumerate(lines):
                if line[0] == ("newmodelline",):
                    line = [fmts[0].format(" | "), line[1]]
                else:
                    line = [fmt.format(s) for fmt, s in zip(fmts, line)]
                lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        if latex not in latex_coltitles:
            # only reachable when no row was formatted; otherwise the row
            # loop above has already rejected the option
            raise ValueError("Unexpected latex option, %s." % latex)
        coltitles = latex_coltitles[latex]
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines