Coverage for gpkit/constraints/prog_factories.py: 84%

138 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-07 22:13 -0500

1"Scripts for generating, solving and sweeping programs" 

2from time import time 

3import warnings as pywarnings 

4import numpy as np 

5from adce import adnumber 

6from ..nomials import parse_subs 

7from ..solution_array import SolutionArray 

8from ..keydict import KeyDict 

9from ..small_scripts import maybe_flatten 

10from ..small_classes import FixedScalar 

11from ..exceptions import Infeasible 

12from ..globals import SignomialsEnabled 

13 

14 

def evaluate_linked(constants, linked):
    # pylint: disable=too-many-branches
    """Compute linked variables in place, recording gradients when possible.

    Each linked function is first evaluated on a KeyDict whose values are
    `adnumber`-wrapped constants (each tagged with its key), so derivatives
    can be tracked automatically. The result is stored in `constants[v]`.
    If the result exposes `.x`, that value is stored. Its `.d()` entries
    with a truthy tag are saved as `v.descr["gradients"]`.

    If anything along that path raises, the error propagates when
    `gpkit.settings["ad_errors_raise"]` is truthy. Otherwise the function
    is re-run on a plain (non-AD) KeyDict of `constants`, any gradients on
    `v` are dropped, and a warning is printed.

    Arguments
    ---------
    constants : dict
        Maps variable keys to fixed values. Mutated in place: linked
        results are written into it, and the "gradients" entry is removed
        from the `descr` of every key it held on entry.
    linked : dict
        Maps linked variable keys to the functions that compute them.
    """
    ad_subs = KeyDict({key: adnumber(maybe_flatten(val), key)
                       for key, val in constants.items()})
    plain_subs = None  # built lazily on the first AD failure, then reused
    vector_outputs = {}  # veckey -> full output array of that veckey's vecfn
    for key in constants:  # plain constants carry no gradients
        key.descr.pop("gradients", None)
    for v, linkfn in linked.items():
        try:
            vectorized = v.veckey and v.veckey.vecfn
            if vectorized:
                # evaluate the whole vector once, then index out each element
                if v.veckey not in vector_outputs:
                    with SignomialsEnabled():  # to allow use of gpkit.units
                        vecout = v.veckey.vecfn(ad_subs)
                    vector_outputs[v.veckey] = (
                        vecout if hasattr(vecout, "shape") else np.array(vecout))
                out = vector_outputs[v.veckey][v.idx]
            else:
                with SignomialsEnabled():  # to allow use of gpkit.units
                    out = linkfn(ad_subs)
            if isinstance(out, FixedScalar):  # to allow use of gpkit.units
                out = out.value
            if hasattr(out, "units"):
                out = out.to(v.units or "dimensionless").magnitude
            elif out != 0 and v.units:
                pywarnings.warn(
                    f"Linked function for {v} did not return a united value."
                    " Modifying it to do so (e.g. by using `()` instead of `[]`"
                    " to access variables) will reduce errors.")
            out = maybe_flatten(out)
            if hasattr(out, "x"):
                # an AD result: keep its value and its tagged derivatives
                constants[v] = out.x
                v.descr["gradients"] = {adn.tag: grad
                                        for adn, grad in out.d().items()
                                        if adn.tag}
            else:
                # a new fixed variable, not a calculated one
                constants[v] = out
        except Exception as exception:  # pylint: disable=broad-except
            from .. import settings  # pylint: disable=import-outside-toplevel
            if settings.get("ad_errors_raise", None):
                raise
            if plain_subs is None:
                plain_subs = KeyDict(constants)
            constants[v] = linkfn(plain_subs)
            v.descr.pop("gradients", None)
            print("Warning: skipped auto-differentiation of linked variable"
                  f" {v} because {exception!r} was raised. Set `gpkit.settings"
                  "[\"ad_errors_raise\"] = True` to raise such Exceptions"
                  " directly.\n")
            if ("Automatic differentiation not yet supported for <class "
                    "'gpkit.nomials.math.Monomial'> objects") in str(exception):
                print("This particular warning may have come from using"
                      f" gpkit.units.* in the function for {v}; try using"
                      " gpkit.ureg.* or gpkit.units.*.units instead.")

