Coverage for gpkit/keydict.py: 84%
226 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-05 22:33 -0500
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-05 22:33 -0500
1"Implements KeyDict and KeySet classes"
2from collections import defaultdict
3from collections.abc import Hashable
4import numpy as np
5from .small_classes import Numbers, Quantity, FixedScalar
6from .small_scripts import is_sweepvar, isnan, veclinkedfn
# dtype numpy gives to arrays of plain Python ints; compared against
# incoming array dtypes when values are stored
INT_DTYPE = np.dtype(int)
# unit multiplier applied to values whose key has no units
DIMLESS_QUANTITY = Quantity(1, "dimensionless")
def clean_value(key, value):
    """Reduce `value` to something that can be stored under `key`.

    A FixedScalar (a variable-less monomial) is replaced by its `.value`,
    so that `x.sub({x: gpkit.units.m})` and `x.sub({x: gpkit.ureg.m})`
    are equivalent.

    A Quantity is then converted to the key's units (or "dimensionless"
    when the key has none) and only its magnitude is returned, because
    quantities can't/shouldn't be stored as elements of numpy arrays.

    Anything else is returned unchanged.
    """
    unwrapped = value.value if isinstance(value, FixedScalar) else value
    if not isinstance(unwrapped, Quantity):
        return unwrapped
    target_units = key.units or "dimensionless"
    return unwrapped.to(target_units).magnitude
class KeyMap:
    """Helper class to provide key mapping to interfaces.

    Mapping keys
    ------------
    A KeyMap keeps an internal list of VarKeys as
    canonical keys, and their values can be accessed with any object whose
    `key` attribute matches one of those VarKeys, or with strings matching
    any of the multiple possible string interpretations of each key.

    For example, after creating the KeyDict kd and setting kd[x] = v (where x
    is a Variable or VarKey), v can be accessed by the following keys:
     - x
     - x.key
     - x.name (a string)
     - "x_modelname" (x's name including modelname)

    Note that if an item is set using a key that does not have a `.key`
    attribute, that key can be set and accessed normally.

    This class is a mixin: it relies on a sibling base class (dict or set in
    this file) to provide `update`, `__contains__` and `__getitem__`.
    """
    collapse_arrays = False
    # class-level placeholder; __init__ gives each instance its own
    # defaultdict(set)
    keymap = []
    log_gets = False
    vks = varkeys = None

    def __init__(self, *args, **kwargs):
        "Passes through to super().__init__ via the `update()` method"
        self.keymap = defaultdict(set)
        self._unmapped_keys = set()
        self.owned = set()
        self.update(*args, **kwargs)  # pylint: disable=no-member

    def parse_and_index(self, key):
        """Returns key if key had one, and veckey/idx for indexed veckeys.

        Returns a `(key, idx)` tuple; `idx` is None unless arrays are being
        collapsed and the resolved key carries an index.

        Raises KeyError if `key` is not known to `self.varkeys`, and
        ValueError if it matches several keys that do not share one veckey.
        """
        try:
            key = key.key
            if self.collapse_arrays and key.idx:
                return key.veckey, key.idx
            return key, None
        except AttributeError:
            if self.vks is None and self.varkeys is None:
                # update_keymap() returns None, which serves as the idx
                return key, self.update_keymap()
        # looks like we're in a substitutions dictionary
        if self.varkeys is None:
            self.varkeys = KeySet(self.vks)
        if key not in self.varkeys:
            raise KeyError(key)
        newkey, *otherkeys = self.varkeys[key]
        if otherkeys:
            if all(k.veckey == newkey.veckey for k in otherkeys):
                return newkey.veckey, None
            raise ValueError(f"{key} refers to multiple keys in this "
                             "substitutions KeyDict. Use "
                             f"`.variables_byname({key})` to see all of them.")
        if self.collapse_arrays and newkey.idx:
            return newkey.veckey, newkey.idx
        return newkey, None

    def __contains__(self, key):
        """In a winding way, figures out if a key is in the KeyDict

        Raises TypeError if the key has an idx but the stored value cannot
        be indexed, and IndexError if the idx is out of bounds for it.
        """
        # pylint: disable=no-member
        try:
            key, idx = self.parse_and_index(key)
        except KeyError:
            return False
        except ValueError:  # multiple keys correspond
            return True
        if not isinstance(key, Hashable):
            return False
        if super().__contains__(key):
            if idx:
                try:
                    val = super().__getitem__(key)[idx]
                    return True if is_sweepvar(val) else not isnan(val).any()
                except TypeError as err:
                    val = super().__getitem__(key)
                    # every fragment needs its own f prefix, otherwise
                    # "{val}" is emitted literally
                    raise TypeError(f"{key} has an idx, but its value in this"
                                    f" KeyDict is the scalar {val}.") from err
                except IndexError as err:
                    val = super().__getitem__(key)
                    raise IndexError(f"key {key} with idx {idx} is out of "
                                     f" bounds for value {val}") from err
        return key in self.keymap

    def update_keymap(self):
        "Updates the keymap with the keys in _unmapped_keys"
        copied = set()  # have to copy bc update leaves duplicate sets
        for key in self._unmapped_keys:
            for mapkey in key.keys:
                if mapkey not in copied and mapkey in self.keymap:
                    # replace the possibly shared set before mutating it
                    self.keymap[mapkey] = set(self.keymap[mapkey])
                    copied.add(mapkey)
                self.keymap[mapkey].add(key)
        self._unmapped_keys = set()
