Coverage for gpkit/constraints/prog_factories.py: 86%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"Scripts for generating, solving and sweeping programs"
2from time import time
3import warnings as pywarnings
4import numpy as np
5from adce import adnumber
6from ..nomials import parse_subs
7from ..solution_array import SolutionArray
8from ..keydict import KeyDict
9from ..small_scripts import maybe_flatten
10from ..small_classes import FixedScalar
11from ..exceptions import Infeasible
12from ..globals import SignomialsEnabled
def evaluate_linked(constants, linked):
    """Evaluates linked variables, recording gradients when they are available.

    Arguments
    ---------
    constants : dict-like
        Maps variable keys to fixed values. Updated in place with the value
        computed for each linked variable.
    linked : dict-like
        Maps variable keys to the functions that compute their values.

    Any "gradients" entry is first removed from the `descr` of every key in
    `constants`. Each linked function is then called on constants wrapped by
    `adnumber`; a result exposing `.x` has that value stored and its
    derivatives saved under `descr["gradients"]`. If anything in that attempt
    raises, the function is re-run on plain constants and a warning is
    printed, unless `settings["ad_errors_raise"]` is truthy, in which case
    the exception propagates.
    """
    ad_constants = KeyDict({key: adnumber(maybe_flatten(value), key)
                            for key, value in constants.items()})
    plain_constants = None  # only built if a differentiated call fails
    vector_results = {}  # each vector function is evaluated at most once
    for key in constants:  # remove gradients from constants
        key.descr.pop("gradients", None)
    for var, linkedfn in linked.items():
        try:
            if not (var.veckey and var.veckey.vecfn):
                with SignomialsEnabled():  # to allow use of gpkit.units
                    result = linkedfn(ad_constants)
            else:
                if var.veckey not in vector_results:
                    with SignomialsEnabled():  # to allow use of gpkit.units
                        vector_result = var.veckey.vecfn(ad_constants)
                    if not hasattr(vector_result, "shape"):
                        vector_result = np.array(vector_result)
                    vector_results[var.veckey] = vector_result
                # this variable is one element of the shared vector output
                result = vector_results[var.veckey][var.idx]
            if isinstance(result, FixedScalar):  # to allow use of gpkit.units
                result = result.value
            if hasattr(result, "units"):
                target_units = var.units or "dimensionless"
                result = result.to(target_units).magnitude
            elif result != 0 and var.units:
                pywarnings.warn(
                    "Linked function for %s did not return a united value."
                    " Modifying it to do so (e.g. by using `()` instead of `[]`"
                    " to access variables) will reduce errors." % var)
            result = maybe_flatten(result)
            if hasattr(result, "x"):
                constants[var] = result.x
                # keep only the derivatives taken w.r.t. tagged inputs
                var.descr["gradients"] = {
                    adnum.tag: gradient
                    for adnum, gradient in result.d().items()
                    if adnum.tag
                }
            else:
                # a new fixed variable, not a calculated one
                constants[var] = result
        except Exception as exception:  # pylint: disable=broad-except
            from .. import settings
            if settings.get("ad_errors_raise", None):
                raise
            if plain_constants is None:
                plain_constants = KeyDict(constants)
            constants[var] = linkedfn(plain_constants)
            var.descr.pop("gradients", None)
            print("Warning: skipped auto-differentiation of linked variable"
                  " %s because %s was raised. Set `gpkit.settings"
                  "[\"ad_errors_raise\"] = True` to raise such Exceptions"
                  " directly.\n" % (var, repr(exception)))
            monomial_ad_error = (
                "Automatic differentiation not yet supported for <class "
                "'gpkit.nomials.math.Monomial'> objects")
            if monomial_ad_error in str(exception):
                print("This particular warning may have come from using"
                      " gpkit.units.* in the function for %s; try using"
                      " gpkit.ureg.* or gpkit.units.*.units instead." % var)
def progify(program, return_attr=None):
    """Builds a method that creates a `program` from the model it is bound to.

    Arguments
    ---------
    program: NomialData
        Class to return, e.g. GeometricProgram or SequentialGeometricProgram
    return_attr: string
        attribute to return in addition to the program

    Returns
    -------
    programfn : function
        Called as `programfn(self, constants=None, **initargs)`. Returns the
        program, or `(program, getattr(program, return_attr))` when
        `return_attr` is truthy.
    """
    def programfn(self, constants=None, **initargs):
        "Return program version of self"
        if not constants:
            # nothing usable was passed in: derive constants from the model
            constants, _, linked = parse_subs(self.varkeys, self.substitutions)
            if linked:
                evaluate_linked(constants, linked)
        prog = program(self.cost, self, constants, **initargs)
        prog.model = self  # NOTE SIDE EFFECTS
        if not return_attr:
            return prog
        return prog, getattr(prog, return_attr)
    return programfn
