Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"Implements ConstraintSet"
2import sys
3from collections import defaultdict, OrderedDict
4from itertools import chain
5import numpy as np
6from ..keydict import KeySet, KeyDict
7from ..small_scripts import try_str_without
8from ..repr_conventions import ReprMixin
9from .single_equation import SingleEquationConstraint
def add_meq_bounds(bounded, meq_bounded):  # TODO: collapse with GP version?
    """Propagate conditional bounds into `bounded` until nothing changes.

    Arguments
    ---------
    bounded : set
        Bounds already known to hold; extended in place.
    meq_bounded : dict
        Maps a candidate bound to a collection of conditions, each of which
        is a set of bounds. A candidate is promoted to `bounded` as soon as
        any one of its conditions is a subset of `bounded`. Entries that are
        promoted (or that were already in `bounded`) are deleted in place.
    """
    changed = True
    while changed:
        changed = False  # stays False when a full pass promotes nothing
        # iterate over a snapshot of the keys: entries are deleted as we go
        for candidate in list(meq_bounded):
            if candidate in bounded:  # nothing left to prove for this one
                del meq_bounded[candidate]
                continue
            if any(condition.issubset(bounded)
                   for condition in meq_bounded[candidate]):
                del meq_bounded[candidate]
                bounded.add(candidate)
                changed = True  # a new bound may satisfy other conditions
28def _sort_by_name_and_idx(var):
29 "return tuple for Variable sorting"
30 return (var.key.str_without(["units", "idx"]), var.key.idx or ())
def _sort_constraints(item):
    """Sort key for a `(label, constraint)` pair.

    Orders SingleEquationConstraint instances first, then constraints whose
    `lineage` attribute is missing or falsy, and finally by label.
    """
    label, constraint = item
    is_single_equation = isinstance(constraint, SingleEquationConstraint)
    has_lineage = bool(getattr(constraint, "lineage", None))
    return (not is_single_equation, has_lineage, label)
def sort_constraints_dict(iterable):
    """Return `(keys, values)` of a `{label: constraint}` dictionary.

    On Python 3.7+ or for an OrderedDict, the dictionary's own order is used
    and its key/value views are returned. Otherwise the items are sorted with
    `_sort_constraints` and two generators over the sorted items are returned.
    """
    keeps_order = (sys.version_info >= (3, 7)
                   or isinstance(iterable, OrderedDict))
    if not keeps_order:
        ordered = sorted(iterable.items(), key=_sort_constraints)
        return ((pair[0] for pair in ordered),
                (pair[1] for pair in ordered))
    return iterable.keys(), iterable.values()
def flatiter(iterable, yield_if_hasattr=None):
    """Yield the constraints contained in `iterable`, depth-first.

    Arguments
    ---------
    iterable : dict or iterable
        For a dict, only its values (as ordered by `sort_constraints_dict`)
        are traversed.
    yield_if_hasattr : str, optional
        When given, an iterable element that has this attribute is yielded
        as-is instead of being descended into.

    Elements without `__iter__` are always yielded directly.
    """
    if isinstance(iterable, dict):
        _, iterable = sort_constraints_dict(iterable)
    for element in iterable:
        if not hasattr(element, "__iter__"):  # a leaf
            yield element
            continue
        if yield_if_hasattr and hasattr(element, yield_if_hasattr):
            yield element  # a container the caller asked to receive whole
            continue
        try:  # `.flat` is directly iterable, as on a numpy array
            yield from element.flat
        except TypeError:  # `.flat` is a method, as on a ConstraintSet
            yield from element.flat(yield_if_hasattr)
        except AttributeError:  # no `.flat` at all: probably a list or dict
            yield from flatiter(element, yield_if_hasattr)
