Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Defines SolutionArray class""" 

2import re 

3import difflib 

4from operator import sub 

5import warnings as pywarnings 

6import pickle 

7import gzip 

8import pickletools 

9import numpy as np 

10from .nomials import NomialArray 

11from .small_classes import DictOfLists, Strings 

12from .small_scripts import mag, try_str_without 

13from .repr_conventions import unitstr, lineagestr 

14 

15 

# Break points used when wrapping long constraint strings: a "*" that is not
# part of "**" (captured with its neighbouring characters), or a
# space-padded "+", ">=", "<=" or "=".
CONSTRSPLITPATTERN = re.compile("|".join((
    r"([^*]\*[^*])",
    r"( \+ )",
    r"( >= )",
    r"( <= )",
    r"( = )",
)))

# (before, after) substitutions applied in this order to formatted value
# strings, so that NaN entries end up displayed as a dash.
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]

24 

25 

class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    This approximately halves the size of the pickled solution.

    With ``saveconstraints`` true, a fixed list of attributes is deleted
    from every constraint key of ``solarray["sensitivities"]["constraints"]``
    (only where the attribute is truthy) and restored on exit. Otherwise the
    whole ``"constraints"`` entry is popped on entry and put back on exit.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.attrstore = {}
        self.saveconstraints = saveconstraints
        self.constraintstore = None

    def __enter__(self):
        if not self.saveconstraints:
            sensitivities = self.solarray["sensitivities"]
            self.constraintstore = sensitivities.pop("constraints")
            return
        for attrname in ("bounded", "meq_bounded", "vks",
                         "v_ss", "unsubbed", "varkeys"):
            removed = {}
            for constr in self.solarray["sensitivities"]["constraints"]:
                if getattr(constr, attrname, None):
                    removed[constr] = getattr(constr, attrname)
                    delattr(constr, attrname)
            self.attrstore[attrname] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = \
                self.constraintstore
            return
        for attrname, removed in self.attrstore.items():
            for constr, original in removed.items():
                setattr(constr, attrname, original)

60 

def msenss_table(data, _, **kwargs):
    """Return the lines of the "Model Sensitivities" table.

    Arguments
    ---------
    data : dict-like
        A solution; sensitivities are read from
        ``data["sensitivities"]["models"]``, a mapping from model name to a
        scalar or array sensitivity.
    _ : ignored
        Placeholder for the ``showvars`` argument every TABLEFNS entry takes.
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that later sections are sorted by it.

    Returns
    -------
    list of str
        The table lines. The list is empty when there are no model
        sensitivities or fewer than two named models to show.
    """
    if "models" not in data.get("sensitivities", {}):
        return []

    def _all_small(senss):
        # np.asarray lets plain Python floats through as well as arrays
        return bool(np.all(np.asarray(senss) < 0.1))

    def _sortkey(item):
        # large models first (by rounded mean), then small ones by max
        model, senss = item
        if _all_small(senss):
            return (True, -np.max(senss), model)
        return (False, -round(np.mean(senss), 1), model)

    data = sorted(data["sensitivities"]["models"].items(), key=_sortkey)
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        if _all_small(msenss):
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % np.log10(msenss))
            else:
                msenssstr = "  =0  "  # same 6-char width as other entries
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = np.asarray(msenss) - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                # "  - " is 4 chars wide, like each "%+4.1f" entry
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else "  - "
                             for d in np.atleast_1d(deltas)]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        # blank out repeats so runs of identical values read as a group
        if msenssstr == previousmsenssstr:
            msenssstr = " "*len(msenssstr)
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    # header is two lines; only show the table with at least two models
    return lines + [""] if len(lines) > 3 else []

96 

97 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Return the lines of a variable-sensitivity table.

    ``data`` is either a solution, whose ``["sensitivities"]["variables"]``
    mapping is used, or an already-extracted key -> sensitivity mapping.
    When ``showvars`` is given, only those keys present in the data are
    shown. Remaining keyword arguments are forwarded to ``var_table``.
    """
    sensitivities = data.get("sensitivities", {})
    if "variables" in sensitivities:
        data = data["sensitivities"]["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)

107 

108 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Return lines for a table of the ``nvars`` most sensitive variables.

    Variables already in ``showvars`` are filtered out by
    ``topsenss_filter``. If any were filtered, the title says "Next Most
    Sensitive".
    """
    topdata, already_shown = topsenss_filter(data, showvars, nvars)
    if already_shown:
        title = "Next Most Sensitive Variables"
    else:
        title = "Most Sensitive Variables"
    return senss_table(topdata, title=title, hidebelowminval=True, **kwargs)

116 

117 

def topsenss_filter(data, showvars, nvars=5):
    """Filter sensitivities down to the top ``nvars`` by mean magnitude.

    Arguments
    ---------
    data : dict-like
        A solution (``["sensitivities"]["variables"]`` is used), or a
        key -> sensitivity mapping.
    showvars : iterable
        Keys already shown elsewhere; they are excluded from the result.
        For each one excluded, ``nvars`` shrinks by one, but never below 3.
    nvars : int
        How many variables to keep; ``nvars <= 0`` keeps none.

    Returns
    -------
    (dict, set)
        The kept key -> sensitivity entries, and the set of ``showvars``
        keys that were excluded. Entries containing NaN are never kept.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    # ascending order, so the most sensitive keys are at the end
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    filter_already_shown = set(showvars).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    if nvars <= 0:  # topk[-0:] would otherwise return everything
        return {}, filter_already_shown
    return {k: data[k] for k in topk[-nvars:]}, filter_already_shown

131 

132 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Return lines for a table of variables with small sensitivities.

    Arguments
    ---------
    data : dict-like
        A solution (``["sensitivities"]["variables"]`` is used), or a
        key -> sensitivity mapping.
    _ : ignored
        Placeholder for the ``showvars`` argument every TABLEFNS entry takes.
    maxval : float
        Entries whose mean absolute sensitivity is below this are shown.
    """
    # check for the same key that is read (previously checked "constants",
    # which crashed when only "variables" was present)
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

139 

140 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return lines for the most sensitive (tightest) constraints.

    Constraints whose sensitivity (the last sweep point, for sweeps) is at
    least ``tight_senss`` are ranked by their rounded sensitivity, highest
    first, and the top ``ntightconstrs`` are passed to constraint_table.
    """
    title = "Most Sensitive Constraints"
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constr, senss in self["sensitivities"]["constraints"].items():
        value = senss[-1] if is_sweep else senss
        if value >= tight_senss:
            shown = "%+6.2g" % value
            rows.append(((-float(shown), str(constr)),
                         shown, id(constr), constr))
    data = sorted(rows)[:ntightconstrs]
    return constraint_table(data, title, **kwargs)

155 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines for constraints whose sensitivity is at most min_senss.

    For sweeps, the sensitivity at the last sweep point is used.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    constraints = self["sensitivities"]["constraints"]
    data = [(0, "", id(constr), constr)
            for constr, senss in constraints.items()
            if (senss[-1] if is_sweep else senss) <= min_senss]
    return constraint_table(data, title, **kwargs)

169 

170 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of 4-tuples
        ``(sortby, openingstr, id(constraint), constraint)``. Rows are sorted
        on these tuples. ``openingstr`` is printed left of the " : "
        separator and the constraint's string to its right.
    title : str
        Table title, underlined with dashes.
    sortbymodel : bool
        If True, rows are grouped under the constraint's ``lineagestr`` and
        that string is removed from the printed constraint.
    showmodels : bool
        If False, all lineage is excluded from the constraint strings.

    Returns
    -------
    list of str
        Title, underline, formatted rows (or "(none)"), then a blank line.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            # number models in order of first appearance in the sorted data
            models[model] = len(models)
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer between model groups
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        constrstr = constrstr.replace(model, "")
        minlen, maxlen = 25, 80
        # wrap the constraint string into lines of at most maxlen chars,
        # preferring to break at CONSTRSPLITPATTERN separators
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # keep only the separator's first character here and
                # prepend the rest of it to the following segment
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            while len(line) > maxlen:  # hard-wrap anything still too long
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths are computed over every row except model headers
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]

