Hot-keys on this page:
r, m, x, p: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1"Scripts for generating, solving and sweeping programs"
2from time import time
3import numpy as np
4from ad import adnumber
5from ..nomials import parse_subs
6from ..solution_array import SolutionArray
7from ..keydict import KeyDict
8from ..small_scripts import maybe_flatten
9from ..small_classes import FixedScalar
10from ..exceptions import Infeasible
11from ..globals import SignomialsEnabled
def _warn_unitless_link(var):
    "Prints a warning that the linked function for `var` returned no units."
    print("Warning: linked function for %s did not return"
          " a united value. Modifying it to do so (e.g. by"
          " using `()` instead of `[]` to access variables)"
          " would reduce the risk of errors." % var)


def evaluate_linked(constants, linked):
    """Evaluates the values and gradients of linked variables.

    Arguments
    ---------
    constants : dict
        Maps variable keys to fixed values. Updated in place: the computed
        value of every linked variable is written into it.
    linked : dict
        Maps linked variable keys to functions; each function is called with
        a KeyDict of the constants and returns that variable's value.

    Side effects
    ------------
    Deletes `descr["gradients"]` from every key in `constants` that has
    gradients. Sets `v.descr["gradients"]` for each linked variable whose
    value was auto-differentiated, and pops it for those that fell back to
    plain evaluation.

    Raises
    ------
    Re-raises any error from auto-differentiation when
    `settings["ad_errors_raise"]` is set. Errors from the plain fallback
    evaluation propagate.
    """
    # evaluate linked functions on adnumbers so gradients can be recovered
    kdc = KeyDict({k: adnumber(maybe_flatten(v), k)
                   for k, v in constants.items()})
    kdc_plain = None  # plain-valued KeyDict, built only if AD fails
    array_calculated = {}  # vecfn output, computed once per veckey
    for key in constants:  # remove gradients from constants
        if key.gradients:
            del key.descr["gradients"]
    for v, f in linked.items():
        try:
            if v.veckey and v.veckey.vecfn:
                if v.veckey not in array_calculated:
                    with SignomialsEnabled():  # to allow use of gpkit.units
                        vecout = v.veckey.vecfn(kdc)
                    if not hasattr(vecout, "shape"):
                        vecout = np.array(vecout)
                    array_calculated[v.veckey] = vecout
                    # np.any (unlike the builtin any) also accepts 0-d and
                    # multi-dimensional outputs instead of raising
                    if (np.any(vecout != 0) and v.veckey.units
                            and not hasattr(vecout, "units")):
                        _warn_unitless_link(v.veckey)
                out = array_calculated[v.veckey][v.idx]
            else:
                with SignomialsEnabled():  # to allow use of gpkit.units
                    out = f(kdc)
            if isinstance(out, FixedScalar):  # to allow use of gpkit.units
                out = out.value
            if hasattr(out, "units"):
                out = out.to(v.units or "dimensionless").magnitude
            elif out != 0 and v.units and not v.veckey:
                _warn_unitless_link(v)
            if not hasattr(out, "x"):
                constants[v] = out
                continue  # a new fixed variable, not a calculated one
            constants[v] = out.x
            v.descr["gradients"] = {adn.tag: grad
                                    for adn, grad in out.d().items()}
        except Exception as exception:  # pylint: disable=broad-except
            from .. import settings
            if settings.get("ad_errors_raise", None):
                raise
            print("Warning: skipped auto-differentiation of linked variable"
                  " %s because %s was raised. Set `gpkit.settings"
                  "[\"ad_errors_raise\"] = True` to raise such Exceptions"
                  " directly.\n" % (v, repr(exception)))
            if kdc_plain is None:
                kdc_plain = KeyDict(constants)
            # same context as the primary path, so gpkit.units still work
            with SignomialsEnabled():
                constants[v] = f(kdc_plain)
            v.descr.pop("gradients", None)
def progify(program, return_attr=None):
    """Builds a method that converts a model into an instance of `program`.

    Arguments
    ---------
    program: NomialData
        Class to instantiate, e.g. GeometricProgram or
        SequentialGeometricProgram
    return_attr: string
        If given, name of an attribute of the new program that is returned
        alongside it.

    Returns
    -------
    function
        `programfn(self, constants=None, **initargs)`, returning the program,
        or a `(program, attribute)` tuple when `return_attr` is truthy.
    """
    def programfn(self, constants=None, **initargs):
        "Return program version of self"
        consts = constants
        if not consts:
            # nothing supplied: derive constants from the substitutions
            consts, _, links = parse_subs(self.varkeys, self.substitutions)
            if links:
                evaluate_linked(consts, links)
        new_prog = program(self.cost, self, consts, **initargs)
        new_prog.model = self  # NOTE SIDE EFFECTS: back-reference to model
        if not return_attr:
            return new_prog
        return new_prog, getattr(new_prog, return_attr)
    return programfn