class ConstraintSet(list, ReprMixin):
    """Recursive container for ConstraintSets and Inequalities.

    A ConstraintSet is a list of constraints (which may themselves be
    ConstraintSets, lists, dicts or arrays of constraints). On construction
    it gathers, from every contained constraint that exposes them, the
    `varkeys`, `substitutions`, `bounded` and `meq_bounded` collections.
    """
    # Class-level defaults; `idxlookup` is replaced per instance (never
    # mutated) when the set is built from a dictionary.
    unique_varkeys, idxlookup = frozenset(), {}
    _name_collision_varkeys = None  # cache for name_collision_varkeys()

    def __init__(self, constraints, substitutions=None):  # pylint: disable=too-many-branches,too-many-statements
        """Build the set and aggregate data from the contained constraints.

        Arguments
        ---------
        constraints : dict, ConstraintSet or iterable
            A dict contributes its values as elements and its keys as names
            (stored in `idxlookup`); a ConstraintSet is wrapped so that it
            becomes the single element of this set.
        substitutions : dict, optional
            Merged into `self.substitutions` after those collected from the
            contained constraints.

        Raises
        ------
        ValueError
            (built by `badelement`) if an element cannot be processed.
        """
        if isinstance(constraints, dict):
            keys, constraints = sort_constraints_dict(constraints)
            self.idxlookup = {k: i for i, k in enumerate(keys)}
        elif isinstance(constraints, ConstraintSet):
            constraints = [constraints]  # put it one level down
        list.__init__(self, constraints)
        self.varkeys = KeySet(self.unique_varkeys)
        self.substitutions = KeyDict({k: k.value for k in self.unique_varkeys
                                      if "value" in k.descr})
        self.substitutions.varkeys = self.varkeys
        self.bounded, self.meq_bounded = set(), defaultdict(set)
        for i, constraint in enumerate(self):
            if hasattr(constraint, "varkeys"):
                self._update(constraint)
            elif not (hasattr(constraint, "as_hmapslt1")
                      or hasattr(constraint, "as_gpconstr")):
                # not a constraint itself: treat it as a container
                try:
                    for subconstraint in flatiter(constraint, "varkeys"):
                        self._update(subconstraint)
                except Exception as e:
                    raise badelement(self, i, constraint) from e
            elif isinstance(constraint, ConstraintSet):
                raise badelement(self, i, constraint,
                                 " It had not yet been initialized!")
        if substitutions:
            self.substitutions.update(substitutions)
        for subkey in self.substitutions:
            if subkey.shape and not subkey.idx:  # vector sub found
                for key in self.varkeys:
                    if key.veckey:
                        self.varkeys.keymap[key.veckey].add(key)
                break  # vectorkeys need to be mapped only once
        for subkey in self.substitutions:
            for key in self.varkeys[subkey]:
                # a substituted key counts as bounded on both sides
                self.bounded.add((key, "upper"))
                self.bounded.add((key, "lower"))
                if key.value is not None and not key.constant:
                    del key.descr["value"]
                    if key.veckey and key.veckey.value is not None:
                        del key.veckey.descr["value"]
        add_meq_bounds(self.bounded, self.meq_bounded)

    def _update(self, constraint):
        """Merge one constraint's varkeys, substitutions and bounds into self.

        If the constraint has no `substitutions` attribute, the values found
        in its varkeys' descriptions are used instead.
        """
        self.varkeys.update(constraint.varkeys)
        if hasattr(constraint, "substitutions"):
            self.substitutions.update(constraint.substitutions)
        else:
            self.substitutions.update({k: k.value \
                for k in constraint.varkeys if "value" in k.descr})
        self.bounded.update(constraint.bounded)
        for bound, solutionset in constraint.meq_bounded.items():
            self.meq_bounded[bound].update(solutionset)

    def __getitem__(self, key):
        """Look up by constraint name, by integer position, or by variable.

        A key present in `idxlookup` is first translated to its position;
        an integer returns the element at that position; anything else is
        resolved as a variable name through `variables_byname`.
        """
        if key in self.idxlookup:
            key = self.idxlookup[key]
        if isinstance(key, int):
            return list.__getitem__(self, key)
        return self._choosevar(key, self.variables_byname(key))

    def _choosevar(self, key, variables):
        """Reduce the variables matching `key` to one variable or one array.

        Raises
        ------
        KeyError
            if `variables` is empty.
        ValueError
            if several variables match and they are not all elements of
            the same vector.
        """
        if not variables:
            raise KeyError(key)
        firstvar, *othervars = variables
        veckey = firstvar.key.veckey
        if veckey is None or any(v.key.veckey != veckey for v in othervars):
            if not othervars:
                return firstvar
            raise ValueError("multiple variables are called '%s'; show them"
                             " with `.variables_byname('%s')`" % (key, key))
        from ..nomials import NomialArray  # all one vector!
        arr = NomialArray(np.full(veckey.shape, np.nan, dtype="object"))
        for v in variables:
            arr[v.key.idx] = v
        arr.key = veckey
        return arr

    def variables_byname(self, key):
        "Get all variables with a given name, sorted by name and index"
        from ..nomials import Variable
        return sorted([Variable(k) for k in self.varkeys[key]],
                      key=_sort_by_name_and_idx)

    def constrained_varkeys(self):
        "Return the varkeys of every constraint yielded by self.flat()"
        constrained_varkeys = set()
        for constraint in self.flat(yield_if_hasattr="varkeys"):
            constrained_varkeys.update(constraint.varkeys)
        return constrained_varkeys

    # as a class attribute this makes `self.flat(...)` call `flatiter(self, ...)`
    flat = flatiter

    def as_hmapslt1(self, subs):
        "Yields hmaps<=1 from self.flat()"
        yield from chain(*(c.as_hmapslt1(subs)
                           for c in flatiter(self,
                                             yield_if_hasattr="as_hmapslt1")))

    def process_result(self, result):
        """Does arbitrary computation / manipulation of a program's result

        There's no guarantee what order different constraints will process
        results in, so any changes made to the program's result should be
        careful not to step on other constraint's toes.

        Potential Uses
        --------------
          - check that an inequality was tight
          - add values computed from solved variables

        After the contained constraints have processed the result, every
        unique varkey with an `evalfn` that is missing from
        `result["variables"]` is evaluated and stored in both
        `result["variables"]` and `result["freevariables"]`.
        """
        for constraint in self.flat(yield_if_hasattr="process_result"):
            # flat() also yields non-iterable leaves that lack the method
            if hasattr(constraint, "process_result"):
                constraint.process_result(result)
        evalfn_vars = {v.veckey or v for v in self.unique_varkeys
                       if v.evalfn and v not in result["variables"]}
        for v in evalfn_vars:
            val = v.evalfn(result["variables"])
            result["variables"][v] = result["freevariables"][v] = val

    def __repr__(self):
        "Returns namespaced string."
        if not self:
            return "<gpkit.%s object>" % self.__class__.__name__
        return ("<gpkit.%s object containing %i top-level constraint(s)"
                " and %i variable(s)>" % (self.__class__.__name__,
                                          len(self), len(self.varkeys)))

    def name_collision_varkeys(self):
        "Returns the set of contained varkeys whose names are not unique"
        if self._name_collision_varkeys is None:  # computed once, then cached
            self._name_collision_varkeys = {
                key for key in self.varkeys
                if len(self.varkeys[key.str_without(["lineage", "vec"])]) > 1}
        return self._name_collision_varkeys

    def lines_without(self, excluded):
        """Lines representation of a ConstraintSet.

        Arguments
        ---------
        excluded : iterable of str
            Names of the parts to leave out. If "root" is absent this call is
            treated as the root of the printout, and "root" is added before
            recursing so that nested sets know they are not.

        While the lines are generated at the root with "unnecessary lineage"
        excluded, every name-colliding varkey is temporarily flagged with
        `descr["necessarylineage"]`; the flag is removed again even if
        generating the lines raises.
        """
        excluded = frozenset(excluded)
        root, rootlines = "root" not in excluded, []
        if root:
            excluded = excluded.union(["root"])
        flag_collisions = root and "unnecessary lineage" in excluded
        if flag_collisions:
            for key in self.name_collision_varkeys():
                key.descr["necessarylineage"] = True
        try:
            if root and hasattr(self, "_rootlines"):
                rootlines = self._rootlines(excluded)  # pylint: disable=no-member
            lines = recursively_line(self, excluded)
        finally:
            # the flags live on shared key objects: never leave them behind
            if flag_collisions:
                for key in self.name_collision_varkeys():
                    del key.descr["necessarylineage"]
        # NOTE(review): indent literal widths are kept exactly as found;
        # confirm the intended number of spaces.
        indent = " " if getattr(self, "lineage", None) else ""
        if flag_collisions:
            indent += " "
        return rootlines + [(indent+line).rstrip() for line in lines]

    def str_without(self, excluded=("unnecessary lineage", "units")):
        "String representation of a ConstraintSet."
        return "\n".join(self.lines_without(excluded))

    def latex(self, excluded=("units",)):
        """LaTeX representation of a ConstraintSet.

        Arguments
        ---------
        excluded : iterable of str
            Names of the parts to leave out. If "root" is absent the output
            is wrapped in an `array` environment and "root" is added for the
            nested calls. The caller's object is never modified.
        """
        lines = []
        root = "root" not in excluded
        if root:
            # build a new tuple: `+=` would extend a caller's list in place
            # and fail outright for a set or frozenset
            excluded = tuple(excluded) + ("root",)
            lines.append("\\begin{array}{ll} \\text{}")
            if hasattr(self, "_rootlatex"):
                lines.append(self._rootlatex(excluded))  # pylint: disable=no-member
        # NOTE(review): literal kept exactly as found; confirm its width.
        prefix = " & "
        for constraint in self:
            cstr = try_str_without(constraint, excluded, latex=True)
            # test against the very prefix that is added, so the check cannot
            # drift out of step with the literal's length
            if not cstr.startswith(prefix):  # require indentation
                cstr = prefix + cstr + " \\\\"
            lines.append(cstr)
        if root:
            lines.append("\\end{array}")
        return "\n".join(lines)

    def as_view(self):
        "Return a ConstraintSetView of this ConstraintSet."
        return ConstraintSetView(self)
