import cartopy.crs as ccrs
import matplotlib as mpl
import numpy as np
import xarray as xr
from unseen_awg.utils import is_no_jump
[docs]
def transition_init_time_plot(ax_joint, ax_top, ax_right, traj, only_jumps=False):
    """Joint plot of the day of year of ``init_time`` at consecutive steps.

    For every pair of consecutive ``out_time`` steps ``(i, i + 1)`` of ``traj``
    the day of year of ``init_time`` at step ``i`` (x) is scattered against the
    one at step ``i + 1`` (y) on ``ax_joint``. The marginal distributions are
    drawn as filled one-day histograms on ``ax_top`` (x) and ``ax_right`` (y).

    Parameters
    ----------
    ax_joint, ax_top, ax_right : matplotlib.axes.Axes
        Axes for the scatter plot and the top/right marginal histograms. The
        marginal axes are moved so they line up with the (equal-aspect) joint
        axes, and finally share their x- (top) / y-axis (right) with it.
    traj : xarray object
        Trajectory with an ``out_time`` dimension and an ``init_time``
        variable. If it has a non-empty ``seed`` dimension, the pairs of all
        seeds are pooled.
    only_jumps : bool, default False
        If True, keep only the transitions for which ``is_no_jump`` is False
        (its result is assumed to hold one flag per consecutive step pair —
        TODO confirm).
    """
    d_min = 0
    d_max = 366
    # One bin per day of year, centred on the integer days.
    bins = np.arange(d_min - 0.5, d_max + 1.5)
    x, y = _init_time_transition_pairs(traj, only_jumps)

    # Main scatter plot
    ax_joint.scatter(x, y, s=1, color="C0")
    ax_joint.grid(True, linestyle="--", linewidth=0.5, alpha=0.7)
    ax_joint.set_xlabel(r"day of year $t_{init}$, step i")
    ax_joint.set_ylabel(r"day of year $t_{init}$, step i+1")
    ax_joint.set_xlim(d_min, d_max)
    ax_joint.set_ylim(d_min, d_max)
    ax_joint.set_aspect("equal")
    ax_joint.set_xticks(np.arange(d_min, d_max, 100))
    ax_joint.set_yticks(np.arange(d_min, d_max, 100))
    # Draw once so that get_position() reports the joint box after the
    # equal-aspect adjustment, then align the marginals with it.
    ax_joint.figure.canvas.draw()
    joint_pos = ax_joint.get_position()
    top_pos = ax_top.get_position()
    ax_top.set_position([joint_pos.x0, top_pos.y0, joint_pos.width, top_pos.height])
    right_pos = ax_right.get_position()
    ax_right.set_position(
        [right_pos.x0, joint_pos.y0, right_pos.width, joint_pos.height]
    )

    # Top marginal histogram (step i)
    ax_top.hist(x, bins=bins, color="C0", edgecolor="C0", histtype="stepfilled")
    ax_top.set_xlim(d_min, d_max)
    ax_top.tick_params(labelbottom=False)

    # Right marginal histogram (step i+1)
    ax_right.hist(
        y,
        bins=bins,
        color="C0",
        edgecolor="C0",
        orientation="horizontal",
        histtype="stepfilled",
    )
    ax_right.set_ylim(d_min, d_max)
    ax_right.tick_params(labelleft=False)

    # Seaborn-like styling: no top/right spines, ticks pointing outwards.
    for ax in (ax_joint, ax_top, ax_right):
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.tick_params(direction="out")
    ax_right.sharey(ax_joint)
    ax_top.sharex(ax_joint)