233 

234 

def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Arguments
    ---------
    self : dict-like
        A solution; ``self["warnings"]`` maps warning type to a list of
        ``(msg, data)`` entries, or to a per-sweep array of such lists.
    _ : ignored
        Placeholder for the ``showvars`` argument every TABLEFNS entry takes.
    **kwargs
        Forwarded to constraint_table for constraint-specific warnings.

    Returns
    -------
    list of str
        The table lines, or an empty list when there is nothing to report.
    """
    header = "WARNINGS"
    lines = ["~"*len(header), header, "~"*len(header)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title, including any " in sweep N"
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"  # replace the last section's trailing blank line
    return lines + [""]

282 

283 

# Maps a table name accepted by SolutionArray.table to the function that
# builds its lines; each is called as fn(solution, showvars, **kwargs).
TABLEFNS = dict([
    ("sensitivities", senss_table),
    ("top sensitivities", topsenss_table),
    ("insensitivities", insenss_table),
    ("model sensitivities", msenss_table),
    ("tightest constraints", tight_table),
    ("loose constraints", loose_table),
    ("warnings", warnings_table),
])

291 } 

292 

def unrolled_absmax(values):
    """Return the (signed) element of largest magnitude across ``values``.

    ``values`` is an iterable of numbers and/or arrays. NaN entries are
    ignored, and all-NaN or empty items are skipped. Previously, a single
    NaN hid the rest of its array, and empty arrays raised. Returns
    ``None`` if nothing is left.
    """
    finalval, absmaxest = None, 0
    for val in values:
        absval = np.abs(val)
        if np.all(np.isnan(absval)):  # also true for empty arrays
            continue
        absmaxval = np.nanmax(absval)
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        flat_idx = np.nanargmax(np.abs(finalval))
        return finalval[np.unravel_index(flat_idx, finalval.shape)]
    return finalval

304 

305 

def cast(function, val1, val2):
    """Apply ``function(val1, val2)``, aligning arrays of different rank.

    When both arguments have a ``shape`` and their ``ndim`` differs, new
    axes are appended to the END of the lower-rank argument. Its existing
    axes then line up with the LEADING axes of the other argument, whereas
    plain numpy broadcasting aligns trailing axes. Any warnings raised
    meanwhile (e.g. divide-by-zero) are suppressed.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if not (hasattr(val1, "shape") and hasattr(val2, "shape")):
            return function(val1, val2)
        if val1.ndim == val2.ndim:
            return function(val1, val2)
        if val1.ndim > val2.ndim:
            extra = val1.ndim - val2.ndim
            widened = val2[(slice(None),)*val2.ndim + (np.newaxis,)*extra]
            return function(val1, widened)
        extra = val2.ndim - val1.ndim
        widened = val1[(slice(None),)*val1.ndim + (np.newaxis,)*extra]
        return function(widened, val2)

