Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"Scripts for generating, solving and sweeping programs" 

2from time import time 

3import numpy as np 

4from ad import adnumber 

5from ..nomials import parse_subs 

6from ..solution_array import SolutionArray 

7from ..keydict import KeyDict 

8from ..small_scripts import maybe_flatten 

9from ..small_classes import FixedScalar 

10from ..exceptions import Infeasible 

11from ..globals import SignomialsEnabled 

12 

13 

def evaluate_linked(constants, linked):
    """Evaluates the values and gradients of linked variables.

    Every constant is wrapped in an `adnumber` tagged with its key, so that
    calling a linked function on them yields both a value and its
    derivatives with respect to those constants.

    Arguments
    ---------
    constants : dict
        Maps variable keys to their values. MODIFIED IN PLACE: an entry is
        written for every key of `linked`, and any "gradients" entry is
        removed from the `descr` of the keys already present.
    linked : dict
        Maps variable keys to the functions that compute their values from
        a KeyDict of constants.

    Raises
    ------
    Whatever a linked function raises, but only if the "ad_errors_raise"
    setting is truthy. Otherwise a warning is printed and the function is
    re-evaluated on the plain (non-adnumber) constants, without gradients.
    """
    kdc = KeyDict({k: adnumber(maybe_flatten(v), k)
                   for k, v in constants.items()})
    kdc_plain = None  # built lazily, only if auto-differentiation fails
    array_calculated = {}  # veckey -> output of its vecfn, computed once
    for key in constants:  # remove gradients from constants
        if key.gradients:
            del key.descr["gradients"]
    for v, f in linked.items():
        try:
            if v.veckey and v.veckey.vecfn:
                if v.veckey not in array_calculated:
                    with SignomialsEnabled():  # to allow use of gpkit.units
                        vecout = v.veckey.vecfn(kdc)
                    if not hasattr(vecout, "shape"):
                        vecout = np.array(vecout)
                    array_calculated[v.veckey] = vecout
                    # np.any, unlike the builtin any, also accepts 0-d and
                    # multi-dimensional comparison results; the builtin
                    # raised on those and forced the fallback path below
                    if (np.any(vecout != 0) and v.veckey.units
                            and not hasattr(vecout, "units")):
                        print("Warning: linked function for %s did not return"
                              " a united value. Modifying it to do so (e.g. by"
                              " using `()` instead of `[]` to access variables)"
                              " would reduce the risk of errors." % v.veckey)
                out = array_calculated[v.veckey][v.idx]
            else:
                with SignomialsEnabled():  # to allow use of gpkit.units
                    out = f(kdc)
            if isinstance(out, FixedScalar):  # to allow use of gpkit.units
                out = out.value
            if hasattr(out, "units"):
                out = out.to(v.units or "dimensionless").magnitude
            elif out != 0 and v.units and not v.veckey:
                print("Warning: linked function for %s did not return"
                      " a united value. Modifying it to do so (e.g. by"
                      " using `()` instead of `[]` to access variables)"
                      " would reduce the risk of errors." % v)
            if not hasattr(out, "x"):
                constants[v] = out
                continue  # a new fixed variable, not a calculated one
            constants[v] = out.x
            v.descr["gradients"] = {adn.tag: grad
                                    for adn, grad in out.d().items()}
        except Exception as exception:  # pylint: disable=broad-except
            from .. import settings
            if settings.get("ad_errors_raise", None):
                raise
            print("Warning: skipped auto-differentiation of linked variable"
                  " %s because %s was raised. Set `gpkit.settings"
                  "[\"ad_errors_raise\"] = True` to raise such Exceptions"
                  " directly.\n" % (v, repr(exception)))
            if kdc_plain is None:
                kdc_plain = KeyDict(constants)
            # fall back to a plain evaluation; no gradient is available
            constants[v] = f(kdc_plain)
            v.descr.pop("gradients", None)

69 

70 

def progify(program, return_attr=None):
    """Builds a method that converts a model into a `program` instance.

    Arguments
    ---------
    program : callable
        Invoked as `program(self.cost, self, constants, **initargs)`,
        e.g. GeometricProgram or SequentialGeometricProgram
    return_attr : string (default None)
        If truthy, the generated function returns a tuple of the program
        and this attribute of it, instead of the program alone.
    """
    def programfn(self, constants=None, **initargs):
        "Return program version of self"
        if not constants:
            # nothing supplied: derive constants from self's substitutions
            constants, _, linked_fns = parse_subs(self.varkeys,
                                                  self.substitutions)
            if linked_fns:
                evaluate_linked(constants, linked_fns)
        built = program(self.cost, self, constants, **initargs)
        built.model = self  # NOTE SIDE EFFECTS
        return (built, getattr(built, return_attr)) if return_attr else built
    return programfn

93 

94 

def solvify(genfunction):
    "Wraps a program generator into a make/solve/sweep method."
    def solvefn(self, solver=None, *, verbosity=1, skipsweepfailures=False,
                **solveargs):
        """Builds a mathematical program from self and solves it.

        If any substitutions are sweeps, one program is solved per sweep
        point (see run_sweep); otherwise a single program is solved.

        Arguments
        ---------
        solver : string or function (default None)
            Passed as the first argument to the program's solve function.
        verbosity : int (default 1)
            Passed unchanged to a single solve, and to run_sweep for sweeps.
        skipsweepfailures : bool (default False)
            Passed to run_sweep; not used for a single solve.
        **solveargs : Passed to the solve call(s). If it contains a falsy
            "process_result", self.process_result is not called.

        Returns
        -------
        sol : SolutionArray
            Also stored as self.solution.

        Raises
        ------
        Exceptions raised while generating or solving a program propagate;
        run_sweep raises RuntimeWarning for infeasible sweep solves.
        """
        constants, sweep, linked = parse_subs(self.varkeys, self.substitutions)
        sol = SolutionArray()
        sol.modelstr = str(self)

        # NOTE SIDE EFFECTS: self.program and self.solution set below
        if not sweep:
            self.program, solve_program = genfunction(self)
            outcome = solve_program(solver, verbosity=verbosity, **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(outcome)
            sol.append(outcome)
        else:
            run_sweep(genfunction, self, sol, skipsweepfailures,
                      constants, sweep, linked, solver, verbosity,
                      **solveargs)
        sol.to_arrays()
        self.solution = sol
        return sol
    return solvefn

140 

141 

142# pylint: disable=too-many-locals,too-many-arguments,too-many-branches 

def run_sweep(genfunction, self, solution, skipsweepfailures,
              constants, sweep, linked, solver, verbosity, **solveargs):
    """Solves one program per sweep point, appending results to `solution`.

    Arguments
    ---------
    genfunction : function
        Called as `genfunction(self, constants)`; must return a tuple of
        a program and that program's solve function.
    self : the model being swept; self.program is replaced by the list of
        generated programs (one per pass).
    solution : SolutionArray
        Each successful result is appended to it; afterwards swept
        variables are moved from its "constants" to "sweepvariables".
    skipsweepfailures : bool
        If falsy, the first Infeasible solve raises RuntimeWarning.
    constants : dict
        Updated in place with the swept values of each pass.
    sweep : dict
        Maps swept variable keys to their sweep values.
    linked : dict
        Linked variables; if non-empty they are re-evaluated on each pass.
    solver, **solveargs : passed to each solve call.
    verbosity : int
        If greater than 0, prints progress; each solve gets `verbosity-1`.

    Raises
    ------
    RuntimeWarning if a solve is Infeasible and skipsweepfailures is falsy,
    or if no solve produced a result.
    """
    # sort sweeps by the eqstr of their varkey
    ordered = sorted(list(sweep.items()), key=lambda vkval: vkval[0].eqstr)
    sweepvars, sweepvals = zip(*ordered)
    if len(sweep) == 1:
        sweep_grids = np.array(list(sweepvals))
    else:
        sweep_grids = np.meshgrid(*list(sweepvals))

    n_solves = sweep_grids[0].size
    # one flat vector of values per swept variable, indexed by pass number
    sweep_vects = {}
    for var, grid in zip(sweepvars, sweep_grids):
        sweep_vects[var] = grid.reshape(n_solves)

    if verbosity > 0:
        print("Sweeping with %i solves:" % n_solves)
        tic = time()

    self.program = []
    last_error = None
    for i in range(n_solves):
        constants.update({var: vect[i] for var, vect in sweep_vects.items()})
        if linked:
            evaluate_linked(constants, linked)
        program, solvefn = genfunction(self, constants)
        self.program.append(program)  # NOTE: SIDE EFFECTS
        try:
            if verbosity > 1:
                print("\nSolve %i:" % i)
            result = solvefn(solver, verbosity=verbosity-1, **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
        except Infeasible as err:
            last_error = err
            if not skipsweepfailures:
                raise RuntimeWarning(
                    "Solve %i was infeasible; progress saved to m.program."
                    " To continue sweeping after failures, solve with"
                    " skipsweepfailures=True." % i) from err
            if verbosity > 0:
                print("Solve %i was %s." % (i, type(err).__name__))
    if not solution:
        raise RuntimeWarning("All solves were infeasible.") from last_error

    solution["sweepvariables"] = KeyDict()
    ksweep = KeyDict(sweep)
    for var, val in list(solution["constants"].items()):
        if var in ksweep:
            # swept variables get their own entry, keeping every value
            solution["sweepvariables"][var] = val
            del solution["constants"][var]
            continue
        first = val[0]
        if linked:  # if any variables are linked, we check all of them
            rest = val[1:]
            if hasattr(first, "shape"):
                differences = ((other != first).any() for other in rest)
            else:
                differences = (other != first for other in rest)
            if any(differences):
                continue  # value varies between passes: keep all of them
        # identical in every pass, so a single value is stored
        solution["constants"][var] = [first]

    if verbosity > 0:
        soltime = time() - tic
        print("Sweeping took %.3g seconds." % (soltime,))