def recursively_line(iterable, excluded):
    """Generate display lines for a nested container of constraints.

    Elements providing `lines_without` render themselves, non-iterable
    elements are rendered with `try_str_without`, and other iterables are
    rendered recursively. A ConstraintSet with a lineage is preceded by its
    name; an element named through a dict key or `idxlookup` is preceded by
    that name in quotes, unless "constraint names" is excluded.
    """
    if isinstance(iterable, dict):
        names, iterable = sort_constraints_dict(iterable)
        labels = dict(enumerate(names))
    elif hasattr(iterable, "idxlookup"):
        labels = {idx: name for name, idx in iterable.idxlookup.items()}
    else:
        labels = {}
    output = []
    for position, member in enumerate(iterable):
        if hasattr(member, "lines_without"):
            body = member.lines_without(excluded)
        elif not hasattr(member, "__iter__"):
            body = try_str_without(member, excluded).split("\n")
        elif member is iterable:  # guard against infinite recursion
            body = ["(constraint contained itself)"]
        else:
            body = recursively_line(member, excluded)
        if (getattr(member, "lineage", None)
                and isinstance(member, ConstraintSet)):
            name, num = member.lineage[-1]
            if not any(body):
                body = [" " + "(no constraints)"]  # named model indent
            if output:  # blank separator between a model and what precedes it
                output.append("")
            output.append(name + str(num) if num else name)
        elif "constraint names" not in excluded and position in labels:
            output.append("\"%s\":" % labels[position])
            body = [" " + text for text in body]  # named constraint indent
        output.extend(body)
    return output
