Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"Implements KeyDict and KeySet classes"
2from collections import defaultdict
3from collections.abc import Hashable
4import numpy as np
5from .small_classes import Numbers, Quantity, FixedScalar
6from .small_scripts import is_sweepvar, isnan
# The dtype numpy uses for Python ints; integer-valued arrays are compared
# against it so they can be stored as floats instead.
INT_DTYPE = np.dtype(int)
# Unit multiplier applied by KeyDict.__call__ to values whose key has no units.
DIMLESS_QUANTITY = Quantity(1, "dimensionless")
def clean_value(key, value):
    """Unwraps `value` into something that can be stored in a KeyDict.

    A FixedScalar (a variable-less monomial) is replaced by its `.value`, so
    that `x.sub({x: gpkit.units.m})` and `x.sub({x: gpkit.ureg.m})` are
    equivalent. A Quantity is then converted to the key's units (or to
    "dimensionless" when the key has none) and reduced to its magnitude,
    because quantities can't/shouldn't be stored as elements of numpy arrays.
    Any other value is returned unchanged.
    """
    unwrapped = value.value if isinstance(value, FixedScalar) else value
    if not isinstance(unwrapped, Quantity):
        return unwrapped
    target_units = key.units or "dimensionless"
    return unwrapped.to(target_units).magnitude
class KeyMap:
    """Helper class to provide KeyMapping to interfaces.

    Mapping keys
    ------------
    A KeyMap keeps a mapping (`keymap`) from canonical VarKeys, and from
    each of the alternative names listed in a VarKey's `keys`, to the set of
    canonical VarKeys they refer to. Values can therefore be accessed with
    any object whose `key` attribute matches one of those VarKeys, or with
    strings matching any of the possible string interpretations of each key.

    For example, after creating the KeyDict kd and setting kd[x] = v (where x
    is a Variable or VarKey), v can be accessed with the following keys:
     - x
     - x.key
     - x.name (a string)
     - "x_modelname" (x's name including modelname)

    Objects without a `.key` attribute are looked up as-is; when `varkeys`
    is set they are first resolved through it (see `parse_and_index`).
    """
    collapse_arrays = False
    # Class-level placeholder; __init__ replaces it with a defaultdict(set).
    # It must support `in` because unpickling can set items before __init__.
    keymap = []
    log_gets = False
    # Optional keyed collection (presumably a KeySet) used to resolve
    # non-keyed lookups such as strings; see parse_and_index.
    varkeys = None

    def __init__(self, *args, **kwargs):
        "Passes through to super().__init__ via the `update()` method"
        self.keymap = defaultdict(set)
        self._unmapped_keys = set()  # keys whose `keys` aren't in keymap yet
        self.owned = set()
        self.update(*args, **kwargs)  # pylint: disable=no-member

    def parse_and_index(self, key):
        """Resolves `key` to a (canonical key, idx) pair.

        Objects with a `.key` attribute resolve to that key; if
        `collapse_arrays` is set, an indexed key resolves to its `veckey`
        together with its `idx`. Other objects are returned unchanged (after
        mapping any pending keys into the keymap) when no `varkeys` are set,
        and are otherwise looked up in `varkeys`.

        Raises
        ------
        KeyError
            if `varkeys` is set and does not contain `key`.
        ValueError
            if `key` matches several keys that are not all elements of the
            same vector.
        """
        try:
            key = key.key
            if self.collapse_arrays and key.idx:
                return key.veckey, key.idx
            return key, None
        except AttributeError:
            if not self.varkeys:
                self.update_keymap()  # so string lookups see every key
                return key, None
            # looks like we're in a substitutions dictionary
            if key not in self.varkeys:  # pylint:disable=unsupported-membership-test
                raise KeyError(key)
            newkey, *otherkeys = self.varkeys[key]  # pylint:disable=unsubscriptable-object
            if otherkeys:
                if all(k.veckey == newkey.veckey for k in otherkeys):
                    return newkey.veckey, None
                raise ValueError("%s refers to multiple keys in this substitutions"
                                 " KeyDict. Use `.variables_byname(%s)` to see all"
                                 " of them." % (key, key))
            if self.collapse_arrays and newkey.idx:
                return newkey.veckey, newkey.idx
            return newkey, None

    def __contains__(self, key):  # pylint:disable=too-many-return-statements
        """In a winding way, figures out if a key is in the KeyDict.

        Ambiguous keys (matching several keys) count as present. For an
        indexed key, the element must be a sweep variable or be non-NaN.
        """
        try:
            key, idx = self.parse_and_index(key)
        except KeyError:
            return False
        except ValueError:  # multiple keys correspond
            return True
        if not isinstance(key, Hashable):
            return False
        if super().__contains__(key):  # pylint: disable=no-member
            if idx:
                try:
                    value = super().__getitem__(key)[idx]  # pylint: disable=no-member
                    return True if is_sweepvar(value) else not isnan(value)
                except TypeError as err:
                    raise TypeError("%s has an idx, but its value in this"
                                    " KeyDict is the scalar %s."
                                    % (key, super().__getitem__(key))) from err  # pylint: disable=no-member
                except IndexError as err:
                    raise IndexError("key %s with idx %s is out of bounds"
                                     " for value %s" %
                                     (key, idx, super().__getitem__(key))) from err  # pylint: disable=no-member
            return True
        return key in self.keymap

    def update_keymap(self):
        """Maps every key in _unmapped_keys under each name in its `keys`."""
        # Sets in the keymap may be shared with another KeyMap (KeyDict's
        # update() copies the keymap shallowly), so a pre-existing set is
        # copied once before its first modification here. Sets created in
        # this call are fresh, so they are marked as copied immediately;
        # re-copying them for every key that shares a name would be O(n^2).
        copied = set()
        for key in self._unmapped_keys:
            for mapkey in key.keys:
                if mapkey not in copied:
                    if mapkey in self.keymap:
                        self.keymap[mapkey] = set(self.keymap[mapkey])
                    copied.add(mapkey)
                self.keymap[mapkey].add(key)
        self._unmapped_keys = set()
class KeyDict(KeyMap, dict):
    """KeyDicts do two things over a dict: map keys and collapse arrays.

    Example::

        kd = gpkit.keydict.KeyDict()

    For mapping keys, see KeyMap.__doc__

    Collapsing arrays
    -----------------
    If ``.collapse_arrays`` is True then VarKeys which have a `shape`
    parameter (indicating they are part of an array) are stored as numpy
    arrays, and automatically de-indexed when a matching VarKey with a
    particular `idx` parameter is used as a key.

    See also: gpkit/tests/t_keydict.py.
    """
    collapse_arrays = True

    def get(self, key, *alternative):
        """Returns self[key], or alternative[0] if given and key is absent.

        Without an alternative, a missing key raises just like self[key].
        """
        return alternative[0] if alternative and key not in self else self[key]

    def _copyonwrite(self, key):
        "Copies the array stored at `key` unless this KeyDict already owns it"
        if not hasattr(self, "owned"):  # backwards pickle compatibility
            self.owned = set()
        if key not in self.owned:
            super().__setitem__(key, super().__getitem__(key).copy())
            self.owned.add(key)

    def update(self, *args, **kwargs):
        "Iterates through the dictionary created by args and kwargs"
        if (not self and not kwargs and len(args) == 1
                and isinstance(args[0], KeyDict)):
            # Fast path: share the other KeyDict's values and keymap sets.
            # `owned` is left empty so arrays are copied before being written,
            # and keymap sets are copied before being modified.
            super().update(args[0])
            self.keymap.update(args[0].keymap)
            self._unmapped_keys.update(args[0]._unmapped_keys)  # pylint:disable=protected-access
        else:
            for k, v in dict(*args, **kwargs).items():
                self[k] = v

    def __call__(self, key):  # if uniting is ever a speed hit, cache it
        """Returns self[key] with units attached.

        A single match is returned as a Quantity in its key's units
        ("dimensionless" if it has none); when `key` matches several keys,
        the dict from self[key] is returned with each value multiplied by its
        key's units (or by DIMLESS_QUANTITY).
        """
        got = self[key]
        if isinstance(got, dict):
            for k, v in got.items():
                got[k] = v*(k.units or DIMLESS_QUANTITY)
            return got
        if not hasattr(key, "units"):
            key, = self.keymap[self.parse_and_index(key)[0]]
        return Quantity(got, key.units or "dimensionless")

    def __getitem__(self, key):
        """Overloads __getitem__ and [] access to work with all keys.

        An indexed key returns the element at its idx. If `key` maps to
        several keys, a dict of {key: value} is returned. Whole arrays that
        this KeyDict does not own yet are copied first (see _copyonwrite).
        """
        key, idx = self.parse_and_index(key)
        keys = self.keymap[key]
        if not keys:
            del self.keymap[key]  # remove blank entry added by defaultdict
            raise KeyError(key)
        got = {}
        for k in keys:
            if not idx and k.shape:
                self._copyonwrite(k)
            val = dict.__getitem__(self, k)
            if idx:
                val = val[idx]
            if len(keys) == 1:
                return val
            got[k] = val
        return got

    def __setitem__(self, key, value):
        "Overloads __setitem__ and []= to work with all keys"
        # pylint: disable=too-many-boolean-expressions,too-many-branches
        try:
            key, idx = self.parse_and_index(key)
        except KeyError as e:  # may be indexed VectorVariable
            # NOTE: this try/except takes ~4% of the time in this loop
            if not hasattr(key, "shape"):
                raise e
            if not hasattr(value, "shape"):
                value = np.full(key.shape, value)
            elif key.shape != value.shape:
                raise KeyError("Key and value have different shapes") from e
            for subkey, subval in zip(key.flat, value.flat):
                self[subkey] = subval
            return
        value = clean_value(key, value)
        if key not in self.keymap:
            if not hasattr(self, "_unmapped_keys"):
                self.__init__()  # py3's pickle sets items before init... :(
            self.keymap[key].add(key)
            self._unmapped_keys.add(key)
            if idx:
                # first element of a new vector: start from an all-NaN array
                dty = {} if isinstance(value, Numbers) else {"dtype": "object"}
                if getattr(value, "shape", None) and value.dtype != INT_DTYPE:
                    dty["dtype"] = value.dtype
                dict.__setitem__(self, key, np.full(key.shape, np.nan, **dty))
                self.owned.add(key)
        if idx:
            if is_sweepvar(value):
                old = super().__getitem__(key)
                super().__setitem__(key, np.array(old, "object"))
                self.owned.add(key)
            self._copyonwrite(key)
            if hasattr(value, "__call__"):  # a linked function
                old = super().__getitem__(key)
                super().__setitem__(key, np.array(old, dtype="object"))
            super().__getitem__(key)[idx] = value
            return  # successfully set a single index!
        if key.shape:  # now if we're setting an array...
            if getattr(value, "shape", None):  # is the value an array?
                if value.dtype == INT_DTYPE:
                    value = np.array(value, "f")  # convert to float
                if dict.__contains__(self, key):
                    old = super().__getitem__(key)
                    if old.dtype != value.dtype:
                        # e.g. replacing a number with a linked function
                        newly_typed_array = np.array(old, dtype=value.dtype)
                        super().__setitem__(key, newly_typed_array)
                        self.owned.add(key)
                    self._copyonwrite(key)
                    goodvals = ~isnan(value)
                    super().__getitem__(key)[goodvals] = value[goodvals]
                    return  # successfully set only some indexes!
            elif not is_sweepvar(value):  # or needs to be made one?
                if not hasattr(value, "__len__"):
                    value = np.full(key.shape, value, "f")
                elif not isinstance(value[0], np.ndarray):
                    value = np.array([clean_value(key, v) for v in value])
        super().__setitem__(key, value)
        self.owned.add(key)

    def __delitem__(self, key):
        """Overloads del [] to work with all keys.

        A non-keyed key (e.g. a string) deletes every key it maps to. An
        indexed key sets its element to NaN, removing the whole array once
        every element is NaN.
        """
        if not hasattr(key, "key"):  # not a keyed object
            self.update_keymap()
            keys = self.keymap[key]
            if not keys:
                # remove blank entry added by defaultdict, otherwise
                # __contains__ would report the missing key as present
                del self.keymap[key]
                raise KeyError(key)
            # safe to iterate: deletions copy keymap sets before modifying them
            for k in keys:
                del self[k]
        else:
            key = key.key
            veckey, idx = self.parse_and_index(key)
            if idx is None:
                super().__delitem__(key)
            else:
                # the array may be shared via update(); copy before writing
                self._copyonwrite(veckey)
                super().__getitem__(veckey)[idx] = np.nan
                if isnan(super().__getitem__(veckey)).all():
                    super().__delitem__(veckey)
            copiedonwrite = set()  # to save time, .update() does not copy
            mapkeys = set([key])
            if key.keys:
                mapkeys.update(key.keys)
            for mapkey in mapkeys:
                if mapkey in self.keymap:
                    if len(self.keymap[mapkey]) == 1:
                        del self.keymap[mapkey]
                        continue
                    if mapkey not in copiedonwrite:
                        self.keymap[mapkey] = set(self.keymap[mapkey])
                        copiedonwrite.add(mapkey)
                    self.keymap[mapkey].remove(key)
class KeySet(KeyMap, set):
    "KeyMaps that don't collapse arrays or store values."
    collapse_arrays = False

    def update(self, keys=()):
        """Adds `keys` (another KeySet, or any iterable of keys) to this set.

        Each key is also registered in the keymap; the alternative names in
        its `keys` are mapped lazily by `update_keymap`. `keys` defaults to
        empty so that `KeySet()` (whose __init__ calls `update()` with no
        arguments) works.
        """
        if isinstance(keys, KeySet):
            set.update(self, keys)
            for key, value in keys.keymap.items():
                self.keymap[key].update(value)
            self._unmapped_keys.update(keys._unmapped_keys)  # pylint: disable=protected-access
        else:  # set-like interface
            # materialize once: `keys` is traversed three times below, which
            # would silently lose the contents of a generator or iterator
            keys = list(keys)
            for key in keys:
                self.keymap[key].add(key)
            self._unmapped_keys.update(keys)
            super().update(keys)

    def __getitem__(self, key):
        """Gets the keys corresponding to a particular key.

        Unknown keys give an empty set. They are not inserted into the
        keymap, so a lookup never makes `key in self` become True.
        """
        key, _ = self.parse_and_index(key)
        return self.keymap.get(key, set())