Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Defines SolutionArray class""" 

2import re 

3import json 

4import difflib 

5from operator import sub 

6import warnings as pywarnings 

7import pickle 

8import gzip 

9import pickletools 

10import numpy as np 

11from .nomials import NomialArray 

12from .small_classes import DictOfLists, Strings 

13from .small_scripts import mag, try_str_without 

14from .repr_conventions import unitstr, lineagestr 

15 

16 

# Separators at which constraint_table may wrap a long constraint string:
# a "*" between two non-"*" characters (captured together with them), and
# the " + ", " >= ", " <= " and " = " operators.
CONSTRSPLITPATTERN = re.compile(
    r"([^*]\*[^*])"
    r"|( \+ )"
    r"|( >= )"
    r"|( <= )"
    r"|( = )")

# (old, new) substring replacements applied, in order, to formatted value
# strings so that NaN entries (e.g. values hidden by var_table's
# hidebelowminval) print as a dash placeholder.
VALSTR_REPLACES = list(zip(
    ("+nan", "-nan", "nan%", "nan"),
    (" nan", " nan", "nan ", " - "),
))

25 

26 

class SolSavingEnvironment:
    """Context manager that slims a solution down while it is pickled.

    On entry it either strips bulky construction/solve attributes from every
    constraint in ``solarray["sensitivities"]["constraints"]`` (when
    ``saveconstraints`` is true) or pops that constraints entry entirely.
    On exit everything removed is put back. Stripping the attributes
    approximately halves the size of the pickled solution.
    """

    def __init__(self, solarray, saveconstraints):
        self.solarray = solarray
        self.saveconstraints = saveconstraints
        self.attrstore = {}           # attribute name -> {constraint: value}
        self.constraintstore = None   # popped constraints entry, if any

    def __enter__(self):
        senss = self.solarray["sensitivities"]
        if not self.saveconstraints:
            self.constraintstore = senss.pop("constraints")
            return
        for attr in ("bounded", "meq_bounded", "vks",
                     "v_ss", "unsubbed", "varkeys"):
            removed = {}
            for constraint in senss["constraints"]:
                # only truthy attributes are stripped (and later restored)
                if getattr(constraint, attr, None):
                    removed[constraint] = getattr(constraint, attr)
                    delattr(constraint, attr)
            self.attrstore[attr] = removed

    def __exit__(self, type_, val, traceback):
        if not self.saveconstraints:
            self.solarray["sensitivities"]["constraints"] = self.constraintstore
            return
        for attr, removed in self.attrstore.items():
            for constraint, value in removed.items():
                setattr(constraint, attr, value)

61 

def msenss_table(data, _, **kwargs):
    """Returns model sensitivity table lines.

    Arguments
    ---------
    data : SolutionArray or dict
        Model sensitivities are read from ``data["sensitivities"]["models"]``,
        a mapping from model name to a scalar or (for sweeps) an array.
    _ : ignored
        Placeholder for the ``showvars`` argument every TABLEFNS entry takes.
    sortmodelsbysenss : optional keyword
        If truthy, the title notes that later sections are sorted by model
        sensitivity.

    Returns
    -------
    list of str
        Empty when there are no model sensitivities or fewer than two named
        models; otherwise the table lines followed by a blank line.
    """
    if "models" not in data.get("sensitivities", {}):
        return []  # a list, like every other return path

    def is_small(senss):
        "True if every entry of senss is below 0.1 in magnitude."
        # Use the magnitude: a negative value (even round-off like -1e-12)
        # must not reach log10 below, which would give NaN and make "%i" raise.
        return (np.abs(senss) < 0.1).all()

    data = sorted(data["sensitivities"]["models"].items(),
                  key=lambda i: (is_small(i[1]),
                                 -np.max(np.abs(i[1])) if is_small(i[1])
                                 else -round(np.mean(i[1]), 1), i[0]))
    lines = ["Model Sensitivities", "-------------------"]
    if kwargs.get("sortmodelsbysenss"):
        lines[0] += " (sorts models in sections below)"
    previousmsenssstr = ""
    for model, msenss in data:
        if not model:  # for now let's only do named models
            continue
        if is_small(msenss):
            msenss = np.max(np.abs(msenss))
            if msenss:
                msenssstr = "%6s" % ("<1e%i" % np.log10(msenss))
            else:
                msenssstr = " =0 "
        else:
            meansenss = round(np.mean(msenss), 1)
            msenssstr = "%+6.1f" % meansenss
            deltas = msenss - meansenss
            if np.max(np.abs(deltas)) > 0.1:
                # show per-sweep-point deviations from the mean
                deltastrs = ["%+4.1f" % d if abs(d) >= 0.1 else " - "
                             for d in deltas]
                msenssstr += " + [ %s ]" % " ".join(deltastrs)
        # a value identical to the previous row's is printed as blanks
        if msenssstr == previousmsenssstr:
            msenssstr = " "*len(msenssstr)
        else:
            previousmsenssstr = msenssstr
        lines.append("%s : %s" % (msenssstr, model))
    # the header is two lines, so this requires at least two named models
    return lines + [""] if len(lines) > 3 else []

97 

98 

def senss_table(data, showvars=(), title="Variable Sensitivities", **kwargs):
    """Returns sensitivity table lines.

    ``data`` may be a whole solution (its ``["sensitivities"]["variables"]``
    mapping is used) or an already-extracted mapping of variables to
    sensitivities. A non-empty ``showvars`` restricts the table to those
    variables that are present.
    """
    senss = data.get("sensitivities", {})
    if "variables" in senss:
        data = senss["variables"]
    if showvars:
        data = {key: data[key] for key in showvars if key in data}
    return var_table(data, title, sortbyvals=True, skipifempty=True,
                     valfmt="%+-.2g ", vecfmt="%+-8.2g",
                     printunits=False, minval=1e-3, **kwargs)

108 

109 

def topsenss_table(data, showvars, nvars=5, **kwargs):
    """Returns lines for a table of the most sensitive variables.

    Variables already listed in ``showvars`` are left out (see
    topsenss_filter); when any were, the title becomes
    "Next Most Sensitive Variables".
    """
    topdata, already_shown = topsenss_filter(data, showvars, nvars)
    title = ("Next Most Sensitive Variables" if already_shown
             else "Most Sensitive Variables")
    return senss_table(topdata, title=title, hidebelowminval=True, **kwargs)

117 

118 

def topsenss_filter(data, showvars, nvars=5):
    """Filters sensitivities down to the top ``nvars`` variables.

    Arguments
    ---------
    data : dict
        Either a solution (its ``["sensitivities"]["variables"]`` are used)
        or a mapping of variables to sensitivities. Variables whose
        sensitivity contains NaN are ignored.
    showvars : iterable
        Variables already shown elsewhere; they are excluded, and each one
        excluded reduces ``nvars`` by one (but never below 3).
    nvars : int
        How many of the largest mean-|sensitivity| variables to keep.

    Returns
    -------
    (dict, set)
        The kept variables mapped to their sensitivities, and the set of
        variables removed because they were in ``showvars``.
    """
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    mean_abs_senss = {k: np.abs(s).mean() for k, s in data.items()
                      if not np.isnan(s).any()}
    # ascending by mean |sensitivity|, so the most sensitive come last
    topk = [k for k, _ in sorted(mean_abs_senss.items(), key=lambda l: l[1])]
    # set() so any iterable (e.g. a tuple or list) is accepted
    filter_already_shown = set(showvars).intersection(topk)
    for k in filter_already_shown:
        topk.remove(k)
        if nvars > 3:  # always show at least 3
            nvars -= 1
    if nvars <= 0:
        # topk[-0:] would return every variable rather than none
        return {}, filter_already_shown
    return {k: data[k] for k in topk[-nvars:]}, filter_already_shown

132 

133 

def insenss_table(data, _, maxval=0.1, **kwargs):
    """Returns insensitivity table lines.

    Lists variables whose mean absolute sensitivity is below ``maxval``.
    ``data`` may be a solution (its ``["sensitivities"]["variables"]`` are
    used) or an already-extracted mapping of variables to sensitivities;
    the second positional argument (showvars) is ignored.
    """
    # Check for the same key that is read, as senss_table and
    # topsenss_filter do. Previously this tested "constants" but then read
    # "variables", so a solution with only "variables" fell through and
    # the whole solution dict was treated as sensitivities.
    if "variables" in data.get("sensitivities", {}):
        data = data["sensitivities"]["variables"]
    data = {k: s for k, s in data.items() if np.mean(np.abs(s)) < maxval}
    return senss_table(data, title="Insensitive Fixed Variables", **kwargs)

140 

141 

def tight_table(self, _, ntightconstrs=5, tight_senss=1e-2, **kwargs):
    """Return lines for a table of the most sensitive constraints.

    Up to ``ntightconstrs`` constraints with sensitivity of at least
    ``tight_senss`` are listed, most sensitive first. For sweeps only the
    last solution's sensitivities are used.
    """
    title = "Most Sensitive Constraints"
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = []
    for constraint, senss in self["sensitivities"]["constraints"].items():
        value = senss[-1] if is_sweep else senss
        if value >= tight_senss:
            sensstr = "%+6.2g" % value
            # sort key: rounded sensitivity (descending), then the string
            rows.append(((-float(sensstr), str(constraint)),
                         sensstr, id(constraint), constraint))
    rows.sort()
    return constraint_table(rows[:ntightconstrs], title, **kwargs)

156 

def loose_table(self, _, min_senss=1e-5, **kwargs):
    """Return lines for a table of constraints with negligible sensitivity.

    Every constraint whose sensitivity is at most ``min_senss`` is listed.
    For sweeps only the last solution's sensitivities are checked.
    """
    title = "Insensitive Constraints |below %+g|" % min_senss
    is_sweep = len(self) > 1
    if is_sweep:
        title += " (in last sweep)"
    rows = [(0, "", id(constraint), constraint)
            for constraint, senss
            in self["sensitivities"]["constraints"].items()
            if (senss[-1] if is_sweep else senss) <= min_senss]
    return constraint_table(rows, title, **kwargs)

170 

171 

172# pylint: disable=too-many-branches,too-many-locals,too-many-statements 

def constraint_table(data, title, sortbymodel=True, showmodels=True, **_):
    """Creates lines for tables where the right side is a constraint.

    Arguments
    ---------
    data : iterable of 4-tuples
        ``(sortby, openingstr, tiebreak, constraint)``. Rows are sorted by
        the whole tuple (so by ``sortby`` first); ``openingstr`` is printed
        in the left column before " : ". The third entry is not otherwise
        used here; callers in this file pass ``id(constraint)``, presumably
        so sorting never has to compare constraints themselves.
    title : str
        Table heading, underlined with dashes.
    sortbymodel : bool
        If true, rows are grouped under their model name
        (``lineagestr(constraint)``), ordered by first appearance in the
        sorted data.
    showmodels : bool
        If false, lineage is removed from constraint strings entirely;
        otherwise only "unnecessary lineage" is removed.

    Returns
    -------
    list of str
        ``[title, underline, *rows, ""]``; a single "(none)" row when
        ``data`` is empty.
    """
    # TODO: this should support 1D array inputs from sweeps
    excluded = ("units", "unnecessary lineage")
    if not showmodels:
        excluded = ("units", "lineage")  # hide all of it
    models, decorated = {}, []
    for sortby, openingstr, _, constraint in sorted(data):
        model = lineagestr(constraint) if sortbymodel else ""
        if model not in models:
            models[model] = len(models)  # index of first appearance
        constrstr = try_str_without(constraint, excluded)
        if " at 0x" in constrstr:  # don't print memory addresses
            constrstr = constrstr[:constrstr.find(" at 0x")] + ">"
        decorated.append((models[model], model, sortby, constrstr, openingstr))
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        _, model, _, constrstr, openingstr = varlist
        if model != previous_model:
            if lines:
                lines.append(["", ""])  # spacer between model groups
            if model or lines:
                lines.append([("newmodelline",), model])
            previous_model = model
        # the model name is already shown in the group heading
        constrstr = constrstr.replace(model, "")
        # wrap at maxlen characters, only breaking once a line exceeds minlen
        minlen, maxlen = 25, 80
        # re.split yields None for non-participating groups and may yield
        # empty strings; drop both
        segments = [s for s in CONSTRSPLITPATTERN.split(constrstr) if s]
        constraintlines = []
        line = ""
        next_idx = 0
        while next_idx < len(segments):
            segment = segments[next_idx]
            next_idx += 1
            if CONSTRSPLITPATTERN.match(segment) and next_idx < len(segments):
                # a separator: keep only its first character on this line
                # and push the rest onto the next segment, so a break can
                # fall just before the operator
                segments[next_idx] = segment[1:] + segments[next_idx]
                segment = segment[0]
            elif len(line) + len(segment) > maxlen and len(line) > minlen:
                constraintlines.append(line)
                line = " "  # start a new line
            line += segment
            # hard-split anything still longer than maxlen
            while len(line) > maxlen:
                constraintlines.append(line[:maxlen])
                line = " " + line[maxlen:]
        constraintlines.append(line)
        # first wrapped line sits beside openingstr; the rest get a blank
        # left column
        lines += [(openingstr + " : ", constraintlines[0])]
        lines += [("", l) for l in constraintlines[1:]]
    if not lines:
        lines = [("", "(none)")]
    # column widths, ignoring model-heading rows
    maxlens = np.max([list(map(len, line)) for line in lines
                      if line[0] != ("newmodelline",)], axis=0)
    dirs = [">", "<"]  # we'll check lengths before using zip
    assert len(list(dirs)) == len(list(maxlens))
    # left column right-aligned, right column left-aligned
    fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
    for i, line in enumerate(lines):
        if line[0] == ("newmodelline",):
            linelist = [fmts[0].format(" | "), line[1]]
        else:
            linelist = [fmt.format(s) for fmt, s in zip(fmts, line)]
        lines[i] = "".join(linelist).rstrip()
    return [title] + ["-"*len(title)] + lines + [""]

234 

235 

def warnings_table(self, _, **kwargs):
    """Makes a table for all warnings in the solution.

    Arguments
    ---------
    self : solution (dict-like)
        ``self["warnings"]`` maps a warning type to a list of
        ``(msg, data)`` tuples, or for sweeps to an array of such lists.
    _ : ignored (showvars slot shared with the other TABLEFNS)
    **kwargs : passed through to constraint_table

    Returns
    -------
    list of str
        Empty if there are no warnings. Sweeps whose warnings are identical
        at every point are shown once; otherwise each point gets its own
        "... in sweep i" section.
    """
    title = "WARNINGS"
    lines = ["~"*len(title), title, "~"*len(title)]
    if "warnings" not in self or not self["warnings"]:
        return []
    for wtype in sorted(self["warnings"]):
        data_vec = self["warnings"][wtype]
        if len(data_vec) == 0:
            continue
        if not hasattr(data_vec, "shape"):
            data_vec = [data_vec]  # not a sweep
        else:
            all_equal = True
            for data in data_vec[1:]:
                eq_i = (data == data_vec[0])
                if hasattr(eq_i, "all"):
                    eq_i = eq_i.all()
                if not eq_i:
                    all_equal = False
                    break
            if all_equal:
                data_vec = [data_vec[0]]  # warnings identical across sweeps
        for i, data in enumerate(data_vec):
            if len(data) == 0:
                continue
            data = sorted(data, key=lambda l: l[0])  # sort by msg
            subtitle = wtype
            if len(data_vec) > 1:
                subtitle += " in sweep %i" % i
            if wtype == "Unexpectedly Tight Constraints" and data[0][1]:
                data = [(-int(1e5*relax_sensitivity),
                         "%+6.2g" % relax_sensitivity, id(c), c)
                        for _, (relax_sensitivity, c) in data]
                lines += constraint_table(data, subtitle, **kwargs)
            elif wtype == "Unexpectedly Loose Constraints" and data[0][1]:
                data = [(-int(1e5*rel_diff),
                         "%.4g %s %.4g" % tightvalues, id(c), c)
                        for _, (rel_diff, tightvalues, c) in data]
                lines += constraint_table(data, subtitle, **kwargs)
            else:
                # underline the full subtitle, including " in sweep i"
                lines += [subtitle] + ["-"*len(subtitle)]
                lines += [msg for msg, _ in data] + [""]
    if len(lines) == 3:  # just the header
        return []
    lines[-1] = "~~~~~~~~"  # close the section in place of the last blank
    return lines + [""]

283 

284 

# Table name -> function producing that table's lines. SolutionArray.table
# calls each entry as ``fn(solution, showvars, **kwargs)``.
TABLEFNS = dict((
    ("sensitivities", senss_table),
    ("top sensitivities", topsenss_table),
    ("insensitivities", insenss_table),
    ("model sensitivities", msenss_table),
    ("tightest constraints", tight_table),
    ("loose constraints", loose_table),
    ("warnings", warnings_table),
))

293 

def unrolled_absmax(values):
    """From an iterable of numbers and arrays, returns the largest magnitude.

    The returned element keeps its sign. For an array, the single entry with
    the largest magnitude is returned. NaN entries are ignored, so an array
    holding both NaNs and large finite values still competes. Entries that
    are entirely NaN are skipped. Returns None if nothing remains (e.g. empty
    input). Ties go to the later value.
    """
    finalval, absmaxest = None, 0
    for val in values:
        magnitudes = np.abs(val)
        if np.isnan(magnitudes).all():
            continue  # nothing comparable in this entry
        # nanmax: a plain .max() would be NaN here and the entry skipped
        absmaxval = np.nanmax(magnitudes)
        if absmaxval >= absmaxest:
            absmaxest, finalval = absmaxval, val
    if getattr(finalval, "shape", None):
        flatidx = np.nanargmax(np.abs(finalval))
        return finalval[np.unravel_index(flatidx, finalval.shape)]
    return finalval

305 

306 

def cast(function, val1, val2):
    """Apply ``function(val1, val2)``, aligning arrays of different rank.

    When both arguments are arrays of different dimensionality, the
    lower-rank one gets trailing new axes so that its axes line up with the
    *leading* axes of the higher-rank one. Plain numpy broadcasting would
    align trailing axes instead. Argument order is preserved, and all
    warnings (e.g. divide-by-zero) are suppressed during the call.
    """
    with pywarnings.catch_warnings():  # skip those pesky divide-by-zeros
        pywarnings.simplefilter("ignore")
        both_arrays = hasattr(val1, "shape") and hasattr(val2, "shape")
        if not both_arrays or val1.ndim == val2.ndim:
            return function(val1, val2)

        def padded(arr, ndim):
            "arr with trailing length-1 axes appended up to ndim dims."
            extra = ndim - arr.ndim
            return arr[(slice(None),)*arr.ndim + (np.newaxis,)*extra]

        if val1.ndim > val2.ndim:
            return function(val1, padded(val2, val1.ndim))
        return function(padded(val1, val2.ndim), val2)

322 

323 

class SolutionArray(DictOfLists):
    """A dictionary (of dictionaries) of lists, with convenience methods.

    Items
    -----
    cost : array
    variables: dict of arrays
    sensitivities: dict containing:
        monomials : array
        posynomials : array
        variables: dict of arrays
    localmodels : NomialArray
        Local power-law fits (small sensitivities are cut off)

    Example
    -------
    >>> import gpkit
    >>> import numpy as np
    >>> x = gpkit.Variable("x")
    >>> x_min = gpkit.Variable("x_{min}", 2)
    >>> sol = gpkit.Model(x, [x >= x_min]).solve(verbosity=0)
    >>>
    >>> # VALUES
    >>> values = [sol(x), sol.subinto(x), sol["variables"]["x"]]
    >>> assert all(np.array(values) == 2)
    >>>
    >>> # SENSITIVITIES
    >>> senss = [sol.sens(x_min), sol.sens(x_min)]
    >>> senss.append(sol["sensitivities"]["variables"]["x_{min}"])
    >>> assert all(np.array(senss) == 1)
    """
    modelstr = ""
    _name_collision_varkeys = None
    table_titles = {"choicevariables": "Choice Variables",
                    "sweepvariables": "Swept Variables",
                    "freevariables": "Free Variables",
                    "constants": "Fixed Variables",  # TODO: change everywhere
                    "variables": "Variables"}

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:  # computed once, then cached
            self["variables"].update_keymap()
            keymap = self["variables"].keymap
            self._name_collision_varkeys = set()
            for key in list(keymap):
                if hasattr(key, "key"):
                    if len(keymap[key.str_without(["lineage", "vec"])]) > 1:
                        self._name_collision_varkeys.add(key)
        return self._name_collision_varkeys

    def __len__(self):
        try:
            return len(self["cost"])
        except TypeError:
            return 1  # scalar cost: a single solution
        except KeyError:
            return 0

    def __call__(self, posy):
        posy_subbed = self.subinto(posy)
        return getattr(posy_subbed, "c", posy_subbed)

    def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01):
        """Checks for almost-equality between two solutions.

        True if both have the same variables, every value agrees to within
        relative tolerance ``reltol``, and every variable sensitivity agrees
        to within absolute tolerance ``sens_abstol``.
        """
        svars, ovars = self["variables"], other["variables"]
        svks, ovks = set(svars), set(ovars)
        if svks != ovks:
            return False
        for key in svks:
            # np.any so swept (array-valued) entries work too; a bare
            # `if array >= tol` raises "truth value ... is ambiguous"
            if np.any(abs(cast(np.divide, svars[key], ovars[key]) - 1)
                      >= reltol):
                return False
            if np.any(abs(self["sensitivities"]["variables"][key]
                          - other["sensitivities"]["variables"][key])
                      >= sens_abstol):
                return False
        return True

    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    def diff(self, other, showvars=None, *,
             constraintsdiff=True, senssdiff=False, sensstol=0.1,
             absdiff=False, abstol=0.1, reldiff=True, reltol=1.0,
             sortmodelsbysenss=True, **tableargs):
        """Outputs differences between this solution and another

        Arguments
        ---------
        other : solution or string
            strings will be treated as paths to pickled solutions
            (gzip-compressed if they end in ".pgz"); only pass trusted
            files, since unpickling can execute arbitrary code
        showvars : iterable, optional
            if given, only these variables are compared
        constraintsdiff : boolean
            if True (and both solutions have a modelstr), show a unified
            diff of the two model strings
        senssdiff : boolean
            if True, show sensitivity differences
        sensstol : float
            the smallest sensitivity difference worth showing
        absdiff : boolean
            if True, show absolute differences
        abstol : float
            the smallest absolute difference worth showing
        reldiff : boolean
            if True, show relative differences
        reltol : float
            the smallest relative difference worth showing
        sortmodelsbysenss : boolean
            if True (and this solution has sensitivities), order model
            sections by this solution's model sensitivities

        Returns
        -------
        str
        """
        # same guard as table(): a solution may lack sensitivities
        if sortmodelsbysenss and "sensitivities" in self:
            tableargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            tableargs["sortmodelsbysenss"] = False
        tableargs.update({"hidebelowminval": True, "sortbyvals": True,
                          "skipifempty": False})
        if isinstance(other, Strings):
            if other[-4:] == ".pgz":
                other = SolutionArray.decompress_file(other)
            else:
                # NOTE: pickle.load runs arbitrary code from untrusted files
                with open(other, "rb") as fil:
                    other = pickle.load(fil)
        svars, ovars = self["variables"], other["variables"]
        lines = ["Solution Diff",
                 "=============",
                 "(argument is the baseline solution)", ""]
        svks, ovks = set(svars), set(ovars)
        if showvars:
            lines[0] += " (for selected variables)"
            lines[1] += "========================="
            showvars = self._parse_showvars(showvars)
            svks = {k for k in showvars if k in svars}
            ovks = {k for k in showvars if k in ovars}
        if constraintsdiff and other.modelstr and self.modelstr:
            if self.modelstr == other.modelstr:
                lines += ["** no constraint differences **", ""]
            else:
                cdiff = ["Constraint Differences",
                         "**********************"]
                # [2:] drops the "---"/"+++" file-header lines
                cdiff.extend(list(difflib.unified_diff(
                    other.modelstr.split("\n"), self.modelstr.split("\n"),
                    lineterm="", n=3))[2:])
                cdiff += ["", "**********************", ""]
                lines += cdiff
        if svks - ovks:
            lines.append("Variable(s) of this solution"
                         " which are not in the argument:")
            lines.append("\n".join(" %s" % key for key in svks - ovks))
            lines.append("")
        if ovks - svks:
            lines.append("Variable(s) of the argument"
                         " which are not in this solution:")
            lines.append("\n".join(" %s" % key for key in ovks - svks))
            lines.append("")
        sharedvks = svks.intersection(ovks)
        if reldiff:
            rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1)
                        for vk in sharedvks}
            lines += var_table(rel_diff,
                               "Relative Differences |above %g%%|" % reltol,
                               valfmt="%+.1f%% ", vecfmt="%+6.1f%% ",
                               minval=reltol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than reltol
                lines.insert(-1, ("The largest is %+g%%."
                                  % unrolled_absmax(rel_diff.values())))
        if absdiff:
            abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks}
            lines += var_table(abs_diff,
                               "Absolute Differences |above %g|" % abstol,
                               valfmt="%+.2g", vecfmt="%+8.2g",
                               minval=abstol, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than abstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(abs_diff.values())))
        if senssdiff:
            ssenss = self["sensitivities"]["variables"]
            osenss = other["sensitivities"]["variables"]
            senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk])
                           for vk in sharedvks}
            lines += var_table(senss_delta,
                               "Sensitivity Differences |above %g|" % sensstol,
                               valfmt="%+-.2f ", vecfmt="%+-6.2f",
                               minval=sensstol, printunits=False, **tableargs)
            if lines[-2][:10] == "-"*10:  # nothing larger than sensstol
                lines.insert(-1, ("The largest is %+g."
                                  % unrolled_absmax(senss_delta.values())))
        return "\n".join(lines)

    def save(self, filename="solution.pkl",
             *, saveconstraints=True, **pickleargs):
        """Pickles the solution and saves it to a file.

        Solution can then be loaded with e.g.:
        >>> import pickle
        >>> pickle.load(open("solution.pkl", "rb"))
        """
        with SolSavingEnvironment(self, saveconstraints):
            # `with` closes (and flushes) the file; a bare open() leaked it
            with open(filename, "wb") as fil:
                pickle.dump(self, fil, **pickleargs)

    def save_compressed(self, filename="solution.pgz",
                        *, saveconstraints=True, **cpickleargs):
        "Pickle the solution, optimize the pickle, and gzip it to filename."
        with gzip.open(filename, "wb") as f:
            with SolSavingEnvironment(self, saveconstraints):
                pickled = pickle.dumps(self, **cpickleargs)
            f.write(pickletools.optimize(pickled))

    @staticmethod
    def decompress_file(file):
        "Load a gzip-compressed pickle file (only use on trusted files)"
        with gzip.open(file, "rb") as f:
            return pickle.Unpickler(f).load()

    def varnames(self, showvars, exclude):
        """Returns a dict mapping display names to varkeys.

        Names are ``key.str_without(exclude)``; keys whose names collide get
        their necessary lineage included. If ``showvars`` is given, only
        those variables are included.
        """
        if showvars:
            showvars = self._parse_showvars(showvars)
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            names = {}
            for key in showvars or self["variables"]:
                for k in self["variables"].keymap[key]:
                    names[k.str_without(exclude)] = k
        finally:
            # always undo the temporary flag, even if naming fails
            for key in self.name_collision_varkeys():
                del key.descr["necessarylineage"]
        return names

    def savemat(self, filename="solution.mat", *, showvars=None,
                excluded=("unnecessary lineage", "vec")):
        "Saves primal solution as matlab file"
        from scipy.io import savemat
        savemat(filename,
                {name.replace(".", "_"): np.array(self["variables"][key], "f")
                 for name, key in self.varnames(showvars, excluded).items()})

    def todataframe(self, showvars=None,
                    excluded=("unnecessary lineage", "vec")):
        "Returns primal solution as pandas dataframe"
        import pandas as pd  # pylint:disable=import-error
        rows = []
        cols = ["Name", "Index", "Value", "Units", "Label",
                "Lineage", "Other"]
        for _, key in sorted(self.varnames(showvars, excluded).items(),
                             key=lambda k: k[0]):
            value = self["variables"][key]
            if key.shape:
                idxs = []
                it = np.nditer(np.empty(value.shape), flags=['multi_index'])
                while not it.finished:
                    idx = it.multi_index
                    idxs.append(idx[0] if len(idx) == 1 else idx)
                    it.iternext()
            else:
                idxs = [None]
            for idx in idxs:
                row = [
                    key.name,
                    "" if idx is None else idx,
                    value if idx is None else value[idx]]
                rows.append(row)
                row.extend([
                    key.unitstr(),
                    key.label or "",
                    key.lineage or "",
                    ", ".join("%s=%s" % (k, v) for (k, v) in key.descr.items()
                              if k not in ["name", "units", "unitrepr",
                                           "idx", "shape", "veckey",
                                           "value", "vecfn",
                                           "lineage", "label"])])
        return pd.DataFrame(rows, columns=cols)

    def savetxt(self, filename="solution.txt", *, printmodel=True, **kwargs):
        "Saves solution table as a text file"
        with open(filename, "w") as f:
            if printmodel:
                f.write(self.modelstr + "\n")
            f.write(self.table(**kwargs))

    def savejson(self, filename="solution.json"):
        """Saves solution variables as a json file.

        Each variable name maps to ``{"v": value(s), "u": units string}``.
        """
        sol_dict = {}
        for k, v in self["variables"].items():
            # np.asarray(...).tolist() turns arrays into nested lists and
            # numpy scalars (e.g. np.float32, np.int64, which json can't
            # encode) into plain Python numbers; plain numbers pass through
            sol_dict[str(k.name)] = {'v': np.asarray(v).tolist(),
                                     'u': str(k.descr["units"])}
        with open(filename, "w") as f:
            json.dump(sol_dict, f)

    def savecsv(self, filename="solution.csv", *, valcols=5, showvars=None):
        "Saves primal solution as a CSV sorted by modelname, like the tables."
        data = self["variables"]
        if showvars:
            showvars = self._parse_showvars(showvars)
            data = {k: data[k] for k in showvars if k in data}
        # if the columns don't capture any dimensions, skip them
        minspan, maxspan = None, 1
        for v in data.values():
            if getattr(v, "shape", None) and any(di != 1 for di in v.shape):
                minspan_ = min((di for di in v.shape if di != 1))
                maxspan_ = max((di for di in v.shape if di != 1))
                if minspan is None or minspan_ < minspan:
                    minspan = minspan_
                if maxspan is None or maxspan_ > maxspan:
                    maxspan = maxspan_
        if minspan is not None and minspan > valcols:
            valcols = 1
        if maxspan < valcols:
            valcols = maxspan
        lines = var_table(data, "", rawlines=True, maxcolumns=valcols,
                          tables=("cost", "sweepvariables", "freevariables",
                                  "constants", "sensitivities"))
        with open(filename, "w") as f:
            f.write("Model Name,Variable Name,Value(s)" + ","*valcols
                    + "Units,Description\n")
            for line in lines:
                if line[0] == ("newmodelline",):
                    f.write(line[1])
                elif not line[1]:  # spacer line
                    f.write("\n")
                else:
                    f.write("," + line[0].replace(" : ", "") + ",")
                    vals = line[1].replace("[", "").replace("]", "").strip()
                    for el in vals.split():
                        f.write(el + ",")
                    f.write(","*(valcols - len(vals.split())))
                    f.write((line[2].replace("[", "").replace("]", "").strip()
                             + ","))
                    f.write(line[3].strip() + "\n")

    def subinto(self, posy):
        "Returns NomialArray of each solution substituted into posy."
        if posy in self["variables"]:
            return self["variables"](posy)

        if not hasattr(posy, "sub"):
            raise ValueError("no variable '%s' found in the solution" % posy)

        if len(self) > 1:
            return NomialArray([self.atindex(i).subinto(posy)
                                for i in range(len(self))])

        return posy.sub(self["variables"])

    def _parse_showvars(self, showvars):
        "Returns the set of all varkeys matching any entry of showvars."
        showvars_out = set()
        for k in showvars:
            k, _ = self["variables"].parse_and_index(k)
            keys = self["variables"].keymap[k]
            showvars_out.update(keys)
        return showvars_out

    def summary(self, showvars=(), ntopsenss=5, **kwargs):
        "Returns a summary table string: top sensitivities and no constants"
        showvars = self._parse_showvars(showvars)
        out = self.table(showvars, ["cost", "warnings", "sweepvariables",
                                    "freevariables"], **kwargs)
        constants_in_showvars = showvars.intersection(self["constants"])
        senss_tables = []
        if len(self["constants"]) < ntopsenss+2 or constants_in_showvars:
            senss_tables.append("sensitivities")
        if len(self["constants"]) >= ntopsenss+2:
            senss_tables.append("top sensitivities")
        senss_tables.append("tightest constraints")
        senss_str = self.table(showvars, senss_tables, nvars=ntopsenss,
                               **kwargs)
        if senss_str:
            out += "\n" + senss_str
        return out

    def table(self, showvars=(),
              tables=("cost", "warnings", "model sensitivities",
                      "sweepvariables", "freevariables",
                      "constants", "sensitivities", "tightest constraints"),
              sortmodelsbysenss=True, **kwargs):
        """A table representation of this SolutionArray

        Arguments
        ---------
        tables: Iterable
            Which to print of ("cost", "sweepvariables", "freevariables",
                               "constants", "sensitivities")
        fixedcols: If true, print vectors in fixed-width format
        latex: int
            If > 0, return latex format (options 1-3); otherwise plain text
        included_models: Iterable of strings
            If specified, the models (by name) to include
        excluded_models: Iterable of strings
            If specified, model names to exclude

        Returns
        -------
        str
        """
        if sortmodelsbysenss and "sensitivities" in self:
            kwargs["sortmodelsbysenss"] = self["sensitivities"]["models"]
        else:
            kwargs["sortmodelsbysenss"] = False
        varlist = list(self["variables"])
        has_only_one_model = True
        for var in varlist[1:]:
            if var.lineage != varlist[0].lineage:
                has_only_one_model = False
                break
        if has_only_one_model:
            kwargs["sortbymodel"] = False
        for key in self.name_collision_varkeys():
            key.descr["necessarylineage"] = True
        try:
            showvars = self._parse_showvars(showvars)
            strs = []
            for table in tables:
                if "sensitivities" not in self and ("sensitivities" in table or
                                                    "constraints" in table):
                    continue
                if table == "cost":
                    cost = self["cost"]  # pylint: disable=unsubscriptable-object
                    if kwargs.get("latex", None):  # cost is not printed for latex
                        continue
                    strs += ["\n%s\n------------" % "Optimal Cost"]
                    if len(self) > 1:
                        costs = ["%-8.3g" % c for c in mag(cost[:4])]
                        strs += [" [ %s %s ]" % (" ".join(costs),
                                                 "..." if len(self) > 4 else "")]
                    else:
                        strs += [" %-.4g" % mag(cost)]
                    strs[-1] += unitstr(cost, into=" [%s]", dimless="")
                    strs += [""]
                elif table in TABLEFNS:
                    strs += TABLEFNS[table](self, showvars, **kwargs)
                elif table in self:
                    data = self[table]
                    if showvars:
                        showvars = self._parse_showvars(showvars)
                        data = {k: data[k] for k in showvars if k in data}
                    strs += var_table(data, self.table_titles[table], **kwargs)
            if kwargs.get("latex", None):
                preamble = "\n".join(("% \\documentclass[12pt]{article}",
                                      "% \\usepackage{booktabs}",
                                      "% \\usepackage{longtable}",
                                      "% \\usepackage{amsmath}",
                                      "% \\begin{document}\n"))
                strs = [preamble] + strs + ["% \\end{document}"]
        finally:
            # always undo the temporary flag, even if a table fails to render
            for key in self.name_collision_varkeys():
                del key.descr["necessarylineage"]
        return "\n".join(strs)

    def plot(self, posys=None, axes=None):
        "Plots a sweep for each posy"
        if len(self["sweepvariables"]) != 1:
            # unpacking the single swept variable below would raise a less
            # helpful ValueError anyway
            raise ValueError("SolutionArray.plot only supports"
                             " 1-dimensional sweeps")
        # a single string posy (e.g. "cost") must not be iterated by character
        if isinstance(posys, Strings) or not hasattr(posys, "__len__"):
            posys = [posys]
        import matplotlib.pyplot as plt
        from .interactive.plot_sweep import assign_axes
        from . import GPBLU
        (swept, x), = self["sweepvariables"].items()
        posys, axes = assign_axes(swept, posys, axes)
        for posy, ax in zip(posys, axes):
            y = self(posy) if posy not in [None, "cost"] else self["cost"]
            ax.plot(x, y, color=GPBLU)
        if len(axes) == 1:
            axes, = axes
        return plt.gcf(), axes

787 

788 

789# pylint: disable=too-many-branches,too-many-locals,too-many-statements 

def _vector_layout(val, maxcolumns):
    """Flatten array `val` for row-wise printing; return (flatval, ncols).

    The horizontal axis is the axis with the largest length that is still
    <= maxcolumns (ties go to the later axis). If no axis qualifies, the
    last axis is used with one value per printed row. That axis is moved to
    the end before flattening, so each consecutive run of `ncols` entries
    in `flatval` lies along the horizontal axis.
    """
    ndim = len(val.shape)
    horiz_dim, ncols = ndim - 1, 1  # starting values
    for dim_idx, dim_size in enumerate(val.shape):
        if ncols <= dim_size <= maxcolumns:
            horiz_dim, ncols = dim_idx, dim_size
    # keep the other axes in order and put horiz_dim last; transpose's axes
    # argument lists, for each new axis, which old axis it comes from
    dim_order = [d for d in range(ndim) if d != horiz_dim] + [horiz_dim]
    return val.transpose(dim_order).flatten(), ncols


def _clean_valstr(valstr):
    """Apply the module-level VALSTR_REPLACES substitutions (NaN display)."""
    for before, after in VALSTR_REPLACES:
        valstr = valstr.replace(before, after)
    return valstr


def var_table(data, title, *, printunits=True, latex=False, rawlines=False,
              varfmt="%s : ", valfmt="%-.4g ", vecfmt="%-8.3g",
              minval=0, sortbyvals=False, hidebelowminval=False,
              included_models=None, excluded_models=None, sortbymodel=True,
              maxcolumns=5, skipifempty=True, sortmodelsbysenss=None, **_):
    """
    Pretty string representation of a dict of VarKeys
    Iterable values are handled specially (partial printing)

    Arguments
    ---------
    data : dict whose keys are VarKey's
        data to represent in table; values are not modified
    title : string
    printunits : bool
        If true, include each variable's units (plain-text tables only)
    latex : int
        If > 0, return latex format (options 1-3); otherwise plain text.
        1: name/value/units/description, 2: no values, 3: no description.
        Any other truthy value raises ValueError.
    rawlines : bool
        If true, return the unformatted rows (lists of cells) instead
    varfmt : string
        format for variable names
    valfmt : string
        format for scalar values
    vecfmt : string
        format for vector values
    minval : float
        skip entries that are all-NaN or whose max(abs(value)) <= minval
    sortbyvals : boolean
        If true, rows are sorted by their average value instead of by name.
    hidebelowminval : bool
        If true (and minval is nonzero), array entries with
        abs(value) <= minval are displayed as missing (NaN)
    included_models : Iterable of strings
        If specified, the models (by name) to include
    excluded_models : Iterable of strings
        If specified, model names to exclude
    sortbymodel : bool
        If true, group rows by lineagestr(key.lineage)
    maxcolumns : int
        maximum number of array values printed per row
    skipifempty : bool
        If true, return [] when no entry survives the minval filter
    sortmodelsbysenss : dict or None
        model name -> sensitivity (arrays are averaged); when sortbyvals is
        false, model groups are ordered by decreasing sensitivity

    Returns
    -------
    list
        [] if there is nothing to show; the raw rows if `rawlines`;
        otherwise the formatted text lines (title, underline, rows, "")
        or the LaTeX longtable pieces.
    """
    if not data:
        return []
    decorated, models = [], set()
    for i, (k, v) in enumerate(data.items()):
        if np.isnan(v).all() or np.nanmax(np.abs(v)) <= minval:
            continue  # every value is NaN or at/below minval
        if minval and hidebelowminval and getattr(v, "shape", None):
            # blank small entries on a copy: never mutate the caller's data
            v = np.where(np.abs(v) <= minval, np.nan, v)
        model = lineagestr(k.lineage) if sortbymodel else ""
        msenss = -sortmodelsbysenss.get(model, 0) if sortmodelsbysenss else 0
        if hasattr(msenss, "shape"):
            msenss = np.mean(msenss)
        models.add(model)
        b = bool(getattr(v, "shape", None))
        s = k.str_without(("lineage", "vec"))
        # `i` is a unique tiebreaker so keys/values are never compared
        if not sortbyvals:
            decorated.append((msenss, model, b, (varfmt % s), i, k, v))
        else:  # for consistent sorting, add small offset to negative vals
            val = np.nanmean(np.abs(v)) - (1e-9 if np.nanmean(v) < 0 else 0)
            sort = (float("%.4g" % -val), k.name)
            decorated.append((model, sort, msenss, b, (varfmt % s), i, k, v))
    if not decorated and skipifempty:
        return []
    # resolve the latex layout before the loop, so coltitles exists (and
    # bad options are reported) even when no rows end up being printed
    latex_layouts = {1: ("llcl", ["Value", "Units", "Description"]),
                     2: ("lcl", ["Units", "Description"]),
                     3: ("llc", ["Value", "Units"])}
    if latex:
        if latex not in latex_layouts:
            raise ValueError("Unexpected latex option, %s." % latex)
        colfmt, extra_titles = latex_layouts[latex]
        coltitles = [title] + extra_titles
        blankline = [""] * len(coltitles)
    else:
        blankline = ["", "", "", ""]
    if included_models:
        included_models = set(included_models)
        included_models.add("")  # variables without a model always shown
        models = models.intersection(included_models)
    if excluded_models:
        models = models.difference(excluded_models)
    decorated.sort()
    previous_model, lines = None, []
    for varlist in decorated:
        if sortbyvals:
            model, _, _, isvector, varstr, _, var, val = varlist
        else:
            _, model, isvector, varstr, _, var, val = varlist
        if model not in models:
            continue
        if model != previous_model:
            if lines:
                lines.append(list(blankline))
            if model:
                if not latex:
                    lines.append([("newmodelline",), model, "", ""])
                else:
                    lines.append(
                        [r"\multicolumn{3}{l}{\textbf{" + model + r"}} \\"])
            previous_model = model
        label = var.descr.get("label", "")
        units = var.unitstr(" [%s] ") if printunits else ""
        if not isvector:
            valstr = valfmt % val
        else:
            flatval, ncols = _vector_layout(val, maxcolumns)
            vals = [vecfmt % x for x in flatval[:ncols]]
            bracket = " ] " if len(flatval) <= ncols else ""
            valstr = "[ %s%s" % (" ".join(vals), bracket)
        valstr = _clean_valstr(valstr)
        if not latex:
            lines.append([varstr, valstr, units, label])
            if isvector and len(flatval) > ncols:
                # continuation rows for the remaining array values
                values_remaining = len(flatval) - ncols
                while values_remaining > 0:
                    idx = len(flatval) - values_remaining
                    vals = [vecfmt % x for x in flatval[idx:idx+ncols]]
                    values_remaining -= ncols
                    valstr = _clean_valstr(" " + " ".join(vals))
                    if values_remaining <= 0:
                        # last (possibly short) row: close the bracket and
                        # pad roughly to the width of a full row
                        spaces = (-values_remaining
                                  * len(valstr)//(values_remaining + ncols))
                        valstr = valstr + " ]" + " "*spaces
                    lines.append(["", valstr, "", ""])
        else:
            varstr = "$%s$" % varstr.replace(" : ", "")
            unitcell = "$%s$" % var.latex_unitstr()
            if latex == 1:  # normal results table
                lines.append([varstr, valstr, unitcell, label])
            elif latex == 2:  # no values
                lines.append([varstr, unitcell, label])
            else:  # latex == 3: no description
                lines.append([varstr, valstr, unitcell])
    if rawlines:
        return lines
    if not latex:
        if lines:
            maxlens = np.max([list(map(len, line)) for line in lines
                              if line[0] != ("newmodelline",)], axis=0)
            dirs = [">", "<", "<", "<"]
            # check lengths before using zip
            assert len(dirs) == len(maxlens)
            fmts = ["{0:%s%s}" % (direc, L) for direc, L in zip(dirs, maxlens)]
            for i, line in enumerate(lines):
                if line[0] == ("newmodelline",):
                    line = [fmts[0].format(" | "), line[1]]
                else:
                    line = [fmt.format(s) for fmt, s in zip(fmts, line)]
                lines[i] = "".join(line).rstrip()
        lines = [title] + ["-"*len(title)] + lines + [""]
    else:
        lines = (["\n".join(["{\\footnotesize",
                             "\\begin{longtable}{%s}" % colfmt,
                             "\\toprule",
                             " & ".join(coltitles) + " \\\\ \\midrule"])] +
                 [" & ".join(l) + " \\\\" for l in lines] +
                 ["\n".join(["\\bottomrule", "\\end{longtable}}", ""])])
    return lines