def solvify(genfunction):
    "Returns function for making/solving/sweeping a program."
    def solvefn(self, solver=None, *, verbosity=1, skipsweepfailures=False,
                **kwargs):
        """Builds a mathematical program from the model and tries to solve it.

        Arguments
        ---------
        solver : string or function (default None)
            Handed on to the program's solve function unchanged.
        verbosity : int (default 1)
            If greater than 0 prints runtime messages.
            Passed as-is for a single solve; in a sweep each solve
            receives `verbosity - 1`.
        skipsweepfailures : bool (default False)
            If True, when a solve errors during a sweep, skip it.
        **kwargs : Passed to solve and program init calls

        Returns
        -------
        sol : SolutionArray
            Also stored as `self.solution`.

        Raises
        ------
        RuntimeWarning from a sweep when a solve is infeasible and
        `skipsweepfailures` is False, or when every solve was infeasible.
        Errors raised while building or solving the program propagate.
        """
        constants, sweep, linked = parse_subs(self.varkeys, self.substitutions)
        solution = SolutionArray()
        solution.modelstr = str(self)

        # NOTE SIDE EFFECTS: self.program and self.solution set below
        if not sweep:
            self.program, progsolve = genfunction(self, **kwargs)
            result = progsolve(solver, verbosity=verbosity, **kwargs)
            if kwargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
        else:
            run_sweep(genfunction, self, solution, skipsweepfailures,
                      constants, sweep, linked, solver, verbosity, **kwargs)
        solution.to_arrays()
        self.solution = solution
        return solution
    return solvefn
# pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements
def run_sweep(genfunction, self, solution, skipsweepfailures,
              constants, sweep, linked, solver, verbosity, **kwargs):
    """Solves the model once for every combination of swept values.

    Arguments
    ---------
    genfunction : function
        Called as `genfunction(self, constants, **kwargs)`; must return a
        `(program, solve function)` pair.
    self : the model being swept
        Its `program` attribute is replaced by a list of generated programs.
    solution : SolutionArray
        Receives one appended result per successful solve.
    skipsweepfailures : bool
        If True, infeasible solves are skipped instead of aborting the sweep.
    constants : dict-like
        Fixed values; updated in place with each pass's swept values.
    sweep : dict-like
        Maps each swept variable to the values it should take.
    linked : dict-like
        Linked variables, re-evaluated on every pass when non-empty.
    solver, verbosity, **kwargs
        Forwarded to each solve (verbosity is decremented by one).

    Raises
    ------
    RuntimeWarning if a solve is infeasible and `skipsweepfailures` is False,
    or if no solve produced a result.
    """
    # sort sweeps by the eqstr of their varkey
    ordered_sweeps = sorted(sweep.items(), key=lambda item: item[0].eqstr)
    sweepvars, sweepvals = zip(*ordered_sweeps)
    if len(sweep) != 1:
        sweep_grids = np.meshgrid(*sweepvals)
    else:
        sweep_grids = np.array(list(sweepvals))

    npasses = sweep_grids[0].size
    sweep_vects = {}
    for swept_var, grid in zip(sweepvars, sweep_grids):
        sweep_vects[swept_var] = grid.reshape(npasses)

    if verbosity > 0:
        print("Sweeping with %i solves:" % npasses)
        tic = time()

    self.program = []
    last_error = None
    for i in range(npasses):
        constants.update({swept_var: values[i]
                          for swept_var, values in sweep_vects.items()})
        if linked:
            evaluate_linked(constants, linked)
        program, solvefn = genfunction(self, constants, **kwargs)
        program.model = None  # so it doesn't try to debug
        self.program.append(program)  # NOTE: SIDE EFFECTS
        try:
            if verbosity > 1:
                print("\nSolve %i:" % i)
            result = solvefn(solver, verbosity=verbosity-1, **kwargs)
            if kwargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
        except Infeasible as infeasible:
            last_error = infeasible
            if not skipsweepfailures:
                raise RuntimeWarning(
                    "Solve %i was infeasible; progress saved to m.program."
                    " To continue sweeping after failures, solve with"
                    " skipsweepfailures=True." % i) from infeasible
            if verbosity > 0:
                print("Solve %i was %s." % (i, infeasible.__class__.__name__))
    if not solution:
        raise RuntimeWarning("All solves were infeasible.") from last_error

    solution["sweepvariables"] = KeyDict()
    ksweep = KeyDict(sweep)
    # iterate over a snapshot, since entries are deleted and replaced below
    for var, val in list(solution["constants"].items()):
        if var in ksweep:
            solution["sweepvariables"][var] = val
            del solution["constants"][var]
            continue
        first = val[0]
        if linked:  # if any variables are linked, we check all of them
            if hasattr(first, "shape"):
                varies = any((other != first).any() for other in val[1:])
            else:
                varies = any(other != first for other in val[1:])
            if not varies:
                solution["constants"][var] = [first]
        else:
            # nothing is linked, so the first value is kept for every solve
            solution["constants"][var] = [first]

    if verbosity > 0:
        soltime = time() - tic
        print("Sweeping took %.3g seconds." % (soltime,))