Coverage for gpkit/constraints/set.py: 86%

277 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-07 22:15 -0500

1"Implements ConstraintSet" 

2import sys 

3from collections import defaultdict, OrderedDict 

4from itertools import chain 

5import numpy as np 

6from ..keydict import KeySet, KeyDict 

7from ..nomials import NomialArray 

8from ..small_scripts import try_str_without 

9from ..repr_conventions import ReprMixin 

10from .single_equation import SingleEquationConstraint 

11from ..nomials import Variable 

12 

13 

def add_meq_bounds(bounded, meq_bounded):  # TODO: collapse with GP version?
    """Promote conditional bounds into `bounded` until nothing changes.

    `meq_bounded` maps a bound to a collection of conditions (sets of
    bounds). A bound is moved into `bounded` as soon as any one of its
    conditions is a subset of `bounded`. Entries whose bound is already in
    `bounded` are simply dropped. Both arguments are modified in place.
    """
    progress = True
    while progress:
        progress = False  # a full sweep without changes ends the loop
        for candidate in tuple(meq_bounded):
            if candidate in bounded:  # nothing left to prove
                del meq_bounded[candidate]
                continue
            satisfied = any(condition.issubset(bounded)
                            for condition in meq_bounded[candidate])
            if satisfied:
                del meq_bounded[candidate]
                bounded.add(candidate)
                progress = True

29 

30def _sort_by_name_and_idx(var): 

31 "return tuple for Variable sorting" 

32 return (var.key.str_without(["units", "idx"]), var.key.idx or ()) 

33 

def _sort_constraints(item):
    """Sort key for a (label, constraint) pair.

    Single equations sort first (False < True), then constraints without
    a truthy `lineage`, and ties are broken by the label itself.
    """
    label, constraint = item
    is_compound = not isinstance(constraint, SingleEquationConstraint)
    has_lineage = bool(getattr(constraint, "lineage", None))
    return is_compound, has_lineage, label

39 

def sort_constraints_dict(iterable):
    """Return the keys and the values of a {label: constraint} dict.

    On Python 3.7+ (insertion-ordered dicts), or for an OrderedDict, the
    dict's own key and value views are returned unchanged. Otherwise the
    items are sorted with `_sort_constraints` and returned as two
    one-shot generators (keys, values).
    """
    keeps_order = (sys.version_info >= (3, 7)
                   or isinstance(iterable, OrderedDict))
    if keeps_order:
        return iterable.keys(), iterable.values()
    ordered = sorted(iterable.items(), key=_sort_constraints)
    return (label for label, _ in ordered), (cns for _, cns in ordered)

46 

def flatiter(iterable, yield_if_hasattr=None):
    """Yield the constraints found anywhere inside `iterable`.

    Dicts are walked in `sort_constraints_dict` order. Elements without
    `__iter__` are yielded as they are. So are iterables that have the
    attribute named by `yield_if_hasattr`, when that name is given. Any
    other iterable is expanded: through its iterable `.flat` (numpy
    arrays), through a callable `.flat(yield_if_hasattr)` (ConstraintSets,
    whose `flat` is this function), or otherwise by recursing into it.
    """
    if isinstance(iterable, dict):
        iterable = sort_constraints_dict(iterable)[1]
    for item in iterable:
        if not hasattr(item, "__iter__"):
            yield item
            continue
        if yield_if_hasattr and hasattr(item, yield_if_hasattr):
            yield item  # caller wants this container whole
            continue
        try:  # numpy arrays: `.flat` iterates over every element
            yield from item.flat
        except TypeError:  # ConstraintSet: `.flat` is a method, not iterable
            yield from item.flat(yield_if_hasattr)
        except AttributeError:  # probably a list or dict
            yield from flatiter(item, yield_if_hasattr)

63 

64 

