Coverage for gpkit/constraints/prog_factories.py: 84%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

138 statements  

1"Scripts for generating, solving and sweeping programs" 

2from time import time 

3import warnings as pywarnings 

4import numpy as np 

5from adce import adnumber 

6from ..nomials import parse_subs 

7from ..solution_array import SolutionArray 

8from ..keydict import KeyDict 

9from ..small_scripts import maybe_flatten 

10from ..small_classes import FixedScalar 

11from ..exceptions import Infeasible 

12from ..globals import SignomialsEnabled 

13 

14 

def _evaluate_linked_ad(v, f, kdc, vec_cache):
    """Evaluate linked variable `v` on the auto-differentiable constants `kdc`.

    Returns a ``(value, gradients)`` pair. `gradients` maps each tagged
    constant key to d(value)/d(constant), or is None when the function
    returned a plain (non-differentiable) value, i.e. a new fixed number.
    Vector-valued link functions (``v.veckey.vecfn``) are evaluated once per
    vector key; the resulting array is cached in `vec_cache` and indexed by
    ``v.idx`` for each element.
    """
    if v.veckey and v.veckey.vecfn:
        if v.veckey not in vec_cache:
            with SignomialsEnabled():  # to allow use of gpkit.units
                vecout = v.veckey.vecfn(kdc)
            if not hasattr(vecout, "shape"):
                vecout = np.array(vecout)
            vec_cache[v.veckey] = vecout
        out = vec_cache[v.veckey][v.idx]
    else:
        with SignomialsEnabled():  # to allow use of gpkit.units
            out = f(kdc)
    if isinstance(out, FixedScalar):  # to allow use of gpkit.units
        out = out.value
    if hasattr(out, "units"):
        out = out.to(v.units or "dimensionless").magnitude
    elif out != 0 and v.units:
        pywarnings.warn(
            "Linked function for %s did not return a united value."
            " Modifying it to do so (e.g. by using `()` instead of `[]`"
            " to access variables) will reduce errors." % v)
    out = maybe_flatten(out)
    if not hasattr(out, "x"):
        return out, None  # a new fixed variable, not a calculated one
    gradients = {adn.tag: grad for adn, grad in out.d().items() if adn.tag}
    return out.x, gradients


def evaluate_linked(constants, linked):
    """Evaluates the values and gradients of linked variables.

    Each linked function is first evaluated on a copy of `constants` whose
    values are wrapped as ``adnumber``s tagged with their key, so gradients
    with respect to those constants can be read off the result. If that
    raises, the exception is re-raised when
    ``gpkit.settings["ad_errors_raise"]`` is set; otherwise the function is
    re-evaluated on plain values (no gradients) and a warning is printed.

    Arguments
    ---------
    constants : KeyDict  (assumed; maps varkeys to fixed values)
        Updated in place: each linked variable's value is stored under it.
    linked : dict
        Maps each linked varkey to the function that computes its value.

    Side effects
    ------------
    Sets ``descr["gradients"]`` on every linked varkey whose value was
    differentiated, and removes it from all keys of `constants` and from
    linked varkeys whose value was not differentiated.
    """
    kdc = KeyDict({k: adnumber(maybe_flatten(v), k)
                   for k, v in constants.items()})
    kdc_plain = None  # plain-value view, built lazily on first AD failure
    vec_cache = {}
    for key in constants:  # remove gradients from constants
        key.descr.pop("gradients", None)
    for v, f in linked.items():
        try:
            value, gradients = _evaluate_linked_ad(v, f, kdc, vec_cache)
        except Exception as exception:  # pylint: disable=broad-except
            from .. import settings
            if settings.get("ad_errors_raise", None):
                raise
            if kdc_plain is None:
                kdc_plain = KeyDict(constants)
            value, gradients = f(kdc_plain), None
            print("Warning: skipped auto-differentiation of linked variable"
                  " %s because %s was raised. Set `gpkit.settings"
                  "[\"ad_errors_raise\"] = True` to raise such Exceptions"
                  " directly.\n" % (v, repr(exception)))
            if ("Automatic differentiation not yet supported for <class "
                    "'gpkit.nomials.math.Monomial'> objects") in str(exception):
                print("This particular warning may have come from using"
                      " gpkit.units.* in the function for %s; try using"
                      " gpkit.ureg.* or gpkit.units.*.units instead." % v)
        constants[v] = value
        if kdc_plain is not None:
            # keep the fallback view current, so later fallbacks see every
            # value computed so far (not just those before the first failure)
            kdc_plain[v] = value
        if gradients is None:
            # drop any stale gradients from an earlier evaluation
            v.descr.pop("gradients", None)
        else:
            v.descr["gradients"] = gradients

