Coverage for gpkit\constraints\set.py : 0%
![Show keyboard shortcuts](keybd_closed.png)
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"Implements ConstraintSet"
2import sys
3from collections import defaultdict, OrderedDict
4from itertools import chain
5import numpy as np
6from ..keydict import KeySet, KeyDict
7from ..small_scripts import try_str_without
8from ..repr_conventions import ReprMixin
9from .single_equation import SingleEquationConstraint
def add_meq_bounds(bounded, meq_bounded):  # TODO: collapse with GP version?
    """Promote conditional bounds into `bounded` until nothing changes.

    `meq_bounded` maps a bound to a collection of conditions, each a set of
    bounds. A bound moves into `bounded` as soon as any one of its conditions
    is a subset of `bounded`; entries whose bound is already in `bounded`
    are simply discarded. Both arguments are modified in place.
    """
    changed = True
    while changed:
        changed = False  # another pass only if this one promoted something
        for candidate in list(meq_bounded):  # snapshot: entries get deleted
            if candidate in bounded:  # nothing left to prove
                del meq_bounded[candidate]
                continue
            conditions = meq_bounded[candidate]
            if any(cond.issubset(bounded) for cond in conditions):
                del meq_bounded[candidate]
                bounded.add(candidate)
                changed = True
28def _sort_by_name_and_idx(var):
29 "return tuple for Variable sorting"
30 return (var.key.str_without(["units", "idx"]), var.key.idx or ())
def _sort_constraints(item):
    """Sort key for (label, constraint) pairs.

    SingleEquationConstraints sort first; among the rest, constraints without
    a (truthy) `lineage` come before those with one; ties break on label.
    """
    label, constraint = item
    is_compound = not isinstance(constraint, SingleEquationConstraint)
    has_lineage = bool(getattr(constraint, "lineage", None))
    return is_compound, has_lineage, label
def sort_constraints_dict(iterable):
    """Return the (keys, values) of a {label: constraint} dictionary.

    When the dict's insertion order can be trusted (Python 3.7+, or an
    OrderedDict) its own key/value views are returned unchanged; otherwise
    the items are ordered with `_sort_constraints` and returned as two
    generators.
    """
    keeps_order = sys.version_info >= (3, 7) or isinstance(iterable, OrderedDict)
    if keeps_order:
        return iterable.keys(), iterable.values()
    ordered_items = sorted(iterable.items(), key=_sort_constraints)
    return ((label for label, _ in ordered_items),
            (constraint for _, constraint in ordered_items))
def flatiter(iterable, yield_if_hasattr=None):
    """Yield the leaf elements of a possibly nested constraint container.

    A dict contributes its values. Non-iterable items are yielded as-is, as
    are iterable items having the attribute named by `yield_if_hasattr`
    (when given). Anything else is flattened: through its `.flat` iterator
    (numpy arrays), through a `.flat(yield_if_hasattr)` method when `.flat`
    is not iterable (ConstraintSet), or by recursing (lists, tuples, dicts).
    """
    if isinstance(iterable, dict):
        iterable = sort_constraints_dict(iterable)[1]
    for item in iterable:
        if not hasattr(item, "__iter__"):
            yield item
            continue
        if yield_if_hasattr and hasattr(item, yield_if_hasattr):
            yield item
            continue
        try:  # numpy arrays expose an iterable `.flat`
            yield from item.flat
        except TypeError:  # ConstraintSet: `.flat` is a method
            yield from item.flat(yield_if_hasattr)
        except AttributeError:  # plain containers such as lists or dicts
            yield from flatiter(item, yield_if_hasattr)
class ConstraintSet(list, ReprMixin):
    """Recursive container for ConstraintSets and Inequalities.

    Besides holding its elements (it is a list), a ConstraintSet merges the
    varkeys (`vks`), substitutions and bound information (`bounded`,
    `meq_bounded`) of everything it contains.
    """
    unique_varkeys, idxlookup = frozenset(), {}
    _name_collision_varkeys = None
    _varkeys = None

    def __init__(self, constraints, substitutions=None, *, bonusvks=None):  # pylint: disable=too-many-branches,too-many-statements
        """Store `constraints` and merge their varkeys, substitutions, bounds.

        Arguments
        ---------
        constraints : dict, ConstraintSet or iterable
            A dict's keys become constraint names (recorded in `idxlookup`);
            a bare ConstraintSet is wrapped one level down.
        substitutions : dict, optional
            Applied on top of the substitutions collected from constraints.
        bonusvks : iterable, optional
            Extra varkeys added to `vks`.

        Raises
        ------
        ValueError
            Built by `badelement` for an element that cannot be processed.
        """
        if isinstance(constraints, dict):
            keys, constraints = sort_constraints_dict(constraints)
            self.idxlookup = {k: i for i, k in enumerate(keys)}
        elif isinstance(constraints, ConstraintSet):
            constraints = [constraints]  # put it one level down
        list.__init__(self, constraints)
        self.vks = set(self.unique_varkeys)
        self.substitutions = KeyDict({k: k.value for k in self.unique_varkeys
                                      if "value" in k.descr})
        self.substitutions.vks = self.vks
        self.bounded, self.meq_bounded = set(), defaultdict(set)
        for i, constraint in enumerate(self):
            if hasattr(constraint, "vks"):
                self._update(constraint)
            elif not (hasattr(constraint, "as_hmapslt1")
                      or hasattr(constraint, "as_gpconstr")):
                try:
                    for subconstraint in flatiter(constraint, "vks"):
                        self._update(subconstraint)
                except Exception as e:
                    raise badelement(self, i, constraint) from e
            elif isinstance(constraint, ConstraintSet):
                raise badelement(self, i, constraint,
                                 " It had not yet been initialized!")
        if bonusvks:
            self.vks.update(bonusvks)
        if substitutions:
            self.substitutions.update(substitutions)
        for key in self.vks:
            if key not in self.substitutions:
                # not substituted directly; accept a non-nan vector entry
                if key.veckey is None or key.veckey not in self.substitutions:
                    continue
                if np.isnan(self.substitutions[key.veckey][key.idx]):
                    continue
            self.bounded.add((key, "upper"))
            self.bounded.add((key, "lower"))
            # NOTE(review): nesting below reconstructed from a
            # whitespace-collapsed listing - confirm against upstream
            if key.value is not None and not key.constant:
                del key.descr["value"]
                if key.veckey and key.veckey.value is not None:
                    del key.veckey.descr["value"]
        add_meq_bounds(self.bounded, self.meq_bounded)

    def _update(self, constraint):
        "Update parameters with a given constraint"
        self.vks.update(constraint.vks)
        if hasattr(constraint, "substitutions"):
            self.substitutions.update(constraint.substitutions)
        else:
            self.substitutions.update({k: k.value for k in constraint.vks
                                       if "value" in k.descr})
        self.bounded.update(constraint.bounded)
        for bound, solutionset in constraint.meq_bounded.items():
            self.meq_bounded[bound].update(solutionset)

    def __getitem__(self, key):
        """Index by position, slice, constraint name, or variable name.

        Slices and integer positions (including numpy integers) behave as
        for a list; keys found in `idxlookup` are translated to positions;
        anything else is looked up as a variable name.
        """
        if isinstance(key, slice):  # slices are unhashable before Python 3.12
            return list.__getitem__(self, key)
        if key in self.idxlookup:
            key = self.idxlookup[key]
        if isinstance(key, (int, np.integer)):
            return list.__getitem__(self, key)
        return self._choosevar(key, self.variables_byname(key))

    def _choosevar(self, key, variables):
        """Return the one variable, or the vector of variables, named `key`.

        Raises KeyError if `variables` is empty, and ValueError if they are
        several variables that do not all share one veckey.
        """
        if not variables:
            raise KeyError(key)
        firstvar, *othervars = variables
        veckey = firstvar.key.veckey
        if veckey is None or any(v.key.veckey != veckey for v in othervars):
            if not othervars:
                return firstvar
            raise ValueError("multiple variables are called '%s'; show them"
                             " with `.variables_byname('%s')`" % (key, key))
        from ..nomials import NomialArray  # all one vector!
        arr = NomialArray(np.full(veckey.shape, np.nan, dtype="object"))
        for v in variables:
            arr[v.key.idx] = v
        arr.key = veckey
        return arr

    def variables_byname(self, key):
        "Get all variables with a given name, sorted by name then index"
        from ..nomials import Variable
        return sorted([Variable(k) for k in self.varkeys[key]],
                      key=_sort_by_name_and_idx)

    @property
    def varkeys(self):
        "The NomialData's varkeys, created when necessary for a substitution."
        if self._varkeys is None:
            self._varkeys = KeySet(self.vks)
        return self._varkeys

    def constrained_varkeys(self):
        "Return all varkeys in non-ConstraintSet constraints"
        return self.vks - self.unique_varkeys

    flat = flatiter

    def as_hmapslt1(self, subs):
        "Yields hmaps<=1 from self.flat()"
        yield from chain(*(c.as_hmapslt1(subs)
                           for c in flatiter(self,
                                             yield_if_hasattr="as_hmapslt1")))

    def process_result(self, result):
        """Does arbitrary computation / manipulation of a program's result

        There's no guarantee what order different constraints will process
        results in, so any changes made to the program's result should be
        careful not to step on other constraint's toes.

        Potential Uses
        --------------
        - check that an inequality was tight
        - add values computed from solved variables

        """
        for constraint in self.flat(yield_if_hasattr="process_result"):
            if hasattr(constraint, "process_result"):
                constraint.process_result(result)
        evalfn_vars = {v.veckey or v for v in self.unique_varkeys
                       if v.evalfn and v not in result["variables"]}
        for v in evalfn_vars:
            val = v.evalfn(result["variables"])
            result["variables"][v] = result["freevariables"][v] = val

    def __repr__(self):
        "Returns namespaced string."
        if not self:
            return "<gpkit.%s object>" % self.__class__.__name__
        return ("<gpkit.%s object containing %i top-level constraint(s)"
                " and %i variable(s)>" % (self.__class__.__name__,
                                          len(self), len(self.varkeys)))

    def set_necessarylineage(self, clear=False):
        """Set (or, with clear=True, delete) `descr["necessarylineage"]` on
        each varkey whose short name collides with another varkey's.

        The stored value is how many trailing lineage segments are needed to
        tell the varkey apart; the table is computed once and cached.
        """
        if self._name_collision_varkeys is None:
            self._name_collision_varkeys = {}
            name_collisions = defaultdict(set)
            for key in self.varkeys:
                shortname = key.str_without(["lineage", "vec"])
                if len(self.varkeys[shortname]) > 1:
                    name_collisions[shortname].add(key)
            for colliding in name_collisions.values():
                min_namespaced = defaultdict(set)
                for vk in colliding:
                    *_, mineage = vk.lineagestr().split(".")
                    min_namespaced[(mineage, 1)].add(vk)
                # lengthen the lineage suffix of every group still colliding
                while any(len(group) > 1 for group in min_namespaced.values()):
                    for key, group in list(min_namespaced.items()):
                        if len(group) > 1:
                            del min_namespaced[key]
                            mineage, idx = key
                            idx += 1
                            for vk in group:
                                lineages = vk.lineagestr().split(".")
                                submineage = lineages[-idx] + "." + mineage
                                min_namespaced[(submineage, idx)].add(vk)
                for (_, idx), group in min_namespaced.items():
                    vk, = group
                    self._name_collision_varkeys[vk] = idx
        if clear:
            for vk in self._name_collision_varkeys:
                del vk.descr["necessarylineage"]
        else:
            for vk, idx in self._name_collision_varkeys.items():
                vk.descr["necessarylineage"] = idx

    def lines_without(self, excluded):
        "Lines representation of a ConstraintSet."
        excluded = frozenset(excluded)
        root, rootlines = "root" not in excluded, []
        if root:
            excluded = excluded.union(["root"])
            if "unnecessary lineage" in excluded:
                self.set_necessarylineage()
            if hasattr(self, "_rootlines"):
                rootlines = self._rootlines(excluded)  # pylint: disable=no-member
        lines = recursively_line(self, excluded)
        indent = " " if getattr(self, "lineage", None) else ""
        if root and "unnecessary lineage" in excluded:
            indent += " "
            self.set_necessarylineage(clear=True)
        return rootlines + [(indent+line).rstrip() for line in lines]

    def str_without(self, excluded=("unnecessary lineage", "units")):
        "String representation of a ConstraintSet."
        return "\n".join(self.lines_without(excluded))

    def latex(self, excluded=("units",)):
        "LaTeX representation of a ConstraintSet."
        lines = []
        root = "root" not in excluded
        if root:
            # build a new tuple: `+=` would extend a caller's list in place
            # (leaking "root" into later calls) and fails for a set
            excluded = tuple(excluded) + ("root",)
            lines.append("\\begin{array}{ll} \\text{}")
            if hasattr(self, "_rootlatex"):
                lines.append(self._rootlatex(excluded))  # pylint: disable=no-member
        prefix = " & "  # one literal for both the check and the wrapping
        for constraint in self:
            cstr = try_str_without(constraint, excluded, latex=True)
            if not cstr.startswith(prefix):  # require indentation
                cstr = prefix + cstr + " \\\\"
            lines.append(cstr)
        if root:
            lines.append("\\end{array}")
        return "\n".join(lines)

    def as_view(self):
        "Return a ConstraintSetView of this ConstraintSet."
        return ConstraintSetView(self)
