Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"Implements ConstraintSet" 

2import sys 

3from collections import defaultdict, OrderedDict 

4from itertools import chain 

5import numpy as np 

6from ..keydict import KeySet, KeyDict 

7from ..small_scripts import try_str_without 

8from ..repr_conventions import ReprMixin 

9from .single_equation import SingleEquationConstraint 

10 

11 

def add_meq_bounds(bounded, meq_bounded):  # TODO: collapse with GP version?
    """Promote conditional bounds into `bounded` until nothing changes.

    `meq_bounded` maps each bound to a collection of conditions (sets of
    bounds).  A bound is moved into `bounded` as soon as any one of its
    conditions is a subset of `bounded`; bounds already in `bounded` are
    simply dropped from `meq_bounded`.  Both arguments are mutated in place.
    Passes repeat because promoting one bound can satisfy another's condition.
    """
    changed = True
    while changed:
        changed = False  # a pass that promotes nothing ends the loop
        for bound in list(meq_bounded):  # snapshot: entries are deleted below
            if bound in bounded:  # already unconditionally bounded
                del meq_bounded[bound]
            elif any(cond.issubset(bounded) for cond in meq_bounded[bound]):
                del meq_bounded[bound]
                bounded.add(bound)
                changed = True

27 

28def _sort_by_name_and_idx(var): 

29 "return tuple for Variable sorting" 

30 return (var.key.str_without(["units", "idx"]), var.key.idx or ()) 

31 

def _sort_constraints(item):
    """Sort key for (label, constraint) pairs.

    SingleEquationConstraints sort first, then constraints without a
    truthy `lineage`, and ties are broken by label.
    """
    label, constraint = item
    is_compound = not isinstance(constraint, SingleEquationConstraint)
    has_lineage = bool(getattr(constraint, "lineage", None))
    return is_compound, has_lineage, label

37 

def sort_constraints_dict(iterable):
    """Split a {label: constraint} dict into (labels, constraints).

    On Python >= 3.7, or for an OrderedDict, insertion order is kept and
    the dict's own keys()/values() views are returned.  Otherwise the
    items are ordered with `_sort_constraints` and returned as two
    single-pass generators.
    """
    keeps_order = sys.version_info >= (3, 7) or isinstance(iterable, OrderedDict)
    if keeps_order:
        return iterable.keys(), iterable.values()
    pairs = sorted(iterable.items(), key=_sort_constraints)
    labels = (label for label, _ in pairs)
    constraints = (constraint for _, constraint in pairs)
    return labels, constraints

44 

def flatiter(iterable, yield_if_hasattr=None):
    """Recursively yield the leaf constraints contained in `iterable`.

    Items without `__iter__` are yielded as-is.  If `yield_if_hasattr` is
    given, any item that has that attribute is also yielded whole instead
    of being descended into.  A dict is walked over its values, in the
    order chosen by `sort_constraints_dict`.
    """
    items = sort_constraints_dict(iterable)[1] if isinstance(iterable, dict) \
        else iterable
    for item in items:
        # De Morgan of "not iterable or stop-here"; same evaluation order.
        if hasattr(item, "__iter__") and not (
                yield_if_hasattr and hasattr(item, yield_if_hasattr)):
            try:  # numpy array: `.flat` is itself iterable
                yield from item.flat
            except TypeError:  # ConstraintSet: `.flat` is a method
                yield from item.flat(yield_if_hasattr)
            except AttributeError:  # no `.flat`: probably a list or dict
                yield from flatiter(item, yield_if_hasattr)
        else:
            yield item

61 

62 

