Coverage for gpkit/solution_array.py: 81%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

634 statements  

1"""Defines SolutionArray class""" 

2import sys 

3import re 

4import json 

5import difflib 

6from operator import sub 

7import warnings as pywarnings 

8import pickle 

9import gzip 

10import pickletools 

11from collections import defaultdict 

12import numpy as np 

13from .nomials import NomialArray 

14from .small_classes import DictOfLists, Strings, SolverLog 

15from .small_scripts import mag, try_str_without 

16from .repr_conventions import unitstr, lineagestr, UNICODE_EXPONENTS 

17from .breakdowns import Breakdowns 

18 

19 

# Candidate line-break points inside a printed constraint: a single "*"
# together with its neighbouring characters (so "**" is never split),
# and the spaced operators " + ", " >= ", " <= " and " = ".
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# Ordered (old, new) substring replacement pairs for NaN-containing value
# strings; presumably applied in this order by the value formatter further
# down the file (not visible here) -- NOTE(review): confirm against caller.
VALSTR_REPLACES = [("+nan", " nan"), ("-nan", " nan"),
                   ("nan%", "nan "), ("nan", " - ")]

28 

29 

class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    This approximately halves the size of the pickled solution.

    With ``saveconstraints`` true, a fixed set of bulky attributes is
    stripped from every constraint in ``solarray["sensitivities"]
    ["constraints"]`` on entry and put back on exit. Otherwise the whole
    "constraints" entry is popped on entry and restored on exit.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.attrstore = {}  # attribute name -> {constraint: removed value}
        self.saveconstraints = saveconstraints
        self.constraintstore = None  # popped constraints when not saving them

    def __enter__(self):
        senss = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = senss.pop("constraints")
            return
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            removed = {}
            for constraint in senss["constraints"]:
                attrval = getattr(constraint, attrname, None)
                if attrval:  # falsy / missing attributes are left in place
                    removed[constraint] = attrval
                    delattr(constraint, attrname)
            self.attrstore[attrname] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = self.constraintstore
            return
        for attrname, removed in self.attrstore.items():
            for constraint, attrval in removed.items():
                setattr(constraint, attrname, attrval)

64 

def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : dict-like
        A solution; its ["sensitivities"]["models"] maps model names to
        arrays of sensitivities.
    _ : unused (the showvars slot shared by all TABLEFNS entries)
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that later sections are sorted by it.

    Returns
    -------
    list of str
        Always a list (previously "" when no model sensitivities existed);
        empty when there are fewer than two named models to show.
    """
    if "models" not in data.get("sensitivities", {}):
        return []
    data = sorted(data["sensitivities"]["models"].items(),
                  key=lambda i: ((i[1] < 0.1).all(),
                                 -np.max(i[1]) if (i[1] < 0.1).all()
                                 else -round(np.mean(i[1]), 1), i[0]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            # small sensitivities: show only an order-of-magnitude bound
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % max(-3, np.log10(msenss)))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:  # show per-entry deviations
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:  # blank out repeated values
            msenssstr = " "*len(msenssstr)
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    return lines + [""] if len(lines) > 3 else []

100 

101 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    `data` may be a solution (its ["sensitivities"]["variables"] is used)
    or a mapping of variable -> sensitivity; `showvars`, if given, limits
    the table to those variables. Remaining kwargs go to var_table.
    """
    senss = data.get("sensitivities", {})
    if "variables" in senss:
        data = senss["variables"]
    if showvars:
        data = {var: data[var] for var in showvars if var in data}
    fixed_opts = dict(sortbyvals=True, skipifempty=True,
                      valfmt="%+-.2g ", vecfmt="%+-8.2g",
                      printunits=False, minval=1e-3)
    return var_table(data, title, **fixed_opts, **kwargs)

111 

