Coverage for gpkit/constraints/set.py: 86%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

276 statements  

1"Implements ConstraintSet" 

2import sys 

3from collections import defaultdict, OrderedDict 

4from itertools import chain 

5import numpy as np 

6from ..keydict import KeySet, KeyDict 

7from ..small_scripts import try_str_without 

8from ..repr_conventions import ReprMixin 

9from .single_equation import SingleEquationConstraint 

10 

11 

def add_meq_bounds(bounded, meq_bounded):  # TODO: collapse with GP version?
    """Promote conditional bounds to real bounds until nothing changes.

    `meq_bounded` maps a bound to a collection of conditions (sets of
    bounds); a bound is added to `bounded` as soon as any one of its
    conditions is entirely contained in `bounded`. Entries are removed from
    `meq_bounded` once their bound is in `bounded`. Both arguments are
    modified in place.
    """
    changed = True
    while changed:
        changed = False  # a full pass that promotes nothing ends the loop
        for bound in list(meq_bounded):  # copy: entries are deleted below
            if bound in bounded:  # already bounded: conditions are moot
                del meq_bounded[bound]
                continue
            if any(cond.issubset(bounded) for cond in meq_bounded[bound]):
                del meq_bounded[bound]
                bounded.add(bound)
                changed = True

27 

28def _sort_by_name_and_idx(var): 

29 "return tuple for Variable sorting" 

30 return (var.key.str_without(["units", "idx"]), var.key.idx or ()) 

31 

def _sort_constraints(item):
    """Sort key for a (label, constraint) pair.

    Orders SingleEquationConstraints first, then constraints without a
    (truthy) `lineage` before those with one, then by label.
    """
    label, constraint = item
    is_equation = isinstance(constraint, SingleEquationConstraint)
    has_lineage = bool(getattr(constraint, "lineage", None))
    return (not is_equation, has_lineage, label)

37 

def sort_constraints_dict(iterable):
    """Return the keys and values of a {label: constraint} dict.

    OrderedDicts, and any dict on Python >= 3.7 (where dicts keep insertion
    order), are returned in their existing order as key/value views.
    Otherwise the items are sorted with `_sort_constraints` and returned as
    two generators.
    """
    keeps_order = (sys.version_info >= (3, 7)
                   or isinstance(iterable, OrderedDict))
    if keeps_order:
        return iterable.keys(), iterable.values()
    pairs = sorted(iterable.items(), key=_sort_constraints)
    return (label for label, _ in pairs), (value for _, value in pairs)

44 

def flatiter(iterable, yield_if_hasattr=None):
    """Yields contained constraints, optionally including constraintsets.

    Items without `__iter__` are yielded as-is, as are items that have the
    attribute named by `yield_if_hasattr` (when given). Every other item is
    recursed into:

    - if its `.flat` is callable (ConstraintSet.flat), via `.flat(...)`;
    - if it has a non-callable `.flat` (numpy arrays), via that iterator;
    - otherwise (lists, tuples, dicts) via flatiter itself; dicts
      contribute their values, in `sort_constraints_dict` order.
    """
    if isinstance(iterable, dict):
        _, iterable = sort_constraints_dict(iterable)
    for constraint in iterable:
        if (not hasattr(constraint, "__iter__")
                or (yield_if_hasattr
                    and hasattr(constraint, yield_if_hasattr))):
            yield constraint
            continue
        # Look `.flat` up *before* iterating: guarding the `yield from`
        # itself would misread any TypeError/AttributeError raised while
        # the nested iterator is consumed (or thrown into this generator)
        # as "wrong container type" and restart with another strategy.
        try:
            flat = constraint.flat
        except AttributeError:  # probably a list, tuple or dict
            flat = None
        if flat is None:
            yield from flatiter(constraint, yield_if_hasattr)
        elif callable(flat):  # ConstraintSet: .flat is a method
            yield from flat(yield_if_hasattr)
        else:  # numpy array: .flat iterates over every element
            yield from flat

61 

62 