class ConstraintSet(list, ReprMixin):
    """Recursive container for ConstraintSets and Inequalities.

    Besides the list of child constraints it tracks:
      vks           -- set of all varkeys seen in this set and its children
      substitutions -- KeyDict of fixed values (from varkey `value`s,
                       children's substitutions, and the `substitutions` arg)
      bounded       -- set of (varkey, "upper"/"lower") bounds known to hold
      meq_bounded   -- {bound: set of conditions}; merged from children and
                       resolved into `bounded` by add_meq_bounds
    """
    # varkeys owned by the set itself (subclasses may override); their
    # `value`s seed `substitutions`.  `idxlookup` maps constraint labels to
    # list positions and is only replaced per-instance for dict input.
    unique_varkeys, idxlookup = frozenset(), {}
    _name_collision_varkeys = None  # lazy cache, see name_collision_varkeys
    _varkeys = None  # lazy KeySet cache, see varkeys

    def __init__(self, constraints, substitutions=None, *, bonusvks=None):  # pylint: disable=too-many-branches,too-many-statements
        """Build the set and collect varkeys, substitutions and bounds.

        Arguments
        ---------
        constraints : dict, ConstraintSet, or other iterable
            A dict's labels become `idxlookup`; a lone ConstraintSet is
            wrapped one level down.
        substitutions : mapping, optional
            Applied after (so overriding) the children's substitutions.
        bonusvks : iterable of varkeys, optional
            Extra varkeys added to `vks`.

        Raises
        ------
        ValueError (built by badelement) for an element that cannot be
        flattened into constraints, or a ConstraintSet lacking `vks`
        (i.e. not yet initialized).
        """
        if isinstance(constraints, dict):
            keys, constraints = sort_constraints_dict(constraints)
            self.idxlookup = {k: i for i, k in enumerate(keys)}
        elif isinstance(constraints, ConstraintSet):
            constraints = [constraints]  # put it one level down
        list.__init__(self, constraints)
        self.vks = set(self.unique_varkeys)
        self.substitutions = KeyDict({k: k.value for k in self.unique_varkeys
                                      if "value" in k.descr})
        self.substitutions.cset = self
        self.bounded, self.meq_bounded = set(), defaultdict(set)
        for i, constraint in enumerate(self):
            if hasattr(constraint, "vks"):
                self._update(constraint)
            elif not (hasattr(constraint, "as_hmapslt1")
                      or hasattr(constraint, "as_gpconstr")):
                # a container (list, array, dict...): merge each leaf
                try:
                    for subconstraint in flatiter(constraint, "vks"):
                        self._update(subconstraint)
                except Exception as e:
                    raise badelement(self, i, constraint) from e
            elif isinstance(constraint, ConstraintSet):
                raise badelement(self, i, constraint,
                                 " It had not yet been initialized!")
        if bonusvks:
            self.vks.update(bonusvks)
        if substitutions:
            self.substitutions.update(substitutions)
        # drop any KeySet cached earlier in __init__ (substitutions.cset
        # points back here, so it may have been built) so it includes bonusvks
        self._varkeys = None
        for key in self.vks:
            if key not in self.substitutions:
                # an element counts as substituted via its vector's value,
                # unless that element is NaN
                if key.veckey is None or key.veckey not in self.substitutions:
                    continue
                if np.isnan(self.substitutions[key.veckey][key.idx]):
                    continue
            self.bounded.add((key, "upper"))
            self.bounded.add((key, "lower"))
            if key.value is not None and not key.constant:
                # the value now lives in self.substitutions only
                del key.descr["value"]
                if key.veckey and key.veckey.value is not None:
                    del key.veckey.descr["value"]
        add_meq_bounds(self.bounded, self.meq_bounded)

    def _update(self, constraint):
        "Merge a child's vks, substitutions and bound info into this set."
        self.vks.update(constraint.vks)
        if hasattr(constraint, "substitutions"):
            self.substitutions.update(constraint.substitutions)
        else:
            self.substitutions.update({k: k.value for k in constraint.vks
                                       if "value" in k.descr})
        self.bounded.update(constraint.bounded)
        for bound, solutionset in constraint.meq_bounded.items():
            self.meq_bounded[bound].update(solutionset)

    def __getitem__(self, key):
        """Index by position, slice, constraint label, or variable name.

        Labels (from a dict input) are translated through `idxlookup`;
        ints (including numpy integers) and slices index the list; any
        other key is looked up as a variable name.
        """
        try:
            key = self.idxlookup.get(key, key)
        except TypeError:  # unhashable key, e.g. a slice before Python 3.12
            pass
        if isinstance(key, (int, np.integer, slice)):
            return list.__getitem__(self, key)
        return self._choosevar(key, self.variables_byname(key))

    def _choosevar(self, key, variables):
        """Return the single Variable named `key`, or a NomialArray.

        Raises KeyError if nothing matches, and ValueError if several
        variables match without all belonging to one vector.
        """
        if not variables:
            raise KeyError(key)
        firstvar, *othervars = variables
        veckey = firstvar.key.veckey
        if veckey is None or any(v.key.veckey != veckey for v in othervars):
            if not othervars:
                return firstvar
            raise ValueError("multiple variables are called '%s'; show them"
                             " with `.variables_byname('%s')`" % (key, key))
        from ..nomials import NomialArray  # all one vector!
        # elements not present in this set stay NaN
        arr = NomialArray(np.full(veckey.shape, np.nan, dtype="object"))
        for v in variables:
            arr[v.key.idx] = v
        arr.key = veckey
        return arr

    def variables_byname(self, key):
        "Get all variables with a given name, sorted by name then index."
        from ..nomials import Variable
        return sorted([Variable(k) for k in self.varkeys[key]],
                      key=_sort_by_name_and_idx)

    @property
    def varkeys(self):
        "The NomialData's varkeys, created when necessary for a substitution."
        if self._varkeys is None:
            self._varkeys = KeySet(self.vks)
        return self._varkeys

    def constrained_varkeys(self):
        "Return all varkeys in non-ConstraintSet constraints"
        return self.vks - self.unique_varkeys

    # `cs.flat(yield_if_hasattr)` walks this set; flatiter relies on `.flat`
    # being a method (not an iterable) to tell ConstraintSets from arrays.
    flat = flatiter

    def as_hmapslt1(self, subs):
        "Yields hmaps<=1 from self.flat()"
        yield from chain(*(c.as_hmapslt1(subs)
                           for c in flatiter(self,
                                             yield_if_hasattr="as_hmapslt1")))

    def process_result(self, result):
        """Does arbitrary computation / manipulation of a program's result

        There's no guarantee what order different constraints will process
        results in, so any changes made to the program's result should be
        careful not to step on other constraint's toes.

        Potential Uses
        --------------
          - check that an inequality was tight
          - add values computed from solved variables

        """
        for constraint in self.flat(yield_if_hasattr="process_result"):
            if hasattr(constraint, "process_result"):
                constraint.process_result(result)
        # fill in own varkeys that have an evalfn but no solved value
        evalfn_vars = {v.veckey or v for v in self.unique_varkeys
                       if v.evalfn and v not in result["variables"]}
        for v in evalfn_vars:
            val = v.evalfn(result["variables"])
            result["variables"][v] = result["freevariables"][v] = val

    def __repr__(self):
        "Returns namespaced string."
        if not self:
            return "<gpkit.%s object>" % self.__class__.__name__
        return ("<gpkit.%s object containing %i top-level constraint(s)"
                " and %i variable(s)>" % (self.__class__.__name__,
                                          len(self), len(self.varkeys)))

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {
                key for key in self.varkeys
                if len(self.varkeys[key.str_without(["lineage", "vec"])]) > 1}
        return self._name_collision_varkeys

    def lines_without(self, excluded):
        """Lines representation of a ConstraintSet.

        The outermost call ("root" not in `excluded`) temporarily marks
        name-colliding varkeys with `necessarylineage` while nested sets
        render, and always clears those marks again, even on error.
        """
        excluded = frozenset(excluded)
        root, rootlines = "root" not in excluded, []
        if root:
            excluded = excluded.union(["root"])
        flag_collisions = root and "unnecessary lineage" in excluded
        if flag_collisions:
            for key in self.name_collision_varkeys():
                key.descr["necessarylineage"] = True
        try:
            if root and hasattr(self, "_rootlines"):
                rootlines = self._rootlines(excluded)  # pylint: disable=no-member
            lines = recursively_line(self, excluded)
        finally:
            if flag_collisions:
                for key in self.name_collision_varkeys():
                    key.descr.pop("necessarylineage", None)
        indent = " " if getattr(self, "lineage", None) else ""
        if flag_collisions:
            indent += " "
        return rootlines + [(indent+line).rstrip() for line in lines]

    def str_without(self, excluded=("unnecessary lineage", "units")):
        "String representation of a ConstraintSet."
        return "\n".join(self.lines_without(excluded))

    def latex(self, excluded=("units",)):
        """LaTeX representation of a ConstraintSet.

        The root call wraps the rows in an `array` environment; nested
        calls return bare rows that already carry the row prefix.
        """
        excluded = tuple(excluded)  # copy, so a caller's list is never mutated
        row_prefix = " & "  # one literal for both the check and the wrap
        lines = []
        root = "root" not in excluded
        if root:
            excluded += ("root",)
            lines.append("\\begin{array}{ll} \\text{}")
            if hasattr(self, "_rootlatex"):
                lines.append(self._rootlatex(excluded))  # pylint: disable=no-member
        for constraint in self:
            cstr = try_str_without(constraint, excluded, latex=True)
            if not cstr.startswith(row_prefix):  # require indentation
                cstr = row_prefix + cstr + " \\\\"
            lines.append(cstr)
        if root:
            lines.append("\\end{array}")
        return "\n".join(lines)

    def as_view(self):
        "Return a ConstraintSetView of this ConstraintSet."
        return ConstraintSetView(self)

