Hot-keys on this page:
r, m, x, p: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1"Scripts for generating, solving and sweeping programs"
2from time import time
3import numpy as np
4from ad import adnumber
5from ..nomials import parse_subs
6from ..solution_array import SolutionArray
7from ..keydict import KeyDict
8from ..small_scripts import maybe_flatten
9from ..exceptions import Infeasible
def evaluate_linked(constants, linked):
    """Evaluate linked variables, recording their gradients where possible.

    Every value in ``constants`` is wrapped in an ``adnumber`` tagged with
    its key, and each linked function is evaluated on those wrapped values.
    On success the linked variable's value is written into ``constants`` and
    its gradients with respect to the tagged inputs are written into its
    ``descr["gradients"]``. If that evaluation raises, the exception is
    re-raised when ``settings["ad_errors_raise"]`` is set; otherwise a
    warning is printed and the function is re-evaluated on plain values,
    with any ``descr["gradients"]`` entry removed.

    Arguments
    ---------
    constants : dict
        Variable keys -> substituted values. Updated in place with the
        value of every linked variable.
    linked : dict
        Linked variable keys -> functions that take a KeyDict of values.
    """
    ad_values = KeyDict({key: adnumber(maybe_flatten(value), key)
                         for key, value in constants.items()})
    plain_values = None   # only built if some AD evaluation fails
    vector_results = {}   # veckey -> np.array returned by its original_fn
    # strip any gradients recorded on the constants' own keys
    for key in constants:
        if key.gradients:
            del key.descr["gradients"]
    for var, fn in linked.items():
        try:
            vec = var.veckey
            if not (vec and vec.original_fn):
                result = fn(ad_values)
            else:
                # evaluate the vector's function once, then index per element
                if vec not in vector_results:
                    vector_results[vec] = np.array(vec.original_fn(ad_values))
                result = vector_results[vec][var.idx]
            constants[var] = result.x
            var.descr["gradients"] = {
                adn.tag: grad for adn, grad in result.d().items()
                if adn.tag  # untagged adnumbers are user-created
            }
        except Exception as exception:  # pylint: disable=broad-except
            from .. import settings
            if settings.get("ad_errors_raise", None):
                raise
            print("Warning: skipped auto-differentiation of linked variable"
                  " %s because %s was raised. Set `gpkit.settings"
                  "[\"ad_errors_raise\"] = True` to raise such Exceptions"
                  " directly.\n" % (var, repr(exception)))
            if plain_values is None:
                plain_values = KeyDict(constants)
            constants[var] = fn(plain_values)
            var.descr.pop("gradients", None)
def progify(program, return_attr=None):
    """Build a method that turns a model into an instance of `program`.

    Arguments
    ---------
    program : class
        Program class to instantiate, e.g. GeometricProgram or
        SequentialGeometricProgram; it is called as
        ``program(cost, model, constants, **initargs)``.
    return_attr : string (default None)
        If given, the generated method returns ``(prog, prog.<return_attr>)``
        instead of just ``prog``.

    Returns
    -------
    programfn : function
        Method with signature ``programfn(self, constants=None, **initargs)``.
    """
    def programfn(self, constants=None, **initargs):
        "Return program version of self"
        if not constants:
            # nothing supplied: derive constants from the model's substitutions
            constants, _, linked = parse_subs(self.varkeys, self.substitutions)
            if linked:
                evaluate_linked(constants, linked)
        built = program(self.cost, self, constants, **initargs)
        built.model = self  # NOTE SIDE EFFECTS
        return (built, getattr(built, return_attr)) if return_attr else built
    return programfn
