Coverage for gpkit/solution_array.py: 73%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

639 statements  

1"""Defines SolutionArray class""" 

2import sys 

3import re 

4import json 

5import difflib 

6from operator import sub 

7import warnings as pywarnings 

8import pickle 

9import gzip 

10import pickletools 

11from collections import defaultdict 

12import numpy as np 

13from .nomials import NomialArray 

14from .small_classes import DictOfLists, Strings, SolverLog 

15from .small_scripts import mag, try_str_without 

16from .repr_conventions import unitstr, lineagestr, UNICODE_EXPONENTS 

17from .breakdowns import Breakdowns 

18 

19 

# Split points used when wrapping long constraint strings (see
# constraint_table): a "*" that is not part of "**" (the match keeps one
# character on each side), or one of " + ", " >= ", " <= ", " = ".
CONSTRSPLITPATTERN = re.compile(r"([^*]\*[^*])|( \+ )|( >= )|( <= )|( = )")

# (old, new) substring replacement pairs, listed in application order.
# NOTE(review): their consumer is outside this view; presumably they tidy
# NaN entries in formatted value strings.
VALSTR_REPLACES = [("+nan", " nan"),
                   ("-nan", " nan"),
                   ("nan%", "nan "),
                   ("nan", " - ")]

28 

29 

class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    This approximately halves the size of the pickled solution.

    Used as a context manager around pickling a solution:

    * ``saveconstraints=True``: the constraint objects are kept, but the
      bulky attributes listed in ``__enter__`` are deleted on entry and
      restored on exit;
    * ``saveconstraints=False``: the whole
      ``solarray["sensitivities"]["constraints"]`` entry is popped on entry
      and put back on exit.

    Solutions without a "sensitivities" entry are left untouched.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.attrstore = {}  # {attribute name: {constraint: removed value}}
        self.saveconstraints = saveconstraints
        self.constraintstore = None  # the popped constraints entry, if any

    def __enter__(self):
        if "sensitivities" not in self.solarray:
            return
        if self.saveconstraints:
            for constraint_attr in ["bounded", "meq_bounded", "vks",
                                    "v_ss", "unsubbed", "varkeys"]:
                store = {}
                for constraint in self.solarray["sensitivities"]["constraints"]:
                    # missing or falsy attributes are left in place
                    if getattr(constraint, constraint_attr, None):
                        store[constraint] = getattr(constraint, constraint_attr)
                        delattr(constraint, constraint_attr)
                self.attrstore[constraint_attr] = store
        else:
            self.constraintstore = \
                self.solarray["sensitivities"].pop("constraints")

    def __exit__(self, type_, val, traceback):
        # Restore in both modes. constraintstore stays None in
        # saveconstraints mode, so it cannot gate attribute restoration; and
        # an empty popped dict is falsy but must still be put back.
        for constraint_attr, store in self.attrstore.items():
            for constraint, value in store.items():
                setattr(constraint, constraint_attr, value)
        if self.constraintstore is not None:
            self.solarray["sensitivities"]["constraints"] = self.constraintstore

68 

def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : solution mapping
        Reads data["sensitivities"]["models"], a dict of
        {model name: sensitivity array}.
    _ : ignored
        The showvars slot of the TABLEFNS calling convention.
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that later sections are sorted by these
        sensitivities.

    Returns
    -------
    list of str
        Empty when there are no model sensitivities or fewer than two named
        models to show.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # same type as the non-empty result

    def sortkey(item):
        # models whose sensitivities are all < 0.1 go last, ordered by their
        # max; the rest by descending rounded mean; ties broken by name
        model, senss = item
        small = (senss < 0.1).all()
        return (small,
                -np.max(senss) if small else -round(np.mean(senss), 1),
                model)

    data = sorted(data["sensitivities"]["models"].items(), key=sortkey)
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        if (msenss < 0.1).all():
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % max(-3, np.log10(msenss)))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:  # show per-element deviations
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:
            msenssstr = " "*len(msenssstr)  # blank out a repeated value
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    return lines + [""] if len(lines) > 3 else []

104 

