Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Defines SolutionArray class""" 

2import re 

3import json 

4import difflib 

5from operator import sub 

6import warnings as pywarnings 

7import pickle 

8import gzip 

9import pickletools 

10import numpy as np 

11from .nomials import NomialArray 

12from .small_classes import DictOfLists, Strings 

13from .small_scripts import mag, try_str_without 

14from .repr_conventions import unitstr, lineagestr 

15 

16 

# Splits a constraint's string form at single "*" products (one non-"*"
# character on each side) and at the " + ", " >= ", " <= " and " = "
# operators; constraint_table uses it to line-wrap long constraints.
CONSTRSPLITPATTERN = re.compile("|".join([
    r"([^*]\*[^*])",
    r"( \+ )",
    r"( >= )",
    r"( <= )",
    r"( = )",
]))

# (old, new) substring replacement pairs for nan display; presumably applied
# in order to formatted value strings (their use is outside this view).
VALSTR_REPLACES = [
    ("+nan", " nan"),
    ("-nan", " nan"),
    ("nan%", "nan "),
    ("nan", " - "),
]

25 

26 

class SolSavingEnvironment:
    """Temporarily removes construction/solve attributes from constraints.

    Used as a context manager around pickling. On entry, if
    ``saveconstraints`` is true, a fixed set of attributes is stripped from
    every constraint in ``solarray["sensitivities"]["constraints"]`` (only
    attributes whose value is truthy are removed); otherwise the whole
    ``"constraints"`` entry is popped out of the sensitivities dict. On exit
    whatever was removed is put back.

    This approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        self.attrstore = {}  # attribute name -> {constraint: removed value}
        self.constraintstore = None  # the popped "constraints" entry, if any

    def __enter__(self):
        sensitivities = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = sensitivities.pop("constraints")
            return
        constraints = sensitivities["constraints"]
        stripped_attrs = ("bounded", "meq_bounded", "vks",
                          "v_ss", "unsubbed", "varkeys")
        for attr_name in stripped_attrs:
            removed = {}
            for constraint in constraints:
                attr_value = getattr(constraint, attr_name, None)
                if not attr_value:
                    continue  # missing or falsy: nothing worth stripping
                removed[constraint] = attr_value
                delattr(constraint, attr_name)
            self.attrstore[attr_name] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = self.constraintstore
            return
        for attr_name, removed in self.attrstore.items():
            for constraint, attr_value in removed.items():
                setattr(constraint, attr_name, attr_value)

61 

def msenss_table(data, _, **kwargs):
    """Return the lines of the model-sensitivity table.

    Arguments
    ---------
    data : solution-like mapping
        model sensitivities are read from ``data["sensitivities"]["models"]``,
        a mapping of model name to sensitivity (scalar or per-sweep array)
    _ : ignored
        placeholder for ``showvars`` so every TABLEFNS entry shares a signature
    sortmodelsbysenss : optional keyword
        if truthy, the title notes that later sections are sorted by it

    Returns
    -------
    list of str
        empty when there are no model sensitivities, or when fewer than two
        named models would be listed
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other table function
    # np.all (rather than .all()) also accepts plain-float sensitivities
    msenss_items = sorted(
        data["sensitivities"]["models"].items(),
        key=lambda item: (np.all(item[1] < 0.1),
                          -np.max(item[1]) if np.all(item[1] < 0.1)
                          else -round(np.mean(item[1]), 1),
                          item[0]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in msenss_items:
        if not model:  # for now let's only do named models
            continue
        if np.all(msenss < 0.1):
            # all small: show only a rough power-of-ten bound
            msenss = np.max(msenss)
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % np.log10(msenss))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                # append each sweep point's deviation from the mean
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        if msenssstr == previousmsenssstr:
            # blank out a value identical to the previous line's
            msenssstr = " "*len(msenssstr)
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    return (lines + [""]) if len(lines) > 3 else []

97 

98 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Return the lines of a variable-sensitivity table.

    ``data`` is either a full solution, whose
    ``["sensitivities"]["variables"]`` entry is used, or an already-extracted
    mapping of variables to sensitivities. When ``showvars`` is non-empty
    only those keys are kept. Rows are sorted by value, values at or below
    1e-3 are skipped, and units are not printed.
    """
    senss = data
    if "variables" in data.get("sensitivities", {}):
        senss = data["sensitivities"]["variables"]
    if showvars:
        senss = {key: senss[key] for key in showvars if key in senss}
    return var_table(senss, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)

