Coverage for gpkit/constraints/prog_factories.py: 84%
138 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-05 22:33 -0500
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-05 22:33 -0500
1"Scripts for generating, solving and sweeping programs"
2from time import time
3import warnings as pywarnings
4import numpy as np
5from adce import adnumber
6from ..nomials import parse_subs
7from ..solution_array import SolutionArray
8from ..keydict import KeyDict
9from ..small_scripts import maybe_flatten
10from ..small_classes import FixedScalar
11from ..exceptions import Infeasible
12from ..globals import SignomialsEnabled
def evaluate_linked(constants, linked):
    # pylint: disable=too-many-branches
    """Compute the values, and where possible the gradients, of linked variables.

    Every constant is wrapped in an `adnumber` tagged with its own key, and
    each linked function is evaluated on those wrapped values. When the result
    is an AD number (it has an ``x`` attribute), its value is stored in
    `constants` and its derivatives, keyed by the non-empty tags of the inputs
    they refer to, are stored in the linked variable's ``descr["gradients"]``.

    If evaluation raises, the function is re-run on plain (non-AD) values, no
    gradients are recorded, and a warning is printed. If gpkit's
    ``ad_errors_raise`` setting is truthy, the exception propagates instead.

    Arguments
    ---------
    constants : dict-like
        Maps varkeys to fixed values. Mutated in place: each linked
        variable's computed value is added to it, and any existing
        ``"gradients"`` entries are removed from the constants' descrs.
    linked : dict-like
        Maps each linked varkey to the function that computes its value.
    """
    ad_inputs = KeyDict({key: adnumber(maybe_flatten(value), key)
                         for key, value in constants.items()})
    plain_inputs = None  # copy of `constants`, made on the first AD failure
    vector_results = {}  # veckey -> full output of that veckey's vecfn
    for key in constants:  # inputs must not carry stale gradients
        key.descr.pop("gradients", None)

    for var, func in linked.items():
        try:
            veckey = var.veckey
            if veckey and veckey.vecfn:
                # evaluate each vector function once, then index into it
                if veckey not in vector_results:
                    with SignomialsEnabled():  # to allow use of gpkit.units
                        full = veckey.vecfn(ad_inputs)
                    vector_results[veckey] = (full if hasattr(full, "shape")
                                              else np.array(full))
                result = vector_results[veckey][var.idx]
            else:
                with SignomialsEnabled():  # to allow use of gpkit.units
                    result = func(ad_inputs)

            if isinstance(result, FixedScalar):  # to allow use of gpkit.units
                result = result.value
            if hasattr(result, "units"):
                result = result.to(var.units or "dimensionless").magnitude
            elif result != 0 and var.units:
                pywarnings.warn(
                    f"Linked function for {var} did not return a united value."
                    " Modifying it to do so (e.g. by using `()` instead of `[]`"
                    " to access variables) will reduce errors.")

            result = maybe_flatten(result)
            if hasattr(result, "x"):  # an AD number: keep value and gradients
                constants[var] = result.x
                var.descr["gradients"] = {
                    adn.tag: grad for adn, grad in result.d().items()
                    if adn.tag}
            else:  # a new fixed variable, not a calculated one
                constants[var] = result
        except Exception as err:  # pylint: disable=broad-except
            from .. import settings  # pylint: disable=import-outside-toplevel
            if settings.get("ad_errors_raise", None):
                raise
            if plain_inputs is None:
                plain_inputs = KeyDict(constants)
            constants[var] = func(plain_inputs)
            var.descr.pop("gradients", None)
            print("Warning: skipped auto-differentiation of linked variable"
                  f" {var} because {err!r} was raised. Set `gpkit.settings"
                  "[\"ad_errors_raise\"] = True` to raise such Exceptions"
                  " directly.\n")
            monomial_hint = ("Automatic differentiation not yet supported for"
                             " <class 'gpkit.nomials.math.Monomial'> objects")
            if monomial_hint in str(err):
                print("This particular warning may have come from using"
                      f" gpkit.units.* in the function for {var}; try using"
                      " gpkit.ureg.* or gpkit.units.*.units instead.")
def progify(program, return_attr=None):
    """Make a model method that builds a `program` from the model.

    Arguments
    ---------
    program: NomialData
        Class to return, e.g. GeometricProgram or SequentialGeometricProgram
    return_attr: string
        attribute to return in addition to the program

    Returns
    -------
    function
        ``programfn(self, constants=None, **initargs)``. It constructs
        ``program(self.cost, self, constants, **initargs)`` and sets the
        program's ``model`` attribute to ``self``. It returns the program
        alone, or ``(program, getattr(program, return_attr))`` when
        `return_attr` is truthy.
    """
    def programfn(self, constants=None, **initargs):
        """Build this model's program, plus its `return_attr` if requested.

        When `constants` is falsy, they are parsed from the model's own
        substitutions, and any linked substitutions are evaluated into them.
        """
        if not constants:
            constants, _sweep, linked_subs = parse_subs(self.varkeys,
                                                        self.substitutions)
            if linked_subs:
                evaluate_linked(constants, linked_subs)
        built = program(self.cost, self, constants, **initargs)
        built.model = self  # NOTE SIDE EFFECTS
        if not return_attr:
            return built
        return built, getattr(built, return_attr)
    return programfn