def _init_time_transition_pairs(traj, only_jumps):
    """Return ``(x, y)``: ``init_time`` day of year at steps ``i`` and ``i + 1``.

    Pairs from every seed are concatenated when ``traj`` has a non-empty
    ``seed`` dimension; otherwise ``traj`` is treated as a single trajectory.
    """

    def pairs_for(single):
        before = single.isel(out_time=slice(None, -1))
        after = single.isel(out_time=slice(1, None))
        if only_jumps:
            # Compute the jump mask once and apply it to both shifted views.
            keep = ~is_no_jump(single).squeeze()
            before = before.isel(out_time=keep)
            after = after.isel(out_time=keep)
        return (
            before.init_time.dt.dayofyear.load().data,
            after.init_time.dt.dayofyear.load().data,
        )

    if not traj.sizes.get("seed", 0):
        return pairs_for(traj)
    per_seed = [pairs_for(traj.sel(seed=s)) for s in traj.seed]
    x = np.concatenate([pair[0] for pair in per_seed])
    y = np.concatenate([pair[1] for pair in per_seed])
    return x, y
[docs]
def transition_lead_time_plot(
    ax_joint,
    ax_top,
    ax_right,
    traj,
    lt_max=44,
    only_jumps=False,
    use_log_cnorm_in_joint_plot=True,
):
    """Joint 2D histogram of the lead time (in days) at consecutive steps.

    For every pair of consecutive ``out_time`` steps ``(i, i + 1)`` of ``traj``
    the lead time at step ``i`` (x) is histogrammed against the one at step
    ``i + 1`` (y) on ``ax_joint``; marginal step histograms go on ``ax_top``
    (x) and ``ax_right`` (y). Every bin edge is outlined with a thin line.

    Parameters
    ----------
    ax_joint, ax_top, ax_right : matplotlib.axes.Axes
        Axes for the 2D histogram and the top/right marginal histograms. The
        marginal axes are moved to line up with the equal-aspect joint axes.
    traj : xarray object
        Trajectory with an ``out_time`` dimension and a timedelta
        ``lead_time`` variable. If it has a non-empty ``seed`` dimension, the
        pairs of all seeds are pooled.
    lt_max : int, default 44
        Largest lead time (days) shown. Bins are one day wide and centred on
        ``0, 1, ..., lt_max``; values outside that range are not counted.
    only_jumps : bool, default False
        If True, keep only the transitions for which ``is_no_jump`` is False.
    use_log_cnorm_in_joint_plot : bool, default True
        Use a logarithmic colour normalisation for the 2D histogram.
    """
    x, y = _lead_time_transition_pairs(traj, only_jumps)
    # One-day-wide bins centred on the integer lead times 0, 1, ..., lt_max.
    bins = np.arange(-0.5, lt_max + 1.5, 1)
    lims = (-0.5, lt_max + 0.5)
    # Thin black lines on every bin edge, drawn above the histograms.
    edge_style = dict(color="k", linestyle="-", linewidth=0.5, alpha=0.5, zorder=10)

    # Main 2D histogram
    norm = mpl.colors.LogNorm() if use_log_cnorm_in_joint_plot else None
    ax_joint.hist2d(x, y, bins=[bins, bins], cmap="Blues", alpha=0.8, norm=norm)
    for edge in bins:
        ax_joint.axvline(edge, **edge_style)
        ax_joint.axhline(edge, **edge_style)
    ax_joint.set_xlabel(r"$t_{lead}$, step i")
    ax_joint.set_ylabel(r"$t_{lead}$, step i+1")
    ax_joint.set_xlim(*lims)
    ax_joint.set_ylim(*lims)
    ax_joint.set_aspect("equal")
    ax_joint.grid(None)
    # Draw once so that get_position() reports the joint box after the
    # equal-aspect adjustment, then align the marginals with it.
    ax_joint.figure.canvas.draw()
    joint_pos = ax_joint.get_position()
    top_pos = ax_top.get_position()
    ax_top.set_position([joint_pos.x0, top_pos.y0, joint_pos.width, top_pos.height])
    right_pos = ax_right.get_position()
    ax_right.set_position(
        [right_pos.x0, joint_pos.y0, right_pos.width, joint_pos.height]
    )

    # Top marginal histogram (step i)
    ax_top.hist(x, bins=bins, histtype="stepfilled", alpha=0.7, color="C0")
    for edge in bins:
        ax_top.axvline(edge, **edge_style)
    ax_top.set_xlim(*lims)
    ax_top.tick_params(labelbottom=False)
    ax_top.set_xticks([])
    ax_top.grid(None)

    # Right marginal histogram (step i+1)
    ax_right.hist(
        y,
        bins=bins,
        histtype="stepfilled",
        alpha=0.7,
        color="C0",
        orientation="horizontal",
    )
    for edge in bins:
        ax_right.axhline(edge, **edge_style)
    ax_right.set_ylim(*lims)
    ax_right.tick_params(labelleft=False)
    ax_right.set_yticks([])
    ax_right.grid(None)

    # No top/right spines, ticks pointing outwards.
    for ax in (ax_joint, ax_top, ax_right):
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.tick_params(direction="out")