def recursively_line(iterable, excluded):
    """Generates lines in a recursive tree-like fashion, the better to indent.

    Each element contributes its `lines_without(excluded)` if it has one, its
    string form if it is not iterable, or (recursively) the lines of its own
    elements. ConstraintSets with a lineage get a heading line with their
    name. Other elements named by a dict key or by `idxlookup` get a quoted
    name line, and their lines are indented, unless "constraint names" is in
    `excluded`.
    """
    named_constraints = {}
    if isinstance(iterable, dict):
        keys, iterable = sort_constraints_dict(iterable)
        named_constraints = dict(enumerate(keys))
    elif hasattr(iterable, "idxlookup"):
        named_constraints = {i: k for k, i in iterable.idxlookup.items()}
    lines = []
    for i, constraint in enumerate(iterable):
        if hasattr(constraint, "lines_without"):
            clines = constraint.lines_without(excluded)
        elif not hasattr(constraint, "__iter__"):
            clines = try_str_without(constraint, excluded).split("\n")
        elif iterable is constraint:
            clines = ["(constraint contained itself)"]
        else:
            clines = recursively_line(constraint, excluded)
        if (getattr(constraint, "lineage", None)
                and isinstance(constraint, ConstraintSet)):
            name, num = constraint.lineage[-1]
            if not any(clines):
                clines = [" " + "(no constraints)"]  # named model indent
            if lines:
                lines.append("")
            lines.append(name if not num else name + str(num))
        elif "constraint names" not in excluded and i in named_constraints:
            # wrap in a 1-tuple: a bare tuple key would be unpacked by `%`
            lines.append("\"%s\":" % (named_constraints[i],))
            clines = [" " + line for line in clines]  # named constraint indent
        lines.extend(clines)
    return lines