105 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    ``data`` may be a solution (its ["sensitivities"]["variables"] entry is
    used) or a {varkey: sensitivity} mapping. A non-empty ``showvars``
    restricts the rows to those keys. Rendering goes through ``var_table``
    with value-sorted rows, no units, and entries of magnitude <= 1e-3
    skipped.
    """
    sensitivities = data.get("sensitivities", {})
    if "variables" in sensitivities:
        data = sensitivities["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    return var_table(data, title,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g", minval=1e-3,
                     printunits=False, sortbyvals=True, skipifempty=True,
                     **kwargs)

115 

116 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns top sensitivity table lines.

    Picks the most sensitive variables via ``topsenss_filter``. When some of
    them were already displayed through ``showvars``, the title becomes
    "Next Most Sensitive Variables".
    """
    top, already_shown = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(top, title=title, hidebelowminval=True, **kwargs)

124 

125 

def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top ``nvars`` variables.

    Arguments
    ---------
    data : solution or {key: sensitivity} mapping
        A solution's ["sensitivities"]["variables"] entry is used if present.
    showvars : iterable of keys (or None)
        Keys already displayed elsewhere. They are excluded from the result,
        and each exclusion lowers ``nvars`` by one, but never below 3.
    nvars : int
        How many of the highest mean-|sensitivity| entries to keep.

    Returns
    -------
    (dict, set)
        The kept {key: sensitivity} entries (entries containing NaN are
        ignored), and the set of ``showvars`` keys that were filtered out.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    # ascending by mean |sensitivity|: the top entries are at the end
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda kv: kv[1])]
    # accept any iterable, not only sets (the default elsewhere is a tuple)
    filter_already_shown = set(showvars or ()).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    kept = topk[-nvars:] if nvars > 0 else []  # topk[-0:] would be all
    return {k: data[k] for k in kept}, filter_already_shown

139 

140 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    Lists the variables whose mean absolute sensitivity is below ``maxval``
    (NaN means are dropped), under "Insensitive Fixed Variables".

    Arguments
    ---------
    data : solution or {key: sensitivity} mapping
    _ : ignored (showvars slot of the TABLEFNS calling convention)
    maxval : float
    **kwargs : forwarded to senss_table
    """
    sensitivities = data.get("sensitivities", {})
    # Test for the key that is actually read, as the sibling table functions
    # do; testing for "constants" left ``data`` as the whole solution when
    # only "variables" was present.
    if "variables" in sensitivities:
        data = sensitivities["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

147 

148 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return constraint tightness lines.

    Lists at most ``ntightconstrs`` constraints whose sensitivity is at
    least ``tight_senss``, largest (as rounded for display) first, with
    ties broken by the constraint's string form. For sweeps only the last
    sweep point is considered.
    """
    in_sweep = len(self) > 1
    title = "Most Sensitive Constraints"
    if in_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        if in_sweep:
            senss = senss[-1]
        if not senss >= tight_senss:
            continue
        senssstr = "%+6.2g" % abs(senss)
        rows.append(((-float(senssstr), str(constr)), senssstr,
                     id(constr), constr))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)

164 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines listing insensitive (loose) constraints.

    Every constraint whose sensitivity is at most ``min_senss`` is listed
    with an empty left column; for sweeps only the last sweep point is
    checked.
    """
    in_sweep = len(self) > 1
    title = "Insensitive Constraints |below %+g|" % min_senss
    if in_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        if (senss[-1] if in_sweep else senss) <= min_senss:
            rows.append((0, "", id(constr), constr))
    return constraint_table(rows, title, **kwargs)

178 

179 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, id, constraint) tuples
        ``sortby`` orders rows, ``openingstr`` is shown in the left column,
        and the third element (``id(constraint)`` in this file's callers)
        breaks ties before the constraints themselves would be compared.
    title : str
    sortbymodel : bool
        If true, rows are grouped under each constraint's lineage string.
    showmodels : bool
        If false, "lineage" is also excluded when stringifying constraints.

    Returns
    -------
    list of str
        Title, dashed underline, aligned rows, and a trailing "".
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = {"units"} if showmodels else {"units", "lineage"}
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            # models are numbered in order of first appearance in sorted data
            models[model] = len(models)
        constrstr = try_str_without(
            constraint, {":MAGIC:"+lineagestr(constraint)}.union(excluded))
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # blank spacer between model groups
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # keep only the operator match's first character here and
                # prepend the rest to the next segment, so any wrap falls
                # before the operator
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:  # hard-wrap segments longer than maxlen
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths from every non-header row
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]

240 

241 

def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Arguments
    ---------
    self : solution mapping
        Reads self["warnings"], a dict of {warning type: warnings}. Values
        with a ``shape`` are treated as sweeps: each sweep point gets its own
        section unless all points are identical.
    _ : ignored (showvars slot of the TABLEFNS calling convention)
    **kwargs : forwarded to constraint_table for the "Unexpectedly Tight
        Constraints" / "Unexpectedly Loose Constraints" types.

    Returns
    -------
    list of str
        Empty when there is nothing to report.
    """
    header = "WARNINGS"
    lines = ["~"*len(header), header, "~"*len(header)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):  # elementwise comparison result
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title (it may include " in sweep N"),
                # as constraint_table does
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]