112 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns top sensitivity table lines.

    Shows the `nvars` most sensitive variables not already in `showvars`;
    the title says "Next" when some top variables were already shown.
    """
    top, already_shown = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(top, title=title, hidebelowminval=True, **kwargs)

120 

121 

def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top `nvars` variables.

    Arguments
    ---------
    data : dict-like
        Either a solution (its ["sensitivities"]["variables"] is used) or a
        mapping of variable -> sensitivity value(s).
    showvars : iterable
        Variables already shown elsewhere. They are dropped from the result,
        and each one dropped lowers `nvars` by one, but not below 3.
    nvars : int
        Maximum number of variables to return; 0 or less returns none.

    Returns
    -------
    (dict, set)
        The kept variables mapped to their sensitivities, in ascending order
        of mean absolute sensitivity, and the subset of `showvars` that was
        among the candidates. Variables with any NaN sensitivity are skipped.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda kv: kv[1])]
    filter_already_shown = set(showvars).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    # topk[-0:] would be the entire list, so handle nvars <= 0 explicitly
    kept = topk[-nvars:] if nvars > 0 else []
    return {k: data[k] for k in kept}, filter_already_shown

135 

136 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    Arguments
    ---------
    data : dict-like
        A solution (its ["sensitivities"]["variables"] is used) or a mapping
        of variable -> sensitivity value(s).
    _ : unused (the showvars slot shared by all TABLEFNS entries)
    maxval : float
        Variables whose mean absolute sensitivity is below this are shown.
    **kwargs : passed on to senss_table.
    """
    # Check for the same key that is read (previously "constants" was tested
    # but "variables" was read), matching senss_table and topsenss_filter.
    senss = data.get("sensitivities", {})
    if "variables" in senss:
        data = senss["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

143 

144 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return lines listing the most sensitive constraints.

    Only constraints with sensitivity >= `tight_senss` are considered (the
    last sweep point is used for sweeps); at most `ntightconstrs` of them
    are shown, largest first.
    """
    title = "Most Sensitive Constraints"
    in_sweep = len(self) > 1
    if in_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        if in_sweep:
            senss = senss[-1]
        if senss >= tight_senss:
            senss_str = "%+6.2g" % abs(senss)
            rows.append(((-float(senss_str), str(constr)),
                         senss_str, id(constr), constr))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)

160 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines listing constraints with sensitivity <= `min_senss`.

    For sweeps, the sensitivity at the last sweep point is used.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    in_sweep = len(self) > 1
    if in_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        if (senss[-1] if in_sweep else senss) <= min_senss:
            rows.append((0, "", id(constr), constr))
    return constraint_table(rows, title, **kwargs)

174 

175 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def _wrap_constraint_str(constrstr, minlen=25, maxlen=80):
    "Splits a printed constraint into lines of at most `maxlen` characters."
    pieces = [piece for piece in CONSTRSPLITPATTERN.split(constrstr) if piece]
    wrapped, current = [], ""
    pos = 0
    while pos < len(pieces):
        piece = pieces[pos]
        pos += 1
        if CONSTRSPLITPATTERN.match(piece) and pos < len(pieces):
            # keep the first character here and push the operator onto the
            # next piece, so any break lands just before the operator
            pieces[pos] = piece[1:] + pieces[pos]
            piece = piece[0]
        elif len(current) + len(piece) > maxlen and len(current) > minlen:
            wrapped.append(current)
            current = " "  # start a new line
        current += piece
        while len(current) > maxlen:  # hard-wrap anything still too long
            wrapped.append(current[:maxlen])
            current = " " + current[maxlen:]
    wrapped.append(current)
    return wrapped


def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    `data` is an iterable of (sortby, openingstr, id, constraint) tuples;
    `openingstr` is printed right-aligned to the left of each constraint.
    Rows are grouped by model lineage when `sortbymodel` is true.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = {"units"} if showmodels else {"units", "lineage"}
    model_order, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        model_order.setdefault(model, len(model_order))
        constrstr = try_str_without(
            constraint, {":MAGIC:" + lineagestr(constraint)}.union(excluded))
        address_at = constrstr.find(" at 0x")
        if address_at != -1:  # don't print memory addresses
            constrstr = constrstr[:address_at] + ">"
        decorated.append((model_order[model], model, sortby,
                          constrstr, openingstr))
    decorated.sort()

    rows, previous_model = [], None
    for _, model, _, constrstr, openingstr in decorated:
        if model != previous_model:
            if rows:
                rows.append(["", ""])
            if model or rows:
                rows.append([("newmodelline",), model])
            previous_model = model
        first, *rest = _wrap_constraint_str(constrstr)
        rows.append((openingstr + " : ", first))
        rows.extend(("", cont) for cont in rest)
    if not rows:
        rows = [("", "(none)")]

    widths = np.max([[len(cell) for cell in row] for row in rows
                     if row[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(dirs) == len(widths)
    fmts = ["{0:%s%s}" % (direc, width) for direc, width in zip(dirs, widths)]
    out = []
    for row in rows:
        if row[0] == ("newmodelline",):
            cells = [fmts[0].format(" | "), row[1]]
        else:
            cells = [fmt.format(cell) for fmt, cell in zip(fmts, row)]
        out.append("".join(cells).rstrip())
    return [title] + ["-"*len(title)] + out + [""]

236 

237 

def _sweep_entries_equal(data_vec):
    "Returns True if every entry of a swept warnings array equals the first."
    first = data_vec[0]
    for entry in data_vec[1:]:
        same = entry == first
        if hasattr(same, "all"):  # elementwise comparison result
            same = same.all()
        if not same:
            return False
    return True


def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Arguments
    ---------
    self : dict-like
        A solution whose optional "warnings" entry maps each warning type to
        a list of (message, payload) tuples, or (for sweeps) to an array
        holding one such list per sweep point.
    _ : unused (the showvars slot shared by all TABLEFNS entries)
    **kwargs : passed on to constraint_table

    Returns
    -------
    list of str
        Empty if there are no warnings to show.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        elif _sweep_entries_equal(data_vec):
            data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the whole title, including any sweep suffix
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]

285 

def bdtable_gen(key):
    """Generator for breakdown tablefns.

    Returns a TABLEFNS-compatible function that captures, as a list of
    lines, whatever Breakdowns(solution).plot(key) prints.
    """

    def bdtable(self, _showvars, **_):
        "Breakdown plot (for `key`) rendered as text lines"
        bds = Breakdowns(self)
        original_stdout = sys.stdout
        # Build the capture log before swapping stdout, and read lines from
        # it directly: previously a failure here made the `finally` call
        # .lines() on the real stdout, masking the original exception.
        log = SolverLog(original_stdout, verbosity=0)
        try:
            sys.stdout = log
            bds.plot(key)
        finally:
            sys.stdout = original_stdout
        return log.lines()

    return bdtable

302 

303 

# Table name (as accepted in SolutionArray.table's `tables`) -> function
# called as fn(solution, showvars, **kwargs) that returns a list of lines.
TABLEFNS = {
    "sensitivities": senss_table,
    "top sensitivities": topsenss_table,
    "insensitivities": insenss_table,
    "model sensitivities": msenss_table,
    "tightest constraints": tight_table,
    "loose constraints": loose_table,
    "warnings": warnings_table,
    "model sensitivities breakdown": bdtable_gen("model sensitivities"),
    "cost breakdown": bdtable_gen("cost"),
}

314 

def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The returned value keeps its sign. When the winner is an array, the
    element with the largest magnitude is returned. Later entries win ties;
    entries whose magnitude is NaN are never chosen; None is returned if
    nothing is chosen.
    """
    best, best_mag = None, 0
    for candidate in values:
        magnitude = np.abs(candidate).max()
        if magnitude >= best_mag:
            best, best_mag = candidate, magnitude
    if not getattr(best, "shape", None):
        return best
    flat_idx = np.argmax(np.abs(best))
    return best[np.unravel_index(flat_idx, best.shape)]

326 

327 

def cast(function, val1, val2):
    """Applies the binary `function` to val1 and val2, aligning array ranks.

    If both values have a shape and their ranks differ, the lower-rank one
    gets trailing singleton axes. It therefore broadcasts along the leading
    axes of the other, rather than numpy's usual trailing-axis alignment.
    Arguments are always passed as (val1, val2). Python warnings raised
    during the call (e.g. numpy divide-by-zero) are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if not both_arrays or val1.ndim == val2.ndim:
            return function(val1, val2)
        rank_gap = val1.ndim - val2.ndim
        if rank_gap > 0:
            pad = (slice(None),)*val2.ndim + (np.newaxis,)*rank_gap
            return function(val1, val2[pad])
        pad = (slice(None),)*val1.ndim + (np.newaxis,)*(-rank_gap)
        return function(val1[pad], val2)

343 

344 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    _lineageset = False
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def set_necessarylineage(self, clear=False):  # pylint: disable=too-many-branches
        """Marks (or, with clear=True, unmarks) varkeys whose names collide.

        On first call, computes and caches for each varkey the number of
        trailing lineage levels needed to tell it apart from same-named
        varkeys (0 for uniquely named ones). Then sets (or removes)
        vk.descr["necessarylineage"] to that number. Returns None.
        """
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {}
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            name_collisions = defaultdict(set)
            for key in keymap:
                if hasattr(key, "key"):
                    if len(keymap[key.name]) == 1:  # very unique
                        self._name_collision_varkeys[key] = 0
                    else:
                        shortname = key.str_without(["lineage", "vec"])
                        if len(keymap[shortname]) > 1:
                            name_collisions[shortname].add(key)
            for varkeys in name_collisions.values():
                min_namespaced = defaultdict(set)
                for vk in varkeys:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                # prepend lineage levels until every group is a singleton
                while any(len(vks) > 1 for vks in min_namespaced.values()):
                    for key, vks in list(min_namespaced.items()):
                        if len(vks) <= 1:
                            continue
                        del min_namespaced[key]
                        mineage, idx = key
                        idx += 1
                        for vk in vks:
                            lineages = vk.lineagestr().split(".")
                            submineage = lineages[-idx] + "." + mineage
                            min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), vks in min_namespaced.items():
                    vk, = vks
                    self._name_collision_varkeys[vk] = idx
        if clear:
            self._lineageset = False
            for vk in self._name_collision_varkeys:
                # pop: clearing twice (e.g. nested table calls) must not fail
                vk.descr.pop("necessarylineage", None)
        else:
            self._lineageset = True
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def __len__(self):
        "Number of solutions: len(cost) for sweeps, 1 if scalar, 0 if none"
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        "Substitutes the solution into posy, returning the constant if any"
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3):
        "Checks for almost-equality between two solutions"
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            if np.max(abs(cast(np.divide, svars[key], ovars[key]) - 1)) >= reltol:
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (".pgz" paths as gzip-compressed pickles). Unpickling can run
            arbitrary code: only load files from trusted sources.
        showvars : iterable, optional
            if given, only compare these variables
        constraintsdiff : boolean
            if True, show a unified diff of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True (and sensitivities exist), sort models by sensitivity

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as fil:  # don't leak the file handle
                    other = pickle.load(fil)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            # skip shared variables that lack a sensitivity in either
            # solution instead of raising KeyError
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks
                           if vk in ssenss and vk in osenss}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        # the with-block guarantees the file is flushed and closed
        with SolSavingEnvironment(self, saveconstraints), \
                open(filename, "wb") as fil:
            pickle.dump(self, fil, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickles the solution, optimizes the pickle, and gzips it to filename."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        NOTE: unpickling can execute arbitrary code; only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict mapping each variable's name string to its varkey.

        Names are built with `exclude` passed to str_without while the
        minimal disambiguating lineage is set (see set_necessarylineage).
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        self.set_necessarylineage()
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:  # never leave lineage markers behind on error
            self.set_necessarylineage(clear=True)
        return names

    # NOTE: ("vec",) -- the former ("vec") was a plain string, not a tuple
    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("vec",)):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None, excluded=("vec",)):
        "Returns primal solution as pandas dataframe (one row per element)"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        skipped_descr = ["name", "units", "unitrepr", "idx", "shape",
                         "veckey", "value", "vecfn", "lineage", "label"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # C-order indices; 1-d indices are unwrapped to plain ints
                idxs = [idx[0] if len(idx) == 1 else idx
                        for idx in np.ndindex(value.shape)]
            else:
                idxs = [None]
            for idx in idxs:
                rows.append([
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx],
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in skipped_descr)])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json", showvars=None):
        "Saves solution table as a json file"
        sol_dict = {}
        if self._lineageset:
            self.set_necessarylineage(clear=True)
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # add appropriate data for each variable to the dictionary
        for k, v in data.items():
            key = str(k)
            if isinstance(v, (np.ndarray, np.generic)):
                # arrays and numpy scalars (e.g. np.int64, which json cannot
                # serialize) become plain Python lists / numbers
                v = v.tolist()
            sol_dict[key] = {"v": v, "u": k.unitstr()}
        with open(filename, "w") as f:
            json.dump(sol_dict, f)

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])  # model name fills the first column
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"], require_positive=False)

    def _parse_showvars(self, showvars):
        "Expands each requested variable into the set of matching varkeys"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), **kwargs):
        "Print summary table, showing no sensitivities or constants"
        return self.table(showvars,
                          ["cost breakdown", "model sensitivities breakdown",
                           "warnings", "sweepvariables", "freevariables"],
                          **kwargs)

    def table(self, showvars=(),
              tables=("cost breakdown", "model sensitivities breakdown",
                      "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=False, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = not any(var.lineage != varlist[0].lineage
                                     for var in varlist[1:])
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        self.set_necessarylineage()
        try:  # lineage markers are always cleared, even on error
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if "breakdown" in table:
                    if len(self) > 1 or not UNICODE_EXPONENTS:
                        # no breakdowns for sweeps or no-unicode environments
                        table = table.replace(" breakdown", "")
                if "sensitivities" not in self and ("sensitivities" in table or
                                                    "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            self.set_necessarylineage(clear=True)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy (the cost if posys is None).

        Raises
        ------
        ValueError
            unless exactly one variable was swept (previously this printed
            a message and then failed with an unpacking ValueError).
        """
        if len(self["sweepvariables"]) != 1:
            raise ValueError(
                "SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

828 

829 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table; the values are not modified
    title : string
        table title (also the first column heading in LaTeX output)
    printunits : bool
        If true, plain-text rows include each variable's units
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text.
        1: value, units and description; 2: units and description;
        3: value and units.
    rawlines : bool
        If true, return the unformatted per-row lists instead of strings
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip rows whose largest absolute value is <= minval (so with the
        default of 0, all-zero rows are skipped); all-NaN rows are skipped
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true and minval is nonzero, array entries with absolute value
        <= minval are masked as NaN (printed as a dash)
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped under the lineage string of their key
    maxcolumns : int
        maximum number of array entries printed per row
    skipifempty : boolean
        If true, return [] when no row survives the minval/NaN filter
    sortmodelsbysenss : dict, optional
        maps model names to sensitivities; when sortbyvals is false, models
        are ordered by descending mean sensitivity (rounded to 4 places)

    Returns
    -------
    list
        table lines (strings), or per-row lists if `rawlines` is true

    Raises
    ------
    ValueError
        If `latex` is truthy but not one of 1, 2 or 3.
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # nothing non-NaN above minval to show
        if minval and hidebelowminval and getattr(v, "shape", None):
            # mask a copy: `v` belongs to the caller's `data`, and in-place
            # NaN assignment would modify it (and fails for integer arrays)
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table above
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 4)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        # the unique index `i` precedes `k` and `v` in each tuple, so sorting
        # never needs to compare keys or arrays
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if latex:
        # bound before the row loop so the table header exists even when
        # every row is filtered out below
        latex_columns = {1: ["Value", "Units", "Description"],  # normal
                         2: ["Units", "Description"],           # no values
                         3: ["Value", "Units"]}                 # no descr.
        if latex not in latex_columns:
            raise ValueError("Unexpected latex option, %s." % latex)
        coltitles = [title] + latex_columns[latex]
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                # separator row; must match the LaTeX column count
                lines.append([""] * (len(coltitles) if latex else 4))
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if isvector:
            valstr, *extrarows = _vector_valstrs(val, vecfmt, maxcolumns)
        else:
            valstr, extrarows = _apply_valstr_replaces(valfmt % val), []
        if not latex:
            lines.append([varstr, valstr, units, label])
            lines.extend(["", row, "", ""] for row in extrarows)
        else:  # only the first row of a vector is shown in LaTeX output
            varstr = "$%s$" % varstr.replace(" : ", "")
            latexunits = "$%s$" % var.latex_unitstr()
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, latexunits, label])
            elif latex == 2:  # no values
                lines.append([varstr, latexunits, label])
            else:  # latex == 3: no description
                lines.append([varstr, valstr, latexunits])
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
            for i, line in enumerate(lines):
                if line[0] == ("newmodelline",):
                    line = [fmts[0].format(" | "), line[1]]
                else:
                    line = [fmt.format(s) for fmt, s in zip(fmts, line)]
                lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines


def _apply_valstr_replaces(valstr):
    "Applies VALSTR_REPLACES in order (e.g. rendering NaN entries as ' - ')."
    for before, after in VALSTR_REPLACES:
        valstr = valstr.replace(before, after)
    return valstr


def _vector_valstrs(val, vecfmt, maxcolumns):
    """Formats array `val` as rows of at most `maxcolumns` entries.

    The axis laid out horizontally is the largest one no longer than
    `maxcolumns` (ties go to the later axis); if every axis is longer, the
    last axis is used with one entry per row. Returns a list whose first
    element opens with "[ " and whose remaining elements are continuation
    rows; the final row is closed with "]", and a short final continuation
    row is padded in proportion to its missing entries.
    """
    horiz_dim, ncols = len(val.shape) - 1, 1  # starting values
    for dim_idx, dim_size in enumerate(val.shape):
        if ncols <= dim_size <= maxcolumns:
            horiz_dim, ncols = dim_idx, dim_size
    # make horiz_dim the last axis so that consecutive entries of the
    # flattened array run along it
    flatval = np.moveaxis(val, horiz_dim, -1).flatten()
    nvals = len(flatval)
    cells = [vecfmt % x for x in flatval[:ncols]]
    bracket = " ] " if nvals <= ncols else ""
    rows = [_apply_valstr_replaces("[ %s%s" % (" ".join(cells), bracket))]
    for start in range(ncols, nvals, ncols):
        cells = [vecfmt % x for x in flatval[start:start+ncols]]
        row = _apply_valstr_replaces(" " + " ".join(cells))
        if start + ncols >= nvals:  # final row: close the bracket and pad
            nshown = nvals - start
            spaces = (ncols - nshown) * len(row) // nshown
            row = row + " ]" + " "*spaces
        rows.append(row)
    return rows