class ConstraintSetView:
    """Class to access particular views on a set's variables.

    A view pairs an underlying object (`constraintset`) with an index tuple.
    Attribute access is forwarded to the underlying object and the index is
    applied to whatever array-like attribute comes back.
    """

    def __init__(self, constraintset, index=()):
        """Store the underlying object and normalize `index` to a tuple."""
        self.constraintset = constraintset
        try:
            self.index = tuple(index)
        except TypeError:  # probably not iterable
            self.index = (index,)

    def __getitem__(self, index):
        "Appends the index to its own and returns a new view."
        if not isinstance(index, tuple):
            index = (index,)
        # indexes are prepended to match Vectorize convention
        return ConstraintSetView(self.constraintset, index + self.index)

    def __getattr__(self, attr):
        """Returns attribute from the base ConstraintSets

        If it's a another ConstraintSet, return the matching View;
        if it's an array, return it at the specified index;
        otherwise, raise an error.

        Raises
        ------
        AttributeError
            if the underlying object lacks the attribute, or if `attr` is
            one of the view's own attributes or a dunder name.
        ValueError
            if the attribute exists but has no `shape`.
        """
        # __getattr__ only runs when normal lookup failed. If that happens
        # for the view's own attributes the instance is only half built
        # (e.g. while `copy` or `pickle` reconstructs it), and reading
        # `self.constraintset` below would re-enter this method without end.
        # Dunder probes such as `__setstate__` are likewise never forwarded.
        if (attr in ("constraintset", "index")
                or (attr.startswith("__") and attr.endswith("__"))):
            raise AttributeError(attr)
        if not hasattr(self.constraintset, attr):
            raise AttributeError("the underlying object lacks `.%s`." % attr)

        value = getattr(self.constraintset, attr)
        if isinstance(value, ConstraintSet):
            return ConstraintSetView(value, self.index)
        if not hasattr(value, "shape"):
            raise ValueError("attribute %s with value %s did not have"
                             " a shape, so ConstraintSetView cannot"
                             " return an indexed view." % (attr, value))
        index = self.index
        newdims = len(value.shape) - len(self.index)
        if newdims > 0:  # indexes are put last to match Vectorize
            index = (slice(None),)*newdims + index
        return value[index]
def badelement(cns, i, constraint, cause=""):
    """Build a ValueError that pinpoints an invalid element of `cns`.

    The exception is returned, not raised, so that the caller can write
    `raise badelement(...) from err`.

    Arguments
    ---------
    cns : sequence
        The container holding the bad element.
    i : int
        Position of the bad element within `cns`.
    constraint : object
        The bad element itself.
    cause : str, optional
        Extra explanation appended to the message. It is replaced by a hint
        about accidental equalities when `constraint` is a bool.

    Returns
    -------
    ValueError
    """
    if isinstance(constraint, bool):
        cause = " Did the constraint list contain an accidental equality?"
    # Neighbours are wrapped in 1-tuples before formatting: a bare tuple
    # operand would be unpacked by `%` as the argument list and raise
    # TypeError instead of producing the message.
    if len(cns) == 1:
        loc = "the only constraint"
    elif i == 0:
        loc = "at the start, before %s" % (cns[i+1],)
    elif i == len(cns) - 1:
        loc = "at the end, after %s" % (cns[i-1],)
    else:
        loc = "between %s and %s" % (cns[i-1], cns[i+1])
    return ValueError("Invalid ConstraintSet element '%s' %s was %s.%s"
                      % (repr(constraint), type(constraint), loc, cause))