def solvify(genfunction):
    """Build a model method that generates, solves, and sweeps a program.

    `genfunction` is called as ``genfunction(model, **kwargs)`` for a single
    solve, or as ``genfunction(model, constants, **kwargs)`` for each sweep
    point. It must return a ``(program, solve)`` pair, e.g. a function made
    by `progify` with `return_attr` naming the program's solve method.
    """
    def solvefn(self, solver=None, *, verbosity=1, skipsweepfailures=False,
                **kwargs):
        """Forms a mathematical program and attempts to solve it.

        Arguments
        ---------
        solver : string or function (default None)
            If None, uses the default solver found in installation.
        verbosity : int (default 1)
            If greater than 0 prints runtime messages. A single solve
            receives this value unchanged; each solve of a sweep receives
            it decremented by one.
        skipsweepfailures : bool (default False)
            If True, sweep points whose solve raises Infeasible are skipped
            instead of aborting the sweep.
        **kwargs : Passed to solve and program init calls. A falsy
            ``process_result`` entry also skips ``self.process_result``.

        Returns
        -------
        sol : SolutionArray
            See the SolutionArray documentation for details. It is also
            stored as ``self.solution``; the generated program(s) are
            stored as ``self.program``.

        Raises
        ------
        ValueError if the program is invalid.
        RuntimeWarning if an error occurs in solving or parsing the solution
            (e.g. an infeasible sweep point when skipsweepfailures is False).
        """
        constants, sweep, linked = parse_subs(self.varkeys, self.substitutions)
        solution = SolutionArray()
        solution.modelstr = str(self)

        # NOTE SIDE EFFECTS: self.program and self.solution are set below
        if not sweep:
            self.program, solve_program = genfunction(self, **kwargs)
            outcome = solve_program(solver, verbosity=verbosity, **kwargs)
            if kwargs.get("process_result", True):
                self.process_result(outcome)
            solution.append(outcome)
        else:
            run_sweep(genfunction, self, solution, skipsweepfailures,
                      constants, sweep, linked, solver, verbosity, **kwargs)
        solution.to_arrays()
        self.solution = solution
        return solution
    return solvefn
# pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements
def run_sweep(genfunction, self, solution, skipsweepfailures,
              constants, sweep, linked, solver, verbosity, **kwargs):
    """Solve `self` once per point of the sweep grid, filling `solution`.

    Swept variables are ordered by their varkey's ``eqstr``. With one swept
    variable its values are used directly; otherwise every combination
    produced by ``np.meshgrid`` is solved.

    For each point, `constants` is updated in place and linked variables are
    re-evaluated. The generated program is appended to ``self.program``, and
    the result is appended to `solution`.

    Afterwards, swept values are moved from ``solution["constants"]`` to
    ``solution["sweepvariables"]``. Every other constant is reduced to its
    first value. When linked variables exist, this happens only if that
    value is the same across all solves.

    Raises
    ------
    RuntimeWarning
        If a solve raises Infeasible and `skipsweepfailures` is False, or if
        no solve succeeded.
    """
    # order the swept variables by the eqstr of their varkeys
    sweepvars, sweepvals = zip(*sorted(sweep.items(),
                                       key=lambda item: item[0].eqstr))
    grids = (np.array(sweepvals) if len(sweep) == 1
             else np.meshgrid(*sweepvals))
    n_passes = grids[0].size
    point_values = {var: grid.reshape(n_passes)
                    for var, grid in zip(sweepvars, grids)}

    start = time() if verbosity > 0 else None
    self.program = []
    last_error = None
    for idx in range(n_passes):
        constants.update({var: values[idx]
                          for var, values in point_values.items()})
        if linked:
            evaluate_linked(constants, linked)
        program, solve_program = genfunction(self, constants, **kwargs)
        program.model = None  # so it doesn't try to debug
        self.program.append(program)  # NOTE: SIDE EFFECTS
        if idx == 0 and verbosity > 0:  # announce only once generation works
            # TODO: use full string when minimum lineage is set automatically
            names = ", ".join(str(var)
                              for var, val in zip(sweepvars, sweepvals)
                              if not np.isnan(val).all())
            print(f"Sweeping {names} with {n_passes} solves:")
        try:
            if verbosity > 1:
                print(f"\nSolve {idx}:")
            result = solve_program(solver, verbosity=verbosity - 1, **kwargs)
            if kwargs.get("process_result", True):
                self.process_result(result)
            solution.append(result)
            if verbosity == 1:
                print(".", end="", flush=True)
        except Infeasible as err:
            last_error = err
            if not skipsweepfailures:
                raise RuntimeWarning(
                    f"Solve {idx} was infeasible; progress saved to m.program."
                    " To continue sweeping after failures, solve with"
                    " skipsweepfailures=True.") from err
            if verbosity > 1:
                print(f"Solve {idx} was {err.__class__.__name__}.")
            elif verbosity == 1:
                print("!", end="", flush=True)
    if not solution:
        raise RuntimeWarning("All solves were infeasible.") from last_error
    if verbosity == 1:
        print()

    solution["sweepvariables"] = KeyDict()
    swept_keys = KeyDict(sweep)
    for var, val in list(solution["constants"].items()):
        if var in swept_keys:
            solution["sweepvariables"][var] = val
            del solution["constants"][var]
            continue
        if linked:  # linked variables may make any constant vary
            first = val[0]
            if hasattr(first, "shape"):
                varies = any((other != first).any() for other in val[1:])
            else:
                varies = any(other != first for other in val[1:])
            if varies:
                continue
        solution["constants"][var] = [val[0]]

    if verbosity > 0:
        print(f"Sweeping took {time() - start:.3g} seconds.")