Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Defines SolutionArray class""" 

2import re 

3import difflib 

4from operator import sub 

5import warnings as pywarnings 

6import pickle 

7import gzip 

8import json 

9import pickletools 

10import numpy as np 

11from .nomials import NomialArray 

12from .small_classes import DictOfLists, Strings 

13from .small_scripts import mag, try_str_without 

14from .repr_conventions import unitstr, lineagestr 

15 

16 

# Points at which constraint_table may wrap a long constraint string: a lone
# "*" between two non-"*" characters, or one of " + ", " >= ", " <= ", " = ".
# Every alternative is a capturing group, so re.split keeps the separators.
_CONSTRSPLIT_ALTERNATIVES = (r"([^*]\*[^*])", r"( \+ )", r"( >= )",
                             r"( <= )", r"( = )")
CONSTRSPLITPATTERN = re.compile("|".join(_CONSTRSPLIT_ALTERNATIVES))

# (old, new) substring pairs for tidying NaN entries in formatted value
# strings. Their consumer is outside this view; presumably they are applied
# in order, so the generic "nan" -> " - " replacement comes last.
VALSTR_REPLACES = list(zip(("+nan", "-nan", "nan%", "nan"),
                           (" nan", " nan", "nan ", " - ")))

25 

26 

class SolSavingEnvironment:
    """Context manager that slims a solution down while it is pickled.

    With ``saveconstraints`` true, a fixed set of construction/solve
    attributes is stripped from every constraint listed in
    ``solarray["sensitivities"]["constraints"]`` (only attributes with a
    truthy value are touched) and restored on exit. Otherwise the whole
    "constraints" entry is detached on entry and re-attached on exit.
    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.attrstore = {}  # attribute name -> {constraint: stripped value}
        self.saveconstraints = saveconstraints
        self.constraintstore = None  # detached "constraints" entry, if any

    def __enter__(self):
        senss = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = senss.pop("constraints")
            return
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            stashed = {}
            for constr in senss["constraints"]:
                current = getattr(constr, attrname, None)
                if current:  # falsy/missing attributes are left alone
                    stashed[constr] = current
                    delattr(constr, attrname)
            self.attrstore[attrname] = stashed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for attrname, stashed in self.attrstore.items():
            for constr, value in stashed.items():
                setattr(constr, attrname, value)

61 

def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : dict-like solution
        Must contain ``data["sensitivities"]["models"]``, a mapping of model
        name -> sensitivity (scalar or array); otherwise no table is made.
    _ : ignored
        Placeholder for the ``showvars`` argument all table functions take.
    sortmodelsbysenss : optional keyword
        If truthy, the heading notes that it orders the later sections.

    Returns
    -------
    list of str
        The table lines plus a trailing blank, or ``[]`` when there is no
        model data or fewer than two named models would be listed.
    """
    if "models" not in data.get("sensitivities", {}):
        return []

    def sortkey(item):
        # Models whose sensitivities are all small go last, by max value;
        # the rest are ordered by their (rounded) mean sensitivity.
        name, senss = item
        if np.all(senss < 0.1):
            return (True, -np.max(senss), name)
        return (False, -round(np.mean(senss), 1), name)

    data = sorted(data["sensitivities"]["models"].items(), key=sortkey)
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        if np.all(msenss < 0.1):
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % np.log10(msenss))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            # only vectors can deviate this much from their rounded mean
            if np.max(np.abs(deltas)) > 0.1:
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        # blank out a value identical to the row above for readability
        if msenssstr == previousmsenssstr:
            msenssstr = " "*len(msenssstr)
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    # two header lines: only emit the table if at least two models are listed
    return lines + [""] if len(lines) > 3 else []

97 

98 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Return lines for a table of variable sensitivities.

    ``data`` is either a full solution (its
    ``["sensitivities"]["variables"]`` entry is used) or an already
    extracted key -> sensitivity mapping. A non-empty ``showvars`` restricts
    the table to those keys. Rows are sorted by value, printed without
    units, and entries with magnitude <= 1e-3 are skipped.
    """
    nested = data.get("sensitivities", {})
    table_data = nested["variables"] if "variables" in nested else data
    if showvars:
        table_data = {key: table_data[key]
                      for key in showvars if key in table_data}
    fixed_options = dict(sortbyvals=True, skipifempty=True,
                         valfmt="%+-.2g ", vecfmt="%+-8.2g",
                         printunits=False, minval=1e-3)
    return var_table(table_data, title, **fixed_options, **kwargs)

