Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Defines SolutionArray class""" 

2import sys 

3import re 

4import json 

5import difflib 

6from operator import sub 

7import warnings as pywarnings 

8import pickle 

9import gzip 

10import pickletools 

11from collections import defaultdict 

12import numpy as np 

13from .nomials import NomialArray 

14from .small_classes import DictOfLists, Strings, SolverLog 

15from .small_scripts import mag, try_str_without 

16from .repr_conventions import unitstr, lineagestr 

17from .breakdowns import Breakdowns 

18 

19 

# Places where a printed constraint may be split across lines: a "*" between
# two non-"*" characters, or a spaced "+", ">=", "<=" or "=" operator.
# Each alternative is a capturing group so re.split keeps the operators.
_CONSTR_SPLIT_ALTERNATIVES = (r"[^*]\*[^*]", r" \+ ", " >= ", " <= ", " = ")
CONSTRSPLITPATTERN = re.compile(
    "|".join("(%s)" % alternative for alternative in _CONSTR_SPLIT_ALTERNATIVES))

# (old, new) string substitution pairs, in order; presumably applied to
# formatted values so that NaN entries end up rendered as " - ".
VALSTR_REPLACES = list(zip(("+nan", "-nan", "nan%", "nan"),
                           (" nan", " nan", "nan ", " - ")))

28 

29 

class SolSavingEnvironment:
    """Context manager that slims a solution down while it is pickled.

    On entry, if ``saveconstraints`` is true, every truthy attribute among
    ("bounded", "meq_bounded", "vks", "v_ss", "unsubbed", "varkeys") is
    stashed and deleted from each constraint in
    ``solarray["sensitivities"]["constraints"]``. Otherwise that whole
    constraint mapping is popped out of the sensitivities. On exit,
    whatever was removed is put back. This approximately halves the size of
    the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        self.attrstore = {}  # attribute name -> {constraint: stashed value}
        self.constraintstore = None  # popped constraint mapping, if any

    def __enter__(self):
        sensitivities = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = sensitivities.pop("constraints")
            return
        for attr in ("bounded", "meq_bounded", "vks",
                     "v_ss", "unsubbed", "varkeys"):
            stashed = {}
            for constraint in sensitivities["constraints"]:
                value = getattr(constraint, attr, None)
                if value:  # falsy/missing attributes are left alone
                    stashed[constraint] = value
                    delattr(constraint, attr)
            self.attrstore[attr] = stashed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for attr, stashed in self.attrstore.items():
            for constraint, value in stashed.items():
                setattr(constraint, attr, value)

64 

