Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"Scripts for generating, solving and sweeping programs"
2from time import time
3import warnings as pywarnings
4import numpy as np
5from ad import adnumber
6from ..nomials import parse_subs
7from ..solution_array import SolutionArray
8from ..keydict import KeyDict
9from ..small_scripts import maybe_flatten
10from ..small_classes import FixedScalar
11from ..exceptions import Infeasible
12from ..globals import SignomialsEnabled
def evaluate_linked(constants, linked):
    """Evaluates the values and gradients of linked variables.

    Arguments
    ---------
    constants : dict
        Maps VarKeys to substituted values. Modified in place: every linked
        VarKey is assigned its computed plain (non-AD) value.
    linked : dict
        Maps VarKeys to functions that take a KeyDict of the constants
        (wrapped as `adnumber`s tagged with their key) and return a value.

    Side effects
    ------------
    Any "gradients" entry is removed from the `descr` of each key in
    `constants`. For each linked VarKey, `descr["gradients"]` is set to its
    derivatives with respect to the tagged constants, or removed when there
    are none or auto-differentiation failed.
    """
    kdc = KeyDict({k: adnumber(maybe_flatten(v), k)
                   for k, v in constants.items()})
    kdc_plain = None  # plain-valued KeyDict, built lazily for AD fallbacks
    array_calculated = {}  # veckey -> output of its vecfn, computed once
    for key in constants:  # remove gradients from constants
        if key.gradients:
            del key.descr["gradients"]
    for v, f in linked.items():
        try:
            if v.veckey and v.veckey.vecfn:
                # a vector function yields every element at once; cache its
                # output so it is evaluated once per vector, not per element
                if v.veckey not in array_calculated:
                    with SignomialsEnabled():  # to allow use of gpkit.units
                        vecout = v.veckey.vecfn(kdc)
                    if not hasattr(vecout, "shape"):
                        vecout = np.array(vecout)
                    array_calculated[v.veckey] = vecout
                out = array_calculated[v.veckey][v.idx]
            else:
                with SignomialsEnabled():  # to allow use of gpkit.units
                    out = f(kdc)
            if isinstance(out, FixedScalar):  # to allow use of gpkit.units
                out = out.value
            if hasattr(out, "units"):
                out = out.to(v.units or "dimensionless").magnitude
            elif out != 0 and v.units:
                pywarnings.warn(
                    "Linked function for %s did not return a united value."
                    " Modifying it to do so (e.g. by using `()` instead of"
                    " `[]` to access variables) will reduce errors." % v)
            if hasattr(out, "__len__"):
                out = out.item()  # break out of 0-dimensional arrays
            if not hasattr(out, "x"):
                # a new fixed variable, not a calculated one
                constants[v] = out
                v.descr.pop("gradients", None)  # drop stale gradients
                continue
            constants[v] = out.x
            gradients = {adn.tag: grad
                         for adn, grad in out.d().items() if adn.tag}
            if gradients:
                v.descr["gradients"] = gradients
            else:  # don't keep gradients from a previous evaluation
                v.descr.pop("gradients", None)
        except Exception as exception:  # pylint: disable=broad-except
            from .. import settings
            if settings.get("ad_errors_raise", None):
                raise
            print("Warning: skipped auto-differentiation of linked variable"
                  " %s because %s was raised. Set `gpkit.settings"
                  "[\"ad_errors_raise\"] = True` to raise such Exceptions"
                  " directly.\n" % (v, repr(exception)))
            if kdc_plain is None:
                kdc_plain = KeyDict(constants)
            constants[v] = f(kdc_plain)
            v.descr.pop("gradients", None)
def progify(program, return_attr=None):
    """Build a method that converts a model into `program` (and optionally
    also returns one of the resulting program's attributes).

    Arguments
    ---------
    program: NomialData
        Class to return, e.g. GeometricProgram or SequentialGeometricProgram
    return_attr: string
        attribute to return in addition to the program

    Returns
    -------
    function
        `programfn(self, constants=None, **initargs)`. When `constants` is
        falsy, constants are parsed from `self.substitutions` and any linked
        variables are evaluated into them. The program is built as
        `program(self.cost, self, constants, **initargs)` and its `model`
        attribute is set to `self`. Returns the program, or the pair
        `(program, getattr(program, return_attr))` when `return_attr` is set.
    """
    def programfn(self, constants=None, **initargs):
        "Return program version of self"
        if not constants:
            parsed, _, linked_subs = parse_subs(self.varkeys,
                                                self.substitutions)
            if linked_subs:
                evaluate_linked(parsed, linked_subs)
            constants = parsed
        built = program(self.cost, self, constants, **initargs)
        built.model = self  # NOTE SIDE EFFECTS
        if not return_attr:
            return built
        return built, getattr(built, return_attr)
    return programfn
def solvify(genfunction):
    """Returns function for making/solving/sweeping a program.

    Arguments
    ---------
    genfunction : function
        Called as `genfunction(model, **kwargs)` for a single solve, or with
        explicit constants by `run_sweep` during a sweep; must return a
        `(program, solve_function)` pair.
    """
    def solvefn(self, solver=None, *, verbosity=1, skipsweepfailures=False,
                **kwargs):
        """Forms a mathematical program and attempts to solve it.

        Arguments
        ---------
        solver : string or function (default None)
            If None, uses the default solver found in installation.
        verbosity : int (default 1)
            If greater than 0 prints runtime messages.
            For a single solve it is passed to the program's solve unchanged;
            during a sweep it is decremented by one before each solve.
        skipsweepfailures : bool (default False)
            If True, when a solve errors during a sweep, skip it.
        **kwargs : Passed to solve and program init calls
            `process_result` (default True) controls whether
            `self.process_result` is called on each result.

        Returns
        -------
        sol : SolutionArray
            See the SolutionArray documentation for details.

        Raises
        ------
        ValueError if the program is invalid.
        RuntimeWarning if an error occurs in solving or parsing the solution,
            e.g. an infeasible sweep solve (unless skipsweepfailures is True)
            or a sweep in which every solve was infeasible.

        Side effects: sets `self.program` (a list of programs for sweeps)
        and `self.solution`.
        """
        # constants and linked are only consumed by the sweep path; for a
        # single solve genfunction parses the substitutions itself.
        constants, sweep, linked = parse_subs(self.varkeys, self.substitutions)
        solution = SolutionArray()
        solution.modelstr = str(self)

        # NOTE SIDE EFFECTS: self.program and self.solution set below
        if sweep:
            run_sweep(genfunction, self, solution, skipsweepfailures,
                      constants, sweep, linked, solver, verbosity, **kwargs)
        else:
            self.program, progsolve = genfunction(self, **kwargs)
            result = progsolve(solver, verbosity=verbosity, **kwargs)
            if kwargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
        solution.to_arrays()
        self.solution = solution
        return solution
    return solvefn
# pylint: disable=too-many-locals,too-many-arguments,too-many-branches
def run_sweep(genfunction, self, solution, skipsweepfailures,
              constants, sweep, linked, solver, verbosity, **kwargs):
    """Runs through a sweep, solving once per point of the sweep grid.

    Swept variables are ordered by their `eqstr`. With several of them the
    grid is the Cartesian product given by `np.meshgrid`. On each pass,
    `constants` is updated in place with that pass's swept values, linked
    variables are re-evaluated, a program is generated with `genfunction`
    (and appended to `self.program`), and its result is appended to
    `solution`.

    Afterwards swept values are moved from `solution["constants"]` into
    `solution["sweepvariables"]`. Other constants are collapsed to their
    first pass's value. When linked variables exist, this only happens for
    constants that were identical across passes.

    Raises
    ------
    RuntimeWarning
        if a solve is infeasible and `skipsweepfailures` is False, or if
        every solve was infeasible.
    """
    by_eqstr = sorted(sweep.items(), key=lambda item: item[0].eqstr)
    swept_keys, swept_values = zip(*by_eqstr)
    if len(sweep) > 1:  # every combination of the swept values
        grids = np.meshgrid(*swept_values)
    else:
        grids = np.array(swept_values)

    n_passes = grids[0].size
    flat_sweeps = {key: grid.reshape(n_passes)
                   for key, grid in zip(swept_keys, grids)}

    if verbosity > 0:
        print("Sweeping with %i solves:" % n_passes)
        tic = time()

    self.program = []
    last_error = None
    for i in range(n_passes):
        constants.update({key: values[i]
                          for key, values in flat_sweeps.items()})
        if linked:
            evaluate_linked(constants, linked)
        program, solvefn = genfunction(self, constants, **kwargs)
        self.program.append(program)  # NOTE: SIDE EFFECTS
        if verbosity > 1:
            print("\nSolve %i:" % i)
        try:
            result = solvefn(solver, verbosity=verbosity-1, **kwargs)
            if kwargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
        except Infeasible as error:
            last_error = error
            if not skipsweepfailures:
                raise RuntimeWarning(
                    "Solve %i was infeasible; progress saved to m.program."
                    " To continue sweeping after failures, solve with"
                    " skipsweepfailures=True." % i) from error
            if verbosity > 0:
                print("Solve %i was %s." % (i, error.__class__.__name__))
    if not solution:
        raise RuntimeWarning("All solves were infeasible.") from last_error

    solution["sweepvariables"] = KeyDict()
    swept = KeyDict(sweep)
    for key, values in list(solution["constants"].items()):
        if key in swept:
            solution["sweepvariables"][key] = values
            del solution["constants"][key]
            continue
        first = values[0]
        if linked:  # linked values may vary, so only collapse unchanged ones
            if hasattr(first, "shape"):
                changed = any((other != first).any() for other in values[1:])
            else:
                changed = any(other != first for other in values[1:])
            if changed:
                continue
        solution["constants"][key] = [first]

    if verbosity > 0:
        print("Sweeping took %.3g seconds." % (time() - tic,))