256 

def recursively_line(iterable, excluded):
    """Render `iterable` as a list of lines, recursing into containers.

    Items with `lines_without` render themselves; non-iterables go through
    `try_str_without`; other iterables recurse.  A ConstraintSet with a
    truthy `lineage` gets a heading (last lineage name plus number, if
    nonzero), separated from earlier output by a blank line.  Otherwise,
    unless "constraint names" is excluded, labelled items (dict keys or
    `idxlookup` labels) get a quoted heading and indented lines.
    """
    if isinstance(iterable, dict):
        labels, iterable = sort_constraints_dict(iterable)
        label_at = dict(enumerate(labels))
    elif hasattr(iterable, "idxlookup"):
        label_at = {pos: label for label, pos in iterable.idxlookup.items()}
    else:
        label_at = {}
    output = []
    for pos, item in enumerate(iterable):
        if hasattr(item, "lines_without"):
            item_lines = item.lines_without(excluded)
        elif not hasattr(item, "__iter__"):
            item_lines = try_str_without(item, excluded).split("\n")
        elif item is iterable:
            item_lines = ["(constraint contained itself)"]
        else:
            item_lines = recursively_line(item, excluded)

        if (getattr(item, "lineage", None)
                and isinstance(item, ConstraintSet)):
            name, num = item.lineage[-1]
            if not any(item_lines):
                item_lines = [" " + "(no constraints)"]  # named model indent
            if output:
                output.append("")  # blank line between models
            output.append(name + str(num) if num else name)
        elif "constraint names" not in excluded and pos in label_at:
            output.append("\"%s\":" % label_at[pos])
            item_lines = [" " + line for line in item_lines]  # named constraint indent
        output.extend(item_lines)
    return output

