Coverage for gpkit/constraints/prog_factories.py: 84%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

138 statements  

1"Scripts for generating, solving and sweeping programs" 

2from time import time 

3import warnings as pywarnings 

4import numpy as np 

5from adce import adnumber 

6from ..nomials import parse_subs 

7from ..solution_array import SolutionArray 

8from ..keydict import KeyDict 

9from ..small_scripts import maybe_flatten 

10from ..small_classes import FixedScalar 

11from ..exceptions import Infeasible 

12from ..globals import SignomialsEnabled 

13 

14 

def evaluate_linked(constants, linked):
    """Evaluates the values and gradients of linked variables.

    Arguments
    ---------
    constants : dict-like of varkey -> value
        Fixed values. Mutated in place: every linked variable's computed
        value is written back into it as ``constants[v]``.
    linked : dict-like of varkey -> function
        Linked variables and the functions that compute them. Each function
        is called with a KeyDict of the constants, whose values are wrapped
        as `adnumber`s so gradients can be tracked.

    Side effects
    ------------
    - Pops any "gradients" entry from ``key.descr`` of every key in
      `constants`.
    - For each linked variable whose result carries AD information (has an
      ``.x`` attribute), sets ``v.descr["gradients"]`` to a dict mapping
      each non-empty adnumber tag to its gradient (from ``out.d()``).
    - If evaluation raises, it re-raises when
      ``gpkit.settings["ad_errors_raise"]`` is truthy. Otherwise it
      re-evaluates the function on plain (non-AD) values, drops that
      variable's gradients, and prints a warning.

    Returns None.
    """
    # wrap every constant in an adnumber tagged with its own varkey, so that
    # gradients of the linked functions can be read back per tag below
    kdc = KeyDict({k: adnumber(maybe_flatten(v), k)
                   for k, v in constants.items()})
    # plain-valued KeyDict, built lazily only if an AD evaluation fails
    kdc_plain = None
    # cache of vectorized results, so each veckey's vecfn runs only once
    array_calulated = {}
    for key in constants:  # remove gradients from constants
        key.descr.pop("gradients", None)
    for v, f in linked.items():
        try:
            if v.veckey and v.veckey.vecfn:
                # element of a vector variable with a vectorized function:
                # evaluate the whole vector once, then index into it
                if v.veckey not in array_calulated:
                    with SignomialsEnabled():  # to allow use of gpkit.units
                        vecout = v.veckey.vecfn(kdc)
                    if not hasattr(vecout, "shape"):
                        vecout = np.array(vecout)
                    array_calulated[v.veckey] = vecout
                out = array_calulated[v.veckey][v.idx]
            else:
                with SignomialsEnabled():  # to allow use of gpkit.units
                    out = f(kdc)
            if isinstance(out, FixedScalar):  # to allow use of gpkit.units
                out = out.value
            if hasattr(out, "units"):
                # united result: convert to the variable's units, keep number
                out = out.to(v.units or "dimensionless").magnitude
            elif out != 0 and v.units:
                pywarnings.warn(
                    "Linked function for %s did not return a united value."
                    " Modifying it to do so (e.g. by using `()` instead of `[]`"
                    " to access variables) will reduce errors." % v)
            out = maybe_flatten(out)
            if not hasattr(out, "x"):
                # no AD information on the result (not an adnumber)
                constants[v] = out
                continue  # a new fixed variable, not a calculated one
            constants[v] = out.x
            # keep only gradients with respect to tagged (varkey) inputs
            v.descr["gradients"] = {adn.tag: grad
                                    for adn, grad in out.d().items()
                                    if adn.tag}
        except Exception as exception:  # pylint: disable=broad-except
            from .. import settings
            if settings.get("ad_errors_raise", None):
                raise
            # NOTE(review): this snapshot is taken at the first failure and
            # is not refreshed by later successful evaluations in this loop
            if kdc_plain is None:
                kdc_plain = KeyDict(constants)
            # the fallback always calls the per-variable function `f`, even
            # for variables that were being evaluated through veckey.vecfn
            constants[v] = f(kdc_plain)
            v.descr.pop("gradients", None)
            print("Warning: skipped auto-differentiation of linked variable"
                  " %s because %s was raised. Set `gpkit.settings"
                  "[\"ad_errors_raise\"] = True` to raise such Exceptions"
                  " directly.\n" % (v, repr(exception)))
            if ("Automatic differentiation not yet supported for <class "
                    "'gpkit.nomials.math.Monomial'> objects") in str(exception):
                print("This particular warning may have come from using"
                      " gpkit.units.* in the function for %s; try using"
                      " gpkit.ureg.* or gpkit.units.*.units instead." % v)

70 

71 

def progify(program, return_attr=None):
    """Builds a method that converts a model into a `program` instance.

    Arguments
    ---------
    program: NomialData
        Class to instantiate, e.g. GeometricProgram or
        SequentialGeometricProgram. It is called as
        ``program(model.cost, model, constants, **initargs)``.
    return_attr: string
        Optional name of an attribute of the new program to hand back
        alongside it.

    Returns
    -------
    function
        ``programfn(self, constants=None, **initargs)``. It returns the new
        program, or ``(program, getattr(program, return_attr))`` when
        `return_attr` is given.
    """
    def programfn(self, constants=None, **initargs):
        """Builds `program` from this model (``self``).

        When `constants` is missing or empty, they are parsed from the
        model's own substitutions, and any linked variables are evaluated.
        """
        if not constants:
            constants, _, linked = parse_subs(self.varkeys, self.substitutions)
            if linked:
                evaluate_linked(constants, linked)
        built = program(self.cost, self, constants, **initargs)
        built.model = self  # NOTE SIDE EFFECTS: back-reference to the model
        return (built, getattr(built, return_attr)) if return_attr else built
    return programfn

94 

95 

def solvify(genfunction):
    """Wraps a program generator into a make/solve/sweep method.

    Arguments
    ---------
    genfunction : function
        Program generator. For a single solve it is called as
        ``genfunction(model, **kwargs)``; sweeps hand it to `run_sweep`. It
        must return a ``(program, solve_function)`` pair.

    Returns
    -------
    function
        ``solvefn(self, solver=None, *, verbosity=1,
        skipsweepfailures=False, **kwargs)``
    """
    def solvefn(self, solver=None, *, verbosity=1, skipsweepfailures=False,
                **kwargs):
        """Forms a mathematical program and attempts to solve it.

        Arguments
        ---------
        solver : string or function (default None)
            If None, uses the default solver found in installation.
        verbosity : int (default 1)
            If greater than 0 prints runtime messages. It is passed
            unchanged to a single solve; during a sweep each individual
            solve receives verbosity - 1.
        skipsweepfailures : bool (default False)
            If True, a sweep solve that raises Infeasible is skipped
            instead of aborting the sweep.
        **kwargs : Passed to solve and program init calls.
            If ``kwargs["process_result"]`` is given and falsy,
            ``self.process_result`` is not called on results.

        Returns
        -------
        sol : SolutionArray
            See the SolutionArray documentation for details.

        Raises
        ------
        RuntimeWarning
            During a sweep: if a solve is infeasible and skipsweepfailures
            is False, or if every solve was infeasible.
        Exceptions raised while generating or solving the program
        propagate to the caller (originally documented: ValueError if
        the program is invalid).
        """
        constants, sweep, linked = parse_subs(self.varkeys, self.substitutions)
        solution = SolutionArray()
        solution.modelstr = str(self)

        # NOTE SIDE EFFECTS: self.program and self.solution are assigned here
        if not sweep:
            self.program, solve_program = genfunction(self, **kwargs)
            outcome = solve_program(solver, verbosity=verbosity, **kwargs)
            if kwargs.get("process_result", True):
                self.process_result(outcome)
            solution.append(outcome)
        else:
            run_sweep(genfunction, self, solution, skipsweepfailures,
                      constants, sweep, linked, solver, verbosity, **kwargs)
        solution.to_arrays()
        self.solution = solution
        return solution
    return solvefn

141 

142 

# pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements
def run_sweep(genfunction, self, solution, skipsweepfailures,
              constants, sweep, linked, solver, verbosity, **kwargs):
    """Generates and solves one program for every point of a sweep.

    Arguments
    ---------
    genfunction : function
        Called as ``genfunction(self, constants, **kwargs)`` on every pass.
        It must return a ``(program, solve_function)`` pair.
    self : model
        ``self.program`` is replaced by the list of generated programs.
        ``self.process_result`` is called on each successful result unless
        ``kwargs["process_result"]`` is given and falsy.
    solution : SolutionArray
        Each successful result is appended to it. Afterwards it gains a
        "sweepvariables" KeyDict, and its "constants" entries are collapsed.
    skipsweepfailures : bool
        If True, solves raising Infeasible are skipped instead of aborting.
    constants : dict-like
        Fixed values. Updated in place with each pass's swept values (and
        linked values, when there are linked variables).
    sweep : dict-like of varkey -> sequence of values
        Swept variables. With more than one, every combination produced by
        ``np.meshgrid`` is solved.
    linked : dict-like
        Linked variables; re-evaluated on every pass when non-empty.
    solver, verbosity, **kwargs
        Forwarded to each solve, with verbosity reduced by one.

    Raises
    ------
    RuntimeWarning
        If a solve raises Infeasible and skipsweepfailures is False
        (chained from that error), or if `solution` is still empty after
        all passes (chained from the last Infeasible error).
    """
    # order the swept variables by the eqstr of their varkey
    sweepvars, sweepvals = zip(*sorted(sweep.items(),
                                       key=lambda item: item[0].eqstr))
    grids = (np.array(list(sweepvals)) if len(sweep) == 1
             else np.meshgrid(*list(sweepvals)))
    n_passes = grids[0].size
    sweep_vects = dict(zip(sweepvars,
                           (grid.reshape(n_passes) for grid in grids)))

    tic = time() if verbosity > 0 else None

    self.program = []
    last_error = None
    for i in range(n_passes):
        constants.update({var: vect[i] for var, vect in sweep_vects.items()})
        if linked:
            evaluate_linked(constants, linked)
        program, solvefn = genfunction(self, constants, **kwargs)
        program.model = None  # so it doesn't try to debug
        self.program.append(program)  # NOTE: SIDE EFFECTS
        if i == 0 and verbosity > 0:  # wait for successful program gen
            # TODO: use full string when minimum lineage is set automatically
            shown = [str(var) for var, val in zip(sweepvars, sweepvals)
                     if not np.isnan(val).all()]
            print("Sweeping %s with %i solves:" % (", ".join(shown), n_passes))
        try:
            if verbosity > 1:
                print("\nSolve %i:" % i)
            outcome = solvefn(solver, verbosity=verbosity-1, **kwargs)
            if kwargs.get("process_result", True):
                self.process_result(outcome)
            solution.append(outcome)
            if verbosity == 1:
                print(".", end="", flush=True)
        except Infeasible as err:
            last_error = err
            if not skipsweepfailures:
                raise RuntimeWarning(
                    "Solve %i was infeasible; progress saved to m.program."
                    " To continue sweeping after failures, solve with"
                    " skipsweepfailures=True." % i) from err
            if verbosity > 1:
                print("Solve %i was %s." % (i, type(err).__name__))
            elif verbosity == 1:
                print("!", end="", flush=True)
    if not solution:
        raise RuntimeWarning("All solves were infeasible.") from last_error
    if verbosity == 1:
        print()

    _collect_sweep_constants(solution, sweep, linked)

    if verbosity > 0:
        print("Sweeping took %.3g seconds." % (time() - tic,))


def _collect_sweep_constants(solution, sweep, linked):
    """Moves swept values into "sweepvariables" and collapses constants.

    Swept variables' per-solve values are moved from solution["constants"]
    into a new solution["sweepvariables"] KeyDict. Any remaining constant is
    replaced by a one-element list of its first value, unless there are
    linked variables and its value differs between solves.
    """
    solution["sweepvariables"] = KeyDict()
    swept = KeyDict(sweep)
    for var, val in list(solution["constants"].items()):
        if var in swept:
            solution["sweepvariables"][var] = val
            del solution["constants"][var]
            continue
        if linked:  # if any variables are linked, we check all of them
            first = val[0]
            if hasattr(first, "shape"):
                varies = any((other != first).any() for other in val[1:])
            else:
                varies = any(other != first for other in val[1:])
            if varies:
                continue
        solution["constants"][var] = [val[0]]