# -*- coding: utf-8 -*-
"""
Created on Mon Jun  6 12:25:48 2022

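Compute and plot monthly SST anomaly indices for the Nino 1+2, 3, 3.4 and 4
regions relative to a 30-year climatology, and print the linear trend of the
Nino 3.4 anomaly.

Nino region definitions: https://climatedataguide.ucar.edu/climate-data/nino-sst-indices-nino-12-3-34-4-oni-and-tni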

@author: brugger
"""

import netCDF4 as nc
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
from climateUtils import ctrm, runAvg

def genBase(series):
    """Return the 12 per-calendar-month means of a monthly series.

    Element ``i`` of the result is the mean of ``series[i]``, ``series[i+12]``,
    ``series[i+24]``, ... . When ``series`` starts in January, element 0 is the
    January climatology, element 1 February, and so on.

    Parameters
    ----------
    series : sequence of numbers
        Monthly values; any sliceable sequence accepted by ``np.mean``.

    Returns
    -------
    list
        Twelve means, one per offset 0..11.
    """
    baseline = []
    for offset in range(12):
        # Every 12th sample starting at ``offset`` is the same calendar month.
        sameMonth = series[offset::12]
        baseline.append(np.mean(sameMonth))
    return baseline

def subBase(series, baseline):
    """Subtract a repeating 12-value baseline from a monthly series.

    ``series[i]`` has ``baseline[i % 12]`` removed, so ``series[0]`` must be the
    same calendar month as ``baseline[0]`` (e.g. both January) for the result
    to be a true anomaly.

    Parameters
    ----------
    series : iterable of numbers
        Monthly values.
    baseline : sequence
        At least 12 values, indexed by position modulo 12.

    Returns
    -------
    list
        The anomaly series, same length as ``series``.
    """
    anomaly = []
    for idx, value in enumerate(series):
        anomaly.append(value - baseline[idx % 12])
    return anomaly

# Input: gridded SST NetCDF file (hard-coded local path; edit for your machine).
file = r"C:\Users\brugger\Downloads\sst.mnmean.nc"
ds = nc.Dataset(file)

# Decode the numeric time coordinate using its own 'units' attribute, then copy
# each decoded date (year/month/day/hour) into a standard-library datetime.
timeVar = ds['time']
dates = nc.num2date(timeVar, timeVar.units)
dtDates = [datetime(when.year, when.month, when.day, when.hour) for when in dates]

# Fractional year at mid-month, e.g. June 2000 -> 2000 + 5.5/12; this is the
# x-variable of the linear trend fit at the end of the script.
decDates = np.array([when.year + (when.month - 0.5)/12 for when in dtDates])

# Coordinate vectors read fully into numpy arrays.
lonDeg = np.array(ds['lon'])
latDeg = np.array(ds['lat'])
# 2-D coordinate grids: np.meshgrid(x, y) returns arrays shaped
# (len(latDeg), len(lonDeg)), i.e. one value per (lat, lon) grid point.
Lon, Lat = np.meshgrid(lonDeg, latDeg)

# netCDF4 Variable handles; data is only read from the file when indexed.
lon = ds['lon']
lat = ds['lat']

sst = ds['sst']

def _regionMean(latSlice, lonSlice):
    """Return the per-time-step spatial mean of ``sst`` over one index box.

    The whole (time, lat, lon) hyperslab is read in a single request and then
    averaged over the two spatial axes, instead of issuing one file read per
    month. For a masked array the mean skips masked (missing) points, exactly
    as the per-month ``np.mean`` does.

    Parameters
    ----------
    latSlice, lonSlice : slice
        Index ranges into the second and third dimensions of ``sst``.

    Returns
    -------
    list
        One scalar mean per time step, in file order.
    """
    box = sst[:, latSlice, lonSlice]
    return list(np.mean(box, axis=(1, 2)))


# Index boxes are hard-coded for this file's lat/lon grid.
# NOTE(review): the boxes are meant to match the Nino region definitions linked
# in the module docstring; confirm them against latDeg/lonDeg if the dataset changes.
nino4 = _regionMean(slice(42, 47), slice(80, 106))
nino34 = _regionMean(slice(42, 47), slice(95, 121))
nino3 = _regionMean(slice(42, 47), slice(105, 136))
nino12 = _regionMean(slice(44, 50), slice(135, 141))     # Lat not quite right

# Climatology window of 30 whole years (360 months), assuming a continuous
# monthly time axis. Index len(dates) - month is January of the final year.
# Step back to the January of the latest year divisible by 5, then back 30 more
# years. The window ends with the December before that year divisible by 5.
lastDate = dates[-1]
baseStart = len(dates) - lastDate.month - 12 * (lastDate.year % 5) - 12 * 30
baseStop = baseStart + 12 * 30   # exclusive end of the window

# Anomaly = series minus its own monthly climatology over the window.
anomNino4, anomNino34, anomNino3, anomNino12 = (
    subBase(series, genBase(series[baseStart:baseStop]))
    for series in (nino4, nino34, nino3, nino12)
)

# Running-average window length in months.
N = 5

# runAvg appears to return len(series) - (N - 1) points: the x-axis for N = 5
# was dtDates[2:-2]. Derive the trim from N so x and y lengths keep matching
# when N changes. NOTE(review): the N // 2 start offset assumes runAvg is a
# centered "valid-mode" average — confirm in climateUtils.
avgStart = N // 2
avgDates = dtDates[avgStart:avgStart + len(dtDates) - (N - 1)]

plt.figure(1)
plt.clf()
plt.plot(avgDates, runAvg(anomNino4, N), '-c', label="Nino 4 Anomaly")
plt.plot(avgDates, runAvg(anomNino34, N), '-k', label="Nino 3.4 Anomaly")
plt.plot(avgDates, runAvg(anomNino3, N), '-b', label="Nino 3 Anomaly")
plt.plot(avgDates, runAvg(anomNino12, N), '-y', label="Nino 1+2 Anomaly")
plt.grid(axis='both')
plt.legend()

# Nino 3.4 monthly climatologies from 30-year windows, each shifted a further
# 10 years earlier than the window starting at baseStart.
plt.figure(2) ; plt.clf()
for offset in range(0, 100, 10):
    base = baseStart - 12*offset
    if base < 0:
        # The window would start before the first record; a negative index
        # would silently wrap to the end of the series.
        break
    stop = base + 12*30   # exclusive: first month after the 30-year window
    # Label with the inclusive year range actually covered (stop - 1 is the
    # last month inside the window).
    plt.plot(range(1, 13), genBase(nino34[base:stop]),
             label=f"{dates[base].year}-{dates[stop - 1].year}")
plt.legend()

# Least-squares straight line through the Nino 3.4 anomaly versus decimal year.
# The leading coefficient is the slope per year; the print scales it to a decade.
trendNino34, _intercept = np.polyfit(decDates, anomNino34, deg=1)
print(f"Nino 3.4 trend: {10*trendNino34:0.2f} C per decade")