288 

289 

class ConstraintSetView:
    """Indexed view onto a ConstraintSet's (possibly vectorized) attributes.

    Attribute access is forwarded to the wrapped set: nested ConstraintSets
    come back as views sharing this index, and array-valued attributes are
    indexed on their trailing axes.
    """

    def __init__(self, constraintset, index=()):
        self.constraintset = constraintset
        try:
            self.index = tuple(index)
        except TypeError:  # probably not iterable
            self.index = (index,)

    def __getitem__(self, index):
        "Appends the index to its own and returns a new view."
        if not isinstance(index, tuple):
            index = (index,)
        # indexes are prepended to match Vectorize convention
        return ConstraintSetView(self.constraintset, index + self.index)

    def __getattr__(self, attr):
        """Returns attribute from the base ConstraintSets

        If it's a another ConstraintSet, return the matching View;
        if it's an array, return it at the specified index;
        otherwise, raise an error.
        """
        # Read the wrapped set straight from the instance dict: __getattr__
        # also runs on instances whose __init__ never ran (copy.copy,
        # unpickling, __new__), where `self.constraintset` would re-enter
        # __getattr__ until RecursionError.
        try:
            constraintset = self.__dict__["constraintset"]
        except KeyError:
            raise AttributeError(attr) from None
        if not hasattr(constraintset, attr):
            raise AttributeError("the underlying object lacks `.%s`." % attr)

        value = getattr(constraintset, attr)
        if isinstance(value, ConstraintSet):
            return ConstraintSetView(value, self.index)
        if not hasattr(value, "shape"):
            raise ValueError("attribute %s with value %s did not have"
                             " a shape, so ConstraintSetView cannot"
                             " return an indexed view." % (attr, value))
        index = self.index
        newdims = len(value.shape) - len(self.index)
        if newdims > 0:  # indexes are put last to match Vectorize
            index = (slice(None),)*newdims + index
        return value[index]

329 

330 

331 

def badelement(cns, i, constraint, cause=""):
    """Build (but do not raise) a ValueError describing an invalid element.

    Arguments
    ---------
    cns : sequence
        The container holding the bad element (used to describe location).
    i : int
        Position of the bad element in `cns`.
    constraint : object
        The bad element itself.
    cause : str
        Extra explanation appended to the message; replaced with a hint
        about accidental equalities when `constraint` is a bool.

    Returns
    -------
    ValueError, which the caller raises.
    """
    if isinstance(constraint, bool):
        cause = " Did the constraint list contain an accidental equality?"
    # neighbours are wrapped in 1-tuples so a tuple-valued neighbour
    # is formatted as a value instead of being unpacked by `%`
    if len(cns) == 1:
        loc = "the only constraint"
    elif i == 0:
        loc = "at the start, before %s" % (cns[i+1],)
    elif i == len(cns) - 1:
        loc = "at the end, after %s" % (cns[i-1],)
    else:
        loc = "between %s and %s" % (cns[i-1], cns[i+1])
    return ValueError("Invalid ConstraintSet element '%s' %s was %s.%s"
                      % (repr(constraint), type(constraint), loc, cause))