def msenss_table(data, _, **kwargs):
    """Return the lines of the model-sensitivity table.

    Arguments
    ---------
    data : solution (dict-like)
        ``data["sensitivities"]["models"]`` maps model name to its
        sensitivity (a number, or an array with one entry per sweep point).
    _ : ignored
        Placeholder for the ``showvars`` argument every TABLEFNS entry gets.
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that later sections are sorted by it.

    Returns
    -------
    list of str
        Empty when there are no model sensitivities or fewer than two
        named models to list.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other table function

    def sortkey(item):
        "Tiny sensitivities last (by max), others by descending rounded mean."
        model, senss = item
        senss = np.asarray(senss)
        small = (senss < 0.1).all()
        return (small,
                -np.max(senss) if small else -round(np.mean(senss), 1),
                model)

    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in sorted(data["sensitivities"]["models"].items(),
                                key=sortkey):
        if not model:  # for now let's only do named models
            continue
        msenss = np.asarray(msenss)
        if (msenss < 0.1).all():
            msenss = np.max(msenss)
            if msenss:
                # show only the order of magnitude, floored at 1e-3
                msenssstr = "%6s" % ("<1e%i" % max(-3, np.log10(msenss)))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:  # sweep varies noticeably
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:
            msenssstr = " "*len(msenssstr)  # blank out repeated values
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    return lines + [""] if len(lines) > 3 else []

100 

101 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Return the lines of a variable-sensitivity table.

    ``data`` is either a whole solution, whose
    ``["sensitivities"]["variables"]`` entry is then used, or an
    already-extracted mapping of varkey -> sensitivity. A non-empty
    ``showvars`` restricts the table to those keys.
    """
    solution_senss = data.get("sensitivities", {})
    if "variables" in solution_senss:
        data = solution_senss["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)

111 

112 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Return table lines for the ``nvars`` most sensitive variables.

    Variables already listed in ``showvars`` are skipped; when any were,
    the title reads "Next Most Sensitive Variables" instead.
    """
    topdata, already_shown = topsenss_filter(data, showvars, nvars)
    if already_shown:
        title = "Next Most Sensitive Variables"
    else:
        title = "Most Sensitive Variables"
    return senss_table(topdata, title=title, hidebelowminval=True, **kwargs)

120 

121 

def topsenss_filter(data, showvars, nvars=5):
    """Filter sensitivities down to the top ``nvars`` by mean magnitude.

    Arguments
    ---------
    data : solution or mapping of key -> sensitivity
        If ``data["sensitivities"]["variables"]`` exists, that mapping is used.
        Entries containing any NaN are ignored.
    showvars : iterable
        Keys already shown elsewhere; they are excluded, and each one
        excluded reduces ``nvars`` by one (but never below 3).
    nvars : int
        Maximum number of variables to keep; ``<= 0`` keeps none.

    Returns
    -------
    (dict, set)
        The kept key -> sensitivity entries, and the set of ``showvars``
        keys that were excluded.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    # ascending order, so the most sensitive keys are at the end
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda kv: kv[1])]
    filter_already_shown = set(showvars).intersection(topk)
    if filter_already_shown:
        topk = [k for k in topk if k not in filter_already_shown]
    # each already-shown variable uses up a slot, but always show at least 3
    # (matches decrementing nvars once per shown key while nvars > 3)
    nvars = max(min(nvars, 3), nvars - len(filter_already_shown))
    if nvars <= 0:  # topk[-0:] would otherwise return every key
        return {}, filter_already_shown
    return {k: data[k] for k in topk[-nvars:]}, filter_already_shown

135 

136 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Return table lines for fixed variables with small sensitivities.

    Arguments
    ---------
    data : solution or mapping of varkey -> sensitivity
        If ``data["sensitivities"]`` holds a "variables" entry (or, failing
        that, a "constants" entry), that mapping is used.
    _ : ignored
        Placeholder for the ``showvars`` argument every TABLEFNS entry gets.
    maxval : float
        Keep only variables whose mean absolute sensitivity is below this.
    """
    # Check for the same key that is read (previously it tested "constants"
    # but read "variables"), consistent with senss_table/topsenss_filter.
    solution_senss = data.get("sensitivities", {})
    for senss_key in ("variables", "constants"):
        if senss_key in solution_senss:
            data = solution_senss[senss_key]
            break
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

143 

144 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return table lines for the most sensitive (tightest) constraints.

    Lists at most ``ntightconstrs`` constraints whose sensitivity is at
    least ``tight_senss``, largest first. For sweeps (``len(self) > 1``)
    only the last sweep point's sensitivity is considered.
    """
    sweeping = len(self) > 1
    title = "Most Sensitive Constraints"
    if sweeping:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        if sweeping:
            senss = senss[-1]
        if not senss >= tight_senss:  # written this way so NaN is skipped
            continue
        senssstr = "%+6.2g" % abs(senss)
        # sort key: descending (rounded) sensitivity, then constraint text
        rows.append(((-float(senssstr), str(constr)),
                     senssstr, id(constr), constr))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)