class ConstraintSet(list, ReprMixin):
    "Recursive container for ConstraintSets and Inequalities"
    # class-level defaults; `idxlookup` is replaced per-instance for dict input
    unique_varkeys, idxlookup = frozenset(), {}
    _name_collision_varkeys = None  # filled lazily by set_necessarylineage
    _varkeys = None  # filled lazily by the `varkeys` property

    def __init__(self, constraints, substitutions=None, *, bonusvks=None):  # pylint: disable=too-many-branches,too-many-statements
        """Build the set and collect its varkeys, substitutions and bounds.

        Arguments
        ---------
        constraints : dict, ConstraintSet, or iterable of constraints
            A dict is stored in `sort_constraints_dict` order, with its keys
            recorded in `self.idxlookup`; a ConstraintSet is wrapped one
            level down rather than copied.
        substitutions : mapping, optional
            Applied after (so overriding) the contained constraints'
            substitutions.
        bonusvks : iterable of varkeys, optional
            Extra varkeys added to `self.vks`.

        Raises
        ------
        ValueError
            (built by `badelement`) if an element cannot be flattened into
            constraints, or is a ConstraintSet that was never initialized.
        """
        if isinstance(constraints, dict):
            keys, constraints = sort_constraints_dict(constraints)
            self.idxlookup = {k: i for i, k in enumerate(keys)}
        elif isinstance(constraints, ConstraintSet):
            constraints = [constraints]  # put it one level down
        list.__init__(self, constraints)
        self.vks = set(self.unique_varkeys)
        self.substitutions = KeyDict({k: k.value for k in self.unique_varkeys
                                      if "value" in k.descr})
        self.substitutions.vks = self.vks
        self.bounded, self.meq_bounded = set(), defaultdict(set)
        for i, constraint in enumerate(self):
            if hasattr(constraint, "vks"):
                self._update(constraint)
            elif not (hasattr(constraint, "as_hmapslt1")
                      or hasattr(constraint, "as_gpconstr")):
                # not a constraint itself: try to treat it as a container
                try:
                    for subconstraint in flatiter(constraint, "vks"):
                        self._update(subconstraint)
                except Exception as e:
                    raise badelement(self, i, constraint) from e
            elif isinstance(constraint, ConstraintSet):
                # has as_hmapslt1 (so is constraint-like) but lacks `vks`
                raise badelement(self, i, constraint,
                                 " It had not yet been initialized!")
        if bonusvks:
            self.vks.update(bonusvks)
        if substitutions:
            self.substitutions.update(substitutions)
        for key in self.vks:
            # a key counts as substituted if it is in the substitutions, or
            # its vector key is and the key's element there is not nan
            if key not in self.substitutions:
                if key.veckey is None or key.veckey not in self.substitutions:
                    continue
                if np.isnan(self.substitutions[key.veckey][key.idx]):
                    continue
            # a substituted variable is fixed, so bounded on both sides
            self.bounded.add((key, "upper"))
            self.bounded.add((key, "lower"))
            # the key is substituted here, so drop non-constant default values
            if key.value is not None and not key.constant:
                del key.descr["value"]
                if key.veckey and key.veckey.value is not None:
                    del key.veckey.descr["value"]
        add_meq_bounds(self.bounded, self.meq_bounded)

    def _update(self, constraint):
        """Merge a constraint's varkeys, substitutions and bounds into self.

        Constraints without a `substitutions` attribute contribute the
        `value` of each of their varkeys that has one in its descr.
        """
        self.vks.update(constraint.vks)
        if hasattr(constraint, "substitutions"):
            self.substitutions.update(constraint.substitutions)
        else:
            self.substitutions.update({k: k.value
                                       for k in constraint.vks
                                       if "value" in k.descr})
        self.bounded.update(constraint.bounded)
        for bound, solutionset in constraint.meq_bounded.items():
            self.meq_bounded[bound].update(solutionset)

    def __getitem__(self, key):
        """Index by position or slice, by constraint label, or by name.

        Integers (including numpy integers) and slices index the underlying
        list; labels of a dict-constructed set are translated to positions
        via `idxlookup`; any other key is looked up as a variable name
        (see `_choosevar`).
        """
        if isinstance(key, slice):
            # slices are unhashable before Python 3.12, so they must not
            # reach the `idxlookup` membership test below
            return list.__getitem__(self, key)
        if key in self.idxlookup:
            key = self.idxlookup[key]
        if isinstance(key, (int, np.integer)):
            return list.__getitem__(self, key)
        return self._choosevar(key, self.variables_byname(key))

    def _choosevar(self, key, variables):
        """Return the single variable, or a NomialArray of one vector's
        elements, among `variables` (all named `key`).

        Raises KeyError if `variables` is empty, and ValueError if they do
        not all belong to one vector and there is more than one of them.
        """
        if not variables:
            raise KeyError(key)
        firstvar, *othervars = variables
        veckey = firstvar.key.veckey
        if veckey is None or any(v.key.veckey != veckey for v in othervars):
            if not othervars:
                return firstvar
            raise ValueError("multiple variables are called '%s'; show them"
                             " with `.variables_byname('%s')`" % (key, key))
        from ..nomials import NomialArray  # all one vector!
        # positions with no matching variable stay nan
        arr = NomialArray(np.full(veckey.shape, np.nan, dtype="object"))
        for v in variables:
            arr[v.key.idx] = v
        arr.key = veckey
        return arr

    def variables_byname(self, key):
        "Get all variables with a given name, sorted by name then index"
        from ..nomials import Variable
        return sorted([Variable(k) for k in self.varkeys[key]],
                      key=_sort_by_name_and_idx)

    @property
    def varkeys(self):
        "This set's varkeys as a KeySet, built on first access and cached."
        if self._varkeys is None:
            self._varkeys = KeySet(self.vks)
        return self._varkeys

    def constrained_varkeys(self):
        "Return all varkeys in non-ConstraintSet constraints"
        return self.vks - self.unique_varkeys

    flat = flatiter

    def as_hmapslt1(self, subs):
        "Yields hmaps<=1 from self.flat()"
        yield from chain(*(c.as_hmapslt1(subs)
                           for c in flatiter(self,
                                             yield_if_hasattr="as_hmapslt1")))

    def process_result(self, result):
        """Does arbitrary computation / manipulation of a program's result

        There's no guarantee what order different constraints will process
        results in, so any changes made to the program's result should be
        careful not to step on other constraint's toes.

        Potential Uses
        --------------
          - check that an inequality was tight
          - add values computed from solved variables

        """
        for constraint in self.flat(yield_if_hasattr="process_result"):
            if hasattr(constraint, "process_result"):
                constraint.process_result(result)
        # fill in unsolved variables that have an evalfn (per vector if any)
        evalfn_vars = {v.veckey or v for v in self.unique_varkeys
                       if v.evalfn and v not in result["variables"]}
        for v in evalfn_vars:
            val = v.evalfn(result["variables"])
            result["variables"][v] = result["freevariables"][v] = val

    def __repr__(self):
        "Returns namespaced string."
        if not self:
            return "<gpkit.%s object>" % self.__class__.__name__
        return ("<gpkit.%s object containing %i top-level constraint(s)"
                " and %i variable(s)>" % (self.__class__.__name__,
                                          len(self), len(self.varkeys)))

    def set_necessarylineage(self, clear=False):
        """Set (or, with clear=True, remove) varkeys' "necessarylineage".

        On the first call, works out for contained varkeys whose names are
        not unique how many trailing lineage components are needed to tell
        them apart (0 where the name alone suffices), caching the result in
        `self._name_collision_varkeys`. Then writes each number into the
        key's `descr["necessarylineage"]`, or removes that entry when
        `clear` is true. Returns None.
        """
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {}
            name_collisions = defaultdict(set)
            for key in self.varkeys:
                if hasattr(key, "key"):
                    if key.veckey and all(k.veckey == key.veckey
                                          for k in self.varkeys[key.name]):
                        self._name_collision_varkeys[key] = 0
                        self._name_collision_varkeys[key.veckey] = 0
                    elif len(self.varkeys[key.name]) == 1:
                        self._name_collision_varkeys[key] = 0
                    else:
                        shortname = key.str_without(["lineage", "vec"])
                        if len(self.varkeys[shortname]) > 1:
                            name_collisions[shortname].add(key)
            for varkeys in name_collisions.values():
                # group by last lineage component, then keep prepending one
                # more component to every still-ambiguous group
                min_namespaced = defaultdict(set)
                for vk in varkeys:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                while any(len(vks) > 1 for vks in min_namespaced.values()):
                    for key, vks in list(min_namespaced.items()):
                        if len(vks) <= 1:
                            continue
                        del min_namespaced[key]
                        mineage, idx = key
                        idx += 1
                        for vk in vks:
                            lineages = vk.lineagestr().split(".")
                            submineage = lineages[-idx] + "." + mineage
                            min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), vks in min_namespaced.items():
                    vk, = vks
                    self._name_collision_varkeys[vk] = idx
        if clear:
            self._lineageset = False
            for vk in self._name_collision_varkeys:
                # pop, not del: clearing must not fail for unmarked keys
                vk.descr.pop("necessarylineage", None)
        else:
            self._lineageset = True
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def lines_without(self, excluded):
        """Lines representation of a ConstraintSet.

        When "root" is not in `excluded` this is the top-level call: it
        prepends `_rootlines` (if defined) and marks name-colliding varkeys
        with their necessary lineage while rendering, always unmarking them
        afterwards even if rendering raises.
        """
        excluded = frozenset(excluded)
        root, rootlines = "root" not in excluded, []
        if root:
            excluded = {"root"}.union(excluded)
            self.set_necessarylineage()
        try:
            if root and hasattr(self, "_rootlines"):
                rootlines = self._rootlines(excluded)  # pylint: disable=no-member
            lines = recursively_line(self, excluded)
        finally:
            # otherwise varkeys would keep printing with lineage afterwards
            if root:
                self.set_necessarylineage(clear=True)
        indent = " " if root or getattr(self, "lineage", None) else ""
        return rootlines + [(indent+line).rstrip() for line in lines]

    def str_without(self, excluded=("units",)):
        "String representation of a ConstraintSet."
        return "\n".join(self.lines_without(excluded))

    def latex(self, excluded=("units",)):
        """LaTeX representation of a ConstraintSet.

        At the root (when "root" is not in `excluded`) the rows are wrapped
        in an `array` environment, preceded by `_rootlatex` if defined;
        rows not already starting with the row prefix are prefixed and
        terminated with a LaTeX line break.
        """
        row_prefix = " & "
        lines = []
        root = "root" not in excluded
        if root:
            # build a new tuple: `excluded += ...` would mutate a caller's list
            excluded = tuple(excluded) + ("root",)
            lines.append("\\begin{array}{ll} \\text{}")
            if hasattr(self, "_rootlatex"):
                lines.append(self._rootlatex(excluded))  # pylint: disable=no-member
        for constraint in self:
            cstr = try_str_without(constraint, excluded, latex=True)
            # compare against the whole prefix (a fixed-width slice of a
            # different length never matched, double-wrapping nested rows)
            if not cstr.startswith(row_prefix):  # require indentation
                cstr = row_prefix + cstr + " \\\\"
            lines.append(cstr)
        if root:
            lines.append("\\end{array}")
        return "\n".join(lines)

    def as_view(self):
        "Return a ConstraintSetView of this ConstraintSet."
        return ConstraintSetView(self)