108 

109 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Return the lines of the top-``nvars`` sensitivity table.

    If some of the most sensitive variables are already in ``showvars`` they
    are left out, and the title becomes "Next Most Sensitive Variables".
    """
    top_senss, already_shown = topsenss_filter(data, showvars, nvars)
    if already_shown:
        title = "Next Most Sensitive Variables"
    else:
        title = "Most Sensitive Variables"
    return senss_table(top_senss, title=title, hidebelowminval=True, **kwargs)

117 

118 

def topsenss_filter(data, showvars, nvars=5):
    """Filter sensitivities down to the ``nvars`` largest in mean magnitude.

    Arguments
    ---------
    data : solution-like mapping, or mapping of variable -> sensitivity
        if it has ``["sensitivities"]["variables"]``, that mapping is used
    showvars : iterable
        variables already shown elsewhere; they are excluded from the result,
        and each one excluded lowers ``nvars`` by one (but never below 3)
    nvars : int
        how many variables to keep; values <= 0 keep none

    Returns
    -------
    (dict, set)
        the kept variable -> sensitivity entries, and the set of top
        variables that were filtered out because they were in ``showvars``

    Variables whose sensitivity contains any NaN are never selected.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    # ascending order, so the most sensitive variables end up last
    topk = [k for k, _ in sorted(mean_abs_senss.items(),
                                 key=lambda item: item[1])]
    filter_already_shown = set(showvars).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    if nvars <= 0:  # topk[-0:] would otherwise select every variable
        return {}, filter_already_shown
    return {k: data[k] for k in topk[-nvars:]}, filter_already_shown

132 

133 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Return the lines of the insensitive-fixed-variables table.

    Keeps only variables whose mean absolute sensitivity is below ``maxval``
    and renders them with senss_table. ``data`` is either a full solution,
    whose ``["sensitivities"]["variables"]`` entry is used, or an
    already-extracted mapping of variables to sensitivities.
    """
    # test for the same key that is read on the next line (as senss_table
    # and topsenss_filter do), rather than for a different one
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

140 

141 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return the lines of the most-sensitive-constraints table.

    Selects up to ``ntightconstrs`` constraints whose sensitivity (taken from
    the last sweep point for swept solutions) is at least ``tight_senss``.
    Constraints are ranked by sensitivity rounded to two significant figures,
    largest first, with ties broken by the constraint's string form.
    """
    title = "Most Sensitive Constraints"
    swept = len(self) > 1
    if swept:
        title += " (in last sweep)"
    candidates = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        value = senss[-1] if swept else senss
        if value >= tight_senss:
            senssstr = "%+6.2g" % value
            sortkey = (-float(senssstr), str(constraint))
            candidates.append((sortkey, senssstr, id(constraint), constraint))
    candidates.sort()
    return constraint_table(candidates[:ntightconstrs], title, **kwargs)

156 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return the lines of the insensitive-constraints table.

    Lists every constraint whose sensitivity (taken from the last sweep
    point for swept solutions) is at most ``min_senss``.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    swept = len(self) > 1
    if swept:
        title += " (in last sweep)"
    data = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        value = senss[-1] if swept else senss
        if value <= min_senss:
            data.append((0, "", id(constraint), constraint))
    return constraint_table(data, title, **kwargs)

170 

171 

def _wrap_constraint_str(constrstr, minlen=25, maxlen=80):
    """Split one constraint string into display lines of at most maxlen chars.

    Soft breaks happen only between CONSTRSPLITPATTERN segments, and only once
    the current line is longer than ``minlen``; text still longer than
    ``maxlen`` is hard-wrapped. Continuation lines start with a space.
    """
    segments = [seg for seg in CONSTRSPLITPATTERN.split(constrstr) if seg]
    wrapped, line, pos = [], "", 0
    while pos < len(segments):
        segment = segments[pos]
        pos += 1
        if CONSTRSPLITPATTERN.match(segment) and pos < len(segments):
            # keep the separator's first char here and push the rest onto
            # the following segment
            segments[pos] = segment[1:] + segments[pos]
            segment = segment[0]
        elif len(line) + len(segment) > maxlen and len(line) > minlen:
            wrapped.append(line)
            line = " "  # start a new line
        line += segment
        while len(line) > maxlen:
            wrapped.append(line[:maxlen])
            line = " " + line[maxlen:]
    wrapped.append(line)
    return wrapped


# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of (sortby, openingstr, id, constraint) tuples
        ``openingstr`` is printed right-aligned before " : " and the
        constraint; rows are ordered by model (in order of first appearance
        after sorting ``data``) and then by ``sortby``
    title : str
        table title, underlined with dashes
    sortbymodel : bool
        if true, group rows under their constraint's lineage string
    showmodels : bool
        if false, lineage is hidden from constraint strings entirely

    Returns
    -------
    list of str
        title, underline, table rows (or "(none)"), and a trailing ""
    """
    # TODO: this should support 1D array inputs from sweeps
    if showmodels:
        excluded = ("units", "unnecessary lineage")
    else:
        excluded = ("units", "lineage")  # hide all of it
    model_order, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        model_order.setdefault(model, len(model_order))
        constrstr = try_str_without(constraint, excluded)
        address_at = constrstr.find(" at 0x")
        if address_at != -1:  # don't print memory addresses
            constrstr = constrstr[:address_at] + ">"
        decorated.append((model_order[model], model, sortby,
                          constrstr, openingstr))
    decorated.sort()

    rows, previous_model = [], None
    for _, model, _, constrstr, openingstr in decorated:
        if model != previous_model:
            if rows:
                rows.append(["", ""])  # spacer between models
            if model or rows:
                rows.append([("newmodelline",), model])
            previous_model = model
        wrapped = _wrap_constraint_str(constrstr.replace(model, ""))
        rows.append((openingstr + " : ", wrapped[0]))
        rows.extend(("", extra) for extra in wrapped[1:])
    if not rows:
        rows = [("", "(none)")]

    # column widths come from every row except the model-name headers
    widths = np.max([[len(cell) for cell in row] for row in rows
                     if row[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(dirs) == len(widths)
    fmts = ["{0:%s%s}" % (direc, width) for direc, width in zip(dirs, widths)]
    rendered = []
    for row in rows:
        if row[0] == ("newmodelline",):
            cells = [fmts[0].format(" | "), row[1]]
        else:
            cells = [fmt.format(cell) for fmt, cell in zip(fmts, row)]
        rendered.append("".join(cells).rstrip())
    return [title, "-"*len(title)] + rendered + [""]

234 

235 

def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    ``self["warnings"]`` maps each warning type to its data: a list of
    ``(msg, data)`` pairs for an unswept solution, or an array-like (anything
    with a ``shape``) holding one such list per sweep point. When all sweep
    points carry identical warnings they are shown once. "Unexpectedly Tight
    Constraints" / "Unexpectedly Loose Constraints" entries with data are
    rendered via constraint_table; any other type lists its messages.

    Returns
    -------
    list of str
        empty when there are no warnings to show
    """
    header = "WARNINGS"
    lines = ["~"*len(header), header, "~"*len(header)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda entry: entry[0])  # sort by msg
            title = wtype
            if len(data_vec) > 1:
                title += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, title, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, title, **kwargs)
            else:
                # underline the full title, including any " in sweep N"
                lines += [title] + ["-"*len(title)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"
    return lines + [""]

283 

284 

# Table names accepted by SolutionArray.table() that are rendered by a
# dedicated function; each is called as fn(solution, showvars, **kwargs).
TABLEFNS = {
    "sensitivities": senss_table,
    "top sensitivities": topsenss_table,
    "insensitivities": insenss_table,
    "model sensitivities": msenss_table,
    "tightest constraints": tight_table,
    "loose constraints": loose_table,
    "warnings": warnings_table,
}

293 

def unrolled_absmax(values):
    """From an iterable of numbers and arrays, return the largest magnitude.

    The returned value keeps its sign: for an array, the element of largest
    absolute value is returned; for a scalar, the scalar itself. NaNs are
    ignored (an all-NaN entry is skipped). Returns None if ``values`` is
    empty or contains only NaN entries. On ties the later entry wins.
    """
    finalval, absmaxest = None, 0
    for val in values:
        absval = np.abs(val)
        if np.isnan(absval).all():
            continue  # nothing comparable in this entry
        # nanmax, so a single NaN does not hide the rest of an array
        absmaxval = np.nanmax(absval)
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        return finalval[np.unravel_index(np.nanargmax(np.abs(finalval)),
                                         finalval.shape)]
    return finalval

305 

306 

def cast(function, val1, val2):
    """Apply the binary ``function`` to val1 and val2, aligning array ranks.

    If both arguments are arrays of different rank, the lower-rank one gets
    trailing new axes so the two line up along their *leading* axes (numpy's
    own broadcasting would align trailing axes instead). Otherwise the
    arguments are passed through unchanged. All warnings raised while
    ``function`` runs (e.g. divide-by-zero) are suppressed.

    Used e.g. with ``np.divide`` for relative and ``operator.sub`` for
    absolute differences between solutions.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        if hasattr(val1, "shape") and hasattr(val2, "shape"):
            ndim = max(val1.ndim, val2.ndim)
            if val1.ndim < ndim:
                val1 = val1[(slice(None),)*val1.ndim
                            + (np.newaxis,)*(ndim - val1.ndim)]
            elif val2.ndim < ndim:
                val2 = val2[(slice(None),)*val2.ndim
                            + (np.newaxis,)*(ndim - val2.ndim)]
        return function(val1, val2)

322 

323 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        """Returns the set of contained varkeys whose names are not unique.

        Only keymap entries with a ``key`` attribute (presumably VarKeys
        rather than name strings) are considered; a key collides when its
        name without lineage/vec maps to more than one entry. The result is
        computed once and cached on the instance.
        """
        if self._name_collision_varkeys is None:
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        """Number of solutions: len of the cost for sweeps, 1 when the cost
        has no len, 0 when there is no cost at all."""
        try:
            return len(self["cost"])
        except TypeError:
            return 1
        except KeyError:
            return 0

    def __call__(self, posy):
        """Return posy evaluated at this solution (see subinto); a result
        with a ``c`` attribute is unwrapped to that attribute."""
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        Both must contain the same variables; each value must agree within
        relative tolerance ``reltol`` and each variable sensitivity within
        absolute tolerance ``sens_abstol``. Array-valued entries (e.g. from
        sweeps) must agree elementwise.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any keeps these checks valid for array values, whose bare
            # comparison results have an ambiguous truth value
            if np.any(np.abs(cast(np.divide, svars[key], ovars[key]) - 1)
                      >= reltol):
                return False
            if np.any(np.abs(self["sensitivities"]["variables"][key]
                             - other["sensitivities"]["variables"][key])
                      >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (".pgz" files are read with decompress_file)
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True and both solutions have a modelstr, show a unified diff
            of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True, sort model sections by this solution's model
            sensitivities (when it has any)

        Returns
        -------
        str
        """
        # guarded like table(): a solution may lack sensitivities
        if sortmodelsbysenss and "sensitivities" in self:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            # NOTE: unpickling can execute arbitrary code; only diff against
            # solution files from a trusted source.
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                with open(other, "rb") as f:
                    other = pickle.load(f)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # [2:] drops the unified diff's "---"/"+++" file headers
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            # skip shared variables that lack a sensitivity in either one
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks
                           if vk in ssenss and vk in osenss}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Constraint attributes (or, if ``saveconstraints`` is False, the
        constraint sensitivities) are removed while pickling; see
        SolSavingEnvironment. Solution can then be loaded with e.g.::

            import pickle
            with open("solution.pkl", "rb") as f:
                sol = pickle.load(f)
        """
        with SolSavingEnvironment(self, saveconstraints):
            # a with-block, so the file is flushed and closed deterministically
            with open(filename, "wb") as f:
                pickle.dump(self, f, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        """Pickle the solution, optimize the pickle, and gzip it to filename.

        Load the result with SolutionArray.decompress_file.
        """
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        """Load a gzip-compressed pickle file.

        NOTE: unpickling can execute arbitrary code; only load trusted files.
        """
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Return a dict mapping display name -> varkey.

        Covers ``showvars`` if given, otherwise every variable. Names are
        built with ``str_without(exclude)``; colliding names temporarily get
        their necessary lineage so they stay unique.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        collisions = self.name_collision_varkeys()
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:  # never leave the temporary lineage flags behind
            for key in collisions:
                del key.descr["necessarylineage"]
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("unnecessary lineage", "vec")):
        """Saves primal solution as matlab file (float32 arrays, with "." in
        names replaced by "_")."""
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        """Returns primal solution as pandas dataframe.

        One row per scalar variable and per element of each vector variable
        (element index in the "Index" column), sorted by display name.
        """
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        # descr entries that have their own column or are not shown
        hidden_descr = ["name", "units", "unitrepr",
                        "idx", "shape", "veckey",
                        "value", "vecfn",
                        "lineage", "label"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda item: item[0]):
            value = self["variables"][key]
            if key.shape:
                # every element index in C order; 1-d indices are unwrapped
                idxs = [idx[0] if len(idx) == 1 else idx
                        for idx in np.ndindex(value.shape)]
            else:
                idxs = [None]
            for idx in idxs:
                rows.append([
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx],
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in hidden_descr)])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        """Saves solution table (see table()) as a text file, preceded by the
        model string if ``printmodel``."""
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json", printjson=False, showvars=None):
        """Saves solution values and units as a json file.

        Each variable maps to ``{"v": value(s), "u": units}``. If
        ``printjson`` is true nothing is written and the dict's ``str`` is
        returned instead.
        """
        collisions = self.name_collision_varkeys()
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:
            data = self["variables"]
            if showvars:
                showvars = self._parse_showvars(showvars)
                data = {k: data[k] for k in showvars if k in data}
            # add appropriate data for each variable to the dictionary
            sol_dict = {}
            for k, v in data.items():
                value = v.tolist() if isinstance(v, np.ndarray) else v
                sol_dict[str(k)] = {"v": value, "u": k.unitstr()}
        finally:  # never leave the temporary lineage flags behind
            for key in collisions:
                del key.descr["necessarylineage"]
        if printjson:
            # NOTE(review): this is the Python repr, not JSON text
            return str(sol_dict)
        with open(filename, "w") as f:
            json.dump(sol_dict, f)
        return None

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        """Saves primal solution as a CSV sorted by modelname, like the tables.

        ``valcols`` is the number of value columns; it is reduced when no
        vector dimension would fill them.
        """
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min(di for di in v.shape if di != 1)
                maxspan_ = max(di for di in v.shape if di != 1)
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    # no newline: the model name shares a row with the
                    # model's first variable
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        """Returns NomialArray of each solution substituted into posy.

        A variable stored in the solution is looked up directly; otherwise
        posy must support ``sub``. Sweeps substitute point by point.

        Raises
        ------
        ValueError
            if posy is neither in the solution nor substitutable
        """
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        """Expand each entry of showvars (a name or key) into every matching
        key of this solution's variables; returns a set."""
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        """Return a summary table: cost, warnings, swept and free variables,
        sensitivities (only the top ``ntopsenss`` when there are many fixed
        variables) and the tightest constraints; fixed values are omitted."""
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities"), or any name
            in TABLEFNS
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        collisions = self.name_collision_varkeys()
        for key in collisions:
            key.descr["necessarylineage"] = True
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if "sensitivities" not in self and ("sensitivities" in table or
                                                    "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:  # never leave the temporary lineage flags behind
            for key in collisions:
                del key.descr["necessarylineage"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        """Plots a sweep for each posy against the single swept variable.

        ``posys`` may be one posy / name or a sequence of them; None or
        "cost" plots the cost. Returns ``(figure, axes)``, with a single
        axes unwrapped.

        Raises
        ------
        ValueError
            if the solution does not have exactly one swept variable
        """
        if len(self["sweepvariables"]) != 1:
            # previously printed, then failed with an unpacking ValueError
            raise ValueError("SolutionArray.plot only supports"
                             " 1-dimensional sweeps")
        # a lone string is one posy, not a sequence of characters
        if isinstance(posys, Strings) or not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

795 

796 

def _flatten_for_display(val, maxcolumns):
    """Flatten an array so that each printed row is one slice along one axis.

    The horizontal (display) axis is the largest axis whose size is at most
    `maxcolumns` (later axes win ties); if no axis fits, the last axis is
    used with a single value per row.

    Returns
    -------
    (flatval, ncols)
        `val` with the horizontal axis moved last and then flattened
        row-major, and the number of entries per printed row: every
        consecutive run of `ncols` entries of `flatval` is one row.
    """
    horiz_dim, ncols = len(val.shape) - 1, 1  # starting values
    for dim_idx, dim_size in enumerate(val.shape):
        if ncols <= dim_size <= maxcolumns:
            horiz_dim, ncols = dim_idx, dim_size
    # Make horiz_dim the last axis (other axes keep their relative order) so
    # that a row-major flatten places its entries contiguously; np.moveaxis
    # does this correctly for any number of dimensions.
    flatval = np.moveaxis(val, horiz_dim, -1).flatten()
    return flatval, ncols


def _replace_nan_tokens(valstr):
    """Rewrite nan renderings in `valstr` according to VALSTR_REPLACES."""
    for before, after in VALSTR_REPLACES:
        valstr = valstr.replace(before, after)
    return valstr


# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table (its values are not modified)
    title : string
    printunits : bool
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text
    rawlines : bool
        If true, return the unformatted rows (lists of cells) instead
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip values that are all nan or whose largest magnitude is <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : boolean
        If true (and minval is nonzero), array entries with magnitude
        <= minval are printed as nan placeholders (see VALSTR_REPLACES)
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : boolean
        If true, rows are grouped under the lineage string of their key
    maxcolumns : int
        maximum number of vector entries printed per row
    skipifempty : boolean
        If true, return [] when no value passes the minval filter
    sortmodelsbysenss : dict, optional
        model name -> sensitivity (array values are averaged); larger
        sensitivities sort first

    Returns
    -------
    list
        Formatted table lines, or the raw rows if `rawlines` is true.

    Raises
    ------
    ValueError
        If `latex` is truthy but not one of 1, 2 or 3.
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # no values above minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # mask a copy: writing NaNs into `v` itself would corrupt the
            # caller's data just by printing it
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    # Resolve the table layout up front so it is defined even when every
    # row ends up filtered out below.
    if latex:
        coltitles = {1: [title, "Value", "Units", "Description"],  # normal
                     2: [title, "Units", "Description"],  # no values
                     3: [title, "Value", "Units"]}.get(latex)  # no descr.
        if coltitles is None:
            raise ValueError("Unexpected latex option, %s." % latex)
        ncells = len(coltitles)
    else:
        ncells = 4
    if included_models:
        included_models = set(included_models)
        included_models.add("")
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, _, isvector, varstr, _, var, val = varlist
        else:
            _, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:  # blank separator row, one cell per table column
                lines.append([""] * ncells)
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            flatval, ncols = _flatten_for_display(val, maxcolumns)
            vals = [vecfmt % x for x in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        valstr = _replace_nan_tokens(valstr)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector:
                # remaining entries, ncols per continuation row
                for idx in range(ncols, len(flatval), ncols):
                    chunk = flatval[idx:idx+ncols]
                    valstr = _replace_nan_tokens(
                        " " + " ".join(vecfmt % x for x in chunk))
                    if idx + ncols >= len(flatval):  # final row
                        # pad a short final row by the approximate width of
                        # its missing entries before the closing bracket
                        missing = ncols - len(chunk)
                        spaces = missing * len(valstr) // len(chunk)
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            unitcell = "$%s$" % var.latex_unitstr()
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, unitcell, label])
            elif latex == 2:  # no values
                lines.append([varstr, unitcell, label])
            else:  # latex == 3: no description
                lines.append([varstr, valstr, unitcell])
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(dirs) == len(maxlens)
            fmts = ["{0:%s%s}" % (direc, width)
                    for direc, width in zip(dirs, maxlens)]
            for i, line in enumerate(lines):
                if line[0] == ("newmodelline",):
                    line = [fmts[0].format(" | "), line[1]]
                else:
                    line = [fmt.format(s) for fmt, s in zip(fmts, line)]
                lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        colfmt = {1: "llcl", 2: "lcl", 3: "llc"}
        header = "\n".join(["{\\footnotesize",
                            "\\begin{longtable}{%s}" % colfmt[latex],
                            "\\toprule",
                            " & ".join(coltitles) + " \\\\ \\midrule"])
        footer = "\n".join(["\\bottomrule", "\\end{longtable}}", ""])
        lines = ([header] + [" & ".join(l) + " \\\\" for l in lines]
                 + [footer])
    return lines