70 

71 

def progify(program, return_attr=None):
    """Build a method that converts a model into a `program` instance.

    Arguments
    ---------
    program: NomialData
        Class to instantiate, e.g. GeometricProgram or
        SequentialGeometricProgram; it is called as
        ``program(model.cost, model, constants, **initargs)``.
    return_attr: string
        If given (truthy), name of an attribute of the new program that is
        returned alongside it as a ``(program, attribute)`` pair.

    Returns
    -------
    function
        ``programfn(self, constants=None, **initargs)``, suitable for use as
        a model method.
    """
    def programfn(self, constants=None, **initargs):
        "Return program version of self"
        if not constants:
            # nothing (or an empty mapping) supplied: derive constants from
            # the model's own substitutions and evaluate linked variables
            constants, _, linked = parse_subs(self.varkeys, self.substitutions)
            if linked:
                evaluate_linked(constants, linked)
        built = program(self.cost, self, constants, **initargs)
        built.model = self  # NOTE SIDE EFFECTS
        if not return_attr:
            return built
        return built, getattr(built, return_attr)
    return programfn

94 

95 

def solvify(genfunction):
    """Returns function for making/solving/sweeping a program.

    Arguments
    ---------
    genfunction : function
        Called as ``genfunction(model, **kwargs)`` for a single solve, or as
        ``genfunction(model, constants, **kwargs)`` for each pass of a sweep;
        must return a ``(program, solve_function)`` pair (e.g. a function
        made by ``progify`` with a ``return_attr``).
    """
    def solvefn(self, solver=None, *, verbosity=1, skipsweepfailures=False,
                **kwargs):
        """Forms a mathematical program and attempts to solve it.

        Arguments
        ---------
        solver : string or function (default None)
            If None, uses the default solver found in installation.
        verbosity : int (default 1)
            If greater than 0 prints runtime messages.
            A single solve passes it to the program unchanged; during a
            sweep it is decremented by one before being passed to programs.
        skipsweepfailures : bool (default False)
            If True, when a solve during a sweep is infeasible, skip it.
        **kwargs : Passed to solve and program init calls
            If ``process_result`` is given and falsy, results are not passed
            through ``self.process_result``.

        Returns
        -------
        sol : SolutionArray
            See the SolutionArray documentation for details.

        Raises
        ------
        ValueError if the program is invalid.
        RuntimeWarning if an error occurs in solving or parsing the solution.
        """
        constants, sweep, linked = parse_subs(self.varkeys, self.substitutions)
        solution = SolutionArray()
        solution.modelstr = str(self)

        # NOTE SIDE EFFECTS: self.program and self.solution set below
        if sweep:
            run_sweep(genfunction, self, solution, skipsweepfailures,
                      constants, sweep, linked, solver, verbosity, **kwargs)
        else:
            # constants are not passed here: a progify-made genfunction
            # re-derives them from self.substitutions itself
            self.program, progsolve = genfunction(self, **kwargs)
            result = progsolve(solver, verbosity=verbosity, **kwargs)
            if kwargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
        solution.to_arrays()
        self.solution = solution
        return solution
    return solvefn

141 

142 

def _sweep_vectors(sweep):
    """Expand `sweep` ({varkey: values}) into one flat vector per variable.

    Variables are ordered by their ``eqstr``; with several swept variables
    the values are combined into a full grid via ``np.meshgrid``.
    Returns ``(sweepvars, n_passes, {varkey: 1-D array of length n_passes})``.
    """
    sweepvars, sweepvals = zip(*sorted(sweep.items(),
                                       key=lambda vkval: vkval[0].eqstr))
    if len(sweep) == 1:
        sweep_grids = np.array(list(sweepvals))
    else:
        sweep_grids = np.meshgrid(*sweepvals)
    n_passes = sweep_grids[0].size
    sweep_vects = {var: grid.reshape(n_passes)
                   for var, grid in zip(sweepvars, sweep_grids)}
    return sweepvars, n_passes, sweep_vects


def _collapse_sweep_constants(solution, sweep, linked):
    """Move swept values to solution["sweepvariables"]; collapse constants.

    Non-swept constants are stored as a single-element list. If there are
    linked variables, a constant is only collapsed when its value was the
    same in every pass, since linked values may change with the sweep.
    """
    solution["sweepvariables"] = KeyDict()
    ksweep = KeyDict(sweep)
    for var, val in list(solution["constants"].items()):
        if var in ksweep:
            solution["sweepvariables"][var] = val
            del solution["constants"][var]
        elif linked:  # if any variables are linked, we check all of them
            if hasattr(val[0], "shape"):
                differences = ((other != val[0]).any() for other in val[1:])
            else:
                differences = (other != val[0] for other in val[1:])
            if not any(differences):
                solution["constants"][var] = [val[0]]
        else:
            solution["constants"][var] = [val[0]]


# pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements
def run_sweep(genfunction, self, solution, skipsweepfailures,
              constants, sweep, linked, solver, verbosity, **kwargs):
    """Runs through a sweep.

    Builds and solves one program per point of the sweep grid, appending
    each result to `solution` and each program to ``self.program``.
    Infeasible solves raise RuntimeWarning unless `skipsweepfailures` is
    set, in which case they are skipped; if every solve is infeasible a
    RuntimeWarning is raised. Programs receive ``verbosity - 1``.
    `constants` is not modified.
    """
    sweepvars, n_passes, sweep_vects = _sweep_vectors(sweep)

    tic = time()

    self.program = []
    last_error = None
    for i in range(n_passes):
        # A fresh copy per pass: programs kept in self.program must not see
        # later passes' values, and linked functions must not read linked
        # results left over from the previous pass.
        pass_constants = KeyDict(constants)
        pass_constants.update({var: sweep_vect[i]
                               for var, sweep_vect in sweep_vects.items()})
        if linked:
            evaluate_linked(pass_constants, linked)
        program, solvefn = genfunction(self, pass_constants, **kwargs)
        program.model = None  # so it doesn't try to debug
        self.program.append(program)  # NOTE: SIDE EFFECTS
        if i == 0 and verbosity > 0:  # wait for successful program gen
            # TODO: use full string when minimum lineage is set automatically
            sweepvarsstr = ", ".join([sv.name for sv in sweepvars])
            print("Sweeping %s with %i solves:" % (sweepvarsstr, n_passes))
        try:
            if verbosity > 1:
                print("\nSolve %i:" % i)
            result = solvefn(solver, verbosity=verbosity-1, **kwargs)
            if kwargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
            if verbosity == 1:
                print(".", end="", flush=True)
        except Infeasible as e:
            last_error = e
            if not skipsweepfailures:
                raise RuntimeWarning(
                    "Solve %i was infeasible; progress saved to m.program."
                    " To continue sweeping after failures, solve with"
                    " skipsweepfailures=True." % i) from e
            if verbosity > 1:
                print("Solve %i was %s." % (i, e.__class__.__name__))
            if verbosity == 1:
                print("!", end="", flush=True)
    if not solution:
        raise RuntimeWarning("All solves were infeasible.") from last_error
    if verbosity == 1:
        print()

    _collapse_sweep_constants(solution, sweep, linked)

    if verbosity > 0:
        soltime = time() - tic
        print("Sweeping took %.3g seconds." % (soltime,))