def solvify(genfunction):
    """Return a method that makes, solves (or sweeps) and stores a program.

    Arguments
    ---------
    genfunction : function
        Called as ``genfunction(model)`` for a single solve and as
        ``genfunction(model, constants)`` for each solve of a sweep; must
        return a ``(program, solve_function)`` pair.
    """
    def solvefn(self, solver=None, *, verbosity=1, skipsweepfailures=False,
                **solveargs):
        """Forms a mathematical program and attempts to solve it.

        Arguments
        ---------
        solver : string or function (default None)
            If None, uses the default solver found in installation.
        verbosity : int (default 1)
            If greater than 0 prints runtime messages.
            For a single solve it is passed unchanged to the program's
            solve function; during a sweep it is decremented by one before
            being passed to each solve.
        skipsweepfailures : bool (default False)
            If True, a sweep solve that raises Infeasible is skipped
            instead of aborting the sweep.
        **solveargs : Passed to solve() call
            A ``process_result`` entry (default True) also controls whether
            ``self.process_result`` is called on each result.

        Returns
        -------
        sol : SolutionArray
            See the SolutionArray documentation for details.

        Raises
        ------
        RuntimeWarning if a sweep solve is infeasible and skipsweepfailures
            is False, or if every solve of a sweep is infeasible.
        Other exceptions raised while building or solving the program
            propagate to the caller.
        """
        constants, sweep, linked = parse_subs(self.varkeys, self.substitutions)
        solution = SolutionArray()
        solution.modelstr = str(self)

        # NOTE SIDE EFFECTS: self.program and self.solution set below
        if sweep:
            run_sweep(genfunction, self, solution, skipsweepfailures,
                      constants, sweep, linked, solver, verbosity, **solveargs)
        else:
            # genfunction gets no constants here; generators built by
            # progify then re-derive them from self.substitutions
            self.program, progsolve = genfunction(self)
            result = progsolve(solver, verbosity=verbosity, **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
        solution.to_arrays()
        self.solution = solution
        return solution
    return solvefn
# pylint: disable=too-many-locals,too-many-arguments,too-many-branches
def run_sweep(genfunction, self, solution, skipsweepfailures,
              constants, sweep, linked, solver, verbosity, **solveargs):
    """Solve `self` once for each point of a substitution sweep.

    Sweep variables are ordered by their ``eqstr``. With several of them,
    every combination of their values (via ``np.meshgrid``) is solved. On
    each pass ``constants`` is updated in place with that pass's sweep
    values, linked variables are re-evaluated, and a program is built with
    ``genfunction(self, constants)`` and solved. Programs are collected in
    ``self.program`` and results are appended to ``solution``. Afterwards the
    swept values are moved from ``solution["constants"]`` into
    ``solution["sweepvariables"]``.

    Raises
    ------
    RuntimeWarning
        If a solve raises Infeasible and ``skipsweepfailures`` is False,
        or if no solve succeeded.
    """
    ordered = sorted(sweep.items(), key=lambda item: item[0].eqstr)
    sweepvars, sweepvals = zip(*ordered)
    grids = (np.array(list(sweepvals)) if len(sweep) == 1
             else np.meshgrid(*sweepvals))
    n_passes = grids[0].size
    sweep_vects = {var: grid.reshape(n_passes)
                   for var, grid in zip(sweepvars, grids)}

    if verbosity > 0:
        print("Sweeping with %i solves:" % n_passes)
        tic = time()

    self.program = []
    last_error = None
    for i in range(n_passes):
        constants.update({var: vect[i] for var, vect in sweep_vects.items()})
        if linked:
            evaluate_linked(constants, linked)
        prog, progsolve = genfunction(self, constants)
        self.program.append(prog)  # NOTE: SIDE EFFECTS
        try:
            if verbosity > 1:
                print("\nSolve %i:" % i)
            result = progsolve(solver, verbosity=verbosity-1, **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
        except Infeasible as err:
            last_error = err
            if not skipsweepfailures:
                raise RuntimeWarning(
                    "Solve %i was infeasible; progress saved to m.program."
                    " To continue sweeping after failures, solve with"
                    " skipsweepfailures=True." % i) from err
            if verbosity > 0:
                print("Solve %i was %s." % (i, type(err).__name__))
    if not solution:
        raise RuntimeWarning("All solves were infeasible.") from last_error

    _split_sweep_constants(solution, sweep, linked)

    if verbosity > 0:
        print("Sweeping took %.3g seconds." % (time() - tic,))


def _sweep_values_constant(values):
    "True if every entry of `values` equals the first one."
    first = values[0]
    if hasattr(first, "shape"):  # array-valued: compare elementwise
        return not any((other != first).any() for other in values[1:])
    return not any(other != first for other in values[1:])


def _split_sweep_constants(solution, sweep, linked):
    """Move swept variables into solution["sweepvariables"], and collapse
    the remaining per-solve constant lists to their first entry.

    When any variables are linked, a constant is collapsed only if its
    value was the same in every solve.
    """
    solution["sweepvariables"] = KeyDict()
    swept = KeyDict(sweep)
    for var, val in list(solution["constants"].items()):
        if var in swept:
            solution["sweepvariables"][var] = val
            del solution["constants"][var]
        elif not linked or _sweep_values_constant(val):
            solution["constants"][var] = [val[0]]