# Import pyCAPS class file
import pyCAPS

# Import os module
import os
import argparse

# Set up and read command line options. Note that this isn't required for pyCAPS.
parser = argparse.ArgumentParser(description = 'himach Cone Pytest Example',
                                 prog = 'himach_Cone_PyTest',
                                 formatter_class = argparse.ArgumentDefaultsHelpFormatter)

# Set up the available command line options.
# -workDir uses nargs=1, so a value given on the command line arrives as a
# one-element list. argparse does NOT wrap a string default in a list, so the
# default must already be a list; otherwise args.workDir[0] would be only the
# first character of the default path.
parser.add_argument('-workDir', default = ["./"], nargs=1, type=str, help = 'Set working/run directory')
parser.add_argument("-outLevel", default = 1, type=int, choices=[0, 1, 2], help="Set output verbosity")
args = parser.parse_args()

# Run directory is <workDir>/<projectName>; args.workDir is a one-element list
# (nargs=1), so its first entry is the user-selected base directory.
projectName = "HIMachConeTest"
workDir = os.path.join(args.workDir[0], projectName)

# Create the CAPS problem from the cone CSM geometry inside the run directory
geometryScript = "cone.csm"
capsProblem = pyCAPS.Problem(problemName = workDir,
                             capsFile    = geometryScript,
                             outLevel    = args.outLevel)

# Load the egadsTess AIM, which produces the surface mesh consumed by HIMach
egadsTess = capsProblem.analysis.create(aim = "egadsTessAIM",
                                        name = "egadsTess" )

# Tessellation parameters; presumably [max edge length, max sag,
# max dihedral angle in degrees] per the egadsTessAIM docs -- confirm
egadsTess.input.Tess_Params = [0.025, 0.01, 10.0]

# Triangle-only surface elements
egadsTess.input.Mesh_Elements = "Tri"

# Mesh file format
egadsTess.input.Mesh_Format = "vtk"

# Create the HIMach AIM and drive it with the egadsTess surface mesh
himach = capsProblem.analysis.create(aim = "himachAIM", name = "himach")
himach.input["Surface_Mesh"].link(egadsTess.output["Surface_Mesh"])

# Name of (or full path to) the HI-Mach executable
himach.input.HIMach = "himach.exe"

# Remaining HIMach inputs, applied in this exact order
himachSettings = {
    # Free-stream flow conditions
    "Mach"             : 4.0,
    "Alpha"            : 0,
    "Beta"             : 0,
    # Presumably lets the AIM morph its existing mesh when the geometry
    # changes instead of requiring a fresh linked mesh -- confirm in AIM docs
    "Mesh_Morph"       : True,
    # Pressure methods and solver options
    "Windward_Method"  : "modified-newtonian",
    "Leeward_Method"   : "prandtl-meyer",
    "Base_Pressure"    : "gaubeaud",
    "Shielding_Effects": False,
}
for inputName, inputValue in himachSettings.items():
    setattr(himach.input, inputName, inputValue)

# Run the analysis explicitly (optional; outputs can also be requested directly)
himach.runAnalysis()

# Print every HIMach output as "name = value"
for outName in himach.output.keys():
    print(outName, "=", himach.output[outName].value)


# Optional sensitivity study: sweep each geometric design parameter, record the
# HIMach CD/CL (and dynamic outputs) with their analytic derivatives, and plot
# them against central finite differences. Needs numpy and matplotlib; the
# study is skipped with a message if either cannot be imported.


def _central_difference(values, desvals):
    """Return central finite-difference derivatives of *values* w.r.t. *desvals*.

    The estimate only exists at interior sweep points, so the result has
    ``len(desvals) - 2`` entries aligned with ``desvals[1:-1]``.
    """
    values = npy.asarray(values)
    return (values[2:] - values[:-2]) / (desvals[2:] - desvals[:-2])