160 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return table lines for insensitive (loose) constraints.

    Lists every constraint whose sensitivity is at most ``min_senss``; for
    sweeps (``len(self) > 1``) only the last sweep point is considered.
    Rows carry no sort value or opening string, so constraint_table orders
    them by model and constraint text only.
    """
    sweeping = len(self) > 1
    title = "Insensitive Constraints |below %+g|" % min_senss
    if sweeping:
        title += " (in last sweep)"
    data = [(0, "", id(c), c)
            for c, s in self["sensitivities"]["constraints"].items()
            if (s[-1] if sweeping else s) <= min_senss]
    return constraint_table(data, title, **kwargs)

174 

175 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Create the lines of a table whose right-hand column is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, id, constraint) tuples
        ``openingstr`` fills the left column; rows are sorted by the tuple.
    title : str
    sortbymodel : bool
        If true, rows are grouped under their constraint's lineage string.
    showmodels : bool
        If false, "lineage" is also excluded from each constraint string.

    Returns
    -------
    list of str
        Title, underline, the table rows (constraints longer than 80
        characters wrap onto continuation lines) and a trailing "".
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = {"units"} if showmodels else {"units", "lineage"}
    minlen, maxlen = 25, 80

    def wrap(constrstr):
        "Break a constraint string into lines of at most maxlen characters."
        segments = [seg for seg in CONSTRSPLITPATTERN.split(constrstr) if seg]
        wrapped, current, pos = [], "", 0
        while pos < len(segments):
            segment = segments[pos]
            pos += 1
            if CONSTRSPLITPATTERN.match(segment) and pos < len(segments):
                # keep only the first character here and push the operator
                # onto the next segment, so a soft break puts it first
                segments[pos] = segment[1:] + segments[pos]
                segment = segment[0]
            elif len(current) + len(segment) > maxlen and len(current) > minlen:
                wrapped.append(current)
                current = " "  # start a new line
            current += segment
            while len(current) > maxlen:  # hard-wrap anything still too long
                wrapped.append(current[:maxlen])
                current = " " + current[maxlen:]
        wrapped.append(current)
        return wrapped

    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        models.setdefault(model, len(models))  # order of first appearance
        constrstr = try_str_without(
            constraint, {":MAGIC:" + lineagestr(constraint)}.union(excluded))
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()

    lines, previous_model = [], None
    for _, model, _, constrstr, openingstr in decorated:
        if model != previous_model:
            if lines:
                lines.append(["", ""])
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        constraintlines = wrap(constrstr)
        lines.append((openingstr + " : ", constraintlines[0]))
        lines.extend(("", extra) for extra in constraintlines[1:])
    if not lines:
        lines = [("", "(none)")]

    widths = np.max([list(map(len, line)) for line in lines
                     if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(dirs) == len(widths)
    fmts = ["{0:%s%s}" % (direc, width) for direc, width in zip(dirs, widths)]
    formatted = []
    for line in lines:
        if line[0] == ("newmodelline",):
            cells = [fmts[0].format(" | "), line[1]]
        else:
            cells = [fmt.format(cell) for fmt, cell in zip(fmts, line)]
        formatted.append("".join(cells).rstrip())
    return [title, "-"*len(title)] + formatted + [""]

236 

237 

def warnings_table(self, _, **kwargs):
    """Make a table for all warnings in the solution.

    ``self["warnings"]`` maps each warning type to a list of
    (message, details) tuples or, for sweeps, an array holding one such
    list per sweep point. When that array's entries are all equal it is
    shown once; otherwise each sweep point gets its own section.
    "Unexpectedly Tight/Loose Constraints" with details are rendered as
    constraint tables; other types list their messages, sorted.

    Returns
    -------
    list of str
        Empty when there are no warnings to show.
    """
    header = "WARNINGS"
    lines = ["~"*len(header), header, "~"*len(header)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title (including any sweep suffix)
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]

285 

286# TODO: deduplicate these two functions 

287def bdtable_gen(key): 

288 "Generator for breakdown tablefns" 

289 

290 def bdtable(self, _showvars, **_): 

291 "Cost breakdown plot" 

292 bds = Breakdowns(self) 

293 original_stdout = sys.stdout 

294 try: 

295 sys.stdout = SolverLog(original_stdout, verbosity=0) 

296 bds.plot(key) 

297 finally: 

298 lines = sys.stdout.lines() 

299 sys.stdout = original_stdout 

300 return lines 

301 

302 return bdtable 

303 

304 

# table name -> function(solution, showvars, **kwargs) returning table lines
TABLEFNS = {
    "sensitivities": senss_table,
    "top sensitivities": topsenss_table,
    "insensitivities": insenss_table,
    "model sensitivities": msenss_table,
    "tightest constraints": tight_table,
    "loose constraints": loose_table,
    "warnings": warnings_table,
    **{bdkey + " breakdown": bdtable_gen(bdkey)
       for bdkey in ("model sensitivities", "cost")},
}

315 

def unrolled_absmax(values):
    """Return the entry of largest magnitude across numbers and arrays.

    Arguments
    ---------
    values : iterable of numbers and/or numpy arrays

    Returns
    -------
    The single element (keeping its sign) with the largest absolute value
    among all scalars and array elements. NaN elements are ignored; on ties
    between values the later one wins (0-d arrays are returned as-is).
    Returns None when ``values`` is empty or contains only NaNs and empty
    arrays.
    """
    finalval, absmaxest = None, 0
    for val in values:
        absval = np.abs(val)
        if np.isnan(absval).all():
            continue  # nothing comparable here (all NaN, or empty)
        # nan-aware, so one NaN element no longer hides the whole array
        absmaxval = np.nanmax(absval)
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        idx = np.nanargmax(np.abs(finalval))
        return finalval[np.unravel_index(idx, finalval.shape)]
    return finalval

327 

328 

def cast(function, val1, val2):
    """Apply ``function(val1, val2)``, aligning arrays on their leading axes.

    If both arguments are arrays of different dimensionality, the
    lower-dimensional one gets trailing length-1 axes appended. It then
    broadcasts against the *leading* axes of the other (plain numpy
    broadcasting would align trailing axes). In every other case
    ``function`` is applied directly. All warnings raised during the call
    (e.g. divide-by-zero) are suppressed.

    Used e.g. as ``cast(np.divide, a, b)`` for ratios and
    ``cast(operator.sub, a, b)`` for differences.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if not (hasattr(val1, "shape") and hasattr(val2, "shape")):
            return function(val1, val2)
        if val1.ndim == val2.ndim:
            return function(val1, val2)
        lessdim, dimmest = sorted([val1, val2], key=lambda v: v.ndim)
        # keep lessdim's own axes, then add new trailing axes
        add_axes = ((slice(None),)*lessdim.ndim
                    + (np.newaxis,)*(dimmest.ndim - lessdim.ndim))
        if dimmest is val1:
            return function(val1, lessdim[add_axes])
        return function(lessdim[add_axes], val2)