321 

322 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        # computed once, then cached on the instance
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        # number of sweep points: 1 for a scalar cost, 0 if there is no cost
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        Values must agree within relative tolerance ``reltol`` and
        sensitivities within absolute tolerance ``sens_abstol``. For sweeps,
        every element must agree; np.any avoids the ambiguous truth value
        of comparing whole arrays.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            reldiff = np.abs(cast(np.divide, svars[key], ovars[key]) - 1)
            if np.any(reldiff >= reltol):
                return False
            sensdiff = np.abs(self["sensitivities"]["variables"][key]
                              - other["sensitivities"]["variables"][key])
            if np.any(sensdiff >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (".pgz" paths are read as gzip-compressed pickles)
        showvars : iterable, optional
            if given, only compare these variables
        constraintsdiff : boolean
            if True, show a unified diff of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True (and sensitivities exist), sort models by sensitivity

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            # NOTE: unpickling can execute arbitrary code; only diff
            # against solution files from trusted sources.
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as fil:
                    other = pickle.load(fil)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # [2:] drops the unified diff's "---"/"+++" file headers
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            # only variables with a sensitivity in both can be compared
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks
                           if vk in ssenss and vk in osenss}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with SolSavingEnvironment(self, saveconstraints):
            with open(filename, "wb") as fil:
                pickle.dump(self, fil, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle a file and then compress it into a file with extension."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        "Load a gzip-compressed pickle file"
        # NOTE: unpickling can execute arbitrary code; trusted files only.
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict mapping display name -> varkey.

        Names are built with ``str_without(exclude)``. Keys whose names
        collide are temporarily marked "necessarylineage" so that their
        names stay unique. The mark is always removed afterwards, even if
        an error occurs.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        collisions = self.name_collision_varkeys()
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:
            for key in collisions:
                del key.descr["necessarylineage"]
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        # dots are replaced since they are not valid in matlab names;
        # values are stored as single-precision ("f") arrays
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                # one row per element: collect every index of the array
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Expands each showvars entry into the set of matching varkeys"
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Print summary table, showing top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        collisions = self.name_collision_varkeys()
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:  # always clear the temporary "necessarylineage" marks
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if "sensitivities" not in self and ("sensitivities" in table or
                                                    "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            for key in collisions:
                del key.descr["necessarylineage"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy

        Raises
        ------
        ValueError
            If the solution does not have exactly one swept variable.
            Previously a message was printed and the method then failed on
            an opaque unpacking ValueError.
        """
        if len(self["sweepvariables"]) != 1:
            raise ValueError("SolutionArray.plot only supports"
                             " 1-dimensional sweeps")
        if not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

769 

770 

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table. It is not modified.
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted row lists instead of strings
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values with all(abs(value)) <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), array elements with
        abs(value) <= minval are shown as blanks. A float copy is masked,
        so the caller's arrays are left untouched.
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    maxcolumns : int
        Maximum number of vector entries printed per line
    """
    if not data:
        return []
    # column titles per latex option, defined up front so that a table
    # with no rows does not hit an undefined name
    latex_coltitles = {1: [title, "Value", "Units", "Description"],
                       2: [title, "Units", "Description"],
                       3: [title, "Value", "Units"]}
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # skip rows with no values above minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # mask a float copy: never mutate the caller's arrays (e.g. a
            # solution's stored sensitivities), and int arrays can't hold nan
            v = np.array(v, dtype=float)
            v[np.abs(v) <= minval] = np.nan
        model = lineagestr(k.lineage) if sortbymodel else ""
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        # the unique index i precedes k and v, so sorting never compares them
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, msenss, isvector, varstr, _, var, val = varlist
        else:
            msenss, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(["", "", "", ""])
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            last_dim_index = len(val.shape)-1
            horiz_dim, ncols = last_dim_index, 1  # starting values
            for dim_idx, dim_size in enumerate(val.shape):
                if ncols <= dim_size <= maxcolumns:
                    horiz_dim, ncols = dim_idx, dim_size
            # align the array with horiz_dim by making it the last one
            dim_order = list(range(last_dim_index))
            dim_order.insert(horiz_dim, last_dim_index)
            flatval = val.transpose(dim_order).flatten()
            vals = [vecfmt % v for v in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        for before, after in VALSTR_REPLACES:
            valstr = valstr.replace(before, after)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation lines for the remaining vector entries
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval)-values_remaining
                    vals = [vecfmt % v for v in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = " " + " ".join(vals)
                    for before, after in VALSTR_REPLACES:
                        valstr = valstr.replace(before, after)
                    if values_remaining <= 0:
                        # pad a short final line so its "]" lines up
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr(),
                              label])
            elif latex == 2:  # no values
                lines.append([varstr, "$%s$" % var.latex_unitstr(), label])
            elif latex == 3:  # no description
                lines.append([varstr, valstr, "$%s$" % var.latex_unitstr()])
            else:
                raise ValueError("Unexpected latex option, %s." % latex)
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(list(dirs)) == len(list(maxlens))
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
        for i, line in enumerate(lines):
            if line[0] == ("newmodelline",):
                line = [fmts[0].format(" | "), line[1]]
            else:
                line = [fmt.format(s) for fmt, s in zip(fmts, line)]
            lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt[latex],
                             "\\toprule",
                             " & ".join(latex_coltitles[latex])
                             + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines