Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"Implements ConstraintSet" 

2import sys 

3from collections import defaultdict, OrderedDict 

4from itertools import chain 

5import numpy as np 

6from ..keydict import KeySet, KeyDict 

7from ..small_scripts import try_str_without 

8from ..repr_conventions import ReprMixin 

9from .single_equation import SingleEquationConstraint 

10 

11 

12def add_meq_bounds(bounded, meq_bounded): #TODO: collapse with GP version? 

13 "Iterates through meq_bounds until convergence" 

14 still_alive = True 

15 while still_alive: 

16 still_alive = False # if no changes are made, the loop exits 

17 for bound in list(meq_bounded): 

18 if bound in bounded: # bound already exists 

19 del meq_bounded[bound] 

20 continue 

21 for condition in meq_bounded[bound]: 

22 if condition.issubset(bounded): # bound's condition is met 

23 del meq_bounded[bound] 

24 bounded.add(bound) 

25 still_alive = True 

26 break 

27 

28def _sort_by_name_and_idx(var): 

29 "return tuple for Variable sorting" 

30 return (var.key.str_without(["units", "idx"]), var.key.idx or ()) 

31 

32def _sort_constraints(item): 

33 "return tuple for Constraint sorting" 

34 label, constraint = item 

35 return (not isinstance(constraint, SingleEquationConstraint), 

36 bool(getattr(constraint, "lineage", None)), label) 

37 

38def sort_constraints_dict(iterable): 

39 "Sort a dictionary of {k: constraint} and return its keys and values" 

40 if sys.version_info >= (3, 7) or isinstance(iterable, OrderedDict): 

41 return iterable.keys(), iterable.values() 

42 items = sorted(list(iterable.items()), key=_sort_constraints) 

43 return (item[0] for item in items), (item[1] for item in items) 

44 

45def flatiter(iterable, yield_if_hasattr=None): 

46 "Yields contained constraints, optionally including constraintsets." 

47 if isinstance(iterable, dict): 

48 _, iterable = sort_constraints_dict(iterable) 

49 for constraint in iterable: 

50 if (not hasattr(constraint, "__iter__") 

51 or (yield_if_hasattr 

52 and hasattr(constraint, yield_if_hasattr))): 

53 yield constraint 

54 else: 

55 try: # numpy array 

56 yield from constraint.flat 

57 except TypeError: # ConstrainSet 

58 yield from constraint.flat(yield_if_hasattr) 

59 except AttributeError: # probably a list or dict 

60 yield from flatiter(constraint, yield_if_hasattr) 

61 

62 

63class ConstraintSet(list, ReprMixin): 

64 "Recursive container for ConstraintSets and Inequalities" 

65 unique_varkeys, idxlookup = frozenset(), {} 

66 _name_collision_varkeys = None 

67 

68 def __init__(self, constraints, substitutions=None): # pylint: disable=too-many-branches,too-many-statements 

69 if isinstance(constraints, dict): 

70 keys, constraints = sort_constraints_dict(constraints) 

71 self.idxlookup = {k: i for i, k in enumerate(keys)} 

72 elif isinstance(constraints, ConstraintSet): 

73 constraints = [constraints] # put it one level down 

74 list.__init__(self, constraints) 

75 self.varkeys = KeySet(self.unique_varkeys) 

76 self.substitutions = KeyDict({k: k.value for k in self.unique_varkeys 

77 if "value" in k.descr}) 

78 self.substitutions.varkeys = self.varkeys 

79 self.bounded, self.meq_bounded = set(), defaultdict(set) 

80 for i, constraint in enumerate(self): 

81 if hasattr(constraint, "varkeys"): 

82 self._update(constraint) 

83 elif not (hasattr(constraint, "as_hmapslt1") 

84 or hasattr(constraint, "as_gpconstr")): 

85 try: 

86 for subconstraint in flatiter(constraint, "varkeys"): 

87 self._update(subconstraint) 

88 except Exception as e: 

89 raise badelement(self, i, constraint) from e 

90 elif isinstance(constraint, ConstraintSet): 

91 raise badelement(self, i, constraint, 

92 " It had not yet been initialized!") 

93 if substitutions: 

94 self.substitutions.update(substitutions) 

95 for subkey in self.substitutions: 

96 if subkey.shape and not subkey.idx: # vector sub found 

97 for key in self.varkeys: 

98 if key.veckey: 

99 self.varkeys.keymap[key.veckey].add(key) 

100 break # vectorkeys need to be mapped only once 

101 for subkey in self.substitutions: 

102 for key in self.varkeys[subkey]: 

103 self.bounded.add((key, "upper")) 

104 self.bounded.add((key, "lower")) 

105 if key.value is not None and not key.constant: 

106 del key.descr["value"] 

107 if key.veckey and key.veckey.value is not None: 

108 del key.veckey.descr["value"] 

109 add_meq_bounds(self.bounded, self.meq_bounded) 

110 

111 def _update(self, constraint): 

112 "Update parameters with a given constraint" 

113 self.varkeys.update(constraint.varkeys) 

114 if hasattr(constraint, "substitutions"): 

115 self.substitutions.update(constraint.substitutions) 

116 else: 

117 self.substitutions.update({k: k.value \ 

118 for k in constraint.varkeys if "value" in k.descr}) 

119 self.bounded.update(constraint.bounded) 

120 for bound, solutionset in constraint.meq_bounded.items(): 

121 self.meq_bounded[bound].update(solutionset) 

122 

123 def __getitem__(self, key): 

124 if key in self.idxlookup: 

125 key = self.idxlookup[key] 

126 if isinstance(key, int): 

127 return list.__getitem__(self, key) 

128 return self._choosevar(key, self.variables_byname(key)) 

129 

130 def _choosevar(self, key, variables): 

131 if not variables: 

132 raise KeyError(key) 

133 firstvar, *othervars = variables 

134 veckey = firstvar.key.veckey 

135 if veckey is None or any(v.key.veckey != veckey for v in othervars): 

136 if not othervars: 

137 return firstvar 

138 raise ValueError("multiple variables are called '%s'; show them" 

139 " with `.variables_byname('%s')`" % (key, key)) 

140 from ..nomials import NomialArray # all one vector! 

141 arr = NomialArray(np.full(veckey.shape, np.nan, dtype="object")) 

142 for v in variables: 

143 arr[v.key.idx] = v 

144 arr.key = veckey 

145 return arr 

146 

147 def variables_byname(self, key): 

148 "Get all variables with a given name" 

149 from ..nomials import Variable 

150 return sorted([Variable(k) for k in self.varkeys[key]], 

151 key=_sort_by_name_and_idx) 

152 

153 def constrained_varkeys(self): 

154 "Return all varkeys in non-ConstraintSet constraints" 

155 constrained_varkeys = set() 

156 for constraint in self.flat(yield_if_hasattr="varkeys"): 

157 constrained_varkeys.update(constraint.varkeys) 

158 return constrained_varkeys 

159 

160 flat = flatiter 

161 

162 def as_hmapslt1(self, subs): 

163 "Yields hmaps<=1 from self.flat()" 

164 yield from chain(*(c.as_hmapslt1(subs) 

165 for c in flatiter(self, 

166 yield_if_hasattr="as_hmapslt1"))) 

167 

168 def process_result(self, result): 

169 """Does arbitrary computation / manipulation of a program's result 

170 

171 There's no guarantee what order different constraints will process 

172 results in, so any changes made to the program's result should be 

173 careful not to step on other constraint's toes. 

174 

175 Potential Uses 

176 -------------- 

177 - check that an inequality was tight 

178 - add values computed from solved variables 

179 

180 """ 

181 for constraint in self.flat(yield_if_hasattr="process_result"): 

182 if hasattr(constraint, "process_result"): 

183 constraint.process_result(result) 

184 evalfn_vars = {v.veckey or v for v in self.unique_varkeys 

185 if v.evalfn and v not in result["variables"]} 

186 for v in evalfn_vars: 

187 val = v.evalfn(result["variables"]) 

188 result["variables"][v] = result["freevariables"][v] = val 

189 

190 def __repr__(self): 

191 "Returns namespaced string." 

192 if not self: 

193 return "<gpkit.%s object>" % self.__class__.__name__ 

194 return ("<gpkit.%s object containing %i top-level constraint(s)" 

195 " and %i variable(s)>" % (self.__class__.__name__, 

196 len(self), len(self.varkeys))) 

197 

198 def name_collision_varkeys(self): 

199 "Returns the set of contained varkeys whose names are not unique" 

200 if self._name_collision_varkeys is None: 

201 self._name_collision_varkeys = { 

202 key for key in self.varkeys 

203 if len(self.varkeys[key.str_without(["lineage", "vec"])]) > 1} 

204 return self._name_collision_varkeys 

205 

206 def lines_without(self, excluded): 

207 "Lines representation of a ConstraintSet." 

208 excluded = frozenset(excluded) 

209 root, rootlines = "root" not in excluded, [] 

210 if root: 

211 excluded = excluded.union(["root"]) 

212 if "unnecessary lineage" in excluded: 

213 for key in self.name_collision_varkeys(): 

214 key.descr["necessarylineage"] = True 

215 if hasattr(self, "_rootlines"): 

216 rootlines = self._rootlines(excluded) # pylint: disable=no-member 

217 lines = recursively_line(self, excluded) 

218 indent = " " if getattr(self, "lineage", None) else "" 

219 if root and "unnecessary lineage" in excluded: 

220 indent += " " 

221 for key in self.name_collision_varkeys(): 

222 del key.descr["necessarylineage"] 

223 return rootlines + [(indent+line).rstrip() for line in lines] 

224 

225 def str_without(self, excluded=("unnecessary lineage", "units")): 

226 "String representation of a ConstraintSet." 

227 return "\n".join(self.lines_without(excluded)) 

228 

229 def latex(self, excluded=("units",)): 

230 "LaTeX representation of a ConstraintSet." 

231 lines = [] 

232 root = "root" not in excluded 

233 if root: 

234 excluded += ("root",) 

235 lines.append("\\begin{array}{ll} \\text{}") 

236 if hasattr(self, "_rootlatex"): 

237 lines.append(self._rootlatex(excluded)) # pylint: disable=no-member 

238 for constraint in self: 

239 cstr = try_str_without(constraint, excluded, latex=True) 

240 if cstr[:6] != " & ": # require indentation 

241 cstr = " & " + cstr + " \\\\" 

242 lines.append(cstr) 

243 if root: 

244 lines.append("\\end{array}") 

245 return "\n".join(lines) 

246 

247 def as_view(self): 

248 "Return a ConstraintSetView of this ConstraintSet." 

249 return ConstraintSetView(self) 

250 

251def recursively_line(iterable, excluded): 

252 "Generates lines in a recursive tree-like fashion, the better to indent." 

253 named_constraints = {} 

254 if isinstance(iterable, dict): 

255 keys, iterable = sort_constraints_dict(iterable) 

256 named_constraints = dict(enumerate(keys)) 

257 elif hasattr(iterable, "idxlookup"): 

258 named_constraints = {i: k for k, i in iterable.idxlookup.items()} 

259 lines = [] 

260 for i, constraint in enumerate(iterable): 

261 if hasattr(constraint, "lines_without"): 

262 clines = constraint.lines_without(excluded) 

263 elif not hasattr(constraint, "__iter__"): 

264 clines = try_str_without(constraint, excluded).split("\n") 

265 elif iterable is constraint: 

266 clines = ["(constraint contained itself)"] 

267 else: 

268 clines = recursively_line(constraint, excluded) 

269 if (getattr(constraint, "lineage", None) 

270 and isinstance(constraint, ConstraintSet)): 

271 name, num = constraint.lineage[-1] 

272 if not any(clines): 

273 clines = [" " + "(no constraints)"] # named model indent 

274 if lines: 

275 lines.append("") 

276 lines.append(name if not num else name + str(num)) 

277 elif "constraint names" not in excluded and i in named_constraints: 

278 lines.append("\"%s\":" % named_constraints[i]) 

279 clines = [" " + line for line in clines] # named constraint indent 

280 lines.extend(clines) 

281 return lines 

282 

283 

284class ConstraintSetView: 

285 "Class to access particular views on a set's variables" 

286 

287 def __init__(self, constraintset, index=()): 

288 self.constraintset = constraintset 

289 try: 

290 self.index = tuple(index) 

291 except TypeError: # probably not iterable 

292 self.index = (index,) 

293 

294 def __getitem__(self, index): 

295 "Appends the index to its own and returns a new view." 

296 if not isinstance(index, tuple): 

297 index = (index,) 

298 # indexes are preprended to match Vectorize convention 

299 return ConstraintSetView(self.constraintset, index + self.index) 

300 

301 def __getattr__(self, attr): 

302 """Returns attribute from the base ConstraintSets 

303 

304 If it's a another ConstraintSet, return the matching View; 

305 if it's an array, return it at the specified index; 

306 otherwise, raise an error. 

307 """ 

308 if not hasattr(self.constraintset, attr): 

309 raise AttributeError("the underlying object lacks `.%s`." % attr) 

310 

311 value = getattr(self.constraintset, attr) 

312 if isinstance(value, ConstraintSet): 

313 return ConstraintSetView(value, self.index) 

314 if not hasattr(value, "shape"): 

315 raise ValueError("attribute %s with value %s did not have" 

316 " a shape, so ConstraintSetView cannot" 

317 " return an indexed view." % (attr, value)) 

318 index = self.index 

319 newdims = len(value.shape) - len(self.index) 

320 if newdims > 0: # indexes are put last to match Vectorize 

321 index = (slice(None),)*newdims + index 

322 return value[index] 

323 

324 

325 

326def badelement(cns, i, constraint, cause=""): 

327 "Identify the bad element and raise a ValueError" 

328 cause = cause if not isinstance(constraint, bool) else ( 

329 " Did the constraint list contain an accidental equality?") 

330 if len(cns) == 1: 

331 loc = "the only constraint" 

332 elif i == 0: 

333 loc = "at the start, before %s" % cns[i+1] 

334 elif i == len(cns) - 1: 

335 loc = "at the end, after %s" % cns[i-1] 

336 else: 

337 loc = "between %s and %s" % (cns[i-1], cns[i+1]) 

338 return ValueError("Invalid ConstraintSet element '%s' %s was %s.%s" 

339 % (repr(constraint), type(constraint), loc, cause))