Coverage for gpkit\constraints\prog_factories.py : 0%
![Show keyboard shortcuts](keybd_closed.png)
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"Scripts for generating, solving and sweeping programs"
2from time import time
3import numpy as np
4from ad import adnumber
5from ..nomials import parse_subs
6from ..solution_array import SolutionArray
7from ..keydict import KeyDict
8from ..small_scripts import maybe_flatten
9from ..small_classes import FixedScalar
10from ..exceptions import Infeasible
11from ..globals import SignomialsEnabled
def _warn_not_united(key):
    "Prints the warning for a linked function that returned a unitless value."
    print("Warning: linked function for %s did not return"
          " a united value. Modifying it to do so (e.g. by"
          " using `()` instead of `[]` to access variables)"
          " would reduce the risk of errors." % key)


def evaluate_linked(constants, linked):
    """Evaluates the values and gradients of linked variables.

    Every linked function is called with a KeyDict of the constants wrapped
    as `adnumber`s tagged with their own keys. Derivatives of the function's
    output can therefore be traced back to the constants it used.

    Arguments
    ---------
    constants : dict
        Maps VarKeys to their substituted values. Modified in place: each
        linked VarKey is added (or overwritten) with its computed value.
    linked : dict
        Maps each linked VarKey to a function that takes a KeyDict of
        constants and returns that variable's value.

    Side effects
    ------------
    - Deletes the "gradients" descr entry of every key already in
      `constants` that reports gradients.
    - Sets `v.descr["gradients"]` for each linked `v` whose value depends on
      tagged constants. If auto-differentiation fails for `v`, that entry is
      removed instead.

    Raises
    ------
    Re-raises any exception from a linked function when
    `gpkit.settings["ad_errors_raise"]` is truthy. Otherwise the function is
    re-evaluated on plain (non-adnumber) constants, and exceptions from that
    retry propagate.
    """
    # wrap each constant in an adnumber tagged with its key, so derivatives
    # of linked outputs can be attributed to individual constants
    kdc = KeyDict({k: adnumber(maybe_flatten(v), k)
                   for k, v in constants.items()})
    kdc_plain = None  # un-wrapped copy of constants, built only on failure
    array_calulated = {}  # veckey -> full output of its vecfn (computed once)
    for key in constants:  # remove gradients from constants
        if key.gradients:
            del key.descr["gradients"]
    for v, f in linked.items():
        try:
            if v.veckey and v.veckey.vecfn:
                # the whole vector is computed once, then indexed per element
                if v.veckey not in array_calulated:
                    with SignomialsEnabled():  # to allow use of gpkit.units
                        vecout = v.veckey.vecfn(kdc)
                    if not hasattr(vecout, "shape"):
                        vecout = np.array(vecout)
                    array_calulated[v.veckey] = vecout
                    # np.any rather than the builtin any: the builtin
                    # iterates rows of a multi-dimensional array and raises
                    # "truth value ... is ambiguous". The cheap unit checks
                    # come first so the comparison only runs when a warning
                    # is actually possible.
                    if (v.veckey.units and not hasattr(vecout, "units")
                            and np.any(vecout != 0)):
                        _warn_not_united(v.veckey)
                out = array_calulated[v.veckey][v.idx]
            else:
                with SignomialsEnabled():  # to allow use of gpkit.units
                    out = f(kdc)
            if isinstance(out, FixedScalar):  # to allow use of gpkit.units
                out = out.value
            if hasattr(out, "units"):
                out = out.to(v.units or "dimensionless").magnitude
            elif v.units and not v.veckey and out != 0:
                _warn_not_united(v)
            if not hasattr(out, "x"):
                # no `.x`, so presumably not an adnumber: store it as-is
                constants[v] = out
                continue  # a new fixed variable, not a calculated one
            constants[v] = out.x
            # derivatives w.r.t. the tagged constants (tags are their keys)
            gradients = {adn.tag: grad
                         for adn, grad in out.d().items() if adn.tag}
            if gradients:
                v.descr["gradients"] = gradients
        except Exception as exception:  # pylint: disable=broad-except
            from .. import settings
            if settings.get("ad_errors_raise", None):
                raise
            print("Warning: skipped auto-differentiation of linked variable"
                  " %s because %s was raised. Set `gpkit.settings"
                  "[\"ad_errors_raise\"] = True` to raise such Exceptions"
                  " directly.\n" % (v, repr(exception)))
            if kdc_plain is None:
                kdc_plain = KeyDict(constants)
            # NOTE(review): unlike the AD path above, this fallback result is
            # not unwrapped from FixedScalar or converted to v.units.
            constants[v] = f(kdc_plain)
            v.descr.pop("gradients", None)


def progify(program, return_attr=None):
    """Builds a model method that instantiates `program` from the model.

    Arguments
    ---------
    program : class
        Program class to instantiate, e.g. GeometricProgram or
        SequentialGeometricProgram. It is called as
        ``program(model.cost, model, constants, **initargs)``.
    return_attr : string (default None)
        If truthy, the generated method returns
        ``(prog, getattr(prog, return_attr))`` instead of just ``prog``.

    Returns
    -------
    programfn : function
        ``programfn(self, constants=None, **initargs)``. When `constants` is
        falsy it is derived from ``self.substitutions`` (evaluating any
        linked variables). The new program's ``model`` attribute is set to
        ``self``.
    """
    def programfn(self, constants=None, **initargs):
        "Return program version of self"
        subs = constants
        if not subs:
            # nothing supplied: derive substitutions from the model itself
            subs, _, links = parse_subs(self.varkeys, self.substitutions)
            if links:
                evaluate_linked(subs, links)
        instance = program(self.cost, self, subs, **initargs)
        instance.model = self  # NOTE SIDE EFFECTS: program keeps its model
        if not return_attr:
            return instance
        return instance, getattr(instance, return_attr)
    return programfn


def solvify(genfunction):
    """Builds a model method that makes, solves (or sweeps) a program.

    `genfunction` is called as ``genfunction(model)`` for a single solve, or
    via `run_sweep` for a sweep. It must return a ``(program, solve
    function)`` pair.
    """
    def solvefn(self, solver=None, *, verbosity=1, skipsweepfailures=False,
                **solveargs):
        """Forms a mathematical program and attempts to solve it.

        Arguments
        ---------
        solver : string or function (default None)
            If None, uses the default solver found in installation.
        verbosity : int (default 1)
            If greater than 0 prints runtime messages. Passed unchanged to
            a single solve; decremented by one for each solve of a sweep.
        skipsweepfailures : bool (default False)
            If True, when a solve of a sweep is infeasible, skip it.
        **solveargs : Passed to each solve() call. With
            process_result=False, self.process_result is not called.

        Returns
        -------
        sol : SolutionArray
            See the SolutionArray documentation for details. Also stored
            as self.solution.

        Raises
        ------
        ValueError if the program is invalid.
        RuntimeWarning if an error occurs in solving or parsing the solution.
        """
        constants, sweep, linked = parse_subs(self.varkeys, self.substitutions)
        sol = SolutionArray()
        sol.modelstr = str(self)

        # NOTE SIDE EFFECTS: self.program and self.solution are set below
        if not sweep:
            # NOTE(review): constants/linked are unused on this path; they
            # are presumably re-derived inside genfunction(self).
            self.program, solve_program = genfunction(self)
            result = solve_program(solver, verbosity=verbosity, **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(result)
            sol.append(result)
        else:
            run_sweep(genfunction, self, sol, skipsweepfailures,
                      constants, sweep, linked, solver, verbosity,
                      **solveargs)
        sol.to_arrays()
        self.solution = sol
        return sol
    return solvefn


# pylint: disable=too-many-locals,too-many-arguments,too-many-branches
def run_sweep(genfunction, self, solution, skipsweepfailures,
              constants, sweep, linked, solver, verbosity, **solveargs):
    """Solves the model once for every point of a substitution sweep.

    Arguments
    ---------
    genfunction : function
        Called as ``genfunction(self, constants)`` for each sweep point;
        returns a ``(program, solve function)`` pair.
    self : model
        The model being swept. ``self.program`` is replaced by the list of
        generated programs. ``self.process_result`` is called on each result
        unless ``process_result=False`` is passed in `solveargs`.
    solution : SolutionArray
        Each result is appended to it. Afterwards the swept values are moved
        from its "constants" entry to a new "sweepvariables" entry. Other
        constants are collapsed to a single-element list: unconditionally
        when nothing is linked, otherwise only when they did not change
        between solves.
    skipsweepfailures : bool
        If False, the first Infeasible solve aborts the sweep.
    constants : dict
        Fixed substitutions; updated in place with each sweep point (and
        with linked values).
    sweep : dict
        Maps each swept VarKey to its values. With several swept variables,
        every combination (np.meshgrid) is solved.
    linked : dict
        Linked variables, re-evaluated at every sweep point.
    solver, verbosity, **solveargs
        Passed to each solve, with verbosity decremented by one.

    Raises
    ------
    RuntimeWarning
        If a solve is infeasible and skipsweepfailures is False, or if every
        solve was infeasible.
    """
    # order the swept variables deterministically, by their varkey's eqstr
    ordered = sorted(sweep.items(), key=lambda item: item[0].eqstr)
    sweepvars, sweepvals = zip(*ordered)
    if len(sweep) != 1:
        sweep_grids = np.meshgrid(*sweepvals)
    else:
        sweep_grids = np.array(list(sweepvals))

    n_passes = sweep_grids[0].size
    sweep_vects = {var: grid.reshape(n_passes)
                   for var, grid in zip(sweepvars, sweep_grids)}

    if verbosity > 0:
        print("Sweeping with %i solves:" % n_passes)
        tic = time()

    self.program = []
    last_error = None
    for i in range(n_passes):
        point = {var: vect[i] for var, vect in sweep_vects.items()}
        # NOTE(review): the same `constants` object is handed to every
        # genfunction call and then mutated on the next pass; if programs
        # keep a reference to it, earlier programs will see later values.
        constants.update(point)
        if linked:
            evaluate_linked(constants, linked)
        prog, solve = genfunction(self, constants)
        self.program.append(prog)  # NOTE: SIDE EFFECTS
        try:
            if verbosity > 1:
                print("\nSolve %i:" % i)
            result = solve(solver, verbosity=verbosity-1, **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
        except Infeasible as error:
            last_error = error
            if not skipsweepfailures:
                raise RuntimeWarning(
                    "Solve %i was infeasible; progress saved to m.program."
                    " To continue sweeping after failures, solve with"
                    " skipsweepfailures=True." % i) from error
            if verbosity > 0:
                print("Solve %i was %s." % (i, error.__class__.__name__))
    if not solution:
        raise RuntimeWarning("All solves were infeasible.") from last_error

    solution["sweepvariables"] = KeyDict()
    swept = KeyDict(sweep)
    for var, val in list(solution["constants"].items()):
        if var in swept:
            solution["sweepvariables"][var] = val
            del solution["constants"][var]
        elif linked:  # linked values may vary, so every constant is checked
            first, rest = val[0], val[1:]
            if hasattr(first, "shape"):
                varies = any((other != first).any() for other in rest)
            else:
                varies = any(other != first for other in rest)
            if not varies:
                solution["constants"][var] = [first]
        else:
            solution["constants"][var] = [val[0]]

    if verbosity > 0:
        print("Sweeping took %.3g seconds." % (time() - tic,))