344 

345 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None  # varkey -> lineage levels needed
    _lineageset = False  # whether descr["necessarylineage"] is currently set
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def set_necessarylineage(self, clear=False):
        """Set (or, with ``clear=True``, remove) each varkey's minimal lineage.

        On first use this computes ``self._name_collision_varkeys``, a dict
        mapping varkey -> number of trailing lineage components needed to
        tell it apart from same-named varkeys (0 when its name is unique).
        That number is then written to each varkey's
        ``descr["necessarylineage"]``, or removed from it when clearing.
        Clearing a varkey that was never set is a no-op.
        """
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {}
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            name_collisions = defaultdict(set)
            for key in keymap:
                if hasattr(key, "key"):
                    if len(keymap[key.name]) == 1:  # very unique
                        self._name_collision_varkeys[key] = 0
                    else:
                        shortname = key.str_without(["lineage", "vec"])
                        if len(keymap[shortname]) > 1:
                            name_collisions[shortname].add(key)
            for varkeys in name_collisions.values():
                min_namespaced = defaultdict(set)
                for vk in varkeys:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                # prepend one more lineage component to any still-ambiguous group
                while any(len(vks) > 1 for vks in min_namespaced.values()):
                    for key, vks in list(min_namespaced.items()):
                        if len(vks) <= 1:
                            continue
                        del min_namespaced[key]
                        mineage, idx = key
                        idx += 1
                        for vk in vks:
                            lineages = vk.lineagestr().split(".")
                            submineage = lineages[-idx] + "." + mineage
                            min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), vks in min_namespaced.items():
                    vk, = vks
                    self._name_collision_varkeys[vk] = idx
        if clear:
            self._lineageset = False
            for vk in self._name_collision_varkeys:
                vk.descr.pop("necessarylineage", None)
        else:
            self._lineageset = True
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def __len__(self):
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Check for almost-equality between two solutions.

        True when both hold the same variables, each value matches to
        within relative tolerance ``reltol`` and each variable's sensitivity
        matches to within absolute tolerance ``sens_abstol``. Array-valued
        entries (vectors, sweeps) must match elementwise.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any: a bare `if array >= tol` raises for non-scalar arrays
            if np.any(abs(cast(np.divide, svars[key], ovars[key]) - 1)
                      >= reltol):
                return False
            if np.any(abs(self["sensitivities"]["variables"][key]
                          - other["sensitivities"]["variables"][key])
                      >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (paths ending in ".pgz" are read as gzip-compressed pickles)
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True, show a unified diff of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True and model sensitivities exist, sort models by them

        Returns
        -------
        str
        """
        solution_senss = self.get("sensitivities", {})
        if sortmodelsbysenss and "models" in solution_senss:
            tableargs["sortmodelsbysenss"] = solution_senss["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            # NOTE: unpickling can execute arbitrary code; only diff against
            # solution files from a trusted source.
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)

        def note_largest(diffs, fmt):
            "If the table just added shows no rows, report the largest diff."
            # the original check: a dashed underline just before the final
            # blank line means nothing exceeded the tolerance
            if lines[-2][:10] == "-"*10:
                largest = unrolled_absmax(diffs.values())
                if largest is not None:  # e.g. every difference was NaN
                    lines.insert(-1, fmt % largest)

        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            note_largest(rel_diff, "The largest is %+g%%.")
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            note_largest(abs_diff, "The largest is %+g.")
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            note_largest(senss_delta, "The largest is %+g.")
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        If ``saveconstraints`` is false, the constraints are left out of the
        pickle. Solution can then be loaded with e.g.::

            import pickle
            with open("solution.pkl", "rb") as f:
                sol = pickle.load(f)
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as f:  # closed even if pickling fails
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        """Pickle the solution and write it gzip-compressed to ``filename``.

        The pickle is passed through ``pickletools.optimize`` first. Load
        it again with ``SolutionArray.decompress_file``.
        """
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        NOTE: unpickling can execute arbitrary code; only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Return a dict mapping display name -> varkey.

        Names are ``str_without(exclude)`` computed while the minimal
        necessary lineage is set, for ``showvars`` (if given) or else every
        variable in the solution.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        self.set_necessarylineage()
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:  # never leave lineage set on the varkeys
            self.set_necessarylineage(clear=True)
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("vec",)):
        """Save the primal solution as a matlab file.

        Each variable is stored under its display name (with "." replaced
        by "_"), as a float32 array.
        """
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None, excluded=("vec",)):
        """Return the primal solution as a pandas DataFrame.

        One row per scalar variable or per element of a vector variable,
        with columns Name, Index, Value, Units, Label, Lineage, Other.
        """
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table (preceded by the model string) as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json", showvars=None):
        """Save the primal solution as a json file.

        Each variable is written as ``{str(varkey): {"v": value, "u": units}}``,
        with numpy values converted to plain (nested) lists/scalars. Names
        are computed while ``set_necessarylineage`` is active, as in
        ``table``.
        """
        sol_dict = {}
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # previously called a non-existent self.name_collision_varkeys()
        self.set_necessarylineage()
        try:
            for k, v in data.items():
                if isinstance(v, (np.ndarray, np.generic)):
                    v = v.tolist()  # json cannot serialize numpy types
                sol_dict[str(k)] = {"v": v, "u": k.unitstr()}
        finally:
            self.set_necessarylineage(clear=True)
        with open(filename, "w") as f:
            json.dump(sol_dict, f)

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"], require_positive=False)

    def _parse_showvars(self, showvars):
        "Expand each showvars entry into the set of all matching varkeys."
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), **kwargs):
        "Print summary table, showing no sensitivities or constants"
        return self.table(showvars,
                          ["cost breakdown", "model sensitivities breakdown",
                           "warnings", "sweepvariables", "freevariables"],
                          **kwargs)

    def table(self, showvars=(),
              tables=("cost breakdown", "model sensitivities breakdown",
                      "warnings", "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=False, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        showvars: Iterable
            If given, only these variables are shown
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        sortmodelsbysenss: bool
            If true and model sensitivities exist, sort models by them
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        solution_senss = self.get("sensitivities", {})
        if sortmodelsbysenss and "models" in solution_senss:
            kwargs["sortmodelsbysenss"] = solution_senss["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        if all(var.lineage == varlist[0].lineage for var in varlist[1:]):
            kwargs["sortbymodel"] = False  # only one model: no grouping
        self.set_necessarylineage()
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if len(self) > 1 and "breakdown" in table:
                    # no breakdowns for sweeps
                    table = table.replace(" breakdown", "")
                if "sensitivities" not in self and ("sensitivities" in table or
                                                    "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:  # never leave lineage set on the varkeys
            self.set_necessarylineage(clear=True)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plot each posy (default: the cost) against the swept variable.

        Returns
        -------
        (figure, axes)
            ``axes`` is a single Axes object when only one was used.

        Raises
        ------
        ValueError
            If the solution does not have exactly one swept variable.
        """
        if len(self["sweepvariables"]) != 1:
            # previously printed this, then crashed with an unpacking error
            raise ValueError(
                "SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

833 

834 

def _clean_valstr(valstr):
    "Applies the VALSTR_REPLACES substitutions (e.g. nan -> ' - ') to valstr."
    for before, after in VALSTR_REPLACES:
        valstr = valstr.replace(before, after)
    return valstr


# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table. The values are not modified.
    title : string
        table title (also the first column title in latex mode)
    printunits : bool
        If true, include each variable's units (plain-text mode only)
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text.
        1: value, units and description; 2: units and description;
        3: value and units.
    rawlines : bool
        If true, return the unformatted rows instead of table strings
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip variables that are all nan or whose largest absolute
        (non-nan) value is <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : bool
        If true (and minval is nonzero), array entries with
        abs(value) <= minval are displayed as blanks
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : bool
        If true, rows are grouped under a header for each model
    maxcolumns : int
        largest array dimension that will be laid out horizontally
    skipifempty : bool
        If true, return [] when no variable survives the minval filter
    sortmodelsbysenss : dict or None
        If given, maps model names to sensitivities; when sortbyvals is
        false, model groups are ordered by descending mean sensitivity

    Returns
    -------
    list of strings (or of row lists, if rawlines)

    Raises
    ------
    ValueError
        If latex is truthy but not one of 1, 2, 3
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # masked float copy: assigning into v would mutate the caller's
            # data (and cannot store nan in an integer array)
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        if not sortmodelsbysenss:
            msenss = 0
        else:  # sort should match that in msenss_table above
            msenss = -round(np.mean(sortmodelsbysenss.get(model, 0)), 4)
        models.add(model)
        isvector = bool(getattr(v, "shape", None))
        varstr = varfmt % k.str_without(("lineage", "vec"))
        # the unique index i precedes k and v, so sorting never compares them
        if not sortbyvals:
            decorated.append((msenss, model, isvector, varstr, i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, isvector, varstr, i, k, v))
    if not decorated and skipifempty:
        return []

    # column layout is fixed by the output mode; decide it before the loop
    # so it is defined even if every row is filtered out below
    if latex:
        if latex not in (1, 2, 3):
            raise ValueError("Unexpected latex option, %s." % latex)
        coltitles = {1: [title, "Value", "Units", "Description"],
                     2: [title, "Units", "Description"],
                     3: [title, "Value", "Units"]}[latex]
        blankline = [""] * len(coltitles)
    else:
        blankline = ["", "", "", ""]

    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, _, isvector, varstr, _, var, val = varlist
        else:
            _, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(list(blankline))
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            # lay out horizontally the largest dimension that fits in
            # maxcolumns (ties go to the later axis); if none fits, use the
            # last axis with one value per row
            horiz_dim, ncols = len(val.shape) - 1, 1
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # move horiz_dim to the end so that each run of ncols flattened
            # values is one row along that axis
            flatval = np.moveaxis(val, horiz_dim, -1).flatten()
            vals = [vecfmt % x for x in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        valstr = _clean_valstr(valstr)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % x for x in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = _clean_valstr(" " + " ".join(vals))
                    if values_remaining <= 0:
                        # pad a short final row so its closing bracket
                        # lines up roughly with the full rows above it
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            unitstr = "$%s$" % var.latex_unitstr()
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, unitstr, label])
            elif latex == 2:  # no values
                lines.append([varstr, unitstr, label])
            else:  # latex == 3: no description
                lines.append([varstr, valstr, unitstr])
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(dirs) == len(maxlens)
            fmts = ["{0:%s%s}" % (direc, width)
                    for direc, width in zip(dirs, maxlens)]
            for i, line in enumerate(lines):
                if line[0] == ("newmodelline",):
                    line = [fmts[0].format(" | "), line[1]]
                else:
                    line = [fmt.format(s) for fmt, s in zip(fmts, line)]
                lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}[latex]
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt,
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(row) + " \\\\" for row in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines