# -*- coding: utf-8 -*-
"""
Created on Wed Mar 16 11:34:25 2022
@author: dboateng
"""
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
import os
import matplotlib as mpl
import seaborn as sns
from matplotlib.dates import YearLocator
import matplotlib.dates as mdates
from cycler import cycler
try:
from plot_utils import *
except:
from .plot_utils import *
[docs]def plot_monthly_mean(means, stds, color, ylabel=None, ax=None,
fig_path=None, fig_name=None, lolims=False):
if ax is None:
fig,ax = plt.subplots(1,1, sharex=False, figsize=(20, 15))
plot = means.plot(kind="bar", yerr=stds, rot=0, ax=ax, fontsize=20, capsize=4,
width=0.8, color=color, edgecolor=black,
error_kw=dict(ecolor='black',elinewidth=0.5, lolims=lolims))
for ch in plot.get_children():
if str(ch).startswith("Line2D"):
ch.set_marker("_")
ch.set_markersize(10)
break
if ylabel is not None:
ax.set_ylabel(ylabel, fontweight="bold", fontsize=20)
ax.grid(True, linestyle="--", color=gridline_color)
else:
ax.grid(True, linestyle="--", color=gridline_color)
ax.set_yticklabels([])
plt.tight_layout()
if fig_path is not None:
plt.savefig(os.path.join(fig_path, fig_name), bbox_inches="tight", format= "svg")
[docs]def correlation_heatmap(data, cmap, ax=None, vmax=None, vmin=None, center=0, cbar_ax=None,
add_cbar=True, title=None, label= "Correlation Coefficinet", fig_path=None, fig_name=None,
xlabel=None, ylabel=None, fig=None):
if ax is None:
fig,ax = plt.subplots(1,1, sharex=False, figsize=(15, 13))
if add_cbar == True:
if cbar_ax is None:
cbar_ax = [0.90, 0.4, 0.02, 0.25]
cbar_ax = fig.add_axes(cbar_ax)
cbar_ax.get_xaxis().set_visible(False)
cbar_ax.yaxis.set_ticks_position('right')
cbar_ax.set_yticklabels([])
cbar_ax.tick_params(size=0)
sns.set(font_scale=1.2)
if all(parameter is not None for parameter in [vmin, vmax]):
sns.heatmap(ax=ax, data=data, cmap=cmap, vmax=vmax, vmin=vmin, center=center, cbar=add_cbar,
square=True, cbar_ax = cbar_ax, cbar_kws={"label": label, "shrink":0.5,
"drawedges": False,},
linewidth=0.5, linecolor="black",)
else:
sns.heatmap(ax=ax, data=data, cmap=cmap, robust=True, cbar=add_cbar,
square=True, cbar_ax = cbar_ax, cbar_kws={"label": label, "shrink":0.5,
"drawedges": False},
linewidth=0.5, linecolor="black")
if xlabel is not None:
ax.set_xlabel(xlabel, fontsize=18)
ax.set_ylabel(ylabel, fontsize=18)
plt.tight_layout()
plt.subplots_adjust(left=0.15, right=0.88, top=0.97, bottom=0.05)
if fig_path is not None:
plt.savefig(os.path.join(fig_path, fig_name), bbox_inches="tight", format= "svg")
[docs]def barplot(methods, stationnames, path_to_data, ax=None, xlabel=None, ylabel=None,
varname="test_r2", varname_std="test_r2_std", filename="validation_score_", legend=True,
fig_path=None, fig_name=None, show_error=False, width=0.5, rot=0, use_id=True):
if ax is None:
fig,ax = plt.subplots(1,1, sharex=False, figsize=(18, 15))
df, df_std = barplot_data(methods, stationnames, path_to_data, varname=varname, varname_std=varname_std,
filename=filename, use_id=use_id)
colors = [selector_method_colors[m] for m in methods]
mpl.rcParams["axes.prop_cycle"] = cycler("color", colors)
if show_error == True:
df.plot(kind="bar", yerr=df_std, rot=rot, ax=ax, legend = legend, fontsize=20, capsize=4,
width=width, edgecolor=black)
else:
df.plot(kind="bar", rot=rot, ax=ax, legend = legend, fontsize=20, width=width, edgecolor=black)
if ylabel is not None:
ax.set_ylabel(ylabel, fontweight="bold", fontsize=20)
ax.grid(True, linestyle="--", color=gridline_color)
else:
ax.grid(True, linestyle="--", color=gridline_color)
ax.set_yticklabels([])
if xlabel is not None:
ax.set_xlabel(xlabel, fontweight="bold", fontsize=20)
ax.grid(True, linestyle="--", color=gridline_color)
else:
ax.grid(True, linestyle="--", color=gridline_color)
ax.set_xticklabels([])
if legend ==True:
ax.legend(loc="upper right", bbox_to_anchor=(1.15, 1), borderaxespad=0., frameon=True, fontsize=20)
plt.tight_layout()
plt.subplots_adjust(left=0.05, right=0.95, top=0.97, bottom=0.05)
if fig_path is not None:
plt.savefig(os.path.join(fig_path, fig_name), bbox_inches="tight", format= "svg")
[docs]def boxplot(regressors, stationnames, path_to_data, ax=None, xlabel=None, ylabel=None,
varname="test_r2", filename="validation_score_",
fig_path=None, fig_name=None, colors=None, patch_artist=False, rot=45):
if ax is None:
fig,ax = plt.subplots(1,1, sharex=False, figsize=(20, 15))
scores = boxplot_data(regressors, stationnames, path_to_data, filename=filename,
varname=varname)
color = { "boxes": black,
"whiskers": black,
"medians": red,
"caps": black,
}
boxplot = scores.plot(kind= "box", rot=rot, ax=ax, fontsize=20, color= color, sym="+b", grid=False,
widths=0.9, notch=False, patch_artist=patch_artist, return_type="dict")
if colors is not None:
for patch, color in zip(boxplot["boxes"], colors):
patch.set_facecolor(color)
if ylabel is not None:
ax.set_ylabel(ylabel, fontweight="bold", fontsize=20)
ax.grid(True, linestyle="--", color=gridline_color)
else:
ax.grid(True, linestyle="--", color=gridline_color)
ax.set_yticklabels([])
if xlabel is not None:
ax.set_xlabel(xlabel, fontweight="bold", fontsize=20)
ax.grid(True, linestyle="--", color=gridline_color)
else:
ax.grid(True, linestyle="--", color=gridline_color)
ax.set_xticklabels([])
plt.tight_layout()
plt.subplots_adjust(left=0.05, right=0.95, top=0.97, bottom=0.05)
if fig_path is not None:
plt.savefig(os.path.join(fig_path, fig_name), bbox_inches="tight", format= "svg")
[docs]def heatmaps(data, cmap, label=None, title=None, vmax=None, vmin=None, center=None, ax=None,
cbar=True, cbar_ax=None, xlabel=None):
if ax is None:
fig,ax = plt.subplots(1,1, figsize=(20,15))
plt.subplots_adjust(left=0.02, right=1-0.02, top=0.94, bottom=0.45, hspace=0.25)
if all(parameter is not None for parameter in [vmax, vmin, center]):
if cbar == False:
sns.heatmap(data=data, ax=ax, cmap=cmap, vmax=vmax, vmin=vmin, center=center,
square=True, cbar=cbar, linewidth=0.3, linecolor="black")
else:
if cbar_ax is not None:
sns.heatmap(data=data, ax=ax, cmap=cmap, vmax=vmax, vmin=vmin, center=center,
square=True, cbar=cbar, cbar_kws={"label":label,
"shrink":.80, "extend":"both"},
linewidth=0.3, linecolor="black", cbar_ax=cbar_ax)
else:
sns.heatmap(data=data, ax=ax, cmap=cmap, vmax=vmax, vmin=vmin, center=center,
square=True, cbar=cbar, cbar_kws={"label":label,
"shrink":.80, "extend":"both"},
linewidth=0.3, linecolor="black")
else:
if cbar == False:
sns.heatmap(data=data, ax=ax, cmap=cmap, square=True, cbar=cbar, linewidth=0.3,
linecolor="black")
else:
if cbar_ax is not None:
sns.heatmap(data=data, ax=ax, cmap=cmap, square=True, cbar=cbar,
cbar_kws={"label":label,"shrink":.80,"extend":"both"},
linewidth=0.3, linecolor="black", cbar_ax=cbar_ax)
else:
sns.heatmap(data=data, ax=ax, cmap=cmap, square=True, cbar=cbar,
cbar_kws={"label":label,"shrink":.80,"extend":"both"},
linewidth=0.3, linecolor="black")
if title is not None:
ax.set_title(title, fontsize=20, fontweight="bold", loc="left")
if xlabel is not None:
ax.set_xlabel(xlabel, fontweight="bold", fontsize=20)
else:
ax.set_xticklabels([])
[docs]def scatterplot(station_num, stationnames, path_to_data, filename, ax=None,
obs_train_name="obs 1958-2010",
obs_test_name="obs 2011-2020",
val_predict_name="ERA5 1958-2010",
test_predict_name="ERA5 2011-2020",
obs_full_name="obs anomalies",
method = "Stacking", ylabel=None, xlabel=None,
fig_path=None, fig_name=None, train_marker="*", test_marker="o",
train_color=black, test_color=blue,
):
if ax is None:
fig,ax = plt.subplots(1,1, figsize=(20,15))
plt.subplots_adjust(left=0.02, right=1-0.02, top=0.94, bottom=0.45, hspace=0.25)
station_info = prediction_example_data(station_num, stationnames, path_to_data, filename,
obs_test_name=obs_test_name, obs_train_name=obs_train_name,
val_predict_name=val_predict_name, test_predict_name=test_predict_name,
method=method, obs_full_name=obs_full_name)
obs_train = station_info["obs_train"]
obs_test = station_info["obs_test"]
ypred_train = station_info["ypred_train"]
ypred_test = station_info["ypred_test"]
obs = station_info["obs"]
from scipy import stats
regression_stats = stats.linregress(obs_test, ypred_test)
regression_slope = regression_stats.slope * obs + regression_stats.intercept
r2 = regression_stats.rvalue
ax.scatter(obs_train, ypred_train, alpha=0.3, c=train_color, s=100, label=val_predict_name,
marker=train_marker)
ax.scatter(obs_test, ypred_test, alpha=0.3, c=test_color, s=100, label=test_predict_name,
marker=test_marker)
ax.plot(obs, regression_slope, color=red, label="PCC = {:.2f}".format(r2))
ax.legend(loc= "upper left", fontsize=20)
# Plot design
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()
ax.spines["left"].set_position(("outward", 20))
ax.spines["bottom"].set_position(("outward", 20))
if ylabel is not None:
ax.set_ylabel(ylabel, fontweight="bold", fontsize=20)
ax.grid(True, linestyle="--", color=gridline_color)
else:
ax.grid(True, linestyle="--", color=gridline_color)
ax.set_yticklabels([])
if xlabel is not None:
ax.set_xlabel(xlabel, fontweight="bold", fontsize=20)
ax.grid(True, linestyle="--", color=gridline_color)
else:
ax.grid(True, linestyle="--", color=gridline_color)
ax.set_xticklabels([])
plt.tight_layout()
plt.subplots_adjust(left=0.05, right=0.95, top=0.97, bottom=0.05)
if fig_path is not None:
plt.savefig(os.path.join(fig_path, fig_name), bbox_inches="tight", format= "svg")
[docs]def lineplot(station_num, stationnames, path_to_data, filename, ax=None, fig=None,
obs_train_name="obs 1958-2010",
obs_test_name="obs 2011-2020",
val_predict_name="ERA5 1958-2010",
test_predict_name="ERA5 2011-2020",
obs_full_name="obs anomalies",
method = "Stacking", ylabel=None, xlabel=None,
fig_path=None, fig_name=None,
):
if ax is None:
fig, ax = plt.subplots(1, 1, figsize= (20, 15), sharex=True)
plt.subplots_adjust(left=0.12, right=1-0.01, top=0.98, bottom=0.06, hspace=0.01)
station_info = prediction_example_data(station_num, stationnames, path_to_data, filename,
obs_test_name=obs_test_name, obs_train_name=obs_train_name,
val_predict_name=val_predict_name, test_predict_name=test_predict_name,
method=method, obs_full_name=obs_full_name)
ypred_train = station_info["ypred_train"].rolling(3, min_periods=1, win_type="hann",
center=True).mean()
ypred_test = station_info["ypred_test"].rolling(3, min_periods=1, win_type="hann",
center=True).mean()
obs = station_info["obs"].rolling(3, min_periods=1, win_type="hann",
center=True).mean()
ax.plot(obs, linestyle="-", color=green, label="Obs")
ax.plot(ypred_train, linestyle="-.", color= blue, label = val_predict_name)
ax.plot(ypred_test, linestyle="--", color=red, label=test_predict_name)
ax.xaxis.set_major_locator(YearLocator(5))
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
ax.axhline(y=0, linestyle="--", color=grey, linewidth=2)
ax.legend(bbox_to_anchor=(0.01, 1.02, 1., 0.102), loc=3, ncol=3, borderaxespad=0., frameon = True,
fontsize=20)
if ylabel is not None:
ax.set_ylabel(ylabel, fontweight="bold", fontsize=20)
ax.grid(True, linestyle="--", color=gridline_color)
else:
ax.grid(True, linestyle="--", color=gridline_color)
ax.set_yticklabels([])
if xlabel is not None:
ax.set_xlabel(xlabel, fontweight="bold", fontsize=20)
ax.grid(True, linestyle="--", color=gridline_color)
else:
ax.grid(True, linestyle="--", color=gridline_color)
ax.set_xticklabels([])
plt.tight_layout()
if fig_path is not None:
plt.savefig(os.path.join(fig_path, fig_name), bbox_inches="tight", format= "svg")
[docs]def plot_time_series(stationnames, path_to_data, filename, id_name,
daterange, color, label, ymax=None, ymin=None, ax=None, ylabel=None, xlabel=None,
fig_path=None, fig_name=None, method="Stacking", window=12):
if ax is None:
fig, ax = plt.subplots(1, 1, figsize= (20, 15), sharex=True)
df = extract_time_series(stationnames, path_to_data, filename, id_name,
method, daterange,)
df = df.rolling(window, min_periods=1, win_type="hann", center=True).mean()
ax.plot(df["mean"], "--", color=color, label=label)
#try with max and min to notice the difference with 5 years window
ax.fill_between(df.index, df["mean"] - df["std"], df["mean"] + df["std"], color=color,
alpha=0.2,)
ax.xaxis.set_major_locator(YearLocator(10))
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
ax.axhline(y=0, linestyle="--", color=grey, linewidth=2)
ax.legend(frameon=True, fontsize=12, loc="lower left")
if ymax is not None:
ax.set_ylim([ymin, ymax])
if ylabel is not None:
ax.set_ylabel(ylabel, fontweight="bold", fontsize=20)
ax.grid(True, linestyle="--", color=gridline_color)
else:
ax.grid(True, linestyle="--", color=gridline_color)
ax.set_yticklabels([])
# if xlabel is not None:
# ax.set_xlabel(xlabel, fontweight="bold", fontsize=20)
# ax.grid(True)
# else:
# ax.grid(True)
# ax.set_xticklabels([])
plt.tight_layout()
if fig_path is not None:
plt.savefig(os.path.join(fig_path, fig_name), bbox_inches="tight", format= "svg")
[docs]def plot_projection_comparison(stationnames, path_to_data,
filename, id_name, method, stationloc_dir,
daterange, datasets, variable, dataset_varname,
ax=None, xlabel=None, ylabel=None, legend=True,
figpath=None, figname=None, width=0.5, title=None,
vmax=None, vmin=None, use_id=True):
df = extract_comparison_data_means(stationnames, path_to_data, filename, id_name,
method, stationloc_dir, daterange, datasets,
variable, dataset_varname, use_id=use_id)
models_col_names = ["ESD", "MPIESM", "CESM5", "HadGEM2", "CORDEX"]
if ax is None:
fig,ax = plt.subplots(1,1, sharex=False, figsize=(18, 15))
colors = [Models_colors[c] for c in models_col_names]
mpl.rcParams["axes.prop_cycle"] = cycler("color", colors)
if use_id:
df.plot(kind="bar", rot=0, ax=ax, legend=legend, fontsize=20, width=width)
else:
df.plot(kind="bar", rot=45, ax=ax, legend=legend, fontsize=20, width=width)
if vmax is not None:
ax.set_ylim(vmin, vmax)
if ylabel is not None:
ax.set_ylabel(ylabel, fontweight="bold", fontsize=20)
ax.grid(True, linestyle="--", color=gridline_color)
else:
ax.grid(True, linestyle="--", color=gridline_color)
ax.set_yticklabels([])
if xlabel is not None:
ax.set_xlabel(xlabel, fontweight="bold", fontsize=20)
ax.grid(True, linestyle="--", color=gridline_color)
else:
ax.grid(True, linestyle="--", color=gridline_color)
ax.set_xticklabels([])
if title is not None:
ax.set_title(title, fontsize=20, fontweight="bold", loc="left")
if legend ==True:
ax.legend(loc="upper right", bbox_to_anchor=(1.15, 1), borderaxespad=0., frameon=True, fontsize=20)
plt.tight_layout()
plt.subplots_adjust(left=0.05, right=0.95, top=0.97, bottom=0.05)