def _lead_time_transition_pairs(traj, only_jumps):
    """Return ``(x, y)``: lead time in whole days at steps ``i`` and ``i + 1``.

    Lead times are divided by one day and cast with ``astype("int")``. Pairs
    from every seed are concatenated when ``traj`` has a non-empty ``seed``
    dimension.
    """

    def to_days(steps):
        return (steps.lead_time.load().data / np.timedelta64(1, "D")).astype("int")

    def pairs_for(single):
        before = single.isel(out_time=slice(None, -1))
        after = single.isel(out_time=slice(1, None))
        if only_jumps:
            # Compute the jump mask once and apply it to both shifted views.
            keep = ~is_no_jump(single).squeeze()
            before = before.isel(out_time=keep)
            after = after.isel(out_time=keep)
        return to_days(before), to_days(after)

    if not traj.sizes.get("seed", 0):
        return pairs_for(traj)
    per_seed = [pairs_for(traj.sel(seed=s)) for s in traj.seed]
    x = np.concatenate([pair[0] for pair in per_seed])
    y = np.concatenate([pair[1] for pair in per_seed])
    return x, y
[docs]
def transition_valid_time_plot(ax_joint, ax_top, ax_right, traj, only_jumps=False):
    """Joint plot of the day of year of the valid time at consecutive steps.

    The valid time of a step is ``init_time + lead_time``. For every pair of
    consecutive ``out_time`` steps ``(i, i + 1)`` of ``traj`` the valid-time
    day of year at step ``i`` (x) is scattered against the one at step
    ``i + 1`` (y) on ``ax_joint``; the marginal distributions are drawn as
    filled one-day histograms on ``ax_top`` (x) and ``ax_right`` (y).

    Previously, a ``traj`` without a ``seed`` dimension plotted ``init_time``
    instead of the valid time; both code paths now use the valid time.

    Parameters
    ----------
    ax_joint, ax_top, ax_right : matplotlib.axes.Axes
        Axes for the scatter plot and the top/right marginal histograms. The
        marginal axes are moved to line up with the equal-aspect joint axes.
    traj : xarray object
        Trajectory with an ``out_time`` dimension and ``init_time`` /
        ``lead_time`` variables. If it has a non-empty ``seed`` dimension,
        the pairs of all seeds are pooled.
    only_jumps : bool, default False
        If True, keep only the transitions for which ``is_no_jump`` is False.
    """
    d_min = 0
    d_max = 366
    # One bin per day of year, centred on the integer days.
    bins = np.arange(d_min - 0.5, d_max + 1.5)
    x, y = _valid_time_transition_pairs(traj, only_jumps)

    # Main scatter plot
    ax_joint.scatter(x, y, s=1, color="C0")
    ax_joint.grid(True, linestyle="--", linewidth=0.5, alpha=0.7)
    ax_joint.set_xlabel(r"day of year $t_{valid}$, step i")
    ax_joint.set_ylabel(r"day of year $t_{valid}$, step i+1")
    ax_joint.set_xlim(d_min, d_max)
    ax_joint.set_ylim(d_min, d_max)
    ax_joint.set_aspect("equal")
    # Set ticks before the layout-freezing draw (as in the init-time plot).
    ax_joint.set_xticks(np.arange(d_min, d_max, 100))
    ax_joint.set_yticks(np.arange(d_min, d_max, 100))
    # Draw once so that get_position() reports the joint box after the
    # equal-aspect adjustment, then align the marginals with it.
    ax_joint.figure.canvas.draw()
    joint_pos = ax_joint.get_position()
    top_pos = ax_top.get_position()
    ax_top.set_position([joint_pos.x0, top_pos.y0, joint_pos.width, top_pos.height])
    right_pos = ax_right.get_position()
    ax_right.set_position(
        [right_pos.x0, joint_pos.y0, right_pos.width, joint_pos.height]
    )

    # Top marginal histogram (step i)
    ax_top.hist(x, bins=bins, color="C0", edgecolor="C0", histtype="stepfilled")
    ax_top.set_xlim(d_min, d_max)
    ax_top.tick_params(labelbottom=False)

    # Right marginal histogram (step i+1)
    ax_right.hist(
        y,
        bins=bins,
        color="C0",
        edgecolor="C0",
        orientation="horizontal",
        histtype="stepfilled",
    )
    ax_right.set_ylim(d_min, d_max)
    ax_right.tick_params(labelleft=False)

    # Seaborn-like styling: no top/right spines, ticks pointing outwards.
    for ax in (ax_joint, ax_top, ax_right):
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.tick_params(direction="out")


