###################################################################
#                                                                 #
# optAlpha --- find alpha that minimizes CD using OpenMDAO        #
#                                                                 #
#              Written by John Dannenhoffer @ Syracuse University #
#                     and Marshall Galbraith @ MIT                #
#                                                                 #
###################################################################

# import pyCAPS module
import pyCAPS
from   pyOCSM import esp

# import OpenMDAO v3 module
import openmdao
import openmdao.api as om

#------------------------------------------------------------------

# make a semicolon-terminated string from a list (e.g. [1, 2] -> "1;2;")
def makeString(array):
    """Return the items of *array* joined into one semicolon-terminated string.

    Each item is converted with ``str`` and followed by ``";"``, including
    the last one, e.g. ``[1, 2] -> "1;2;"``.  An empty iterable gives ``""``.
    The script uses this to build the data fields of the plotter "add"
    messages.
    """
    # str.join builds the result in one pass instead of the quadratic
    # repeated "+=" concatenation
    return "".join(str(item) + ";" for item in array)

#------------------------------------------------------------------

# global storage for convergence history: msesAnalysis.compute appends the
# Alpha used at every function evaluation, and the script plots this list
# against the call index once the optimization has finished
Alpha_call = []

# class definitions to be used by OpenMDAO
class msesAnalysis(om.ExplicitComponent):
    """OpenMDAO explicit component wrapping the pyCAPS 'mses' analysis.

    Input ``Alpha`` is forwarded to the 'mses' analysis of the CAPS problem
    given in the ``capsProblem`` option.  Output ``CD`` and its partial
    derivative with respect to ``Alpha`` are read back from that analysis.
    Every Alpha evaluated by ``compute`` is appended to the module-level
    ``Alpha_call`` list.
    """

    def initialize(self):
        """Declare the component options."""
        # CAPS Problem that owns the 'mses' analysis
        self.options.declare('capsProblem', types=object)

    def setup(self):
        """Declare the Alpha input, the CD output, and their partial."""
        mses = self._mses()

        # seed the design input with the alpha currently set on the analysis
        self.add_input('Alpha', val=mses.input.Alpha)

        # add output metric
        self.add_output('CD')

        # CD depends on Alpha; the value is supplied by compute_partials
        self.declare_partials('CD', 'Alpha')

    def compute(self, inputs, outputs):
        """Evaluate CD at the requested Alpha and record that Alpha."""
        mses = self._updateAlpha(inputs)

        # grab objective and attach as an output
        outputs['CD'] = mses.output.CD
        print("--> CD      ", outputs['CD'])

        # convergence history, plotted after the optimization
        Alpha_call.append(mses.input.Alpha)

    def compute_partials(self, inputs, partials):
        """Evaluate dCD/dAlpha at the requested Alpha."""
        mses = self._updateAlpha(inputs)

        # get derivatives and set partials
        partials['CD', 'Alpha'] = mses.output["CD"].deriv("Alpha")
        print("--> CD_alpha", partials['CD', 'Alpha'])

    def _mses(self):
        """Return the 'mses' analysis of the CAPS problem option."""
        return self.options['capsProblem'].analysis['mses']

    def _updateAlpha(self, inputs):
        """Push ``inputs['Alpha']`` to the 'mses' analysis if it changed.

        Shared by compute and compute_partials so both always evaluate the
        analysis at the alpha OpenMDAO asked for.

        Returns:
            The 'mses' analysis object.
        """
        mses = self._mses()

        print("\n--> Alpha   ", inputs['Alpha'])

        # OpenMDAO holds the scalar input as a 1-element array; hand CAPS a
        # plain float so the comparison below is a scalar test and the value
        # stored on the analysis (and later recorded) is not an ndarray
        alpha = float(inputs['Alpha'][0])

        # update input values if changed (presumably so an unchanged alpha
        # does not make CAPS rerun MSES -- confirm against pyCAPS docs)
        if mses.input.Alpha != alpha:
            mses.input.Alpha = alpha

        return mses

#------------------------------------------------------------------

# initiate CAPS problem
capsProblem = pyCAPS.Problem(problemName = "optAlpha",
                             capsFile    = "naca.csm",
                             outLevel    = 0)

# Everything after the CAPS problem is opened runs inside try/finally so
# that capsProblem.close() (required if you want to run another pyscript)
# is reached even when setup, optimization, or plotting raises.
try:
    # setup AIM for MSES
    mses = capsProblem.analysis.create(aim  = "msesAIM",
                                       name = "mses")

    # set flow condition
    mses.input.Alpha = 3.0   # Initial guess away from solution Alpha == 0
    mses.input.Mach  = 0.5
    mses.input.Re    = 5e6

    # set meshing parameters
    mses.input.GridAlpha = 0
    mses.input.Airfoil_Points = 201

    # trip the flow near the leading edge to get smooth gradient
    mses.input.xTransition_Upper = 0.1
    mses.input.xTransition_Lower = 0.1

    #--------------------------------------------------------------

    # setup the OpenMDAO problem object
    omProblem = om.Problem()

    # create the OpenMDAO component
    msesSystem = msesAnalysis(capsProblem = capsProblem)

    # add subsystem to model
    omProblem.model.add_subsystem('msesSystem', msesSystem)

    # design variable: alpha, bounded to [-5, 5] degrees
    omProblem.model.add_design_var('msesSystem.Alpha', lower=-5, upper=5)

    # add objective to minimize CD
    omProblem.model.add_objective('msesSystem.CD')

    # gradient-based optimizer driven by the partials from msesAnalysis
    omProblem.driver = om.ScipyOptimizeDriver()
    omProblem.driver.options['optimizer'] = "L-BFGS-B"
    omProblem.driver.options['tol'] = 1.e-9
    omProblem.driver.options['disp'] = True

    # execute optimization; cleanup() runs even if the driver raises
    print("\n==> Starting Optimization...")
    omProblem.setup()
    try:
        omProblem.run_driver()
    finally:
        omProblem.cleanup()

    print("--> Optimized alpha:", omProblem.get_val("msesSystem.Alpha"))

    #--------------------------------------------------------------

    # load the plotter
    esp.TimLoad("plotter", esp.GetEsp("pyscript"), "")

    # x data: function-call index; y data: Alpha at each call and a
    # zero reference line (the expected optimum)
    Iters = range(len(Alpha_call))
    Zeros = len(Alpha_call) * [0]

    # plot the convergence history
    esp.TimMesg("plotter", "new|Optimization convergence|function calls|alpha (deg)|")
    esp.TimMesg("plotter", "add|"+makeString(Iters)+"|"+makeString(Alpha_call)+"|k-+|")
    esp.TimMesg("plotter", "add|"+makeString(Iters)+"|"+makeString(Zeros    )+"|k:|")
    esp.TimMesg("plotter", "show")

    # exit the plotter
    esp.TimQuit("plotter")
finally:
    # close the capsProblem (required if you want to run another pyscript)
    capsProblem.close()