288 

def recursively_line(iterable, excluded):
    """Return the lines for `iterable`'s constraints, nested for indentation.

    Items with `lines_without` render themselves; other non-iterables go
    through `try_str_without`; other iterables are recursed into (an item
    that is the container itself is reported instead). Named sub-models get
    a heading line (separated by a blank line from earlier output), and
    labelled constraints get a quoted-label line with their own lines
    indented, unless "constraint names" is in `excluded`.
    """
    if isinstance(iterable, dict):
        labels, iterable = sort_constraints_dict(iterable)
        label_at = dict(enumerate(labels))
    elif hasattr(iterable, "idxlookup"):
        label_at = {pos: label for label, pos in iterable.idxlookup.items()}
    else:
        label_at = {}
    show_labels = "constraint names" not in excluded

    output = []
    for pos, item in enumerate(iterable):
        if hasattr(item, "lines_without"):
            item_lines = item.lines_without(excluded)
        elif not hasattr(item, "__iter__"):
            item_lines = try_str_without(item, excluded).split("\n")
        elif item is iterable:
            item_lines = ["(constraint contained itself)"]
        else:
            item_lines = recursively_line(item, excluded)

        is_named_model = (getattr(item, "lineage", None)
                          and isinstance(item, ConstraintSet))
        if is_named_model:
            name, num = item.lineage[-1]
            if not any(item_lines):
                item_lines = [" " + "(no constraints)"]  # named model indent
            if output:
                output.append("")
            output.append(name + str(num) if num else name)
        elif show_labels and pos in label_at:
            output.append("\"%s\":" % label_at[pos])
            item_lines = [" " + line for line in item_lines]  # label indent
        output.extend(item_lines)
    return output