class ConstraintSet(list, ReprMixin):  # pylint: disable=too-many-instance-attributes
    """Recursive container for ConstraintSets and Inequalities.

    Behaves as a list of constraints. Elements may themselves be
    ConstraintSets, or lists, dicts or numpy arrays of constraints. From
    everything it contains, it aggregates the variable keys (`vks`), the
    `substitutions`, and the `bounded` / `meq_bounded` bound bookkeeping.
    """
    # Class-level defaults. `idxlookup` is replaced per instance when the
    # constraints are given as a dict. `unique_varkeys` is never assigned in
    # this class (presumably overridden by subclasses -- confirm).
    unique_varkeys, idxlookup = frozenset(), {}
    _name_collision_varkeys = None  # cache filled by set_necessarylineage
    _varkeys = None  # KeySet built lazily by the `varkeys` property
    _lineageset = False  # True while "necessarylineage" is written to descrs

    def __init__(self, constraints, substitutions=None, *, bonusvks=None):  # pylint: disable=too-many-branches,too-many-statements
        """Store the constraints and aggregate their varkeys and bounds.

        Arguments
        ---------
        constraints : dict, ConstraintSet, or iterable
            A dict is stored in `sort_constraints_dict` order, and its keys
            become names usable with `__getitem__`. A lone ConstraintSet is
            wrapped one level down. Anything else is handed to `list`.
        substitutions : dict, optional
            Applied after the substitutions gathered from the constraints.
        bonusvks : iterable, optional
            Extra varkeys added to `self.vks`.

        Raises
        ------
        ValueError
            (built by `badelement`) if an element cannot be searched for
            varkey-bearing constraints, or is an uninitialized ConstraintSet.
        """
        if isinstance(constraints, dict):
            keys, constraints = sort_constraints_dict(constraints)
            self.idxlookup = {k: i for i, k in enumerate(keys)}
        elif isinstance(constraints, ConstraintSet):
            constraints = [constraints]  # put it one level down
        list.__init__(self, constraints)
        self.vks = set(self.unique_varkeys)
        self.substitutions = KeyDict({k: k.value for k in self.unique_varkeys
                                      if "value" in k.descr})
        self.substitutions.vks = self.vks
        self.bounded, self.meq_bounded = set(), defaultdict(set)
        for i, constraint in enumerate(self):
            if hasattr(constraint, "vks"):
                self._update(constraint)
            elif not (hasattr(constraint, "as_hmapslt1")
                      or hasattr(constraint, "as_gpconstr")):
                # not a constraint itself: search inside it for ones with vks
                try:
                    for subconstraint in flatiter(constraint, "vks"):
                        self._update(subconstraint)
                except Exception as e:
                    raise badelement(self, i, constraint) from e
            elif isinstance(constraint, ConstraintSet):
                # a ConstraintSet lacking `vks` never ran this __init__
                raise badelement(self, i, constraint,
                                 " It had not yet been initialized!")
        if bonusvks:
            self.vks.update(bonusvks)
        if substitutions:
            self.substitutions.update(substitutions)
        for key in self.vks:
            if key not in self.substitutions:
                # the key may still be fixed by its vector key's substitution
                if key.veckey is None or key.veckey not in self.substitutions:
                    continue
                if np.isnan(self.substitutions[key.veckey][key.idx]):
                    continue
            # a substituted key is bounded from both sides
            self.bounded.add((key, "upper"))
            self.bounded.add((key, "lower"))
            # NOTE(review): strips stored values from non-constant keys once
            # they count as substituted -- confirm the intended consumer.
            if key.value is not None and not key.constant:
                del key.descr["value"]
                if key.veckey and key.veckey.value is not None:
                    del key.veckey.descr["value"]
        add_meq_bounds(self.bounded, self.meq_bounded)

    def _update(self, constraint):
        """Merge a constraint's varkeys, substitutions and bounds into self.

        A constraint without its own `substitutions` contributes the
        `value` entries found in its varkeys' descrs instead.
        """
        self.vks.update(constraint.vks)
        if hasattr(constraint, "substitutions"):
            self.substitutions.update(constraint.substitutions)
        else:
            self.substitutions.update({k: k.value for k in constraint.vks
                                       if "value" in k.descr})
        self.bounded.update(constraint.bounded)
        for bound, solutionset in constraint.meq_bounded.items():
            self.meq_bounded[bound].update(solutionset)

    def __getitem__(self, key):
        """Look up a constraint by name or position, or variables by name.

        Names from a dict of constraints (`idxlookup`) are first translated
        to positions. Integer positions (including numpy integers) and
        slices index the underlying list; slicing returns a plain list, as
        with `list`. Any other key is treated as a variable name and
        resolved by `_choosevar`.
        """
        try:
            key = self.idxlookup[key]
        except (KeyError, TypeError):  # not a constraint name, or unhashable
            pass
        if isinstance(key, (int, np.integer, slice)):
            return list.__getitem__(self, key)
        return self._choosevar(key, self.variables_byname(key))

    def _choosevar(self, key, variables):
        """Reduce the variables found under name `key` to a single result.

        Returns the lone variable when exactly one match has no vector key.
        When every match shares one vector key, returns a NomialArray of
        that vector's shape with `.key` set to the vector key; matched
        elements are placed at their `idx` and other slots stay NaN.

        Raises KeyError if nothing matched, and ValueError if the matches
        are not all elements of one vector variable.
        """
        if not variables:
            raise KeyError(key)
        firstvar, *othervars = variables
        veckey = firstvar.key.veckey
        if veckey is None or any(v.key.veckey != veckey for v in othervars):
            if not othervars:
                return firstvar
            raise ValueError(f"multiple variables are called '{key}'; show them"
                             f" with `.variables_byname('{key}')`")
        arr = NomialArray(np.full(veckey.shape, np.nan, dtype="object"))
        for v in variables:
            arr[v.key.idx] = v
        arr.key = veckey
        return arr

    def variables_byname(self, key):
        "Get all variables with a given name, sorted by name then index"
        return sorted([Variable(k) for k in self.varkeys[key]],
                      key=_sort_by_name_and_idx)

    @property
    def varkeys(self):
        "The NomialData's varkeys, created when necessary for a substitution."
        if self._varkeys is None:
            self._varkeys = KeySet(self.vks)
        return self._varkeys

    def constrained_varkeys(self):
        "Return all varkeys in non-ConstraintSet constraints"
        return self.vks - self.unique_varkeys

    flat = flatiter

    def as_hmapslt1(self, subs):
        "Yields hmaps<=1 from self.flat()"
        yield from chain(*(c.as_hmapslt1(subs)
                           for c in flatiter(self,
                                             yield_if_hasattr="as_hmapslt1")))

    def process_result(self, result):
        """Does arbitrary computation / manipulation of a program's result

        There's no guarantee what order different constraints will process
        results in, so any changes made to the program's result should be
        careful not to step on other constraints' toes.

        Afterwards, every key in `unique_varkeys` with an `evalfn` that is
        missing from result["variables"] is evaluated. Its vector key is
        used when it has one. The value is stored in both
        result["variables"] and result["freevariables"].

        Potential Uses
        --------------
          - check that an inequality was tight
          - add values computed from solved variables

        """
        for constraint in self.flat(yield_if_hasattr="process_result"):
            if hasattr(constraint, "process_result"):
                constraint.process_result(result)
        evalfn_vars = {v.veckey or v for v in self.unique_varkeys
                       if v.evalfn and v not in result["variables"]}
        for v in evalfn_vars:
            val = v.evalfn(result["variables"])
            result["variables"][v] = result["freevariables"][v] = val

    def __repr__(self):
        "Returns namespaced string."
        if not self:
            return f"<gpkit.{self.__class__.__name__} object>"
        return (f"<gpkit.{self.__class__.__name__} object containing "
                f"{len(self)} top-level constraint(s) and "
                f"{len(self.varkeys)} variable(s)>")

    def set_necessarylineage(self, clear=False):  # pylint: disable=too-many-branches
        """Write (or, with `clear=True`, delete) "necessarylineage" descrs.

        For each contained varkey, the value written is how many trailing
        lineage components are needed to tell it apart from other varkeys
        sharing its short name (0 when no disambiguation is needed). The
        mapping is computed on first use and cached in
        `_name_collision_varkeys`. `_lineageset` records whether the descr
        entries are currently written. Returns None.
        """
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {}
            name_collisions = defaultdict(set)
            for key in self.varkeys:
                if hasattr(key, "key"):
                    if key.veckey and all(k.veckey == key.veckey
                                          for k in self.varkeys[key.name]):
                        self._name_collision_varkeys[key] = 0
                        self._name_collision_varkeys[key.veckey] = 0
                    elif len(self.varkeys[key.name]) == 1:
                        self._name_collision_varkeys[key] = 0
                    else:
                        shortname = key.str_without(["lineage", "vec"])
                        if len(self.varkeys[shortname]) > 1:
                            name_collisions[shortname].add(key)
            for varkeys in name_collisions.values():
                # start from the last lineage component, then keep adding
                # one more component to groups that are still ambiguous
                min_namespaced = defaultdict(set)
                for vk in varkeys:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                while any(len(vks) > 1 for vks in min_namespaced.values()):
                    for key, vks in list(min_namespaced.items()):
                        if len(vks) <= 1:
                            continue
                        del min_namespaced[key]
                        mineage, idx = key
                        idx += 1
                        for vk in vks:
                            lineages = vk.lineagestr().split(".")
                            submineage = lineages[-idx] + "." + mineage
                            min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), vks in min_namespaced.items():
                    vk, = vks
                    self._name_collision_varkeys[vk] = idx
        if clear:
            self._lineageset = False
            for vk in self._name_collision_varkeys:
                del vk.descr["necessarylineage"]
        else:
            self._lineageset = True
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def lines_without(self, excluded):
        """Lines representation of a ConstraintSet.

        `excluded` names parts of the output to omit. When "root" is not in
        it, this is the top-level call. Necessary lineages are written onto
        the varkeys while rendering and are always removed again, even if
        rendering raises. Any `_rootlines` header is prepended, and every
        line is indented.
        """
        excluded = frozenset(excluded)
        root, rootlines = "root" not in excluded, []
        if root:
            excluded = {"root"}.union(excluded)
            self.set_necessarylineage()
        try:
            if root and hasattr(self, "_rootlines"):
                rootlines = self._rootlines(excluded)  # pylint: disable=no-member
            lines = recursively_line(self, excluded)
        finally:
            if root:  # undo the temporary descr edits even on error
                self.set_necessarylineage(clear=True)
        indent = " " if root or getattr(self, "lineage", None) else ""
        return rootlines + [(indent+line).rstrip() for line in lines]

    def str_without(self, excluded=("units",)):
        "String representation of a ConstraintSet."
        return "\n".join(self.lines_without(excluded))

    def latex(self, excluded=("units",)):
        """LaTeX representation of a ConstraintSet.

        At the root ("root" not in `excluded`), rows are wrapped in an
        `array` environment, preceded by any `_rootlatex` header. Each
        contained constraint becomes one row unless its LaTeX already starts
        with the row prefix. Presumably that happens for nested
        ConstraintSets when `try_str_without` dispatches to `.latex` --
        confirm. `excluded` may be any iterable and is never modified.
        """
        row_prefix = "    & "
        excluded = tuple(excluded)  # copy: `+=` must not grow a caller's list
        lines = []
        root = "root" not in excluded
        if root:
            excluded += ("root",)
            lines.append("\\begin{array}{ll} \\text{}")
            if hasattr(self, "_rootlatex"):
                lines.append(self._rootlatex(excluded))  # pylint: disable=no-member
        for constraint in self:
            cstr = try_str_without(constraint, excluded, latex=True)
            if not cstr.startswith(row_prefix):  # require indentation
                cstr = row_prefix + cstr + " \\\\"
            lines.append(cstr)
        if root:
            lines.append("\\end{array}")
        return "\n".join(lines)

    def as_view(self):
        "Return a ConstraintSetView of this ConstraintSet."
        return ConstraintSetView(self)

