Coverage for gpkit/small_classes.py: 91%

137 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-07 22:13 -0500

1"""Miscellaneous small classes""" 

2from operator import xor 

3from functools import reduce 

4import numpy as np 

5from scipy.sparse import csr_matrix 

6from .units import Quantity 

7 

# Type tuples meant for isinstance() checks; HashVector uses `Numbers` to
# decide whether an operand is a scalar.
Strings = tuple([str])
Numbers = tuple([int, float, np.number, Quantity])

10 

11 

class FixedScalarMeta(type):
    """Metaclass whose isinstance() check recognises fixed scalars.

    ``isinstance(obj, cls)`` is truthy when ``obj`` has a truthy ``hmap``
    containing exactly one entry and a falsy ``vks``. (Presumably ``hmap``
    is a Nomial's term map and ``vks`` its variable keys -- confirm against
    gpkit.nomials.)
    """

    def __instancecheck__(cls, obj):
        hmap = getattr(obj, "hmap", None)
        if not hmap:
            # missing or empty hmap: hand back that falsy value unchanged
            return hmap
        if len(obj.hmap) != 1:
            return False
        return not obj.vks

16 

17 

class FixedScalar(metaclass=FixedScalarMeta):
    """Marker type for isinstance checks: scalar Nomials with no variables.

    Membership is decided by FixedScalarMeta.__instancecheck__ (a duck-typed
    test on ``hmap``/``vks``), not by inheritance from this class.
    """
    # pylint: disable=too-few-public-methods

21 

22 

class Count:  # pylint: disable=too-few-public-methods
    """Counter returning 0, 1, 2, ... from successive ``next()`` calls.

    Offers the ``.next()`` method of Python 2's ``itertools.count`` and also
    implements the Python 3 iterator protocol, so ``next(counter)`` and
    iteration work as well. ``self.count`` holds the most recently returned
    value (-1 before the first call).
    """
    def __init__(self):
        self.count = -1

    def next(self):
        "Increment self.count and return it"
        self.count += 1
        return self.count

    # Python 3 spelling of the same operation, so builtin next() works.
    __next__ = next

    def __iter__(self):
        "Return self, as iterators do."
        return self

32 

33 

def matrix_converter(name):
    """Build a method that converts a matrix to another SciPy format via CSR.

    The returned function calls ``self.tocsr()`` and then the method named
    ``"to" + name`` on that result (e.g. ``name="coo"`` calls ``.tocoo()``).
    """
    def to_(self):  # bound as tocoo, tocsc, etc. in CootMatrix
        "Converts to another type of matrix."
        csr_form = self.tocsr()
        converter = getattr(csr_form, "to" + name)
        return converter()

    return to_

40 

41 

class CootMatrix:
    """A very simple sparse matrix representation (coordinate format).

    Entry ``k`` holds ``data[k]`` at position ``(row[k], col[k])``.
    Conversions to SciPy formats all go through ``tocsr``.
    """
    def __init__(self, row, col, data):
        self.row, self.col, self.data = row, col, data
        # Shape is one past the largest index on each axis (0 when empty).
        # `if self.row` relies on sequence truthiness, so row/col are
        # presumably plain lists (a multi-element numpy array would raise)
        # -- confirm with callers.
        self.shape = [(max(self.row) + 1) if self.row else 0,
                      (max(self.col) + 1) if self.col else 0]

    def __eq__(self, other):
        """Compare indices, data and shape with another CootMatrix.

        Returns NotImplemented for any other type rather than raising
        AttributeError, so Python falls back to its default comparison.
        """
        if not isinstance(other, CootMatrix):
            return NotImplemented
        return (self.row == other.row and self.col == other.col
                and self.data == other.data and self.shape == other.shape)

    # Defining __eq__ without __hash__ leaves instances unhashable.

    tocoo = matrix_converter("coo")
    tocsc = matrix_converter("csc")
    todia = matrix_converter("dia")
    todok = matrix_converter("dok")
    todense = matrix_converter("dense")

    def tocsr(self):
        "Converts to a Scipy sparse csr_matrix"
        # No explicit shape: SciPy infers it from the largest indices.
        return csr_matrix((self.data, (self.row, self.col)))

    def dot(self, arg):
        "Returns dot product with arg."
        return self.tocsr().dot(arg)

66 

67 

class SolverLog:
    """File-like text sink that records everything written to it.

    Provides ``write`` and ``flush`` so it can stand in for a stream such as
    stdout. All text accumulates in ``self.written``. When ``verbosity > 0``
    each chunk is also forwarded to ``output.write``, so ``output`` must
    then have a ``write`` method.
    """

    def __init__(self, output=None, *, verbosity=0):
        self.output = output
        self.verbosity = verbosity
        self.written = ""

    def write(self, writ):
        "Record `writ`, echoing it to `output` when verbose."
        # A chunk shaped like "b'...'" looks like the str() of a bytes
        # object (presumably from solver output): keep only the inner text.
        text = writ[2:-1] if writ[:2] == "b'" else writ
        self.written += text
        if self.verbosity > 0:  # pragma: no cover
            self.output.write(text)

    def lines(self):
        "Split everything recorded so far on newlines."
        recorded = self.written
        return recorded.split("\n")

    def flush(self):
        "No-op; present only so this object satisfies the file API."

89 

90 

class DictOfLists(dict):
    """A hierarchy of dictionaries, with lists at the bottom.

    ``append`` adds one sweep's worth of (possibly nested) values to the leaf
    lists. ``atindex`` picks one position out of every leaf. ``to_arrays``
    converts the leaf lists in place via ``_enray``.
    """

    def append(self, sol):
        "Append a dict (of dicts) of values to all held lists."
        if hasattr(self, "initialized"):
            _append_dict(sol, self)
            return
        # First call: build the list skeleton from `sol` itself.
        _enlist_dict(sol, self)
        self.initialized = True  # pylint: disable=attribute-defined-outside-init

    def atindex(self, i):
        "Return a new instance holding position `i` of every leaf."
        cls = self.__class__
        picked = _index_dict(i, self, cls())
        return cls(picked)

    def to_arrays(self):
        "Convert every held list to a numpy array, in place."
        _enray(self, self)

109 

110 

def _enlist_dict(d_in, d_out):
    """Copy d_in into d_out recursively, wrapping each leaf value in a list.

    Nested dicts are rebuilt with their own class. Returns d_out.
    """
    for key, val in d_in.items():
        d_out[key] = (_enlist_dict(val, val.__class__())
                      if isinstance(val, dict) else [val])
    # Sanity check: d_out should now mirror d_in's keys exactly (this fails
    # if d_out started out holding extra keys).
    assert set(d_in.keys()) == set(d_out.keys())
    return d_out

120 

121 

def _append_dict(d_in, d_out):
    """Recursively travels dict d_out and appends items found in d_in.

    Each leaf value of d_in is appended to the list at the same key path in
    d_out. Returns d_out.

    Raises
    ------
    RuntimeWarning
        If d_in has a key, at any nesting level, that d_out lacks (a key that
        was added after the first sweep).
    """
    for k, v in d_in.items():
        # Look the key up before dispatching on the value's type, so a
        # missing nested-dict key is reported the same way as a missing leaf.
        try:
            target = d_out[k]
        except KeyError as e:
            msg = f"Key `{k}` was added after the first sweep."
            raise RuntimeWarning(msg) from e
        if isinstance(v, dict):
            d_out[k] = _append_dict(v, target)
        else:
            target.append(v)
    return d_out

134 

135 

def _index_dict(idx, d_in, d_out):
    """Fill d_out with position `idx` of each leaf in d_in, recursively.

    Leaves whose indexing raises IndexError or TypeError are copied
    unchanged. Nested dicts are rebuilt with their own class. Returns d_out.
    """
    for key, val in d_in.items():
        if isinstance(val, dict):
            d_out[key] = _index_dict(idx, val, val.__class__())
            continue
        try:
            d_out[key] = val[idx]
        except (IndexError, TypeError):  # e.g. a scalar: keep it whole
            d_out[key] = val
    return d_out

147 

148 

def _enray(d_in, d_out):
    """Recursively turns lists into numpy arrays.

    For each leaf sequence ``v`` in d_in, d_out receives:
    - the lone element itself, if ``len(v) == 1``;
    - ``np.array(v, dtype="object")``, if ``v`` has several items and its
      first item is a list;
    - ``np.array(v)`` otherwise, including an empty array for ``[]``.
    Nested dicts are rebuilt with their own class. Returns d_out.

    d_in and d_out may be the same dict (DictOfLists.to_arrays passes self
    twice). Only existing keys are reassigned, so writing during iteration
    does not change the dict's size.
    """
    for k, v in d_in.items():
        if isinstance(v, dict):
            d_out[k] = _enray(v, v.__class__())
        elif len(v) == 1:
            (only,) = v
            d_out[k] = only
        elif len(v) > 1 and isinstance(v[0], list):
            d_out[k] = np.array(v, dtype="object")
        else:
            # An empty sequence lands here too, giving an empty array
            # instead of an IndexError from v[0].
            d_out[k] = np.array(v)
    return d_out

164 

165 

class HashVector(dict):
    """A simple, sparse, string-indexed vector. Inherits from dict.

    The HashVector class supports element-wise arithmetic:
    any undeclared variables are assumed to have a value of zero.

    Arguments
    ---------
    arg : iterable

    Example
    -------
    >>> x = gpkit.nomials.Monomial("x")
    >>> exp = gpkit.small_classes.HashVector({x: 2})
    """
    # Cached result of __hash__ (None = not yet computed). __add__ resets it
    # after mutating its copy, but plain item assignment does not, so a
    # HashVector should presumably not be mutated once hashed -- confirm.
    hashvalue = None

    def __hash__(self):
        "Allows HashVectors to be used as dictionary keys."
        if self.hashvalue is None:
            # XOR of the item hashes does not depend on insertion order.
            self.hashvalue = reduce(xor, map(hash, self.items()), 0)
        return self.hashvalue

    def copy(self):
        "Return a copy of this, carrying over any cached hash."
        hv = self.__class__(self)
        hv.hashvalue = self.hashvalue
        return hv

    def __pow__(self, other):
        "Accepts scalars. Return Hashvector with each value put to a power."
        if isinstance(other, Numbers):
            return self.__class__({k: v**other for (k, v) in self.items()})
        return NotImplemented

    def __mul__(self, other):
        """Returns a new HashVector with every value multiplied by `other`.

        `other` multiplies each value as a whole; dicts are NOT multiplied
        element-wise by key. If any product raises an ordinary exception,
        NotImplemented is returned so Python can try ``other.__rmul__``.
        BaseExceptions such as KeyboardInterrupt propagate."""
        try:
            return self.__class__({k: v*other for (k, v) in self.items()})
        except Exception:  # pylint: disable=broad-except
            return NotImplemented

    def __add__(self, other):
        """Accepts scalars and dicts. Returns with each value added.

        A scalar is added to every value. If the other object inherits from
        dict, addition is element-wise and their keys' union forms the new
        keys; keys whose two values cancel exactly are dropped."""
        if isinstance(other, Numbers):
            return self.__class__({k: v + other for (k, v) in self.items()})
        if isinstance(other, dict):
            sums = self.copy()
            for key, value in other.items():
                if key in sums:
                    svalue = sums[key]
                    if value == -svalue:
                        del sums[key]  # remove zeros created by addition
                    else:
                        sums[key] = value + svalue
                else:
                    sums[key] = value
            sums.hashvalue = None  # copy() kept a now-stale hash
            return sums
        return NotImplemented

    def __neg__(self):
        "Negate every value."
        return -1*self

    def __sub__(self, other):
        "Return ``self + (-other)``."
        return self + -other

    def __rsub__(self, other):
        "Return ``other + (-self)``."
        return other + -self

    def __radd__(self, other):
        "Return ``self + other``."
        return self + other

    def __truediv__(self, other):
        "Divide every value by `other`, via ``self * other**-1``."
        return self * other**-1

    def __rtruediv__(self, other):
        "Return ``other * self**-1``."
        return other * self**-1

    def __rmul__(self, other):
        "Return ``self * other``."
        return self * other

241 

242 

# Module-level empty HashVector shared by every importer; treat it as
# read-only, since mutating it would affect all users.
EMPTY_HV = HashVector({})