71 

72 

def progify(program, return_attr=None):
    """Build a method that converts a model into an instance of `program`.

    Arguments
    ---------
    program: NomialData
        Class to instantiate, e.g. GeometricProgram or
        SequentialGeometricProgram. It is called as
        `program(model.cost, model, constants, **initargs)`.
    return_attr: string
        If truthy, the generated method returns a
        `(program_instance, getattr(program_instance, return_attr))` pair
        instead of just the instance.

    Returns
    -------
    function
        `programfn(self, constants=None, **initargs)`. When `constants` is
        falsy it is rebuilt from the model's substitutions, and linked
        variables are evaluated into it.
    """
    def programfn(self, constants=None, **initargs):
        "Instantiate `program` from this model (see progify for details)."
        if not constants:
            constants, _, linked = parse_subs(self.varkeys, self.substitutions)
            if linked:
                evaluate_linked(constants, linked)
        instance = program(self.cost, self, constants, **initargs)
        instance.model = self  # NOTE SIDE EFFECTS
        if not return_attr:
            return instance
        return instance, getattr(instance, return_attr)
    return programfn

95 

96 

def solvify(genfunction):
    """Return a method that generates, solves (or sweeps) and records a program.

    `genfunction(model, [constants,] **kwargs)` must return a
    `(program, solve_function)` pair, e.g. a function built by `progify`
    with `return_attr="solve"`.
    """
    def solvefn(self, solver=None, *, verbosity=1, skipsweepfailures=False,
                **kwargs):
        """Forms a mathematical program and attempts to solve it.

        Arguments
        ---------
        solver : string or function (default None)
            If None, uses the default solver found in installation.
        verbosity : int (default 1)
            If greater than 0 prints runtime messages.
            A single (non-swept) solve receives this value unchanged;
            during a sweep each solve receives `verbosity - 1`.
        skipsweepfailures : bool (default False)
            If True, a solve during a sweep that raises `Infeasible` is
            skipped rather than aborting the sweep. Other exceptions are
            never skipped.
        **kwargs : Passed to solve and program init calls
            If `process_result` is given and falsy, `self.process_result`
            is not called on the results.

        Returns
        -------
        sol : SolutionArray
            See the SolutionArray documentation for details.

        Raises
        ------
        RuntimeWarning
            During a sweep, if a solve is infeasible and skipsweepfailures
            is False, or if every solve of the sweep was infeasible.
        Any exception raised while generating or solving the program
        (outside the sweep handling above) propagates unchanged.
        """
        constants, sweep, linked = parse_subs(self.varkeys, self.substitutions)
        solution = SolutionArray()
        solution.modelstr = str(self)

        # NOTE SIDE EFFECTS: self.program and self.solution set below
        if sweep:
            run_sweep(genfunction, self, solution, skipsweepfailures,
                      constants, sweep, linked, solver, verbosity, **kwargs)
        else:
            # genfunction re-parses substitutions itself when given no constants
            self.program, progsolve = genfunction(self, **kwargs)
            result = progsolve(solver, verbosity=verbosity, **kwargs)
            if kwargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
        solution.to_arrays()
        self.solution = solution
        return solution
    return solvefn

142 

143 

# pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements
def run_sweep(genfunction, self, solution, skipsweepfailures,
              constants, sweep, linked, solver, verbosity, **kwargs):
    """Runs through a sweep, solving one program per sweep point.

    Arguments
    ---------
    genfunction : function
        Called as `genfunction(self, constants, **kwargs)`; must return a
        `(program, solve_function)` pair.
    self : model
        The model being swept. `self.program` is set to the list of
        generated programs (NOTE: side effect).
    solution : SolutionArray
        Each successful result is appended to it. Afterwards, swept values
        are moved from its "constants" into a new "sweepvariables" KeyDict.
    skipsweepfailures : bool
        If True, solves raising `Infeasible` are skipped instead of aborting.
    constants : dict
        Fixed substitutions; mutated in place with each pass's sweep values
        and, if `linked` is non-empty, the re-evaluated linked variables.
    sweep : dict
        Maps swept variable keys to their sweep values.
    linked : dict
        Maps linked variable keys to their functions (may be empty).
    solver, verbosity, **kwargs
        Passed to each solve (`verbosity` is decremented by one).

    Raises
    ------
    RuntimeWarning
        If a solve is infeasible and skipsweepfailures is False, or if no
        solve succeeded.
    """
    sweepvars, sweepvals, sweep_vects, n_passes = _sweep_vectors(sweep)
    # bound unconditionally so the timing report never sees an unbound name
    tic = time()

    self.program = []
    last_error = None
    for i in range(n_passes):
        constants.update({var: sweep_vect[i]
                          for var, sweep_vect in sweep_vects.items()})
        if linked:
            evaluate_linked(constants, linked)
        program, solvefn = genfunction(self, constants, **kwargs)
        program.model = None  # so it doesn't try to debug
        self.program.append(program)  # NOTE: SIDE EFFECTS
        if i == 0 and verbosity > 0:  # wait for successful program gen
            # TODO: use full string when minimum lineage is set automatically
            sweepvarsstr = ", ".join([str(var)
                                      for var, val in zip(sweepvars, sweepvals)
                                      if not np.isnan(val).all()])
            print(f"Sweeping {sweepvarsstr} with {n_passes} solves:")
        try:
            if verbosity > 1:
                print(f"\nSolve {i}:")
            result = solvefn(solver, verbosity=verbosity-1, **kwargs)
            if kwargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
            if verbosity == 1:
                print(".", end="", flush=True)
        except Infeasible as e:
            last_error = e
            if not skipsweepfailures:
                raise RuntimeWarning(
                    f"Solve {i} was infeasible; progress saved to m.program."
                    " To continue sweeping after failures, solve with"
                    " skipsweepfailures=True.") from e
            if verbosity > 1:
                print(f"Solve {i} was {e.__class__.__name__}.")
            if verbosity == 1:
                print("!", end="", flush=True)
    if not solution:
        raise RuntimeWarning("All solves were infeasible.") from last_error
    if verbosity == 1:
        print()

    _split_sweep_constants(solution, sweep, linked)

    if verbosity > 0:
        soltime = time() - tic
        print(f"Sweeping took {soltime:.3g} seconds.")


def _sweep_vectors(sweep):
    """Flatten a sweep into one value-vector per swept variable.

    Returns `(sweepvars, sweepvals, sweep_vects, n_passes)`. The variables
    are sorted by their `eqstr`. With several swept variables their values
    are combined with `np.meshgrid` (every combination is solved), and
    `sweep_vects[var][i]` is var's value on pass `i`.
    """
    # sort sweeps by the eqstr of their varkey
    sweepvars, sweepvals = zip(*sorted(sweep.items(),
                                       key=lambda vkval: vkval[0].eqstr))
    if len(sweep) == 1:
        sweep_grids = np.array(list(sweepvals))
    else:
        sweep_grids = np.meshgrid(*sweepvals)

    n_passes = sweep_grids[0].size
    sweep_vects = {var: grid.reshape(n_passes)
                   for var, grid in zip(sweepvars, sweep_grids)}
    return sweepvars, sweepvals, sweep_vects, n_passes


def _split_sweep_constants(solution, sweep, linked):
    """Move swept values into solution["sweepvariables"]; collapse the rest.

    Non-swept constants are reduced to their first value. When any
    variables are linked, a constant is collapsed only if every pass
    recorded the same value.
    """
    solution["sweepvariables"] = KeyDict()
    ksweep = KeyDict(sweep)
    for var, val in list(solution["constants"].items()):
        if var in ksweep:
            solution["sweepvariables"][var] = val
            del solution["constants"][var]
        elif linked:  # if any variables are linked, we check all of them
            if hasattr(val[0], "shape"):
                differences = ((other != val[0]).any() for other in val[1:])
            else:
                differences = (other != val[0] for other in val[1:])
            if not any(differences):
                solution["constants"][var] = [val[0]]
        else:
            solution["constants"][var] = [val[0]]