class ConstraintSetView:
    "Class to access particular views on a set's variables"

    def __init__(self, constraintset, index=()):
        """Wrap `constraintset`; `index` is stored as a tuple (a non-iterable
        index becomes a 1-tuple)."""
        self.constraintset = constraintset
        try:
            self.index = tuple(index)
        except TypeError:  # probably not iterable
            self.index = (index,)

    def __getitem__(self, index):
        "Appends the index to its own and returns a new view."
        if not isinstance(index, tuple):
            index = (index,)
        # indexes are preprended to match Vectorize convention
        return ConstraintSetView(self.constraintset, index + self.index)

    def __getattr__(self, attr):
        """Returns attribute from the base ConstraintSets

        If it's a another ConstraintSet, return the matching View;
        if it's an array, return it at the specified index;
        otherwise, raise an error.

        The view's own attributes and dunder names are never forwarded, so
        protocol probes (`hasattr(view, "__iter__")`, copy/pickle looking up
        `__setstate__` on a not-yet-initialized instance) get a plain
        AttributeError instead of infinite recursion or a ValueError.
        """
        if attr in ("constraintset", "index") or (
                attr.startswith("__") and attr.endswith("__")):
            raise AttributeError("ConstraintSetView does not forward `.%s`."
                                 % attr)
        if not hasattr(self.constraintset, attr):
            raise AttributeError("the underlying object lacks `.%s`." % attr)

        value = getattr(self.constraintset, attr)
        if isinstance(value, ConstraintSet):
            return ConstraintSetView(value, self.index)
        if not hasattr(value, "shape"):
            raise ValueError("attribute %s with value %s did not have"
                             " a shape, so ConstraintSetView cannot"
                             " return an indexed view." % (attr, value))
        index = self.index
        newdims = len(value.shape) - len(self.index)
        if newdims > 0:  # indexes are put last to match Vectorize
            index = (slice(None),)*newdims + index
        return value[index]
def badelement(cns, i, constraint, cause=""):
    """Build (and return, not raise) a ValueError for a bad element.

    The message shows `constraint`'s repr and type and where it sits in
    `cns` (index `i`) relative to its neighbours, followed by `cause`. For
    boolean elements (Python or numpy) `cause` is replaced by a hint that an
    equality comparison was probably evaluated by accident.
    """
    if isinstance(constraint, (bool, np.bool_)):
        cause = " Did the constraint list contain an accidental equality?"
    # neighbours are wrapped in 1-tuples so a tuple element is not unpacked
    # by `%` formatting
    if len(cns) == 1:
        loc = "the only constraint"
    elif i == 0:
        loc = "at the start, before %s" % (cns[i+1],)
    elif i == len(cns) - 1:
        loc = "at the end, after %s" % (cns[i-1],)
    else:
        loc = "between %s and %s" % (cns[i-1], cns[i+1])
    return ValueError("Invalid ConstraintSet element '%s' %s was %s.%s"
                      % (repr(constraint), type(constraint), loc, cause))