289 

def recursively_line(iterable, excluded):
    """Render a (possibly nested) constraint container as a list of lines.

    Elements with `lines_without` render themselves. Non-iterables go
    through `try_str_without`. A container that holds itself is reported
    instead of being recursed into, and other iterables are rendered
    recursively. A ConstraintSet with a `lineage` gets a header made from
    its last lineage entry. Otherwise, when "constraint names" is not in
    `excluded`, elements named by a dict key or by `idxlookup` get a
    quoted-name header and their lines are indented.
    """
    if isinstance(iterable, dict):
        labels, iterable = sort_constraints_dict(iterable)
        names_by_position = dict(enumerate(labels))
    elif hasattr(iterable, "idxlookup"):
        names_by_position = {position: label
                             for label, position in iterable.idxlookup.items()}
    else:
        names_by_position = {}
    output = []
    for position, item in enumerate(iterable):
        if hasattr(item, "lines_without"):
            item_lines = item.lines_without(excluded)
        elif not hasattr(item, "__iter__"):
            item_lines = try_str_without(item, excluded).split("\n")
        elif item is iterable:
            item_lines = ["(constraint contained itself)"]
        else:
            item_lines = recursively_line(item, excluded)
        is_named_model = (getattr(item, "lineage", None)
                          and isinstance(item, ConstraintSet))
        if is_named_model:
            model_name, model_num = item.lineage[-1]
            if not any(item_lines):
                item_lines = [" " + "(no constraints)"]  # named model indent
            if output:
                output.append("")  # blank line between consecutive blocks
            output.append(model_name + str(model_num) if model_num
                          else model_name)
        elif ("constraint names" not in excluded
              and position in names_by_position):
            output.append(f"\"{names_by_position[position]}\":")
            item_lines = [" " + line for line in item_lines]  # name indent
        output.extend(item_lines)
    return output