320 

321 

class ConstraintSetView:
    "Class to access particular views on a set's variables"

    def __init__(self, constraintset, index=()):
        """Wrap `constraintset`, storing `index` as a tuple.

        A non-iterable `index` (e.g. a single int) becomes a 1-tuple.
        """
        self.constraintset = constraintset
        try:
            self.index = tuple(index)
        except TypeError:  # probably not iterable
            self.index = (index,)

    def __getitem__(self, index):
        "Prepends the index to its own and returns a new view."
        if not isinstance(index, tuple):
            index = (index,)
        # indexes are prepended to match Vectorize convention
        return ConstraintSetView(self.constraintset, index + self.index)

    def __getattr__(self, attr):
        """Returns attribute from the base ConstraintSets

        If it's another ConstraintSet, return the matching View;
        if it's an array, return it at the specified index;
        otherwise, raise an error.
        """
        # __getattr__ only runs for missing attributes. If our own state is
        # missing (e.g. an instance made by copy/pickle before its __dict__
        # is restored), fail fast: reading self.constraintset here would
        # otherwise re-enter __getattr__ forever.
        if attr in ("constraintset", "index"):
            raise AttributeError(attr)
        if not hasattr(self.constraintset, attr):
            raise AttributeError("the underlying object lacks `.%s`." % attr)

        value = getattr(self.constraintset, attr)
        if isinstance(value, ConstraintSet):
            return ConstraintSetView(value, self.index)
        if not hasattr(value, "shape"):
            raise ValueError("attribute %s with value %s did not have"
                             " a shape, so ConstraintSetView cannot"
                             " return an indexed view." % (attr, value))
        index = self.index
        newdims = len(value.shape) - len(self.index)
        if newdims > 0:  # indexes are put last to match Vectorize
            index = (slice(None),)*newdims + index
        return value[index]

361 

362 

363 

def badelement(cns, i, constraint, cause=""):
    """Identify the bad element and return (not raise) a ValueError for it.

    `cns` is the containing sequence and `i` the bad element's position;
    the message names its neighbours for context. If `constraint` is a
    bool, `cause` is replaced by a hint about an accidental equality.
    Callers are expected to `raise` the returned exception.
    """
    cause = cause if not isinstance(constraint, bool) else (
        " Did the constraint list contain an accidental equality?")
    # neighbours are wrapped in 1-tuples: a bare tuple neighbour would
    # otherwise be unpacked by `%` and break the error message itself
    if len(cns) == 1:
        loc = "the only constraint"
    elif i == 0:
        loc = "at the start, before %s" % (cns[i+1],)
    elif i == len(cns) - 1:
        loc = "at the end, after %s" % (cns[i-1],)
    else:
        loc = "between %s and %s" % (cns[i-1], cns[i+1])
    return ValueError("Invalid ConstraintSet element '%s' %s was %s.%s"
                      % (repr(constraint), type(constraint), loc, cause))