Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"Scripts for generating, solving and sweeping programs"
2from time import time
3import numpy as np
4from ad import adnumber
5from ..nomials import parse_subs
6from ..solution_array import SolutionArray
7from ..keydict import KeyDict
8from ..small_scripts import maybe_flatten
9from ..exceptions import Infeasible
def evaluate_linked(constants, linked):
    """Evaluates the values and gradients of linked variables, in place.

    Arguments
    ---------
    constants : dict
        Maps varkeys to substituted values. Updated in place: every linked
        variable's computed value is written into it.
    linked : dict
        Maps each linked varkey to the function computing its value from a
        KeyDict of the constants.

    Side effects
    ------------
    ``descr["gradients"]`` is removed from every key of ``constants`` and
    from any linked variable whose value could not be auto-differentiated.
    For auto-differentiated linked variables it is set to a dict mapping
    each input varkey (the adnumber tag) to the partial derivative.
    """
    # Wrap each constant in an adnumber tagged with its key, so that
    # out.d() below can report derivatives keyed by those tags.
    kdc = KeyDict({k: adnumber(maybe_flatten(v), k)
                   for k, v in constants.items()})
    # NOTE(review): kdc is built once, before the loop, so a linked function
    # does not see values of other linked variables computed earlier in this
    # same call (unless they were already in `constants`) -- confirm intended.
    kdc_plain = None  # non-AD inputs for the fallback, built on first failure
    array_calculated = {}  # veckey -> full vector output of its original_fn
    for key in constants:  # remove stale gradients from constants
        key.descr.pop("gradients", None)
    for v, f in linked.items():
        try:
            if v.veckey and v.veckey.original_fn:
                # Vector variable: evaluate the whole-vector function once per
                # veckey, then index the cached result for each element.
                if v.veckey not in array_calculated:
                    ofn = v.veckey.original_fn
                    array_calculated[v.veckey] = np.array(ofn(kdc))
                out = array_calculated[v.veckey][v.idx]
            else:
                out = f(kdc)
            if not hasattr(out, "x"):
                # A new fixed variable, not a calculated one: it has no
                # gradients, so drop any left over from an earlier solve.
                constants[v] = out
                v.descr.pop("gradients", None)
                continue
            constants[v] = out.x
            v.descr["gradients"] = {adn.tag: grad
                                    for adn, grad in out.d().items()
                                    if adn.tag}  # else it's user-created
        except Exception as exception:  # pylint: disable=broad-except
            from .. import settings
            if settings.get("ad_errors_raise", None):
                raise
            print("Warning: skipped auto-differentiation of linked variable"
                  " %s because %s was raised. Set `gpkit.settings"
                  "[\"ad_errors_raise\"] = True` to raise such Exceptions"
                  " directly.\n" % (v, repr(exception)))
            # Fall back to evaluating f on plain (non-AD) values.
            if kdc_plain is None:
                kdc_plain = KeyDict(constants)
            constants[v] = f(kdc_plain)
            v.descr.pop("gradients", None)
def progify(program, return_attr=None):
    """Builds a model method that instantiates `program` for the model.

    Arguments
    ---------
    program: class
        Program class to instantiate, e.g. GeometricProgram or
        SequentialGeometricProgram. It is called as
        ``program(model.cost, model, constants, **initargs)``.
    return_attr: string
        If given, the generated method returns
        ``(prog, getattr(prog, return_attr))`` instead of just ``prog``.

    Returns
    -------
    programfn : function(self, constants=None, **initargs)
    """
    def programfn(self, constants=None, **initargs):
        "Return program version of self"
        if not constants:
            # Nothing (or nothing truthy) supplied: derive the constants from
            # the model's own substitutions, evaluating linked variables too.
            parsed, _, linked_vars = parse_subs(self.varkeys,
                                                self.substitutions)
            if linked_vars:
                evaluate_linked(parsed, linked_vars)
            constants = parsed
        prog = program(self.cost, self, constants, **initargs)
        prog.model = self  # NOTE SIDE EFFECTS: back-reference to the model
        return (prog, getattr(prog, return_attr)) if return_attr else prog
    return programfn
def solvify(genfunction):
    """Returns a model method that forms and solves (or sweeps) a program.

    Arguments
    ---------
    genfunction : function
        Called as ``genfunction(model)`` or ``genfunction(model, constants)``;
        must return a ``(program, solve_callable)`` pair.
    """
    def solvefn(self, solver=None, *, verbosity=1, skipsweepfailures=False,
                **solveargs):
        """Forms a mathematical program and attempts to solve it.

        Arguments
        ---------
        solver : string or function (default None)
            If None, uses the default solver found in installation.
        verbosity : int (default 1)
            If greater than 0 prints runtime messages.
            Passed unchanged to a single solve; during a sweep, each
            individual solve receives verbosity - 1.
        skipsweepfailures : bool (default False)
            If True, when a solve is infeasible during a sweep, skip it.
        **solveargs : Passed to solve() call
            If ``process_result`` is present and falsy,
            ``self.process_result`` is not called on the results.

        Returns
        -------
        sol : SolutionArray
            See the SolutionArray documentation for details.

        Raises
        ------
        ValueError if the program is invalid.
        RuntimeWarning if an error occurs in solving or parsing the solution.
        """
        constants, sweep, linked = parse_subs(self.varkeys, self.substitutions)
        solution = SolutionArray()
        solution.modelstr = str(self)

        # NOTE SIDE EFFECTS: self.program and self.solution are set below
        if not sweep:
            program, solve_program = genfunction(self)
            self.program = program
            single_result = solve_program(solver, verbosity=verbosity,
                                          **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(single_result)
            solution.append(single_result)
        else:
            run_sweep(genfunction, self, solution, skipsweepfailures,
                      constants, sweep, linked, solver, verbosity,
                      **solveargs)
        solution.to_arrays()
        self.solution = solution
        return solution
    return solvefn
# pylint: disable=too-many-locals,too-many-arguments,too-many-branches
def run_sweep(genfunction, self, solution, skipsweepfailures,
              constants, sweep, linked, solver, verbosity, **solveargs):
    """Runs through a sweep, solving once per point of the sweep grid.

    Sweep variables are ordered by their varkey's ``eqstr``. With more than
    one sweep variable, their values are combined with ``np.meshgrid`` so
    that every combination is solved. Each generated program is appended to
    ``self.program`` and each result to ``solution``. Afterwards the swept
    values are moved from ``solution["constants"]`` into
    ``solution["sweepvariables"]``, and the remaining constants are collapsed
    to their first value (when linked variables exist, only if every solve
    saw the same value).

    Raises RuntimeWarning if a solve is Infeasible and ``skipsweepfailures``
    is False, or if no solve succeeded at all.
    """
    ordered_items = sorted(sweep.items(), key=lambda item: item[0].eqstr)
    sweepvars, sweepvals = zip(*ordered_items)
    if len(sweepvals) == 1:
        grids = np.array(list(sweepvals))
    else:
        grids = np.meshgrid(*sweepvals)

    n_passes = grids[0].size
    # one flat column of values per sweep variable, one entry per solve
    columns = {var: grid.reshape(n_passes)
               for var, grid in zip(sweepvars, grids)}

    if verbosity > 0:
        print("Sweeping with %i solves:" % n_passes)
        tic = time()

    self.program = []
    last_error = None
    for i in range(n_passes):
        constants.update({var: column[i] for var, column in columns.items()})
        if linked:
            evaluate_linked(constants, linked)
        program, solve_program = genfunction(self, constants)
        self.program.append(program)  # NOTE: SIDE EFFECTS
        if verbosity > 1:
            print("\nSolve %i:" % i)
        try:
            pass_result = solve_program(solver, verbosity=verbosity-1,
                                        **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(pass_result)
            solution.append(pass_result)
        except Infeasible as err:
            last_error = err
            if not skipsweepfailures:
                raise RuntimeWarning(
                    "Solve %i was infeasible; progress saved to m.program."
                    " To continue sweeping after failures, solve with"
                    " skipsweepfailures=True." % i) from err
            if verbosity > 0:
                print("Solve %i was %s." % (i, err.__class__.__name__))
    if not solution:
        raise RuntimeWarning("All solves were infeasible.") from last_error

    solution["sweepvariables"] = KeyDict()
    ksweep = KeyDict(sweep)
    # iterate over a snapshot: entries are deleted/replaced inside the loop
    for var, values in list(solution["constants"].items()):
        if var in ksweep:
            solution["sweepvariables"][var] = values
            del solution["constants"][var]
            continue
        if linked:
            # linked variables may make "constants" vary between solves;
            # only collapse when every solve saw the same value
            first = values[0]
            if hasattr(first, "shape"):
                varied = any((other != first).any() for other in values[1:])
            else:
                varied = any(other != first for other in values[1:])
            if varied:
                continue
        solution["constants"][var] = [values[0]]

    if verbosity > 0:
        print("Sweeping took %.3g seconds." % (time() - tic,))