Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"Scripts for generating, solving and sweeping programs" 

2from time import time 

3import numpy as np 

4from ad import adnumber 

5from ..nomials import parse_subs 

6from ..solution_array import SolutionArray 

7from ..keydict import KeyDict 

8from ..small_scripts import maybe_flatten 

9from ..exceptions import Infeasible 

10 

11 

def evaluate_linked(constants, linked):
    """Evaluate linked variables in place, recording values and gradients.

    Every function in ``linked`` is called with a KeyDict of the current
    constants wrapped as ``adnumber`` objects. A result exposing an ``x``
    attribute is treated as auto-differentiated: its ``x`` is written to
    ``constants`` and its derivatives (from ``d()``) with respect to tagged
    adnumbers are stored under the key's ``descr["gradients"]``. Any other
    result is stored in ``constants`` unchanged.

    If anything raises during that evaluation, a warning is printed and the
    function is re-run on plain (non-adnumber) constants, unless
    ``settings["ad_errors_raise"]`` is truthy, in which case the exception
    is re-raised.

    Arguments
    ---------
    constants : dict-like
        Maps keys to values; updated in place with the linked results.
    linked : dict-like
        Maps keys to the callables that compute their values.
    """
    ad_constants = KeyDict({key: adnumber(maybe_flatten(value), key)
                            for key, value in constants.items()})
    plain_constants = None  # built lazily, on the first failed evaluation
    vector_results = {}  # one evaluation per vector key, shared by elements
    for key in constants:  # remove gradients from constants
        if key.gradients:
            del key.descr["gradients"]
    for var, fn in linked.items():
        try:
            veckey = var.veckey
            if veckey and veckey.original_fn:
                # element of a vector with its own function: evaluate the
                # whole vector once, then pick out this element
                if veckey not in vector_results:
                    vector_fn = veckey.original_fn
                    vector_results[veckey] = np.array(vector_fn(ad_constants))
                result = vector_results[veckey][var.idx]
            else:
                result = fn(ad_constants)
            if hasattr(result, "x"):
                constants[var] = result.x
                gradients = {}
                for adn, grad in result.d().items():
                    if adn.tag:  # untagged adnumbers are user-created
                        gradients[adn.tag] = grad
                var.descr["gradients"] = gradients
            else:
                # a new fixed variable, not a calculated one
                constants[var] = result
        except Exception as err:  # pylint: disable=broad-except
            from .. import settings
            if settings.get("ad_errors_raise", None):
                raise
            print("Warning: skipped auto-differentiation of linked variable"
                  " %s because %s was raised. Set `gpkit.settings"
                  "[\"ad_errors_raise\"] = True` to raise such Exceptions"
                  " directly.\n" % (var, repr(err)))
            if plain_constants is None:
                plain_constants = KeyDict(constants)
            constants[var] = fn(plain_constants)
            var.descr.pop("gradients", None)

49 

50 

def progify(program, return_attr=None):
    """Build a method that constructs a ``program`` from a model.

    Arguments
    ---------
    program : callable
        Program class (e.g. GeometricProgram or SequentialGeometricProgram),
        called as ``program(cost, model, constants, **initargs)``.
    return_attr : string, optional
        If given, the generated method returns ``(prog, attribute)`` where
        the attribute is looked up by this name on the new program.

    Returns
    -------
    programfn : function
        Method taking ``(self, constants=None, **initargs)``.
    """
    def programfn(self, constants=None, **initargs):
        """Return the program form of self.

        When ``constants`` is empty or None the constants are parsed from
        the model's substitutions, and any linked variables are evaluated.
        """
        if constants:
            subs = constants
        else:
            subs, _, linked = parse_subs(self.varkeys, self.substitutions)
            if linked:
                evaluate_linked(subs, linked)
        built = program(self.cost, self, subs, **initargs)
        built.model = self  # NOTE SIDE EFFECTS
        return (built, getattr(built, return_attr)) if return_attr else built
    return programfn

73 

74 

def solvify(genfunction):
    """Build a solve method around a program-generating function.

    ``genfunction`` is called as ``genfunction(model)`` (or, during sweeps,
    ``genfunction(model, constants)``) and must return a
    ``(program, solve_function)`` pair.
    """
    def solvefn(self, solver=None, *, verbosity=1, skipsweepfailures=False,
                **solveargs):
        """Forms a mathematical program and attempts to solve it.

        Arguments
        ---------
        solver : string or function (default None)
            Passed through to the program's solve function.
        verbosity : int (default 1)
            Passed unchanged to a single solve. In a sweep, values greater
            than 0 print sweep progress messages and each individual solve
            receives ``verbosity - 1``.
        skipsweepfailures : bool (default False)
            If True, when a solve is infeasible during a sweep, skip it.
        **solveargs : Passed to the solve call. If ``process_result`` is
            present and falsy, ``self.process_result`` is not called.

        Returns
        -------
        sol : SolutionArray
            The collected results; also stored as ``self.solution``.

        Raises
        ------
        RuntimeWarning from the sweep when a solve is infeasible and is not
        skipped, or when every solve was infeasible. Exceptions raised by a
        single (non-sweep) solve propagate unchanged.
        """
        constants, sweep, linked = parse_subs(self.varkeys, self.substitutions)
        sol = SolutionArray()
        sol.modelstr = str(self)

        # NOTE SIDE EFFECTS: self.program and self.solution set below
        if not sweep:
            program, progsolve = genfunction(self)
            self.program = program
            outcome = progsolve(solver, verbosity=verbosity, **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(outcome)
            sol.append(outcome)
        else:
            run_sweep(genfunction, self, sol, skipsweepfailures,
                      constants, sweep, linked, solver, verbosity, **solveargs)
        sol.to_arrays()
        self.solution = sol
        return sol
    return solvefn

120 

121 

122# pylint: disable=too-many-locals,too-many-arguments,too-many-branches 

def run_sweep(genfunction, self, solution, skipsweepfailures,
              constants, sweep, linked, solver, verbosity, **solveargs):
    """Solve the model once per sweep point, collecting into ``solution``.

    Sweep variables are ordered by their key's ``eqstr``; with more than one
    swept variable the points are the full grid from ``np.meshgrid``. For
    each point the swept values are written into ``constants``, linked
    variables are re-evaluated, and a fresh program is generated, appended
    to ``self.program`` and solved with ``verbosity - 1``.

    Afterwards swept variables are moved from ``solution["constants"]`` to
    ``solution["sweepvariables"]``, and each remaining constant is collapsed
    to a one-element list (when variables are linked, only if its value is
    the same across all solves).

    Raises
    ------
    RuntimeWarning
        If a solve is Infeasible and ``skipsweepfailures`` is falsy, or if
        no solve produced a result.
    """
    # sort sweeps by the eqstr of their varkey
    ordered = sorted(sweep.items(), key=lambda item: item[0].eqstr)
    sweepvars, sweepvals = zip(*ordered)
    if len(sweep) == 1:
        sweep_grids = np.array(list(sweepvals))
    else:
        sweep_grids = np.meshgrid(*sweepvals)

    N_passes = sweep_grids[0].size
    sweep_vects = {}
    for var, grid in zip(sweepvars, sweep_grids):
        sweep_vects[var] = grid.reshape(N_passes)

    if verbosity > 0:
        print("Sweeping with %i solves:" % N_passes)
        tic = time()

    self.program = []
    last_error = None
    for i in range(N_passes):
        point = {var: vect[i] for var, vect in sweep_vects.items()}
        constants.update(point)
        if linked:
            evaluate_linked(constants, linked)
        program, solvefn = genfunction(self, constants)
        self.program.append(program)  # NOTE: SIDE EFFECTS
        try:
            if verbosity > 1:
                print("\nSolve %i:" % i)
            result = solvefn(solver, verbosity=verbosity-1, **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
        except Infeasible as err:
            last_error = err
            if not skipsweepfailures:
                raise RuntimeWarning(
                    "Solve %i was infeasible; progress saved to m.program."
                    " To continue sweeping after failures, solve with"
                    " skipsweepfailures=True." % i) from err
            if verbosity > 0:
                print("Solve %i was %s." % (i, err.__class__.__name__))
    if not solution:
        raise RuntimeWarning("All solves were infeasible.") from last_error

    solution["sweepvariables"] = KeyDict()
    ksweep = KeyDict(sweep)
    for var, vals in list(solution["constants"].items()):
        if var in ksweep:
            solution["sweepvariables"][var] = vals
            del solution["constants"][var]
            continue
        if linked:  # if any variables are linked, we check all of them
            if hasattr(vals[0], "shape"):
                varies = any((other != vals[0]).any() for other in vals[1:])
            else:
                varies = any(other != vals[0] for other in vals[1:])
            if varies:
                continue  # value changes across the sweep: keep every entry
        solution["constants"][var] = [vals[0]]

    if verbosity > 0:
        soltime = time() - tic
        print("Sweeping took %.3g seconds." % (soltime,))