Coverage for gpkit/keydict.py: 87%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

224 statements  

1"Implements KeyDict and KeySet classes" 

2from collections import defaultdict 

3from collections.abc import Hashable 

4import numpy as np 

5from .small_classes import Numbers, Quantity, FixedScalar 

6from .small_scripts import is_sweepvar, isnan, veclinkedfn 

7 

# A dimensionless Quantity of 1: KeyDict.__call__ multiplies a value by it when
# that value's key has no units.
DIMLESS_QUANTITY = Quantity(1, "dimensionless")
# numpy's dtype for Python `int`. KeyDict.__setitem__ compares against it to
# convert default-integer arrays to float and to avoid propagating an int dtype.
INT_DTYPE = np.dtype(int)

10 

def clean_value(key, value):
    """Normalizes `value` before it is stored under `key`.

    Variable-less monomials (FixedScalar) are unwrapped to their value, so
    that `x.sub({x: gpkit.units.m})` and `x.sub({x: gpkit.ureg.m})` are
    equivalent. Any resulting Quantity is then converted into `key`'s units
    (dimensionless when the key has none) and reduced to its bare magnitude,
    because quantities can't/shouldn't be stored as elements of numpy arrays.
    Every other value is returned unchanged.
    """
    if isinstance(value, FixedScalar):
        value = value.value
    if not isinstance(value, Quantity):
        return value
    converted = value.to(key.units or "dimensionless")
    return converted.magnitude

23 

24 

class KeyMap:
    """Mixin letting a container be indexed by many equivalent keys.

    Mapping keys
    ------------
    A KeyMap keeps VarKeys as its canonical keys. A stored entry can then be
    reached through any object whose `key` attribute is one of those VarKeys,
    or through any of the alternate names listed in that VarKey's `keys`.

    For example, after creating a KeyDict kd and setting kd[x] = v (where x
    is a Variable or VarKey), v can be retrieved with any of:
      - x
      - x.key
      - x.name (a string)
      - "x_modelname" (x's name including modelname)

    Keys that have no `.key` attribute are resolved through the keymap or,
    when `vks`/`varkeys` is set (a substitutions dictionary), through the
    `varkeys` KeySet.
    """
    collapse_arrays = False
    # Placeholder until __init__ installs a defaultdict. Kept as a container
    # (not None) so `key not in self.keymap` works on unpickled instances
    # whose __init__ has not run yet.
    keymap = []
    log_gets = False
    vks = None
    varkeys = None

    def __init__(self, *args, **kwargs):
        "Sets up the mapping state, then hands all arguments to `update()`"
        self.owned = set()
        self._unmapped_keys = set()
        self.keymap = defaultdict(set)
        self.update(*args, **kwargs)  # pylint: disable=no-member

    def parse_and_index(self, key):
        """Returns (canonical key, idx).

        idx is non-None only when arrays are collapsed and the key is an
        indexed element, in which case the canonical key is its veckey.
        """
        try:
            key = key.key
            if not (self.collapse_arrays and key.idx):
                return key, None
            return key.veckey, key.idx
        except AttributeError:
            return self._resolve_unkeyed(key)

    def _resolve_unkeyed(self, key):
        "Resolves a key with no `.key` attribute, such as a name string."
        if self.vks is None and self.varkeys is None:
            self.update_keymap()
            return key, None
        # looks like we're in a substitutions dictionary
        if self.varkeys is None:
            self.varkeys = KeySet(self.vks)
        if key not in self.varkeys:
            raise KeyError(key)
        first, *rest = self.varkeys[key]
        if rest:
            if not all(other.veckey == first.veckey for other in rest):
                raise ValueError("%s refers to multiple keys in this substitutions"
                                 " KeyDict. Use `.variables_byname(%s)` to see all"
                                 " of them." % (key, key))
            return first.veckey, None
        if self.collapse_arrays and first.idx:
            return first.veckey, first.idx
        return first, None

    def __contains__(self, key):  # pylint:disable=too-many-return-statements
        """Reports whether `key`, in any accepted form, resolves to an entry.

        An indexed key only counts when its element is a sweep or is not NaN.
        A name that is ambiguous among several keys counts as present.
        """
        try:
            key, idx = self.parse_and_index(key)
        except KeyError:
            return False
        except ValueError:  # the name matched several keys
            return True
        if not isinstance(key, Hashable):
            return False
        if super().__contains__(key) and idx:  # pylint: disable=no-member
            try:
                element = super().__getitem__(key)[idx]  # pylint: disable=no-member
                if is_sweepvar(element):
                    return True
                return not isnan(element).any()
            except TypeError:
                raise TypeError("%s has an idx, but its value in this"
                                " KeyDict is the scalar %s."
                                % (key, super().__getitem__(key)))  # pylint: disable=no-member
            except IndexError:
                raise IndexError("key %s with idx %s is out of bounds"
                                 " for value %s" %
                                 (key, idx, super().__getitem__(key)))  # pylint: disable=no-member
        return key in self.keymap

    def update_keymap(self):
        "Registers every name of each key in `_unmapped_keys` in the keymap"
        already_copied = set()
        for unmapped in self._unmapped_keys:
            for name in unmapped.keys:
                if name in self.keymap and name not in already_copied:
                    # KeyDict.update may share these sets with another
                    # KeyDict, so copy each one before its first mutation.
                    self.keymap[name] = set(self.keymap[name])
                    already_copied.add(name)
                self.keymap[name].add(unmapped)
        self._unmapped_keys = set()

118 

119 

class KeyDict(KeyMap, dict):
    """KeyDicts do two things over a dict: map keys and collapse arrays.

    Create one with ``kd = gpkit.keydict.KeyDict()``.

    For mapping keys, see KeyMap.__doc__

    Collapsing arrays
    -----------------
    If ``.collapse_arrays`` is True then VarKeys which have a `shape`
    parameter (indicating they are part of an array) are stored as numpy
    arrays, and automatically de-indexed when a matching VarKey with a
    particular `idx` parameter is used as a key.

    Arrays adopted from another KeyDict via `update` are shared until this
    KeyDict first reads a whole array or writes to it, at which point
    `_copyonwrite` gives it a private copy; `owned` records which stored
    arrays are already private.

    See also: gpkit/tests/t_keydict.py.
    """
    collapse_arrays = True

    def get(self, key, *alternative):
        "Returns self[key], or alternative[0] if one is given and key is absent."
        return alternative[0] if alternative and key not in self else self[key]

    def _copyonwrite(self, key):
        "Copies the array stored at `key` unless this KeyDict already owns it"
        if not hasattr(self, "owned"):  # backwards pickle compatibility
            self.owned = set()
        if key not in self.owned:
            super().__setitem__(key, super().__getitem__(key).copy())
            self.owned.add(key)

    def update(self, *args, **kwargs):
        """Adds every item of dict(*args, **kwargs) to this KeyDict.

        Fast path: when this KeyDict is empty and the only argument is another
        KeyDict (and there are no keyword items), its storage, keymap and
        pending keys are adopted directly. Stored arrays and keymap sets are
        then shared and are copied lazily before they are mutated.
        """
        if (not self and len(args) == 1 and not kwargs
                and isinstance(args[0], KeyDict)):
            other = args[0]
            super().update(other)
            self.keymap.update(other.keymap)
            self._unmapped_keys.update(other._unmapped_keys)  # pylint:disable=protected-access
        else:
            for k, v in dict(*args, **kwargs).items():
                self[k] = v

    def __call__(self, key):  # if uniting is ever a speed hit, cache it
        """Returns self[key] as a Quantity in the key's units.

        If `key` resolves to several keys, returns a dict instead, with each
        value multiplied by its own key's units (dimensionless if none).
        """
        got = self[key]
        if isinstance(got, dict):
            for k, v in got.items():
                got[k] = v*(k.units or DIMLESS_QUANTITY)
            return got
        if not hasattr(key, "units"):
            key, = self.keymap[self.parse_and_index(key)[0]]
        return Quantity(got, key.units or "dimensionless")

    def __getitem__(self, key):
        """Overloads __getitem__ and [] access to work with all keys.

        Returns the single matching value, or a {key: value} dict when `key`
        maps to several stored keys. Raises KeyError when nothing matches.
        """
        key, idx = self.parse_and_index(key)
        keys = self.keymap[key]
        if not keys:
            del self.keymap[key]  # remove blank entry added by defaultdict
            raise KeyError(key)
        got = {}
        for k in keys:
            if not idx and k.shape:
                self._copyonwrite(k)
            val = dict.__getitem__(self, k)
            if idx:
                # idx from the right, in case of sweep (the stored array
                # presumably has extra leading sweep axes). A separate local
                # keeps the Ellipsis from being prepended again for the next
                # key in `keys`, which would be an invalid numpy index.
                elem_idx = (...,) + idx if len(idx) < len(val.shape) else idx
                val = val[elem_idx]
            if len(keys) == 1:
                return val
            got[k] = val
        return got

    def __setitem__(self, key, value):
        "Overloads __setitem__ and []= to work with all keys"
        # pylint: disable=too-many-boolean-expressions,too-many-branches,too-many-statements
        try:
            key, idx = self.parse_and_index(key)
        except KeyError as e:  # may be indexed VectorVariable
            # NOTE: this try/except takes ~4% of the time in this loop
            if not hasattr(key, "shape"):
                raise e
            if not hasattr(value, "shape"):
                value = np.full(key.shape, value)
            elif key.shape != value.shape:
                raise KeyError("Key and value have different shapes") from e
            for subkey, subval in zip(key.flat, value.flat):
                self[subkey] = subval
            return
        value = clean_value(key, value)
        if key not in self.keymap:
            if not hasattr(self, "_unmapped_keys"):
                self.__init__()  # py3's pickle sets items before init... :(
            self.keymap[key].add(key)
            self._unmapped_keys.add(key)
            if idx:
                # first element of a new array: start from an all-NaN array
                dty = {} if isinstance(value, Numbers) else {"dtype": "object"}
                if getattr(value, "shape", None) and value.dtype != INT_DTYPE:
                    dty["dtype"] = value.dtype
                dict.__setitem__(self, key, np.full(key.shape, np.nan, **dty))
                self.owned.add(key)
        if idx:
            if is_sweepvar(value):
                old = super().__getitem__(key)
                super().__setitem__(key, np.array(old, "object"))
                self.owned.add(key)
            self._copyonwrite(key)
            if hasattr(value, "__call__"):  # a linked function
                old = super().__getitem__(key)
                super().__setitem__(key, np.array(old, dtype="object"))
            super().__getitem__(key)[idx] = value
            return  # successfully set a single index!
        if key.shape:  # now if we're setting an array...
            if hasattr(value, "__call__"):  # a linked vector-function
                key.vecfn = value
                value = np.empty(key.shape, dtype="object")
                it = np.nditer(value, flags=['multi_index', 'refs_ok'])
                while not it.finished:
                    i = it.multi_index
                    it.iternext()
                    value[i] = veclinkedfn(key.vecfn, i)
            if getattr(value, "shape", None):  # is the value an array?
                if value.dtype == INT_DTYPE:
                    # convert to (double-precision) float; "f" would be float32
                    value = np.array(value, float)
                if dict.__contains__(self, key):
                    old = super().__getitem__(key)
                    if old.dtype != value.dtype:
                        # e.g. replacing a number with a linked function
                        newly_typed_array = np.array(old, dtype=value.dtype)
                        super().__setitem__(key, newly_typed_array)
                        self.owned.add(key)
                    self._copyonwrite(key)
                    # NaN entries in the new value leave old entries in place
                    goodvals = ~isnan(value)
                    super().__getitem__(key)[goodvals] = value[goodvals]
                    return  # successfully set only some indexes!
            elif not is_sweepvar(value):  # or needs to be made one?
                if not hasattr(value, "__len__"):
                    # broadcast a scalar at full (float64) precision
                    value = np.full(key.shape, value, float)
                elif not isinstance(value[0], np.ndarray):
                    clean_values = [clean_value(key, v) for v in value]
                    dtype = None
                    if any(is_sweepvar(cv) for cv in clean_values):
                        dtype = "object"
                    value = np.array(clean_values, dtype=dtype)
        super().__setitem__(key, value)
        self.owned.add(key)

    def __delitem__(self, key):
        """Overloads del [] to work with all keys.

        A key without a `.key` attribute deletes every stored key it maps to.
        Deleting an indexed key NaNs out that element and removes the whole
        array once every element is NaN. Raises KeyError if nothing matches.
        """
        if not hasattr(key, "key"):  # not a keyed object
            self.update_keymap()
            keys = self.keymap.get(key)  # .get: don't create a blank entry
            if not keys:
                # drop any blank entry too; otherwise `key in self` would
                # report True after this failed deletion
                self.keymap.pop(key, None)
                raise KeyError(key)
            # safe to iterate: the nested deletions below replace keymap sets
            # with copies instead of mutating `keys` in place
            for k in keys:
                del self[k]
        else:
            key = key.key
            veckey, idx = self.parse_and_index(key)
            if idx is None:
                super().__delitem__(key)
            else:
                super().__getitem__(veckey)[idx] = np.nan
                if isnan(super().__getitem__(veckey)).all():
                    super().__delitem__(veckey)
            copiedonwrite = set()  # to save time, .update() does not copy
            mapkeys = {key}
            if key.keys:
                mapkeys.update(key.keys)
            for mapkey in mapkeys:
                if mapkey in self.keymap:
                    if len(self.keymap[mapkey]) == 1:
                        del self.keymap[mapkey]
                        continue
                    if mapkey not in copiedonwrite:
                        self.keymap[mapkey] = set(self.keymap[mapkey])
                        copiedonwrite.add(mapkey)
                    self.keymap[mapkey].remove(key)

295 

296 

class KeySet(KeyMap, set):
    "KeyMaps that don't collapse arrays or store values."
    collapse_arrays = False

    def update(self, keys=()):
        """Adds `keys` to this set and queues them for keymap indexing.

        Each key is mapped to itself immediately. Its alternate names (its
        `.keys`) are registered later by `update_keymap()`. `keys` defaults
        to empty so that `KeySet()` can be constructed with no arguments.
        """
        # materialize once: `keys` is walked three times below, so a one-shot
        # iterator would otherwise only reach the keymap
        keys = list(keys)
        for key in keys:
            self.keymap[key].add(key)
        self._unmapped_keys.update(keys)
        super().update(keys)

    def __getitem__(self, key):
        """Gets the set of keys corresponding to a particular key.

        Returns an empty set when nothing matches. Unlike indexing the
        defaultdict directly, this does not insert a blank entry, which would
        make `key in self` wrongly return True afterwards.
        """
        key, _ = self.parse_and_index(key)
        return self.keymap.get(key, set())