289 

def bdtable_gen(key):
    """Generator for breakdown tablefns.

    Returns a function with the TABLEFNS signature that runs
    ``Breakdowns(solution).plot(key)`` with ``sys.stdout`` temporarily
    replaced by a ``SolverLog`` (verbosity=0), and returns that log's
    ``lines()``.
    """

    def bdtable(self, _showvars, **_):
        "Cost breakdown plot"
        bds = Breakdowns(self)
        original_stdout = sys.stdout
        # Build the log before swapping stdout. If this were inside the try,
        # a constructor failure would make the finally clause call .lines()
        # on the real stdout and mask the original exception.
        log = SolverLog(original_stdout, verbosity=0)
        sys.stdout = log
        try:
            bds.plot(key)
        finally:
            sys.stdout = original_stdout
        return log.lines()

    return bdtable

306 

307 

# Table name (as accepted by SolutionArray.table) -> function called as
# fn(solution, showvars, **kwargs) that returns a list of lines.
TABLEFNS = {
    "sensitivities": senss_table,
    "top sensitivities": topsenss_table,
    "insensitivities": insenss_table,
    "model sensitivities": msenss_table,
    "tightest constraints": tight_table,
    "loose constraints": loose_table,
    "warnings": warnings_table,
}
# "<key> breakdown" tables render Breakdowns plots for that key
TABLEFNS.update({bdkey + " breakdown": bdtable_gen(bdkey)
                 for bdkey in ("model sensitivities", "cost")})

318 

def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    Returns the signed entry of largest absolute value: a scalar itself, or
    the single element of largest magnitude within an array. On ties the
    later value wins. NaNs are ignored (an array is still considered through
    its finite entries), as are empty arrays. Returns None if nothing
    qualifies.
    """
    finalval, absmaxest = None, 0
    for val in values:
        absval = np.abs(val)
        if np.size(absval) == 0 or np.isnan(absval).all():
            continue
        absmaxval = np.nanmax(absval)
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        return finalval[np.unravel_index(np.nanargmax(np.abs(finalval)),
                                         finalval.shape)]
    return finalval

330 

331 

def cast(function, val1, val2):
    """Apply ``function(val1, val2)``, aligning array dimensions from the left.

    If both arguments are array-like (have a ``shape``) but differ in
    ``ndim``, trailing length-1 axes are appended to the lower-dimensional
    one, so its axes line up with the *leading* axes of the other. NumPy's
    default broadcasting would align trailing axes instead. Otherwise the
    arguments are passed through unchanged.

    All Python warnings raised during the call are suppressed, e.g. NumPy's
    divide-by-zero RuntimeWarning. In this file it is called with
    ``np.divide`` (ratios) and ``operator.sub`` (differences).
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if hasattr(val1, "shape") and hasattr(val2, "shape"):
            if val1.ndim < val2.ndim:
                val1 = val1[(slice(None),)*val1.ndim
                            + (np.newaxis,)*(val2.ndim - val1.ndim)]
            elif val2.ndim < val1.ndim:
                val2 = val2[(slice(None),)*val2.ndim
                            + (np.newaxis,)*(val1.ndim - val2.ndim)]
        return function(val1, val2)

347 

