Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"Scripts for generating, solving and sweeping programs"
2from time import time
3import numpy as np
4from ad import adnumber
5from ..nomials import parse_subs
6from ..solution_array import SolutionArray
7from ..keydict import KeyDict
8from ..small_scripts import maybe_flatten
9from ..exceptions import Infeasible
def evaluate_linked(constants, linked):
    """Evaluate linked variables in place, recording values and gradients.

    Every entry of ``constants`` is wrapped in an ``adnumber`` tagged with its
    own key, so each linked function's output carries derivatives with respect
    to those constants.

    Arguments
    ---------
    constants : dict-like
        Mapping of variable keys to values. Mutated in place: each linked
        variable's evaluated value is written into it.
    linked : dict-like
        Mapping of linked variable keys to the functions that compute them.

    Side effects
    ------------
    On success, ``var.descr["gradients"]`` is set to a ``{tag: gradient}``
    dict covering tagged ad numbers only. If auto-differentiation fails and
    ``settings["ad_errors_raise"]`` is not set, a warning is printed, the
    function is re-run on plain (non-ad) values, and any stored gradients are
    removed.
    """
    # Built once up front: linked functions see only the constants present at
    # entry, not values linked earlier in this loop.
    ad_constants = KeyDict({key: adnumber(maybe_flatten(value), key)
                            for key, value in constants.items()})
    plain_constants = None  # lazily built the first time AD fails
    vector_results = {}     # one original_fn call per vector key
    for var, func in linked.items():
        try:
            vkey = var.veckey
            if vkey and vkey.original_fn:
                # element of a vector linked as a whole: compute the full
                # array once, then pick this element out by its index
                if vkey not in vector_results:
                    vector_fn = vkey.original_fn
                    vector_results[vkey] = np.array(vector_fn(ad_constants))
                evaluated = vector_results[vkey][var.idx]
            else:
                evaluated = func(ad_constants)
            constants[var] = evaluated.x
            gradients = {}
            for adn, grad in evaluated.d().items():
                if adn.tag:  # untagged ad numbers are user-created; skip them
                    gradients[adn.tag] = grad
            var.descr["gradients"] = gradients
        except Exception as exception:  # pylint: disable=broad-except
            from .. import settings
            if settings.get("ad_errors_raise", None):
                raise
            print("Warning: skipped auto-differentiation of linked variable"
                  " %s because %s was raised. Set `gpkit.settings"
                  "[\"ad_errors_raise\"] = True` to raise such Exceptions"
                  " directly.\n" % (var, repr(exception)))
            if plain_constants is None:
                # snapshot includes any linked values computed so far
                plain_constants = KeyDict(constants)
            constants[var] = func(plain_constants)
            var.descr.pop("gradients", None)
def progify(program, return_attr=None):
    """Build a method that turns a model into an instance of ``program``.

    Arguments
    ---------
    program: NomialData
        Class to instantiate, e.g. GeometricProgram or
        SequentialGeometricProgram. It is called as
        ``program(model.cost, model, constants, **initargs)``.
    return_attr: string
        attribute to return in addition to the program

    Returns
    -------
    function
        ``programfn(self, constants=None, **initargs)``. If ``constants`` is
        falsy (None or empty), it is parsed from ``self.varkeys`` and
        ``self.substitutions`` and any linked variables are evaluated.
        The built program gets ``.model = self``. The function returns the
        program, or ``(program, getattr(program, return_attr))`` when
        ``return_attr`` is truthy.
    """
    def programfn(self, constants=None, **initargs):
        "Return program version of self"
        if constants:
            subs = constants
        else:
            subs, _, linked = parse_subs(self.varkeys, self.substitutions)
            if linked:
                evaluate_linked(subs, linked)
        built = program(self.cost, self, subs, **initargs)
        built.model = self  # NOTE SIDE EFFECTS
        if not return_attr:
            return built
        return built, getattr(built, return_attr)
    return programfn
def solvify(genfunction):
    """Return a method that forms, solves and, if needed, sweeps a program.

    Arguments
    ---------
    genfunction : function
        Called as ``genfunction(model)`` for a single solve, or as
        ``genfunction(model, constants)`` at each sweep point. It must return
        a ``(program, solve_function)`` pair, such as a function built by
        ``progify`` with a ``return_attr``.

    Returns
    -------
    function
        ``solvefn`` (documented below), intended to be bound as a model
        method.
    """
    def solvefn(self, solver=None, *, verbosity=1, skipsweepfailures=False,
                **solveargs):
        """Forms a mathematical program and attempts to solve it.

        If the model's substitutions contain sweeps, one program is formed
        and solved per sweep point via ``run_sweep``. Otherwise a single
        program is formed and solved.

        Arguments
        ---------
        solver : string or function (default None)
            If None, uses the default solver found in installation.
        verbosity : int (default 1)
            If greater than 0 prints runtime messages. For a single solve it
            is passed to the program's solve unchanged. During a sweep it is
            decremented by one before being passed to each solve.
        skipsweepfailures : bool (default False)
            If True, sweep points whose solve raises ``Infeasible`` are
            skipped. Other exceptions still propagate.
        **solveargs : Passed to solve() call
            If ``process_result`` is given and falsy, ``self.process_result``
            is not called on the results.

        Returns
        -------
        sol : SolutionArray
            See the SolutionArray documentation for details. It is also
            stored as ``self.solution``. The program (a list of programs for
            sweeps) is stored as ``self.program``.

        Raises
        ------
        RuntimeWarning
            During a sweep, if a solve raises ``Infeasible`` and
            ``skipsweepfailures`` is False, or if no sweep point solved.
        Exceptions raised while forming or solving a single program (e.g.
        ``ValueError`` for an invalid program, ``Infeasible``) propagate
        unchanged.
        """
        constants, sweep, linked = parse_subs(self.varkeys, self.substitutions)
        solution = SolutionArray()
        solution.modelstr = str(self)

        # NOTE SIDE EFFECTS: self.program and self.solution set below
        if sweep:
            run_sweep(genfunction, self, solution, skipsweepfailures,
                      constants, sweep, linked, solver, verbosity, **solveargs)
        else:
            # NOTE(review): constants/linked are unused on this path;
            # genfunction is called without them (progify's programfn then
            # re-parses the substitutions itself).
            self.program, progsolve = genfunction(self)
            result = progsolve(solver, verbosity=verbosity, **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
        solution.to_arrays()
        self.solution = solution
        return solution
    return solvefn
# pylint: disable=too-many-locals,too-many-arguments,too-many-branches
def run_sweep(genfunction, self, solution, skipsweepfailures,
              constants, sweep, linked, solver, verbosity, **solveargs):
    """Runs through a sweep, appending each solved point to ``solution``.

    Swept variables are ordered by their varkey's ``eqstr``. A single sweep
    is used as-is. Several sweeps are combined into a full grid with
    ``np.meshgrid``.

    For every grid point, this function:

    - writes the point into ``constants`` (mutated in place);
    - re-evaluates linked variables, if any;
    - builds a program with ``genfunction(self, constants)``;
    - appends that program to ``self.program``, which is reset to ``[]``;
    - solves it with ``verbosity - 1``.

    After the sweep, swept variables are moved from ``solution["constants"]``
    to ``solution["sweepvariables"]``. Every other constant is collapsed to a
    one-element list, unless linked variables exist and its value varied
    across points.

    Raises
    ------
    RuntimeWarning
        If a solve raises ``Infeasible`` and ``skipsweepfailures`` is False,
        or if no sweep point solved successfully.
    """
    # sort sweeps by the eqstr of their varkey
    ordered = sorted(sweep.items(), key=lambda item: item[0].eqstr)
    sweepvars, sweepvals = zip(*ordered)
    if len(sweep) == 1:
        grids = np.array(list(sweepvals))
    else:
        grids = np.meshgrid(*sweepvals)

    n_passes = grids[0].size
    flat_values = {}
    for var, grid in zip(sweepvars, grids):
        flat_values[var] = grid.reshape(n_passes)

    if verbosity > 0:
        print("Sweeping over %i solves." % n_passes)
        tic = time()

    self.program = []
    last_error = None
    for point in range(n_passes):
        constants.update({var: values[point]
                          for var, values in flat_values.items()})
        if linked:
            evaluate_linked(constants, linked)
        program, solve_point = genfunction(self, constants)
        self.program.append(program)  # NOTE: SIDE EFFECTS
        try:
            result = solve_point(solver, verbosity=verbosity-1, **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
        except Infeasible as error:
            last_error = error
            if not skipsweepfailures:
                raise RuntimeWarning(
                    "Sweep halted! Progress saved to m.program. To skip over"
                    " such failures, solve with skipsweepfailures=True."
                ) from error
    if not solution:
        raise RuntimeWarning("No sweeps solved successfully.") from last_error

    solution["sweepvariables"] = KeyDict()
    ksweep = KeyDict(sweep)
    for var, val in list(solution["constants"].items()):
        if var in ksweep:
            solution["sweepvariables"][var] = val
            del solution["constants"][var]
            continue
        if linked:  # if any variables are linked, we check all of them
            first = val[0]
            if hasattr(first, "shape"):
                varies = any((other != first).any() for other in val[1:])
            else:
                varies = any(other != first for other in val[1:])
            if varies:
                continue  # keep the per-point values
        solution["constants"][var] = [val[0]]

    if verbosity > 0:
        elapsed = time() - tic
        print("Sweeping took %.3g seconds." % (elapsed,))