Coverage for gpkit/constraints/prog_factories.py: 84%

138 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 16:49 -0500

1"Scripts for generating, solving and sweeping programs" 

2from time import time 

3import warnings as pywarnings 

4import numpy as np 

5from adce import adnumber 

6from ..nomials import parse_subs 

7from ..solution_array import SolutionArray 

8from ..keydict import KeyDict 

9from ..small_scripts import maybe_flatten 

10from ..small_classes import FixedScalar 

11from ..exceptions import Infeasible 

12from ..globals import SignomialsEnabled 

13 

14 

def evaluate_linked(constants, linked):
    """Evaluates the values and gradients of linked variables.

    Every constant is wrapped in an ``adnumber`` tagged with its own key, and
    each linked function is evaluated on that mapping. The result is written
    back into ``constants`` IN PLACE. When the result carries derivative
    information (it has an ``x`` attribute), the gradients with respect to the
    tagged constants are stored in ``var.descr["gradients"]``.

    Arguments
    ---------
    constants : dict
        Maps keys to fixed values. Mutated: linked results are added to it,
        and the "gradients" entry is removed from every pre-existing key's
        descr.
    linked : dict
        Maps each linked key to the function computing its value.

    If evaluation under auto-differentiation raises, the function is
    re-evaluated on the plain (non-AD) constants, no gradients are kept, and
    a warning is printed. The exception is re-raised instead when
    ``gpkit.settings["ad_errors_raise"]`` is set.
    """
    # AD-tagged view of the constants, so results track derivatives per key.
    ad_inputs = KeyDict({key: adnumber(maybe_flatten(value), key)
                         for key, value in constants.items()})
    plain_inputs = None  # built lazily, only once some AD evaluation fails
    vector_outputs = {}  # veckey -> full array returned by its vecfn
    for key in constants:  # remove gradients from constants
        key.descr.pop("gradients", None)
    for var, func in linked.items():
        try:
            veckey = var.veckey
            if veckey and veckey.vecfn:
                # evaluate the whole vector function once, then index into it
                if veckey not in vector_outputs:
                    with SignomialsEnabled():  # to allow use of gpkit.units
                        full = veckey.vecfn(ad_inputs)
                    vector_outputs[veckey] = (full if hasattr(full, "shape")
                                              else np.array(full))
                out = vector_outputs[veckey][var.idx]
            else:
                with SignomialsEnabled():  # to allow use of gpkit.units
                    out = func(ad_inputs)
            if isinstance(out, FixedScalar):  # to allow use of gpkit.units
                out = out.value
            if hasattr(out, "units"):
                out = out.to(var.units or "dimensionless").magnitude
            elif out != 0 and var.units:
                pywarnings.warn(
                    "Linked function for %s did not return a united value."
                    " Modifying it to do so (e.g. by using `()` instead of `[]`"
                    " to access variables) will reduce errors." % var)
            out = maybe_flatten(out)
            if hasattr(out, "x"):
                # a calculated value: keep its value and its AD gradients
                constants[var] = out.x
                var.descr["gradients"] = {adn.tag: grad
                                          for adn, grad in out.d().items()
                                          if adn.tag}
            else:
                # a new fixed variable, not a calculated one
                constants[var] = out
        except Exception as exception:  # pylint: disable=broad-except
            from .. import settings  # local import, as in the original
            if settings.get("ad_errors_raise", None):
                raise
            if plain_inputs is None:
                plain_inputs = KeyDict(constants)
            constants[var] = func(plain_inputs)
            var.descr.pop("gradients", None)
            print("Warning: skipped auto-differentiation of linked variable"
                  " %s because %s was raised. Set `gpkit.settings"
                  "[\"ad_errors_raise\"] = True` to raise such Exceptions"
                  " directly.\n" % (var, repr(exception)))
            if ("Automatic differentiation not yet supported for <class "
                    "'gpkit.nomials.math.Monomial'> objects") in str(exception):
                print("This particular warning may have come from using"
                      " gpkit.units.* in the function for %s; try using"
                      " gpkit.ureg.* or gpkit.units.*.units instead." % var)