348 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
        (the table code below also reads "models" and "constraints")
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    _lineageset = False
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def set_necessarylineage(self, clear=False):  # pylint: disable=too-many-branches
        """Sets (or, with ``clear=True``, removes) "necessarylineage" on varkeys.

        The first call builds ``self._name_collision_varkeys``, mapping
        varkeys to the number of trailing lineage components needed to tell
        them apart from other varkeys with the same short name (0 when the
        name alone is unique). Each call then writes those counts into
        ``vk.descr["necessarylineage"]`` and sets ``_lineageset``, or, when
        ``clear`` is true, removes them and unsets ``_lineageset``.
        Returns None.
        """
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {}
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            name_collisions = defaultdict(set)
            for key in keymap:
                if hasattr(key, "key"):  # skip non-varkey entries (names)
                    if len(keymap[key.name]) == 1:  # very unique
                        self._name_collision_varkeys[key] = 0
                    else:
                        shortname = key.str_without(["lineage", "vec"])
                        if len(keymap[shortname]) > 1:
                            name_collisions[shortname].add(key)
            for varkeys in name_collisions.values():
                min_namespaced = defaultdict(set)
                for vk in varkeys:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                # lengthen ambiguous lineage suffixes one component at a time
                # until each suffix identifies a single varkey
                while any(len(vks) > 1 for vks in min_namespaced.values()):
                    for key, vks in list(min_namespaced.items()):
                        if len(vks) <= 1:
                            continue
                        del min_namespaced[key]
                        mineage, idx = key
                        idx += 1
                        for vk in vks:
                            lineages = vk.lineagestr().split(".")
                            submineage = lineages[-idx] + "." + mineage
                            min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), vks in min_namespaced.items():
                    vk, = vks
                    self._name_collision_varkeys[vk] = idx
        if clear:
            self._lineageset = False
            for vk in self._name_collision_varkeys:
                # pop, not del: clearing when nothing is set must not raise
                vk.descr.pop("necessarylineage", None)
        else:
            self._lineageset = True
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def __len__(self):
        """Number of solutions.

        This is len(cost) for sweeps, 1 for a scalar cost, and 0 if there is
        no cost entry.
        """
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        """Value of ``posy`` in this solution.

        Returns the substituted result's ``.c`` attribute when it has one,
        else the result itself.
        """
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3):
        """Checks for almost-equality between two solutions.

        True if both have the same variable keys and every variable's ratio
        differs from 1 by less than ``reltol``.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            reldiff = np.max(abs(cast(np.divide, svars[key], ovars[key]) - 1))
            if reldiff >= reltol:
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (paths ending in ".pgz" are read with decompress_file)
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True, show a unified diff of the two solutions' modelstr
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True and this solution has sensitivities, sort model sections
            by this solution's model sensitivities

        Returns
        -------
        str
        """
        # guarded like table(): solutions without sensitivities must not fail
        if sortmodelsbysenss and "sensitivities" in self:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                # NOTE: unpickling can execute arbitrary code; only load
                # solution files from trusted sources
                with open(other, "rb") as f:
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])  # drop the ---/+++ file headers
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            # variables without a sensitivity entry in both solutions are
            # skipped instead of raising KeyError
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in svks.intersection(ovks)
                           if vk in ssenss and vk in osenss}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as f:  # closed even if dump fails
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        NOTE: unpickling can execute arbitrary code; only use on trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns {name: varkey}, names built with minimal unique lineage.

        Necessary lineage is set only while the names are built and is
        cleared again even if an error occurs.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        self.set_necessarylineage()
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:
            self.set_necessarylineage(clear=True)
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("vec",)):  # a 1-tuple; ("vec") is just a string
        """Saves primal solution as matlab file.

        "." in variable names is replaced by "_", and values are stored as
        float32 ("f") arrays.
        """
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None, excluded=("vec",)):
        """Returns primal solution as pandas dataframe.

        One row per scalar variable, and one row per element of each vector
        variable, with columns Name, Index, Value, Units, Label, Lineage and
        Other (the remaining descr entries as "k=v" pairs).
        """
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # C-order element indices; 1-D indices are unwrapped to ints
                idxs = [idx[0] if len(idx) == 1 else idx
                        for idx in np.ndindex(*value.shape)]
            else:
                idxs = [None]
            other = ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])
            for idx in idxs:
                rows.append([key.name,
                             "" if idx is None else idx,
                             value if idx is None else value[idx],
                             key.unitstr(),
                             key.label or "",
                             key.lineage or "",
                             other])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json", showvars=None):
        """Saves solution table as a json file.

        Writes {str(varkey): {"v": value, "u": unit string}}.
        """
        sol_dict = {}
        if self._lineageset:
            self.set_necessarylineage(clear=True)
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # add appropriate data for each variable to the dictionary
        for k, v in data.items():
            # .tolist() turns arrays into (nested) lists and NumPy scalars
            # into Python scalars, which json can serialize
            value = v.tolist() if hasattr(v, "tolist") else v
            sol_dict[str(k)] = {"v": value, "u": k.unitstr()}
        with open(filename, "w") as f:
            json.dump(sol_dict, f)

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"], require_positive=False)

    def _parse_showvars(self, showvars):
        "Expands each entry of showvars into the set of matching varkeys"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), **kwargs):
        "Print summary table, showing no sensitivities or constants"
        return self.table(showvars,
                          ["cost breakdown", "model sensitivities breakdown",
                           "warnings", "sweepvariables", "freevariables"],
                          **kwargs)

    def table(self, showvars=(),
              tables=("cost breakdown", "model sensitivities breakdown",
                      "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=False, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        self.set_necessarylineage()
        try:  # always clear the lineage set above, even on errors
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if "breakdown" in table:
                    if (len(self) > 1 or not UNICODE_EXPONENTS
                            or "sensitivities" not in self):
                        # no breakdowns for sweeps or no-unicode environments
                        table = table.replace(" breakdown", "")
                if "sensitivities" not in self and ("sensitivities" in table or
                                                    "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            self.set_necessarylineage(clear=True)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy.

        Raises
        ------
        ValueError
            if the solution does not have exactly one swept variable (this
            previously printed a message and then failed while unpacking).
        """
        if len(self["sweepvariables"]) != 1:
            raise ValueError(
                "SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

834 

835 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table; its values are never modified
    title : string
        table title (also the first column title in latex mode)
    printunits : bool
        If true, plain-text rows include each key's units
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text.
        1: value, units and description columns; 2: no values;
        3: no description.
    rawlines : bool
        If true, return the unformatted rows (lists of cell strings)
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip variables whose largest absolute value is <= minval
        (variables whose values are all NaN are always skipped)
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), array entries whose absolute value
        is <= minval are displayed as a dash
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped under their model (lineage) name
    maxcolumns : int
        maximum number of array entries printed per row
    skipifempty : boolean
        If true, return [] when no variable survives the minval filter
    sortmodelsbysenss : dict, optional
        maps model names to sensitivities; models are then ordered by
        decreasing mean sensitivity

    Returns
    -------
    list
        lines of the table, or the row lists themselves if rawlines is true

    Raises
    ------
    ValueError
        if latex is truthy but not one of 1, 2 or 3
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values above minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # mask small entries in a copy: assigning into ``v`` in place
            # would overwrite the caller's (e.g. the solution's) arrays
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table above
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 4)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        # the unique index i precedes k and v, so sorting never compares them
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    coltitles = None
    if latex:
        # decided up front so the table header exists even with no rows
        latex_coltitles = {1: [title, "Value", "Units", "Description"],
                           2: [title, "Units", "Description"],
                           3: [title, "Value", "Units"]}
        if latex not in latex_coltitles:
            raise ValueError("Unexpected latex option, %s." % latex)
        coltitles = latex_coltitles[latex]
    # blank separator rows must have exactly one cell per table column
    ncells = len(coltitles) if latex else 4
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append([""] * ncells)
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            ndim = len(val.shape)
            horiz_dim, ncols = ndim - 1, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last axis, so
            # that each chunk of ncols flattened values is one row along it
            dim_order = [d for d in range(ndim) if d != horiz_dim]
            dim_order.append(horiz_dim)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows for the remaining array entries
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = "  " + "  ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad a short last row so its "]" lines up
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + "  ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            latex_units = "$%s$" % var.latex_unitstr()
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, latex_units, label])
            elif latex == 2:  # no values
                lines.append([varstr, latex_units, label])
            else:  # latex == 3: no description
                lines.append([varstr, valstr, latex_units])
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
            for i, line in enumerate(lines):
                if line[0] == ("newmodelline",):
                    line = [fmts[0].format(" | "), line[1]]
                else:
                    line = [fmt.format(s) for fmt, s in zip(fmts, line)]
                lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines