# -*- coding: utf-8 -*-
from netCDF4 import Dataset
from netCDF4 import MFDataset
from staircase import staircase_from_line
import numpy as np
# ---------------
# User settings
# --------------

# Names of the (u, v) transport variables read from u_file / v_file for each
# supported ``flag`` value of comp_flux.
_TRANSPORT_VARS = {
    'vfl': ('uocetr_eff', 'vocetr_eff'),
    'hfl': ('u_heattr', 'v_heattr'),
    'sfl': ('u_salttr', 'v_salttr'),
}


def _read_var(path, name, ndim):
    """Read every element of the ``ndim``-dimensional variable ``name``.

    The file is closed even if the read raises.
    """
    with Dataset(path) as ds:
        return ds.variables[name][(slice(None),) * ndim]


def _read_ice_transports(i_file, flag):
    """Return the (u, v) ice terms that comp_flux adds to level 0 for ``flag``."""
    all3 = (slice(None),) * 3
    with Dataset(i_file) as ds:
        var = ds.variables
        if flag == 'vfl':
            # Ice + snow terms divided by 1026 (presumably a reference seawater
            # density converting mass to volume transport -- confirm).
            u_ice = (var['xmtrpice'][all3] + var['xmtrpsnw'][all3]) / 1026.
            v_ice = (var['ymtrpice'][all3] + var['ymtrpsnw'][all3]) / 1026.
        elif flag == 'hfl':
            # No ice contribution to this transport. The ice field is still
            # read so the zeros keep its shape and any missing-value mask.
            u_ice = var['xmtrpice'][all3] * 0.
            v_ice = var['ymtrpice'][all3] * 0.
        elif flag == 'sfl':
            # Divided by 1000 (presumably a unit conversion -- confirm).
            u_ice = var['xstrpice'][all3] / 1000.
            v_ice = var['ystrpice'][all3] / 1000.
        else:
            raise ValueError("unsupported flag %r" % (flag,))
    return u_ice, v_ice


def _sum_last_two(a):
    """Sum a 4-D array over axis 3 and then axis 2.

    Kept as two successive sums (rather than axis=(2, 3)) so the
    floating-point summation order matches the original implementation.
    """
    return np.sum(np.sum(a, axis=3), axis=2)


def _section_mask(transport, jsec, isec, jsecp, isecp):
    """Return a 0/1 mask shaped like ``transport[0]``.

    The mask is 1 at section points (jsec, isec) whose neighbour
    (jsecp, isecp) is not itself a section point, and 0 everywhere else.
    """
    on_section = np.zeros_like(transport[0], subok=False)
    on_section[:, jsec, isec] = 1.
    mask = np.zeros_like(on_section)
    mask[:, jsec, isec] = on_section[:, jsec, isec] * (1. - on_section[:, jsecp, isecp])
    return mask