def _plot_sensitivity(desvals, despmtr, C, C_despmtr, c):
    """Plot a functional and its sensitivities over one design-parameter sweep.

    The functional is drawn on the left axis. The analytic derivative and a
    central finite-difference estimate are drawn on the right axis. The figure
    is saved as ``Cone_<c>_<despmtr>.png``.

    Parameters
    ----------
    desvals : numpy.ndarray
        Swept values of the design parameter.
    despmtr : str
        Name of the swept design parameter (used in labels and the file name).
    C : sequence
        Functional value at each entry of *desvals*.
    C_despmtr : sequence
        Analytic derivative of the functional at each entry of *desvals*.
    c : str
        LaTeX name of the functional, e.g. ``'C_D'``.
    """
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()

    # Finite-difference check of the analytic derivative (interior points only)
    dC_despmtr = _central_difference(C, desvals)

    # Angles are typeset as Greek letters with a "g" subscript, presumably to
    # distinguish them from the flow angles Alpha/Beta; others are upright text.
    if despmtr in ("alpha", "beta"):
        xpmtr = "\\"+despmtr+r"_g"
    else:
        xpmtr = r"\mathrm{"+despmtr+r"}"

    # Plot the functional
    lns1 = ax1.plot(desvals, C, 'o-', label = r"$"+c+r"("+xpmtr+r")$", color='blue')

    # Plot the analytic derivative
    lns2 = ax2.plot(desvals, C_despmtr, 's--', label = r"$\partial "+c+r"/\partial "+xpmtr+r"$", color='red')

    # Plot the finite-difference derivative
    lns3 = ax2.plot(desvals[1:-1], dC_despmtr, '.--', label = r"$\Delta "+c+r"/\Delta "+xpmtr+r"$", color='black')

    ax1.set_xlabel(r"$"+xpmtr+r"$")
    ax1.set_ylabel(r"$"+c+r"$", color='blue')
    ax2.set_ylabel(r"$\partial "+c+r"$", color='red')

    fig.tight_layout()  # otherwise the right y-label is slightly clipped

    # Shrink the axes' height by 10% to leave room for the legend above them
    box = ax1.get_position()
    ax1.set_position([box.x0, box.y0, box.width, box.height * 0.9])
    ax2.set_position([box.x0, box.y0, box.width, box.height * 0.9])

    # One legend covering the lines from both y-axes
    lns = lns1+lns2+lns3
    labs = [l.get_label() for l in lns]
    ax1.legend(lns, labs, loc='lower center', bbox_to_anchor=(0.5, 1.01), fancybox=True, ncol=3)

    fig.savefig("Cone_"+c+"_"+despmtr+".png")


# Only the imports are guarded: an ImportError raised later (inside CAPS or
# plotting) is a real failure and must not be reported as a missing module.
try:
    import numpy as npy
    from matplotlib import pyplot as plt
except ImportError:
    print("Unable to import matplotlib.pyplot module.")
else:
    # Baseline design-parameter values, restored before each sweep so every
    # parameter is varied about the original geometry
    L0 = capsProblem.geometry.despmtr.L
    H0 = capsProblem.geometry.despmtr.H
    a0 = capsProblem.geometry.despmtr.alpha
    b0 = capsProblem.geometry.despmtr.beta

    for despmtr in ['L', 'H', 'alpha', 'beta']:

        capsProblem.geometry.despmtr.L     = L0
        capsProblem.geometry.despmtr.H     = H0
        capsProblem.geometry.despmtr.alpha = a0
        capsProblem.geometry.despmtr.beta  = b0

        # Request sensitivities with respect to the swept parameter only
        himach.input.Design_Variable = {despmtr:{}}

        # Angles are swept over a fixed range; other parameters over
        # 25%..175% of their baseline value
        if despmtr in ('alpha', 'beta'):
            dmin = -10
            dmax = 10
        else:
            dmin = capsProblem.geometry.despmtr[despmtr].value*0.25
            dmax = capsProblem.geometry.despmtr[despmtr].value*1.75

        desvals = npy.linspace(dmin, dmax, 21)
        CD = []
        CL = []
        dynout = {}
        CD_despmtr = []
        CL_despmtr = []
        dynout_despmtr = {}
        for desval in desvals:
            capsProblem.geometry.despmtr[despmtr].value = desval

            print("-"*20)
            print(despmtr, desval)
            print("-"*20)

            # Requesting outputs presumably re-runs the analysis on demand for
            # the updated geometry
            CD.append(himach.output["CD"].value)
            CD_despmtr.append(himach.output["CD"].deriv(despmtr))

            CL.append(himach.output["CL"].value)
            CL_despmtr.append(himach.output["CL"].deriv(despmtr))

            for key in himach.dynout.keys():
                dynout.setdefault(key, []).append(himach.dynout[key].value)
                dynout_despmtr.setdefault(key, []).append(himach.dynout[key].deriv(despmtr))

            # Drop the link to the egadsTess mesh after the first evaluation.
            # With Mesh_Morph enabled, the AIM presumably morphs its existing
            # mesh for later geometry changes. Repeated unlink calls are
            # assumed harmless -- confirm.
            himach.input["Surface_Mesh"].unlink()

        print(CD, CD_despmtr)
        print(CL, CL_despmtr)
        #for key in dynout.keys():
        #    print(dynout[key], dynout_despmtr[key])

        _plot_sensitivity(desvals, despmtr, CD, CD_despmtr, 'C_D')
        _plot_sensitivity(desvals, despmtr, CL, CL_despmtr, 'C_L')
        #for key in dynout.keys():
        #    _plot_sensitivity(desvals, despmtr, dynout[key], dynout_despmtr[key], key)

    plt.show()