class KeyDict(KeyMap, dict):
    """KeyDicts do two things over a dict: map keys and collapse arrays.

    Example: ``kd = gpkit.keydict.KeyDict()``

    For mapping keys, see KeyMap.__doc__

    Collapsing arrays
    -----------------
    If ``.collapse_arrays`` is True then VarKeys which have a `shape`
    parameter (indicating they are part of an array) are stored as numpy
    arrays, and automatically de-indexed when a matching VarKey with a
    particular `idx` parameter is used as a key.

    See also: gpkit/tests/t_keydict.py.
    """
    collapse_arrays = True

    def get(self, key, *alternative):
        "Returns self[key], or the first alternative if key is not present"
        return alternative[0] if alternative and key not in self else self[key]

    def _copyonwrite(self, key):
        "Copies arrays before they are written to"
        if not hasattr(self, "owned"):  # backwards pickle compatibility
            self.owned = set()
        if key not in self.owned:
            super().__setitem__(key, super().__getitem__(key).copy())
            self.owned.add(key)

    def update(self, *args, **kwargs):
        "Iterates through the dictionary created by args and kwargs"
        if not self and len(args) == 1 and isinstance(args[0], KeyDict):
            # fast path: adopt the other KeyDict's values and keymap sets
            # without copying; `owned` is left untouched, so _copyonwrite
            # copies an array before this instance first writes to it
            super().update(args[0])
            self.keymap.update(args[0].keymap)
            self._unmapped_keys.update(args[0]._unmapped_keys)  # pylint:disable=protected-access
        else:
            for k, v in dict(*args, **kwargs).items():
                self[k] = v

    def __call__(self, key):  # if uniting is ever a speed hit, cache it
        "Returns self[key] with the units of the matching key(s) attached"
        got = self[key]
        if isinstance(got, dict):
            for k, v in got.items():
                got[k] = v*(k.units or DIMLESS_QUANTITY)
            return got
        if not hasattr(key, "units"):
            # single-element unpacking: exactly one canonical key matched
            key, = self.keymap[self.parse_and_index(key)[0]]
        return Quantity(got, key.units or "dimensionless")

    def __getitem__(self, key):
        """Overloads __getitem__ and [] access to work with all keys

        Returns the stored value when exactly one canonical key matches,
        otherwise a dict mapping each matching key to its value.
        Raises KeyError if nothing matches.
        """
        key, idx = self.parse_and_index(key)
        keys = self.keymap[key]
        if not keys:
            del self.keymap[key]  # remove blank entry added by defaultdict
            raise KeyError(key)
        got = {}
        for k in keys:
            if not idx and k.shape:
                # the whole array is handed out, so make sure it is our own
                self._copyonwrite(k)
            val = dict.__getitem__(self, k)
            if idx:
                # per-iteration local, so an Ellipsis added for one key is
                # never stacked onto the index used for the next
                index = idx
                if len(index) < len(val.shape):
                    index = (...,) + index  # idx from the right, in case of sweep
                val = val[index]
            if len(keys) == 1:
                return val
            got[k] = val
        return got

    def __setitem__(self, key, value):
        "Overloads __setitem__ and []= to work with all keys"
        # pylint: disable=too-many-boolean-expressions,too-many-branches,too-many-statements
        try:
            key, idx = self.parse_and_index(key)
        except KeyError as e:  # may be indexed VectorVariable
            # NOTE: this try/except takes ~4% of the time in this loop
            if not hasattr(key, "shape"):
                raise e
            if not hasattr(value, "shape"):
                value = np.full(key.shape, value)
            elif key.shape != value.shape:
                raise KeyError("Key and value have different shapes") from e
            for subkey, subval in zip(key.flat, value.flat):
                self[subkey] = subval
            return
        value = clean_value(key, value)
        if key not in self.keymap:
            if not hasattr(self, "_unmapped_keys"):
                self.__init__()  # py3's pickle sets items before init... :(
            self.keymap[key].add(key)
            self._unmapped_keys.add(key)
            if idx:
                # first write to one element: start from an all-NaN array
                dty = {} if isinstance(value, Numbers) else {"dtype": "object"}
                if getattr(value, "shape", None) and value.dtype != INT_DTYPE:
                    dty["dtype"] = value.dtype
                dict.__setitem__(self, key, np.full(key.shape, np.nan, **dty))
                self.owned.add(key)
        if idx:
            if is_sweepvar(value):
                old = super().__getitem__(key)
                super().__setitem__(key, np.array(old, "object"))
                self.owned.add(key)
            self._copyonwrite(key)
            if hasattr(value, "__call__"):  # a linked function
                old = super().__getitem__(key)
                super().__setitem__(key, np.array(old, dtype="object"))
            super().__getitem__(key)[idx] = value
            return  # successfully set a single index!
        if key.shape:  # now if we're setting an array...
            if hasattr(value, "__call__"):  # a linked vector-function
                key.vecfn = value
                value = np.empty(key.shape, dtype="object")
                it = np.nditer(value, flags=['multi_index', 'refs_ok'])
                while not it.finished:
                    i = it.multi_index
                    it.iternext()
                    value[i] = veclinkedfn(key.vecfn, i)
            if getattr(value, "shape", None):  # is the value an array?
                if value.dtype == INT_DTYPE:
                    value = np.array(value, "f")  # convert to float
                if dict.__contains__(self, key):
                    old = super().__getitem__(key)
                    if old.dtype != value.dtype:
                        # e.g. replacing a number with a linked function
                        newly_typed_array = np.array(old, dtype=value.dtype)
                        super().__setitem__(key, newly_typed_array)
                        self.owned.add(key)
                    self._copyonwrite(key)
                    # NaN entries in the new value leave old entries alone
                    goodvals = ~isnan(value)
                    super().__getitem__(key)[goodvals] = value[goodvals]
                    return  # successfully set only some indexes!
            elif not is_sweepvar(value):  # or needs to be made one?
                if not hasattr(value, "__len__"):
                    value = np.full(key.shape, value, "f")
                elif not isinstance(value[0], np.ndarray):
                    clean_values = [clean_value(key, v) for v in value]
                    dtype = None
                    if any(is_sweepvar(cv) for cv in clean_values):
                        dtype = "object"
                    value = np.array(clean_values, dtype=dtype)
        super().__setitem__(key, value)
        self.owned.add(key)

    def __delitem__(self, key):
        """Overloads del [] to work with all keys

        A key without a `.key` attribute deletes every canonical key it maps
        to. Deleting an indexed key sets that element to NaN and removes the
        array once all of its elements are NaN. Raises KeyError if nothing
        matches.
        """
        if not hasattr(key, "key"):  # not a keyed object
            self.update_keymap()
            keys = self.keymap[key]
            if not keys:
                # remove blank entry added by defaultdict, as __getitem__
                # does, so a later `key in self` is not fooled by it
                del self.keymap[key]
                raise KeyError(key)
            for k in keys:
                del self[k]
        else:
            key = key.key
            veckey, idx = self.parse_and_index(key)
            if idx is None:
                super().__delitem__(key)
            else:
                # the array may still be shared with the KeyDict this one
                # was updated from; copy it before writing the NaN
                self._copyonwrite(veckey)
                super().__getitem__(veckey)[idx] = np.nan
                if isnan(super().__getitem__(veckey)).all():
                    super().__delitem__(veckey)
            copiedonwrite = set()  # to save time, .update() does not copy
            mapkeys = {key}
            if key.keys:
                mapkeys.update(key.keys)
            for mapkey in mapkeys:
                if mapkey in self.keymap:
                    if len(self.keymap[mapkey]) == 1:
                        del self.keymap[mapkey]
                        continue
                    if mapkey not in copiedonwrite:
                        self.keymap[mapkey] = set(self.keymap[mapkey])
                        copiedonwrite.add(mapkey)
                    self.keymap[mapkey].remove(key)
class KeySet(KeyMap, set):
    "KeyMaps that don't collapse arrays or store values."
    collapse_arrays = False

    def update(self, keys=()):
        """Adds every key in `keys` to the set and to the keymap.

        `keys` may be any iterable. It defaults to empty so that the
        argument-less `self.update()` call in KeyMap.__init__ works.
        """
        if iter(keys) is keys:
            # one-shot iterator (e.g. a generator): it is consumed three
            # times below, so materialize it first
            keys = tuple(keys)
        for key in keys:
            self.keymap[key].add(key)
        self._unmapped_keys.update(keys)
        super().update(keys)

    def __getitem__(self, key):
        """Gets the keys corresponding to a particular key.

        Returns an empty set when nothing matches. The lookup does not
        insert into the keymap, so it cannot make a later `in` test
        succeed for an unknown key.
        """
        key, _ = self.parse_and_index(key)
        return self.keymap.get(key, set())