import numpy as np

from pyNastran.utils.numpy_utils import integer_types
from pyNastran.op2.tables.oes_stressStrain.real.oes_objects import OES_Object
from pyNastran.op2.result_objects.op2_objects import get_times_dtype
from pyNastran.f06.f06_formatting import write_floats_12e, _eigenvalue_header # write_floats_13e,


class RealNonlinearBushArray(OES_Object): # 226-CBUSHNL
    """
    Real nonlinear force/stress/strain results for CBUSH elements.

    Each element stores 18 values per time step, in the order returned by
    ``get_headers``::

        [fx, fy, fz, otx, oty, otz, etx, ety, etz,
         mx, my, mz, orx, ory, orz, erx, ery, erz]

    The F06 layout this class writes looks like::

                  N O N L I N E A R   F O R C E S  A N D  S T R E S S E S  I N   B U S H   E L E M E N T S    ( C B U S H )

                                F O R,C E                               S T R E S S                             S T R A I N
      ELEMENT ID.   FORCE-X      FORCE-Y      FORCE-Z       STRESS-TX    STRESS-TY    STRESS-TZ     STRAIN-TX    STRAIN-TY    STRAIN-TZ
                    MOMENT-X     MOMENT-Y     MOMENT-Z      STRESS-RX    STRESS-RY    STRESS-RZ     STRAIN-RX    STRAIN-RY    STRAIN-RZ
             6      0.0          0.0          0.0           0.0          0.0          0.0           0.0          0.0          0.0
                    0.0          0.0          0.0           0.0          0.0          0.0           0.0          0.0          0.0
    """
    def __init__(self, data_code, is_sort1, isubcase, dt):
        """tested by elements/loadstep_elements.op2"""
        # NOTE(review): is_sort1 and dt are unused here; presumably kept so
        # the constructor matches the other result classes - confirm.
        OES_Object.__init__(self, data_code, isubcase, apply_data_code=True)
        self.nelements = 0  # result specific

    @property
    def is_real(self) -> bool:
        """This result stores real (not complex) data."""
        return True

    @property
    def is_complex(self) -> bool:
        """This result does not store complex data."""
        return False

    @property
    def nnodes_per_element(self) -> int:
        """One result row per element (no per-node breakdown)."""
        return 1

    def _reset_indices(self) -> None:
        """Reset the running element/total write cursors."""
        self.itotal = 0
        self.ielement = 0

    def _get_msgs(self):
        raise NotImplementedError()

    def get_headers(self) -> list[str]:
        """Return the 18 column names matching the last axis of ``self.data``."""
        headers = ['fx', 'fy', 'fz', 'otx', 'oty', 'otz', 'etx', 'ety', 'etz',
                   'mx', 'my', 'mz', 'orx', 'ory', 'orz', 'erx', 'ery', 'erz']
        return headers

    def build(self):
        """sizes the vectorized attributes of the RealNonlinearBushArray"""
        assert self.ntimes > 0, 'ntimes=%s' % self.ntimes
        assert self.nelements > 0, 'nelements=%s' % self.nelements
        assert self.ntotal > 0, 'ntotal=%s' % self.ntotal

        # nelements was accumulated over every time step before build() is
        # called (presumably during the sizing pass), so reduce it to a
        # per-time-step count.
        self.nelements //= self.ntimes
        self.itime = 0
        self.ielement = 0
        self.itotal = 0

        _dtype, idtype, fdtype = get_times_dtype(self.nonlinear_factor, self.size, self.analysis_fmt)
        self._times = np.zeros(self.ntimes, dtype=self.analysis_fmt)
        self.element = np.zeros(self.nelements, dtype=idtype)

        # data[itime, ielement, :] = [fx, fy, fz, otx, oty, otz, etx, ety, etz,
        #                             mx, my, mz, orx, ory, orz, erx, ery, erz]
        self.data = np.zeros((self.ntimes, self.nelements, 18), dtype=fdtype)

    def build_dataframe(self):
        """
        Create a pandas dataframe in ``self.data_frame``.

        Only transient results (``nonlinear_factor`` set) are supported.

        Raises
        ------
        NotImplementedError
            For static results.
        """
        if self.nonlinear_factor in (None, np.nan):
            raise NotImplementedError('static pandas nonlinear cbush')

        headers = self.get_headers()
        column_names, column_values = self._build_dataframe_transient_header()
        self.data_frame = self._build_pandas_transient_elements(
            column_values, column_names,
            headers, self.element, self.data)

    def __eq__(self, table):  # pragma: no cover
        """
        Compare this result against another table of the same type.

        Returns True when the data matches (exactly, or within
        ``np.allclose`` tolerance for every element row).

        Raises
        ------
        ValueError
            If any element row differs.  The message lists up to ~11 of the
            mismatching rows.
        NotImplementedError
            For SORT2 data when the arrays differ.
        """
        self._eq_header(table)
        assert self.is_sort1 == table.is_sort1
        if np.array_equal(self.data, table.data):
            return True

        msg = 'table_name=%r class_name=%s\n' % (self.table_name, self.__class__.__name__)
        msg += '%s\n' % str(self.code_information())
        if not self.is_sort1:
            raise NotImplementedError(self.is_sort2)

        nmismatch = 0
        ntimes = self.data.shape[0]
        for itime in range(ntimes):
            for ieid, eid in enumerate(self.element):
                t1 = self.data[itime, ieid, :]
                t2 = table.data[itime, ieid, :]
                if not np.allclose(t1, t2):
                    # one line for the element id, then the 18 values from
                    # each table on their own line
                    msg += '%s\n  (%s)\n  (%s)\n' % (
                        eid,
                        ', '.join(str(value) for value in t1),
                        ', '.join(str(value) for value in t2))
                    nmismatch += 1
                if nmismatch > 10:
                    print(msg)
                    raise ValueError(msg)
        if nmismatch > 0:
            print(msg)
            raise ValueError(msg)
        return True

    def add_sort1(self, dt, eid, fx, fy, fz, otx, oty, otz, etx, ety, etz,
                  mx, my, mz, orx, ory, orz, erx, ery, erz):
        """
        Store one element row of SORT1 transient data at the current cursor.

        Writes to ``data[self.itime, self.ielement]`` and advances
        ``self.ielement``.  ``self.itime`` is not advanced here.
        NOTE(review): presumably the reader advances it between time steps.
        """
        assert self.sort_method == 1, self
        assert isinstance(eid, integer_types) and eid > 0, 'dt=%s eid=%s' % (dt, eid)
        self._times[self.itime] = dt
        self.element[self.ielement] = eid
        self.data[self.itime, self.ielement, :] = [
            fx, fy, fz, otx, oty, otz, etx, ety, etz,
            mx, my, mz, orx, ory, orz, erx, ery, erz
        ]
        self.ielement += 1

    def get_stats(self, short: bool=False) -> list[str]:
        """Return a list of human-readable summary lines for this result."""
        if not self.is_built:
            return [
                f'<{self.__class__.__name__}>; table_name={self.table_name!r}\n',
                f'  ntimes: {self.ntimes:d}\n',
                f'  ntotal: {self.ntotal:d}\n',
            ]

        ntimes, nelements, _ = self.data.shape
        assert self.ntimes == ntimes, 'ntimes=%s expected=%s' % (self.ntimes, ntimes)
        assert self.nelements == nelements, 'nelements=%s expected=%s' % (self.nelements, nelements)

        msg = []
        if self.nonlinear_factor not in (None, np.nan):  # transient
            msg.append('  type=%s ntimes=%i nelements=%i\n'
                       % (self.__class__.__name__, ntimes, nelements))
            ntimes_word = 'ntimes'
        else:
            msg.append('  type=%s nelements=%i\n'
                       % (self.__class__.__name__, nelements))
            ntimes_word = '1'
        msg.append('  etype\n')
        headers = self.get_headers()
        n = len(headers)
        msg.append('  data: [%s, nelements, %i] where %i=[%s]\n' % (ntimes_word, n, n, str(', '.join(headers))))
        msg.append(f'  data.shape = {self.data.shape}\n')
        msg.append(f'  element type: {self.element_name}-{self.element_type}\n')
        msg += self.get_data_code()
        return msg

    def write_f06(self, f06_file, header=None, page_stamp='PAGE %s',
                  page_num: int=1, is_mag_phase: bool=False, is_sort1: bool=True):
        """
        Write the result to an F06 file (SORT1 only).

        Returns the last page number written.

        Raises
        ------
        NotImplementedError
            When ``is_sort1`` is False or the stored data is SORT2.
        """
        if header is None:
            header = []
        if not is_sort1:
            raise NotImplementedError('SORT2')

        # adjacent literals: this is a single header string
        msg = [
            '             N O N L I N E A R   F O R C E S  A N D  S T R E S S E S  I N   B U S H   E L E M E N T S    ( C B U S H )\n'
            ' \n'
            '                           F O R,C E                               S T R E S S                             S T R A I N\n'
            ' ELEMENT ID.   FORCE-X      FORCE-Y      FORCE-Z       STRESS-TX    STRESS-TY    STRESS-TZ     STRAIN-TX    STRAIN-TY    STRAIN-TZ\n'
            '               MOMENT-X     MOMENT-Y     MOMENT-Z      STRESS-RX    STRESS-RY    STRESS-RZ     STRAIN-RX    STRAIN-RY    STRAIN-RZ\n'
        ]

        if not self.is_sort1:
            raise NotImplementedError(self.__class__.__name__)
        return self._write_sort1_as_sort1(header, page_stamp, page_num, f06_file, msg)

    def _write_sort1_as_sort1(self, header, page_stamp, page_num, f06_file, msg_temp):
        """
        Write one F06 page per time step.

        Each element takes two lines: the 9 translational values
        (force, stress, strain), then the 9 rotational values.

        Returns the last page number written.
        """
        ntimes = self.data.shape[0]
        eids = self.element

        for itime in range(ntimes):
            dt = self._times[itime]
            header = _eigenvalue_header(self, header, itime, ntimes, dt)
            f06_file.write(''.join(header + msg_temp))

            for eid, row in zip(eids, self.data[itime]):
                words = write_floats_12e(list(row))
                translational, rotational = words[:9], words[9:]
                f06_file.write(
                    ' %8i     %-12s %-12s %-12s  %-12s %-12s %-12s  %-12s %-12s %s\n'
                    ' %8s     %-12s %-12s %-12s  %-12s %-12s %-12s  %-12s %-12s %s\n' % (
                        eid, *translational,
                        '', *rotational))
            f06_file.write(page_stamp % page_num)
            page_num += 1
        return page_num - 1

    def write_op2(self, op2_file, op2_ascii, itable, new_result, date,
                  is_mag_phase=False, endian='>'):
        """
        Write the result to an OP2 file (SORT1 only).

        Parameters
        ----------
        endian : str or bytes
            Byte-order prefix for the data records (e.g. ``b'<'``).  A
            ``str`` such as the default ``'>'`` is accepted and encoded.

        Returns
        -------
        itable : int
            The updated table counter.
        """
        import inspect
        from struct import Struct, pack
        frame = inspect.currentframe()
        call_frame = inspect.getouterframes(frame, 2)
        op2_ascii.write(f'{self.__class__.__name__}.write_op2: {call_frame[1][3]}\n')
        # drop the frame references so no reference cycle outlives the call
        del frame, call_frame

        if itable == -1:
            self._write_table_header(op2_file, op2_ascii, date)
            itable = -3

        # table 4 info
        nelements = self.data.shape[1]
        ntotal = self.num_wide * nelements

        op2_ascii.write(f'  ntimes = {self.ntimes}\n')

        eids_device = self.element * 10 + self.device_code

        if not self.is_sort1:
            raise NotImplementedError('SORT2')
        if isinstance(endian, str):
            # Struct needs the whole format to be one type, so encode the
            # str default to bytes before concatenating
            endian = endian.encode('ascii')
        struct1 = Struct(endian + b'i18f')

        op2_ascii.write(f'nelements={nelements:d}\n')

        for itime in range(self.ntimes):
            self._write_table_3(op2_file, op2_ascii, new_result, itable, itime)

            # record 4
            itable -= 1
            header = [4, itable, 4,
                      4, 1, 4,
                      4, 0, 4,
                      4, ntotal, 4,
                      4 * ntotal]
            # NOTE(review): the record header uses native byte order, not
            # `endian`; this matches the other writers - confirm.
            op2_file.write(pack('%ii' % len(header), *header))
            op2_ascii.write('r4 [4, 0, 4]\n')
            op2_ascii.write(f'r4 [4, {itable:d}, 4]\n')
            op2_ascii.write(f'r4 [4, {4 * ntotal:d}, 4]\n')

            # one record per element: eid_device followed by the 18 values
            for eid_device, row in zip(eids_device, self.data[itime]):
                data = [eid_device, *row]
                op2_ascii.write('  eid=%s data=%s\n' % (eid_device, str(data)))
                op2_file.write(struct1.pack(*data))

            itable -= 1
            header = [4 * ntotal,]
            op2_file.write(pack('i', *header))
            op2_ascii.write('footer = %s\n' % header)
            new_result = False
        return itable