321 

322 

class ConstraintSetView:
    "Class to access particular views on a set's variables"

    def __init__(self, constraintset, index=()):
        """Wrap `constraintset`, remembering `index` as a tuple.

        A non-iterable `index` is wrapped in a 1-tuple.
        """
        self.constraintset = constraintset
        try:
            self.index = tuple(index)
        except TypeError:  # probably not iterable
            self.index = (index,)

    def __getitem__(self, index):
        "Prepends the index to its own and returns a new view."
        if not isinstance(index, tuple):
            index = (index,)
        # indexes are prepended to match Vectorize convention
        return ConstraintSetView(self.constraintset, index + self.index)

    def __getattr__(self, attr):
        """Returns attribute from the base ConstraintSets

        If it's another ConstraintSet, return the matching View;
        if it has a shape, return it at the specified index;
        otherwise, raise a ValueError.

        Raises AttributeError if the wrapped object lacks `attr`, or if
        this view has no wrapped object yet (e.g. a half-built copy).
        """
        try:
            # Read __dict__ directly. Going through `self.constraintset`
            # re-enters __getattr__ (and recurses until RecursionError) on
            # instances whose __init__ has not run, as in copy/pickle.
            base = self.__dict__["constraintset"]
        except KeyError:
            raise AttributeError(attr) from None
        if not hasattr(base, attr):
            raise AttributeError(f"the underlying object lacks `.{attr}`.")

        value = getattr(base, attr)
        if isinstance(value, ConstraintSet):
            return ConstraintSetView(value, self.index)
        if not hasattr(value, "shape"):
            raise ValueError(
                f"attribute {attr} with value {value} did not have a shape, "
                "so ConstraintSetView cannot return an indexed view.")
        index = self.index
        newdims = len(value.shape) - len(self.index)
        if newdims > 0:  # indexes are put last to match Vectorize
            index = (slice(None),)*newdims + index
        return value[index]

362 

363 

364 

def badelement(cns, i, constraint, cause=""):
    """Build (and return, not raise) a ValueError for a bad element.

    Arguments
    ---------
    cns : sequence
        The container being built; its neighbours of position `i` are
        quoted in the message.
    i : int
        Position of the offending element.
    constraint : object
        The offending element.
    cause : str
        Extra explanation appended to the message. It is replaced by a hint
        about accidental equalities when `constraint` is a bool.
    """
    if isinstance(constraint, bool):
        cause = " Did the constraint list contain an accidental equality?"
    # Index neighbours positionally. A ConstraintSet's __getitem__ first maps
    # keys through its constraint-name lookup, which can remap integers.
    elements = list(cns)
    if len(elements) == 1:
        loc = "the only constraint"
    elif i == 0:
        loc = f"at the start, before {elements[i+1]}"
    elif i == len(elements) - 1:
        loc = f"at the end, after {elements[i-1]}"
    else:
        loc = f"between {elements[i-1]} and {elements[i+1]}"
    return ValueError(f"Invalid ConstraintSet element '{constraint!r}' "
                      f"{type(constraint)} was {loc}.{cause}")