# -*- coding: utf-8 -*-
"""
Created on Thu Aug  5 10:45:44 2021

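Python script to plot UAH v6.0 monthly anomaly grids for four atmospheric
layers (TLT, TMT, TTP, TLS) on Mollweide maps, printing a latitude-weighted
global mean for each layer.
Expects the grid files (tltmonamg, tmtmonamg, ttpmonamg, tlsmonamg) in the
user's Downloads directory.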
@author: Mark B
"""

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Data from https://www.nsstc.uah.edu/data/msu/v6.0/tlt/
dataYear = 2022                         # Year encoded in the grid file names
dataDir = Path.home() / "Downloads"     # Directory holding the downloaded grids

# One grid file per atmospheric layer. as_posix() keeps forward slashes so
# later code can take the file name with file.split('/')[-1] on any OS.
files = [(dataDir / f"{layer}monamg.{dataYear}_6.0").as_posix()
         for layer in ("tlt", "tmt", "ttp", "tls")]

# Grid geometry: 2.5 degree cells, 72 latitudes x 144 longitudes
nLat, nLon = 72, 144
valuesPerLine = 16      # Grid values written per text line in the data files

# Cell-centre coordinates in radians (matplotlib's geographic projections,
# e.g. 'mollweide', take longitude/latitude in radians)
lats = np.linspace(np.deg2rad(-88.75), np.deg2rad(88.75), nLat)     # Center of 2.5 degree latitude grid
lons = np.linspace(np.deg2rad(-178.75), np.deg2rad(178.75), nLon)   # Center of 2.5 degree longitude grid
Lon, Lat = np.meshgrid(lons, lats)      # Both shaped (nLat, nLon)
months = ['January', 'February', 'March', 'April', 'May', 'June', 'July',
          'August', 'September', 'October', 'November', 'December']

# Each monthly grid holds nLat*nLon values written valuesPerLine per line;
# integer division keeps rowsPerMonth an exact int for line-offset arithmetic.
rowsPerMonth = nLat * nLon // valuesPerLine

def _globalMean(grid, latsRad):
    """Return the cos(latitude)-weighted mean of *grid*, ignoring NaN cells.

    grid    : 2-D array with one row per latitude in *latsRad*.
    latsRad : 1-D array of row-centre latitudes, already in radians.
    Returns NaN when *grid* contains no valid (non-NaN) cells.
    """
    # latsRad is already in radians, so cos() is applied directly (converting
    # again with deg2rad would make every weight ~1). Weighting per cell rather
    # than per row-mean also keeps rows that are only partly NaN.
    weights = np.broadcast_to(np.cos(latsRad)[:, np.newaxis], grid.shape)
    valid = ~np.isnan(grid)
    if not valid.any():
        return np.nan
    return np.sum(grid[valid] * weights[valid]) / np.sum(weights[valid])


# Loop through list of months
for month in [2]:       # Months to plot (zero indexed)
    fig = plt.figure(figsize=(15, 8))
    # Each month in a file is one header line followed by rowsPerMonth data
    # lines, so this month's header sits at line index month*(rowsPerMonth+1).
    headerLine = month * (rowsPerMonth + 1)
    for n, file in enumerate(files):
        fileName = file.split('/')[-1]
        with open(file, "r") as fid:
            for _ in range(headerLine):
                fid.readline()
            header = fid.readline().split()
        if not header:
            # File ends before this month (e.g. a year still in progress)
            print(f"{fileName}: no data for {months[month]}, skipping")
            continue
        # Header fields; only iYear is used (in the figure title)
        [iSat, iYear, iMonth, aChan] = header

        df = pd.read_fwf(file, widths=[5] * 16,
                         skiprows=headerLine + 1,
                         header=None,
                         na_values=["-9999"], nrows=rowsPerMonth)
        # Stored values are scaled by 100 (presumably hundredths of a degree)
        monGrid = df.to_numpy(dtype=float) / 100
        if monGrid.size != lats.size * lons.size or np.isnan(monGrid).all():
            # Truncated month or nothing but fill values: nothing to plot
            print(f"{fileName}: no valid grid for {months[month]}, skipping")
            continue
        # NOTE(review): assumes rows run over latitude in the same order as
        # `lats` and columns over longitude in the order of `lons` — confirm
        # against the UAH grid readme.
        monGrid = monGrid.reshape((lats.size, lons.size))

        ax = fig.add_subplot(2, 2, n + 1, projection='mollweide')

        # Calculate global weighted average
        globeMean = _globalMean(monGrid, lats)
        print("Month {0:d} : {1:8.3f}  ({2})".format(month + 1, globeMean, fileName))

        # Plot global grid; Lon/Lat are cell centres, hence shading='nearest'
        cs = ax.pcolormesh(Lon, Lat, monGrid, cmap='seismic',
                           vmin=-9.5, vmax=9.5, shading='nearest')
        fig.colorbar(cs, ax=ax)
        ax.set_title(f"{months[month]} {fileName}")
        ax.grid(True)       # Explicitly on (plt.grid() with no value toggles)

    if not fig.axes:
        # No layer had data for this month; don't show an empty figure
        plt.close(fig)
        continue
    fig.suptitle(f"UAH {months[month]} {iYear}")
    plt.show()
    plt.pause(1.0)