70 

71 

def progify(program, return_attr=None):
    """Generates function that returns a program() and optionally an attribute.

    Arguments
    ---------
    program: NomialData
        Class to return, e.g. GeometricProgram or SequentialGeometricProgram
    return_attr: string
        attribute to return in addition to the program

    Returns
    -------
    programfn : function
        ``programfn(self, constants=None, **initargs)``, which builds
        ``program(self.cost, self, constants, **initargs)``, sets the
        program's ``model`` attribute to ``self``, and returns the program
        (or ``(program, getattr(program, return_attr))`` when
        ``return_attr`` is given).
    """
    def programfn(self, constants=None, **initargs):
        """Return program version of self.

        When `constants` is None, they are parsed from ``self.substitutions``
        and any linked variables are evaluated. Any other value, including
        an empty mapping, is passed to the program unchanged.
        """
        # Sentinel check: an explicit empty mapping is a real argument and
        # must not be silently replaced by the model's substitutions.
        if constants is None:
            constants, _, linked = parse_subs(self.varkeys, self.substitutions)
            if linked:
                evaluate_linked(constants, linked)
        prog = program(self.cost, self, constants, **initargs)
        prog.model = self  # NOTE SIDE EFFECTS
        if return_attr:
            return prog, getattr(prog, return_attr)
        return prog
    return programfn

94 

95 

def solvify(genfunction):
    """Returns function for making/solving/sweeping a program.

    Arguments
    ---------
    genfunction : function
        Called as ``genfunction(self, **kwargs)`` for a single solve, or
        with the current constants during a sweep. It must return a
        ``(program, solve function)`` pair.
    """
    def solvefn(self, solver=None, *, verbosity=1, skipsweepfailures=False,
                **kwargs):
        """Forms a mathematical program and attempts to solve it.

        Arguments
        ---------
        solver : string or function (default None)
            If None, uses the default solver found in installation.
        verbosity : int (default 1)
            If greater than 0 prints runtime messages. Passed unchanged to
            a single solve; decremented by one for each solve of a sweep.
        skipsweepfailures : bool (default False)
            If True, infeasible solves during a sweep are skipped.
        **kwargs : Passed to solve and program init calls

        Returns
        -------
        sol : SolutionArray
            See the SolutionArray documentation for details.

        Raises
        ------
        RuntimeWarning
            During a sweep, if a solve is infeasible and skipsweepfailures
            is False, or if no solve succeeds. Errors raised while
            generating or solving the program are not caught here.
        """
        constants, sweep, linked = parse_subs(self.varkeys, self.substitutions)
        sol = SolutionArray()
        sol.modelstr = str(self)

        # NOTE SIDE EFFECTS: self.program and self.solution are assigned here
        if not sweep:
            self.program, progsolve = genfunction(self, **kwargs)
            result = progsolve(solver, verbosity=verbosity, **kwargs)
            if kwargs.get("process_result", True):
                self.process_result(result)
            sol.append(result)
        else:
            run_sweep(genfunction, self, sol, skipsweepfailures,
                      constants, sweep, linked, solver, verbosity, **kwargs)
        sol.to_arrays()
        self.solution = sol
        return sol
    return solvefn

141 

142 

# pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements
def run_sweep(genfunction, self, solution, skipsweepfailures,
              constants, sweep, linked, solver, verbosity, **kwargs):
    """Runs through a sweep.

    Solves one program per point of the grid spanned by the swept values,
    appending each successful result to `solution`. Afterwards it moves the
    swept values into ``solution["sweepvariables"]`` and collapses constants
    that did not change between solves to a single entry.

    Arguments
    ---------
    genfunction : function
        Called as ``genfunction(self, constants, **kwargs)``; must return a
        ``(program, solvefn)`` pair.
    self : object
        The object being swept. ``self.program`` is replaced with the list of
        generated programs, and ``self.process_result(result)`` is called on
        each result unless ``kwargs["process_result"]`` is falsy.
    solution : SolutionArray
        Receives one appended result per successful solve. Its "constants"
        entry is rewritten, and a "sweepvariables" entry is added.
    skipsweepfailures : bool
        If True, Infeasible solves are skipped instead of aborting the sweep.
    constants : dict
        Fixed substitutions. Updated IN PLACE with each pass's sweep values
        (and linked-variable values).
    sweep : dict
        Maps each swept key to its array of values.
    linked : dict
        Linked variables, re-evaluated on every pass via evaluate_linked.
    solver, verbosity, **kwargs
        Passed on to each solve (verbosity decremented by one).

    Raises
    ------
    RuntimeWarning
        If a solve is infeasible and skipsweepfailures is False, if every
        solve was infeasible, or if the sweep contains no values at all.
    """
    # sort sweeps by the eqstr of their varkey
    sweepvars, sweepvals = zip(*sorted(sweep.items(),
                                       key=lambda vkval: vkval[0].eqstr))
    if len(sweep) == 1:
        sweep_grids = np.array(list(sweepvals))
    else:
        sweep_grids = np.meshgrid(*sweepvals)

    N_passes = sweep_grids[0].size
    # flatten each grid so index i picks one point of the full grid
    sweep_vects = {var: grid.reshape(N_passes)
                   for var, grid in zip(sweepvars, sweep_grids)}

    tic = time()  # always bound; only reported when verbosity > 0

    self.program = []
    last_error = None
    for i in range(N_passes):
        constants.update({var: sweep_vect[i]
                          for var, sweep_vect in sweep_vects.items()})
        if linked:
            evaluate_linked(constants, linked)
        program, solvefn = genfunction(self, constants, **kwargs)
        program.model = None  # so it doesn't try to debug
        self.program.append(program)  # NOTE: SIDE EFFECTS
        if i == 0 and verbosity > 0:  # wait for successful program gen
            # TODO: use full string when minimum lineage is set automatically
            sweepvarsstr = ", ".join([str(var)
                                      for var, val in zip(sweepvars, sweepvals)
                                      if not np.isnan(val).all()])
            print("Sweeping %s with %i solves:" % (sweepvarsstr, N_passes))
        try:
            if verbosity > 1:
                print("\nSolve %i:" % i)
            result = solvefn(solver, verbosity=verbosity-1, **kwargs)
            if kwargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
            if verbosity == 1:
                print(".", end="", flush=True)
        except Infeasible as e:
            last_error = e
            if not skipsweepfailures:
                raise RuntimeWarning(
                    "Solve %i was infeasible; progress saved to m.program."
                    " To continue sweeping after failures, solve with"
                    " skipsweepfailures=True." % i) from e
            if verbosity > 1:
                print("Solve %i was %s." % (i, e.__class__.__name__))
            if verbosity == 1:
                print("!", end="", flush=True)
    if not solution:
        if not N_passes:
            # nothing was attempted, so "infeasible" would be misleading
            raise RuntimeWarning("The sweep contained no values, so nothing"
                                 " was solved.")
        raise RuntimeWarning("All solves were infeasible.") from last_error
    if verbosity == 1:
        print()

    solution["sweepvariables"] = KeyDict()
    ksweep = KeyDict(sweep)
    for var, val in list(solution["constants"].items()):
        if var in ksweep:
            solution["sweepvariables"][var] = val
            del solution["constants"][var]
        elif linked:  # if any variables are linked, we check all of them
            first = val[0]
            if hasattr(first, "shape"):
                differences = ((other != first).any() for other in val[1:])
            else:
                differences = (other != first for other in val[1:])
            if not any(differences):
                solution["constants"][var] = [first]
        else:
            solution["constants"][var] = [val[0]]

    if verbosity > 0:
        soltime = time() - tic
        print("Sweeping took %.3g seconds." % (soltime,))