def comp_flux(temp_field, salt_field, volmask, iflag, i_file, u_file, v_file,
              t_file, fltm, flsm, fltp, flsp, i0, j0, i1, j1, minjk, flag):
    """Compute the transport across a cross-section and its mean T/S.

    The section is the staircase returned by
    ``staircase_from_line(i0, i1, j0, j1)``.

    Only grid points whose temperature lies in [fltm, fltp] and whose
    salinity lies in [flsm, flsp] (both inclusive) contribute.

    Parameters
    ----------
    temp_field, salt_field : str
        Names of the temperature / salinity variables in ``t_file``.
    volmask : array
        Indexed as ``volmask[0, :, J, I]``. Its values on the section points
        weight the T/S means. Presumably 4-D with a length-1 leading axis
        that broadcasts over the data's first axis -- confirm.
    iflag : int
        When equal to 1, the ice transports from ``i_file`` are added to
        index 0 of axis 1 of the u/v transports. Otherwise ``i_file`` is not
        read.
    i_file, u_file, v_file, t_file : str
        Paths of the NetCDF files holding the ice, u, v and T/S fields.
    fltm, fltp : float
        Inclusive temperature bounds.
    flsm, flsp : float
        Inclusive salinity bounds.
    i0, j0, i1, j1 : int
        Grid indices of the section end points.
    minjk :
        Unused; kept for interface compatibility.
    flag : {'vfl', 'hfl', 'sfl'}
        Selects the transport variables read from ``u_file`` / ``v_file``:
        'vfl' -> uocetr_eff/vocetr_eff, 'hfl' -> u_heattr/v_heattr,
        'sfl' -> u_salttr/v_salttr.

    Returns
    -------
    flux_total : array
        ``-sign(j1 - j0) * u + v``. Each transport is masked to the section
        points and to the T/S class. Same shape as the u/v transports.
    mtemp, msalt : array
        volmask-weighted means of temp / salt over the qualifying section
        points, summed over the last two axes. Where no point qualifies, the
        denominator is replaced by 1, so the mean is 0.

    Raises
    ------
    ValueError
        If ``flag`` is not one of 'vfl', 'hfl', 'sfl'.
    """
    # Fail fast, before any file is opened, instead of hitting a NameError on
    # an undefined transport array much later.
    if flag not in _TRANSPORT_VARS:
        raise ValueError("flag must be one of %s, got %r"
                         % (", ".join(sorted(_TRANSPORT_VARS)), flag))

    # --- Section geometry ---------------------------------------------------
    isec, jsec = staircase_from_line(i0, i1, j0, j1)
    isec = np.asarray(isec)
    jsec = np.asarray(jsec)
    sign_dj = np.sign(j1 - j0)

    # volmask restricted to the section points (index 0 of axis 0 only).
    lvolmask = np.zeros_like(volmask, subok=False)
    lvolmask[0, :, jsec, isec] = volmask[0, :, jsec, isec]

    # U points move one index lower in i when j1 > j0. Their neighbour is
    # sign_dj further back in i. When j1 == j0 the neighbour is the point
    # itself, so every U point is masked out and only V transport counts.
    isec_u = isec - 1 if sign_dj == 1 else isec
    jsec_u = jsec
    isecp_u = isec_u - sign_dj
    jsecp_u = jsec_u
    # V points use the section indices directly. Their neighbour is one up in j.
    isec_v, jsec_v = isec, jsec
    isecp_v, jsecp_v = isec_v, jsec_v + 1

    # --- Read fields (files closed even if a read fails) --------------------
    u_name, v_name = _TRANSPORT_VARS[flag]
    uocetr_eff = _read_var(u_file, u_name, 4)
    vocetr_eff = _read_var(v_file, v_name, 4)
    if iflag == 1:
        uocetr_ice, vocetr_ice = _read_ice_transports(i_file, flag)

    all4 = (slice(None),) * 4
    with Dataset(t_file) as ft:
        temp = ft.variables[temp_field][all4]
        salt = ft.variables[salt_field][all4]

    # --- T/S class mask and section means -----------------------------------
    trmask = np.zeros_like(temp, subok=False)
    in_class = np.where((temp >= fltm) & (temp <= fltp)
                        & (salt >= flsm) & (salt <= flsp))
    trmask[in_class] = 1.

    weighted_temp = lvolmask * temp * trmask
    weighted_salt = lvolmask * salt * trmask
    mvol = _sum_last_two(lvolmask * trmask)
    mvol[mvol == 0.] = 1.  # avoid 0/0 where no section point is in the class
    mtemp = _sum_last_two(weighted_temp) / mvol
    msalt = _sum_last_two(weighted_salt) / mvol

    # --- Transport across the section ---------------------------------------
    umask = _section_mask(uocetr_eff, jsec_u, isec_u, jsecp_u, isecp_u)
    vmask = _section_mask(vocetr_eff, jsec_v, isec_v, jsecp_v, isecp_v)

    if iflag == 1:
        uocetr_eff[:, 0, :, :] = uocetr_eff[:, 0, :, :] + uocetr_ice
        vocetr_eff[:, 0, :, :] = vocetr_eff[:, 0, :, :] + vocetr_ice

    # NOTE(review): trmask is built from the T-file fields but is applied
    # unchanged to the U and V transports; confirm this collocation is intended.
    uocetr_eff = uocetr_eff * umask * trmask
    vocetr_eff = vocetr_eff * vmask * trmask

    flux_total = -sign_dj * uocetr_eff + vocetr_eff
    return flux_total, mtemp, msalt