def _valid_time_transition_pairs(traj, only_jumps):
    """Return ``(x, y)``: valid-time day of year at steps ``i`` and ``i + 1``.

    The valid time is ``init_time + lead_time``. Pairs from every seed are
    concatenated when ``traj`` has a non-empty ``seed`` dimension.
    """

    def valid_doy(steps):
        return (steps.init_time + steps.lead_time).dt.dayofyear.load().data

    def pairs_for(single):
        before = single.isel(out_time=slice(None, -1))
        after = single.isel(out_time=slice(1, None))
        if only_jumps:
            # Compute the jump mask once and apply it to both shifted views.
            keep = ~is_no_jump(single).squeeze()
            before = before.isel(out_time=keep)
            after = after.isel(out_time=keep)
        return valid_doy(before), valid_doy(after)

    if not traj.sizes.get("seed", 0):
        return pairs_for(traj)
    per_seed = [pairs_for(traj.sel(seed=s)) for s in traj.seed]
    x = np.concatenate([pair[0] for pair in per_seed])
    y = np.concatenate([pair[1] for pair in per_seed])
    return x, y
# https://stackoverflow.com/questions/25812255/row-and-column-headers-in-matplotlibs-subplots
[docs]
def map_plot_without_frame_with_bounds(ax, da: xr.DataArray, **plot_kwargs):
    """Plot ``da`` on a map axis, replacing the axis frame by a data-extent box.

    The black box is padded by half a grid spacing on each side, so it
    encloses the full outer grid cells. The spacing is taken from the first
    two ``latitude`` / ``longitude`` values, so at least two points are
    needed along each. The axis spines are hidden, ``da.plot`` draws the data
    in PlateCarree coordinates, and 50m coastlines are added on top.

    Returns
    -------
    The object returned by ``da.plot`` (e.g. for building a colorbar).
    """
    lats = da.latitude
    lons = da.longitude
    half_dlat = abs(lats[1] - lats[0]) / 2
    half_dlon = abs(lons[1] - lons[0]) / 2
    south, north = lats.min() - half_dlat, lats.max() + half_dlat
    west, east = lons.min() - half_dlon, lons.max() + half_dlon

    # Closed rectangle: SW -> SE -> NE -> NW -> SW.
    frame_lons = [west, east, east, west, west]
    frame_lats = [south, south, north, north, south]
    ax.plot(
        frame_lons,
        frame_lats,
        color="black",
        linewidth=1,
        transform=ccrs.PlateCarree(),  # remove the transform to get straight lines
    )

    for spine_name in ax.spines:
        ax.spines[spine_name].set_visible(False)

    mappable = da.plot(ax=ax, transform=ccrs.PlateCarree(), **plot_kwargs)
    ax.coastlines(resolution="50m", linewidth=1, color="k")
    return mappable
[docs]
def contourf_plot_without_frame_with_bounds(ax, da: xr.DataArray, **plot_kwargs):
    """Filled-contour plot of ``da`` on a map axis with a data-extent frame.

    The axis spines are hidden and a black box is drawn through the outermost
    ``latitude`` / ``longitude`` values. There is no half-cell padding, since
    filled contours do not extend beyond the outermost grid points. The data
    are drawn with ``da.plot.contourf`` in PlateCarree coordinates, then
    coastlines are added.

    Returns
    -------
    The object returned by ``da.plot.contourf``, returned for consistency
    with ``map_plot_without_frame_with_bounds`` (e.g. for building a
    colorbar). Callers that ignored the former ``None`` return are unaffected.
    """
    lat_max = da.latitude.max()
    lat_min = da.latitude.min()
    lon_max = da.longitude.max()
    lon_min = da.longitude.min()
    # Black frame around the region with data (closed rectangle).
    ax.plot(
        [lon_min, lon_max, lon_max, lon_min, lon_min],
        [lat_min, lat_min, lat_max, lat_max, lat_min],
        color="black",
        transform=ccrs.PlateCarree(),  # remove the transform to get straight lines
    )
    for spine in ax.spines.values():
        spine.set_visible(False)
    mappable = da.plot.contourf(ax=ax, transform=ccrs.PlateCarree(), **plot_kwargs)
    ax.coastlines()
    return mappable
[docs]
def add_contours(
    ax,
    da: xr.DataArray,
    major_levels,
    minor_levels,
    use_contour_labels=True,
    linewidth_major=1,
    linewidth_minor=0.5,
    **plot_kwargs,
):
    """Overlay black contour lines of ``da`` (squeezed) on a map axis.

    Two contour sets are drawn in PlateCarree coordinates: ``minor_levels``
    with ``linewidth_minor`` first, then ``major_levels`` with
    ``linewidth_major``. When ``use_contour_labels`` is true, only the major
    contours get inline labels. Extra keyword arguments are forwarded to both
    ``da.plot.contour`` calls.
    """
    field = da.squeeze()

    def draw(levels, width):
        return field.plot.contour(
            ax=ax,
            levels=levels,
            linewidths=width,
            transform=ccrs.PlateCarree(),
            colors="k",
            **plot_kwargs,
        )

    # Thin minor lines first, so the major lines are added on top of them.
    draw(minor_levels, linewidth_minor)
    major = draw(major_levels, linewidth_major)
    if use_contour_labels:
        ax.clabel(major, major.levels)
[docs]
def add_label_to_axes(
    ax,
    label,
    ax_xpos=0.05,
    ax_ypos=0.95,
    ha="left",
    va="top",
    edgecolor="white",
    **font_kwargs,
):
    """Write ``label`` inside a rounded white box on ``ax``.

    ``ax_xpos`` / ``ax_ypos`` are in axes coordinates (``ax.transAxes``:
    (0, 0) is the lower-left and (1, 1) the upper-right corner of the axes).
    ``ha`` / ``va`` set the text alignment relative to that point, and
    ``edgecolor`` is the outline colour of the box. Any further keyword
    arguments are passed through to ``ax.text``.
    """
    box = {"facecolor": "white", "edgecolor": edgecolor, "boxstyle": "round4"}
    ax.text(
        ax_xpos,
        ax_ypos,
        label,
        transform=ax.transAxes,
        ha=ha,
        va=va,
        bbox=box,
        **font_kwargs,
    )