Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"Scripts for generating, solving and sweeping programs" 

2from time import time 

3import numpy as np 

4from ad import adnumber 

5from ..nomials import parse_subs 

6from ..solution_array import SolutionArray 

7from ..keydict import KeyDict 

8from ..small_scripts import maybe_flatten 

9from ..exceptions import Infeasible 

10 

11 

def evaluate_linked(constants, linked):
    """Evaluate linked variables in place, recording their gradients.

    Arguments
    ---------
    constants : dict
        Maps variable keys to their substituted values. The computed value
        of every linked variable is written back into this dict.
    linked : dict
        Maps each linked variable key to the function that computes it.

    Every constant is wrapped in an ``adnumber`` tagged with its key before
    the linked functions are called. The value (``.x``) of each result is
    stored in ``constants``, and its tagged derivatives (``.d()``) go into
    ``v.descr["gradients"]``. If that evaluation raises, one of two things
    happens:

    * The exception is re-raised, when ``settings["ad_errors_raise"]`` is
      set.
    * Otherwise a warning is printed, the variable is recomputed from plain
      (non-AD) values, and its stored gradients are dropped.
    """
    ad_inputs = KeyDict({key: adnumber(maybe_flatten(value), key)
                         for key, value in constants.items()})
    plain_inputs = None  # built lazily, only once some AD evaluation fails
    vector_outputs = {}  # veckey -> array returned by its original_fn
    for var, func in linked.items():
        try:
            vec = var.veckey
            if vec and vec.original_fn:
                # Call the whole-vector function once per veckey, then
                # index this element's entry out of the cached result.
                if vec not in vector_outputs:
                    vector_outputs[vec] = np.array(vec.original_fn(ad_inputs))
                out = vector_outputs[vec][var.idx]
            else:
                out = func(ad_inputs)
            constants[var] = out.x
            var.descr["gradients"] = {adn.tag: grad
                                      for adn, grad in out.d().items()
                                      if adn.tag}  # untagged: user-created
        except Exception as exception:  # pylint: disable=broad-except
            from .. import settings
            if settings.get("ad_errors_raise", None):
                raise
            print("Warning: skipped auto-differentiation of linked variable"
                  " %s because %s was raised. Set `gpkit.settings"
                  "[\"ad_errors_raise\"] = True` to raise such Exceptions"
                  " directly.\n" % (var, repr(exception)))
            if plain_inputs is None:
                plain_inputs = KeyDict(constants)
            constants[var] = func(plain_inputs)
            var.descr.pop("gradients", None)

43 

44 

def progify(program, return_attr=None):
    """Build a method that turns a model into an instance of `program`.

    Arguments
    ---------
    program: NomialData
        Class to instantiate, e.g. GeometricProgram or
        SequentialGeometricProgram. It is called as
        ``program(model.cost, model, constants, **initargs)``.
    return_attr: string
        If given, the generated method returns ``(prog, prog.<return_attr>)``
        instead of just ``prog``.
    """
    def programfn(self, constants=None, **initargs):
        "Return program version of self"
        if not constants:
            # No (or empty) constants supplied: derive them from the
            # model's own substitutions and evaluate any linked variables.
            constants, _, linked = parse_subs(self.varkeys,
                                              self.substitutions)
            if linked:
                evaluate_linked(constants, linked)
        prog = program(self.cost, self, constants, **initargs)
        prog.model = self  # NOTE SIDE EFFECTS
        if not return_attr:
            return prog
        return prog, getattr(prog, return_attr)
    return programfn

67 

68 

def solvify(genfunction):
    """Return a method that makes, solves, or sweeps a program.

    `genfunction` is called as ``genfunction(model, constants)`` and must
    return ``(program, solve_function)``.
    """
    def solvefn(self, solver=None, *, verbosity=1, skipsweepfailures=False,
                **solveargs):
        """Forms a mathematical program and attempts to solve it.

        Arguments
        ---------
        solver : string or function (default None)
            If None, uses the default solver found in installation.
        verbosity : int (default 1)
            If greater than 0 prints runtime messages.
            For a single solve it is passed to the program unchanged; during
            a sweep it is decremented by one before being passed to each
            solve.
        skipsweepfailures : bool (default False)
            If True, when a solve during a sweep raises Infeasible, skip
            that sweep point and continue.
        **solveargs : Passed to solve() call

        Returns
        -------
        sol : SolutionArray
            See the SolutionArray documentation for details.

        Raises
        ------
        ValueError if the program is invalid.
        RuntimeWarning if an error occurs in solving or parsing the solution,
            if a sweep is halted by an infeasible point (and
            skipsweepfailures is False), or if no sweep point solves.
        """
        constants, sweep, linked = parse_subs(self.varkeys, self.substitutions)
        solution = SolutionArray()
        solution.modelstr = str(self)

        # NOTE SIDE EFFECTS: self.program and self.solution set below
        if sweep:
            run_sweep(genfunction, self, solution, skipsweepfailures,
                      constants, sweep, linked, solver, verbosity, **solveargs)
        else:
            # Reuse the substitutions parsed above rather than letting
            # genfunction parse them a second time; this is the same
            # genfunction(self, constants) contract run_sweep relies on.
            if linked:
                evaluate_linked(constants, linked)
            self.program, progsolve = genfunction(self, constants)
            result = progsolve(solver, verbosity=verbosity, **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
        solution.to_arrays()
        self.solution = solution
        return solution
    return solvefn

114 

115 

# pylint: disable=too-many-locals,too-many-arguments,too-many-branches
def run_sweep(genfunction, self, solution, skipsweepfailures,
              constants, sweep, linked, solver, verbosity, **solveargs):
    """Solve `self` once for every point of the grid spanned by `sweep`.

    Arguments
    ---------
    genfunction : callable
        Called as ``genfunction(self, constants)``; returns a program and
        the function that solves it.
    self : model being swept
        ``self.program`` is replaced by the list of generated programs.
    solution : SolutionArray
        Each successful result is appended to it. Afterwards its
        "sweepvariables" entry holds the swept values. Each remaining
        "constants" entry is reduced to its first value: always when there
        are no linked variables, otherwise only when that value is the same
        in every solve.
    skipsweepfailures : bool
        If True, points whose solve raises Infeasible are skipped.
    constants : dict
        Fixed substitutions; updated in place with each point's swept
        values (and re-evaluated linked values, if any).
    sweep : dict
        Maps each swept variable to its values. Several sweeps are combined
        with ``np.meshgrid``.
    linked : dict
        Linked variables, re-evaluated at every sweep point.
    solver, verbosity, **solveargs
        Forwarded to each solve, with verbosity decremented by one.

    Raises
    ------
    RuntimeWarning
        If a point is infeasible and skipsweepfailures is False, or if no
        point solves successfully.
    """
    # sort the sweeps by the eqstr of their varkey
    ordered_items = sorted(sweep.items(), key=lambda item: item[0].eqstr)
    sweepvars, sweepvals = zip(*ordered_items)
    if len(sweep) == 1:
        grids = np.array(list(sweepvals))
    else:
        grids = np.meshgrid(*sweepvals)

    n_passes = grids[0].size
    flat_sweeps = {var: grid.reshape(n_passes)
                   for var, grid in zip(sweepvars, grids)}

    if verbosity > 0:
        print("Sweeping over %i solves." % n_passes)
        tic = time()

    self.program = []
    last_error = None
    for point in range(n_passes):
        constants.update({var: values[point]
                          for var, values in flat_sweeps.items()})
        if linked:
            evaluate_linked(constants, linked)
        program, solvefn = genfunction(self, constants)
        self.program.append(program)  # NOTE: SIDE EFFECTS
        try:
            result = solvefn(solver, verbosity=verbosity-1, **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
        except Infeasible as err:
            last_error = err
            if not skipsweepfailures:
                raise RuntimeWarning(
                    "Sweep halted! Progress saved to m.program. To skip over"
                    " such failures, solve with skipsweepfailures=True."
                ) from err
    if not solution:
        raise RuntimeWarning("No sweeps solved successfully.") from last_error

    solution["sweepvariables"] = KeyDict()
    swept = KeyDict(sweep)
    for var, history in list(solution["constants"].items()):
        if var in swept:
            # swept values move from "constants" to "sweepvariables"
            solution["sweepvariables"][var] = history
            del solution["constants"][var]
            continue
        if linked:
            # linked values are recomputed per point, so a constant is only
            # collapsed when it is identical across every solve
            first = history[0]
            if hasattr(first, "shape"):
                varies = any((other != first).any() for other in history[1:])
            else:
                varies = any(other != first for other in history[1:])
            if varies:
                continue
        solution["constants"][var] = [history[0]]

    if verbosity > 0:
        print("Sweeping took %.3g seconds." % (time() - tic,))