def solvify(genfunction):
    """Wraps a program generator into a make/solve/sweep method.

    Arguments
    ---------
    genfunction : function
        Called as `genfunction(model)` for a single solve (and as
        `genfunction(model, constants)` during sweeps); must return a
        `(program, solve_function)` pair.

    Returns
    -------
    function
        The `solvefn` method documented below.
    """
    def solvefn(self, solver=None, *, verbosity=1, skipsweepfailures=False,
                **solveargs):
        """Forms a mathematical program and attempts to solve it.

        Arguments
        ---------
        solver : string or function (default None)
            If None, uses the default solver found in installation.
        verbosity : int (default 1)
            If greater than 0 prints runtime messages. Passed unchanged to a
            single solve; decremented by one for each solve of a sweep.
        skipsweepfailures : bool (default False)
            If True, when a solve is infeasible during a sweep, skip it.
        **solveargs : Passed to solve() call

        Returns
        -------
        sol : SolutionArray
            See the SolutionArray documentation for details.

        Raises
        ------
        ValueError if the program is invalid.
        RuntimeWarning if a sweep solve is infeasible (and
        skipsweepfailures is False) or if every sweep solve is infeasible.
        """
        consts, sweeps, links = parse_subs(self.varkeys, self.substitutions)
        sol = SolutionArray()
        sol.modelstr = str(self)

        # NOTE SIDE EFFECTS: self.program and self.solution are set below
        if not sweeps:
            self.program, solve_program = genfunction(self)
            res = solve_program(solver, verbosity=verbosity, **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(res)
            sol.append(res)
        else:
            run_sweep(genfunction, self, sol, skipsweepfailures,
                      consts, sweeps, links, solver, verbosity, **solveargs)
        sol.to_arrays()
        self.solution = sol
        return sol
    return solvefn
# pylint: disable=too-many-locals,too-many-arguments,too-many-branches
def run_sweep(genfunction, self, solution, skipsweepfailures,
              constants, sweep, linked, solver, verbosity, **solveargs):
    """Runs through a sweep.

    Solves once per point of the swept values. With one swept variable that
    is its value array; with several it is their `np.meshgrid`. Each result
    is appended to `solution`. Afterwards, swept variables are moved from
    solution["constants"] into solution["sweepvariables"]. Every other
    constant is collapsed to its first value, unless variables are linked
    and its values differ between solves.

    Side effects: `constants` is updated in place on every pass, and
    `self.program` becomes the list of generated programs.

    Raises RuntimeWarning when a solve is infeasible (unless
    `skipsweepfailures`) or when no solve succeeded.
    """
    # order swept variables by the eqstr of their varkey
    ordered = sorted(list(sweep.items()), key=lambda vkval: vkval[0].eqstr)
    sweepvars, sweepvals = zip(*ordered)
    grids = (np.array(list(sweepvals)) if len(sweep) == 1
             else np.meshgrid(*list(sweepvals)))
    n_passes = grids[0].size
    columns = {var: grid.reshape(n_passes)
               for var, grid in zip(sweepvars, grids)}

    if verbosity > 0:
        print("Sweeping with %i solves:" % n_passes)
        tic = time()

    self.program = []
    last_error = None
    for i in range(n_passes):
        constants.update({var: column[i] for var, column in columns.items()})
        if linked:
            evaluate_linked(constants, linked)
        prog, solve_pass = genfunction(self, constants)
        self.program.append(prog)  # NOTE: SIDE EFFECTS
        if verbosity > 1:
            print("\nSolve %i:" % i)
        try:
            outcome = solve_pass(solver, verbosity=verbosity-1, **solveargs)
            if solveargs.get("process_result", True):
                self.process_result(outcome)
            solution.append(outcome)
        except Infeasible as err:
            last_error = err
            if not skipsweepfailures:
                raise RuntimeWarning(
                    "Solve %i was infeasible; progress saved to m.program."
                    " To continue sweeping after failures, solve with"
                    " skipsweepfailures=True." % i) from err
            if verbosity > 0:
                print("Solve %i was %s." % (i, type(err).__name__))
    if not solution:
        raise RuntimeWarning("All solves were infeasible.") from last_error

    solution["sweepvariables"] = KeyDict()
    ksweep = KeyDict(sweep)
    for var, val in list(solution["constants"].items()):
        if var in ksweep:
            solution["sweepvariables"][var] = val
            del solution["constants"][var]
            continue
        first = val[0]
        if linked:  # if any variables are linked, we check all of them
            if hasattr(first, "shape"):
                varies = any((later != first).any() for later in val[1:])
            else:
                varies = any(later != first for later in val[1:])
            if varies:
                continue  # keep the full per-solve list
        solution["constants"][var] = [first]

    if verbosity > 0:
        print("Sweeping took %.3g seconds." % (time() - tic,))