108 

109 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Return lines for a table of the ``nvars`` most sensitive variables.

    Variables already in ``showvars`` are left out by ``topsenss_filter``;
    when that happens the heading says "Next Most Sensitive Variables".
    """
    top_senss, already_shown = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(top_senss, title=title, hidebelowminval=True,
                       **kwargs)

117 

118 

def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to top N vars.

    Arguments
    ---------
    data : dict
        A full solution (its ``["sensitivities"]["variables"]`` entry is
        used) or an already extracted key -> sensitivity mapping.
    showvars : iterable
        Keys already displayed elsewhere; they are excluded from the result.
        Any iterable is accepted (e.g. the ``()`` default of callers).
    nvars : int
        How many keys to keep. Each excluded key lowers this by one, but
        never below 3.

    Returns
    -------
    (dict, set)
        The kept key -> sensitivity entries (largest mean |sensitivity|) and
        the set of ``showvars`` keys that were excluded. Keys with any NaN
        sensitivity are never kept.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    # ascending by mean |sensitivity|, so the largest are at the end
    topk = [k for k, _ in sorted(mean_abs_senss.items(),
                                 key=lambda item: item[1])]
    filter_already_shown = set(showvars).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    if nvars <= 0:  # topk[-0:] would otherwise select everything
        return {}, filter_already_shown
    return {k: data[k] for k in topk[-nvars:]}, filter_already_shown

132 

133 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    Lists variables whose mean |sensitivity| is below ``maxval``.

    Arguments
    ---------
    data : dict
        A full solution (its ``["sensitivities"]["variables"]`` entry is
        used) or an already extracted key -> sensitivity mapping.
    _ : ignored
        Placeholder for the ``showvars`` argument all table functions take.
    maxval : float
        Exclusive upper bound on mean |sensitivity| for a row to appear.
    """
    senss = data.get("sensitivities", {})
    # Check for the same key that is read (the old code tested "constants"
    # but read "variables", falling through to the whole solution dict).
    if "variables" in senss:
        data = senss["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

140 

141 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return lines for the "Most Sensitive Constraints" table.

    Keeps constraints whose sensitivity is >= ``tight_senss`` (using the
    last sweep point when the solution is a sweep), orders them by
    descending sensitivity rounded to two significant figures (ties broken
    by the constraint's string), and shows at most ``ntightconstrs``.
    """
    title = "Most Sensitive Constraints"
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        final = senss[-1] if is_sweep else senss
        if final >= tight_senss:
            label = "%+6.2g" % final
            rows.append(((-float(label), str(constr)), label,
                         id(constr), constr))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)

156 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines for the "Insensitive Constraints" table.

    Lists every constraint whose sensitivity is <= ``min_senss`` (using the
    last sweep point when the solution is a sweep). Rows carry no opening
    string; their order is left to ``constraint_table``.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        if (senss[-1] if is_sweep else senss) <= min_senss:
            rows.append((0, "", id(constr), constr))
    return constraint_table(rows, title, **kwargs)

170 

171 

172# pylint: disable=too-many-branches,too-many-locals,too-many-statements 

def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of 4-tuples ``(sortby, openingstr, id, constraint)``
        ``sortby`` orders rows within a model group, ``openingstr`` is
        printed to the left of " : ", the third element only breaks ties
        when ``data`` is sorted, and ``constraint`` is printed on the right.
    title : str
        Table heading; an underline of the same length follows it.
    sortbymodel : bool
        If True, rows are grouped under ``lineagestr(constraint)`` headers;
        otherwise every row shares one unnamed group.
    showmodels : bool
        If False, lineage is omitted entirely from constraint strings.

    Returns
    -------
    list of str
        ``[title, underline, *rows, ""]``; the single row "(none)" is used
        when ``data`` is empty.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    # models: group name -> order of first appearance in sorted(data)
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer between model groups
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        constrstr = constrstr.replace(model, "")
        # wrap at most maxlen chars per line, breaking only once a line is
        # longer than minlen
        minlen, maxlen = 25, 80
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # a separator keeps only its first character here; the rest
                # is prepended to the next segment, so a wrap (checked when
                # that segment is added) falls just before the operator
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            # hard-split anything still too long
            while len(line) > maxlen:
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # per-column maximum widths, ignoring model-header lines
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    # first column right-aligned, second left-aligned
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]

234 

235 

def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Arguments
    ---------
    self : dict-like solution
        ``self["warnings"]`` maps a warning type to its entries: a list of
        ``(msg, detail)`` pairs, or (for sweeps) an array holding one such
        list per sweep point.
    _ : ignored
        Placeholder for the ``showvars`` argument all table functions take.
    **kwargs
        Forwarded to ``constraint_table`` for the two constraint-warning
        types.

    Returns
    -------
    list of str
        The framed warnings section, or ``[]`` if there is nothing to show.
    """
    header = "WARNINGS"
    lines = ["~"*len(header), header, "~"*len(header)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda entry: entry[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                # detail is (relax_sensitivity, constraint)
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                # detail is (rel_diff, 3-tuple of tightness values, constraint)
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title, including any " in sweep i"
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]

283 

284 

# Dispatch table consulted by SolutionArray.table: maps a table name to the
# function that renders it. Each renderer is called as
# fn(solution, showvars, **kwargs) and its result is appended to the output.
TABLEFNS = dict([
    ("sensitivities", senss_table),
    ("top sensitivities", topsenss_table),
    ("insensitivities", insenss_table),
    ("model sensitivities", msenss_table),
    ("tightest constraints", tight_table),
    ("loose constraints", loose_table),
    ("warnings", warnings_table),
])

293 

def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The signed element with the greatest absolute value across all entries
    is returned (later entries win ties). NaNs are ignored, so an entry
    containing a NaN still competes on its finite elements; entries that
    are entirely NaN are skipped. Returns None if nothing is comparable
    (e.g. ``values`` is empty).
    """
    finalval, absmaxest = None, 0
    for val in values:
        absval = np.abs(val)
        if np.isnan(absval).all():
            continue  # nothing comparable in this entry
        absmaxval = np.nanmax(absval)
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        # pick the (signed) element achieving the max, skipping NaNs
        return finalval[np.unravel_index(np.nanargmax(np.abs(finalval)),
                                         finalval.shape)]
    return finalval

305 

306 

def cast(function, val1, val2):
    """Return ``function(val1, val2)``, aligning arrays of differing rank.

    When both arguments have a ``shape`` but different ``ndim``, trailing
    length-1 axes are appended to the lower-rank one, so their leading axes
    line up. Argument order is always preserved. Python warnings raised
    inside (e.g. NumPy divide-by-zero) are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if not both_arrays or val1.ndim == val2.ndim:
            return function(val1, val2)
        lower, higher = sorted([val1, val2], key=lambda arr: arr.ndim)
        expander = ((slice(None),)*lower.ndim
                    + (np.newaxis,)*(higher.ndim - lower.ndim))
        widened = lower[expander]
        if higher is val1:
            return function(higher, widened)
        return function(widened, higher)

322 

323 

def diff_retrieval(self, other, sharedvks, showvars=None, *, jsondiff=False,
                   senssdiff=False, absdiff=False, reldiff=False):
    """Compute the per-variable differences used by ``SolutionArray.diff``.

    Arguments
    ---------
    self, other : SolutionArray
        The solution being compared and the baseline it is compared to.
    sharedvks : iterable of VarKeys
        Keys present in both solutions' "variables".
    showvars : unused
        Accepted for call compatibility only.
    jsondiff : bool
        If True, keys are converted with ``str`` and NumPy values with
        ``tolist()`` so the result can be passed to ``json.dump``.
    senssdiff, absdiff, reldiff : bool
        Which kinds of difference to compute.

    Returns
    -------
    dict
        Holds any of "rel" (``100*(self/other - 1)``, i.e. percent), "abs"
        (``self - other``) and "sens" (sensitivity difference); each maps a
        VarKey (or its string when ``jsondiff``) to the difference.
    """
    svars = self["variables"]
    ovars = other["variables"]
    diff_dict = {}
    if reldiff:
        diff_dict["rel"] = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                            for vk in sharedvks}
    if absdiff:
        diff_dict["abs"] = {vk: cast(sub, svars[vk], ovars[vk])
                            for vk in sharedvks}
    if senssdiff:
        ssenss = self["sensitivities"]["variables"]
        osenss = other["sensitivities"]["variables"]
        diff_dict["sens"] = {vk: cast(sub, ssenss[vk], osenss[vk])
                             for vk in sharedvks}
    if not jsondiff:
        return diff_dict

    def to_jsonable(val):
        "NumPy arrays/scalars -> plain Python lists/numbers for json"
        if isinstance(val, (np.ndarray, np.generic)):
            return val.tolist()
        return val

    # Every kind (including "sens") is converted the same way.
    return {kind: {str(vk): to_jsonable(val) for vk, val in diffs.items()}
            for kind, diffs in diff_dict.items()}

373 

374 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None  # per-instance cache, filled lazily
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        """Number of solutions: len(cost); 1 if cost is a scalar; 0 if no cost."""
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        """Value of ``posy`` in this solution (its ``.c`` attribute if the
        substituted result has one)."""
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        Returns False if the variable sets differ, if any value differs by
        a relative amount >= ``reltol``, or if any sensitivity differs by
        an absolute amount >= ``sens_abstol``. Array-valued (swept) entries
        fail if any element exceeds its tolerance.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            reldelta = cast(np.divide, svars[key], ovars[key]) - 1
            if np.any(np.abs(reldelta) >= reltol):
                return False
            sensdelta = (self["sensitivities"]["variables"][key]
                         - other["sensitivities"]["variables"][key])
            if np.any(np.abs(sensdelta) >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *, jsondiff=False,
             filename="solution.json", constraintsdiff=True,
             senssdiff=False, sensstol=0.1, absdiff=False, abstol=0.1,
             reldiff=True, reltol=1.0, sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (".pgz" paths as gzip-compressed pickles). NOTE: unpickling can
            execute arbitrary code; only load files you trust.
        showvars : iterable, optional
            if given, only these variables are compared
        jsondiff : boolean
            if True, write the raw differences to ``filename`` as JSON and
            return them as a dict instead of a string
        filename : string
            output path used when ``jsondiff`` is True
        constraintsdiff : boolean
            if True and both solutions have a modelstr, show a unified diff
            of the model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True (and sensitivities exist), order models by sensitivity

        Returns
        -------
        str (or dict, when ``jsondiff`` is True)
        """
        if sortmodelsbysenss and "sensitivities" in self:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as fil:
                    other = pickle.load(fil)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # [2:] drops the "---"/"+++" file headers of unified_diff
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)

        # retrieve diff data
        diff_dict = diff_retrieval(self, other, sharedvks, showvars=showvars,
                                   jsondiff=jsondiff, senssdiff=senssdiff,
                                   absdiff=absdiff, reldiff=reldiff)
        if jsondiff:
            with open(filename, "w") as f:
                json.dump(diff_dict, f)
            return diff_dict

        if reldiff:
            rel_diff = diff_dict['rel']
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = diff_dict['abs']
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            senss_delta = diff_dict['sens']
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:

            import pickle
            with open("solution.pkl", "rb") as fil:
                sol = pickle.load(fil)
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as fil:
                pickle.dump(self, fil, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        NOTE: unpickling can execute arbitrary code; only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict of display name -> VarKey for the variables.

        Names are built with ``str_without(exclude)``; keys whose names
        collide get their lineage kept. ``showvars``, if given, restricts
        the result to those variables.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        collisions = self.name_collision_varkeys()
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:  # never leave the temporary flag behind
            for key in collisions:
                key.descr.pop("necessarylineage", None)
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file (values stored as float32, 'f')"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe, one row per element"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Expands each showvars entry into the set of matching VarKeys"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Returns summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
            or any key of TABLEFNS; sensitivity/constraint tables are
            skipped when the solution has no sensitivities
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        collisions = self.name_collision_varkeys()
        for key in collisions:
            key.descr["necessarylineage"] = True
        strs = []
        try:
            showvars = self._parse_showvars(showvars)
            for table in tables:
                if "sensitivities" not in self and ("sensitivities" in table or
                                                    "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:  # never leave the temporary flag behind
            for key in collisions:
                key.descr.pop("necessarylineage", None)
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy.

        Raises
        ------
        ValueError
            If the solution is not a sweep over exactly one variable.
        """
        if len(self["sweepvariables"]) != 1:
            raise ValueError(
                "SolutionArray.plot only supports 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

827 

828 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the rows (lists of cells) before column formatting
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip variables whose values all satisfy abs(value) <= minval
        (variables whose values are all nan are skipped too)
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true and minval is nonzero, array entries with abs(value) <= minval
        are shown as " - ". The arrays in `data` are not modified.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped under the lineage string of their key
    maxcolumns : int
        maximum number of array entries shown per printed row
    skipifempty : boolean
        If true, return [] when no variable passes the minval filter
    sortmodelsbysenss : dict or None
        model name -> sensitivity; when sortbyvals is false, models with
        larger (mean) values are listed first

    Returns
    -------
    list
        Table lines (or unformatted rows if `rawlines`); [] if `data` is empty.

    Raises
    ------
    ValueError
        If `latex` is truthy but not one of 1, 2 or 3.
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # every value is nan or at/below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # Build a masked copy: assigning into v would overwrite the
            # caller's data in place (and fails for integer arrays).
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    # Column layout and headers for each LaTeX table style. Resolved here so
    # that a table whose rows were all filtered out can still be rendered.
    latex_layouts = {1: ("llcl", [title, "Value", "Units", "Description"]),
                     2: ("lcl", [title, "Units", "Description"]),
                     3: ("llc", [title, "Value", "Units"])}
    if latex and latex not in latex_layouts:
        raise ValueError("Unexpected latex option, %s." % latex)
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    # the unique index i precedes k and v in each tuple, so sorting never
    # has to compare VarKeys or arrays
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # move horiz_dim to the end (other axes keep their order) so that
            # each run of ncols flattened entries lies along horiz_dim
            flatval = np.moveaxis(val, horiz_dim, -1).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows of ncols entries each
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad a short final row in proportion to the
                        # number of missing entries, then close the bracket
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
            else:  # latex == 3: no description (validated above)
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
            for i, line in enumerate(lines):
                if line[0] == ("newmodelline",):
                    line = [fmts[0].format(" | "), line[1]]
                else:
                    line = [fmt.format(s) for fmt, s in zip(fmts, line)]
                lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt, coltitles = latex_layouts[latex]
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt,
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines