"""Weather generator module for analog-based weather simulation.
This module provides the core WeatherGenerator class that implements analog-based
weather generation using similarity measures and probability models to sample
realistic weather trajectories from historical data.
"""
import datetime
import os
import shutil
import uuid
from typing import Any
import cartopy.crs as ccrs
import dask.array
import matplotlib.figure
import matplotlib.pyplot as plt
import metpy.xarray  # noqa: F401  (registers the ``.metpy`` accessor used below)
import numpy as np
import xarray as xr
import yaml
from icecream import ic
from matplotlib.gridspec import GridSpec
from tqdm.auto import tqdm
import unseen_awg.similarity_measures
from unseen_awg.data_classes import InitTimeLeadTimeMemberState
from unseen_awg.plotting_utils import add_contours, map_plot_without_frame_with_bounds
from unseen_awg.probability_models import ProbabilityModel
from unseen_awg.snakemake_utils import snakemake_handler
from unseen_awg.time_steppers import TimeStepper
from unseen_awg.timestep_utils import (
dayofyear_year_to_datetime64_naive,
is_in_window_from_time,
is_in_window_from_year_fraction,
)
from unseen_awg.utils import (
apply_similarity_metric,
get_k_random_indices,
get_k_smallest_indices,
get_map_valid_n_day_transitions,
)
def setup_lazy_similarity_dataset(
ds_year_dayofyear_format: xr.Dataset,
window_size: int,
ref_time: np.datetime64 = np.datetime64("2000-01-01", "ns"),
) -> xr.Dataset:
"""Set up a lazy dataset to store similarities computed in weather generator in.
Creates a dataset structure for computing similarities between weather states
within a specified time window. The dataset includes coordinates for reference
states and candidate states with time shifts.
The dataset has dimensions that identify the base sample
(dayofyear, year, sample, ensemble member) and additional dimensions that identify
the candidate (d_shift, c_year, c_sample, c_ensemble_member). The valid_time of
the candidate state can be computed from c_year, dayofyear, and d_shift.
Parameters
----------
ds_year_dayofyear_format : xr.Dataset
Input dataset in year-dayofyear format containing weather data.
window_size : int
Size of the time window (in days) for similarity computations.
ref_time : np.datetime64, optional
Reference time for temporal calculations, by default
np.datetime64("2000-01-01", "ns").
Returns
-------
xr.Dataset
Lazy dataset with similarity computation structure including coordinates
for reference and candidate states.
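
    Examples
    --------
    A minimal sketch, assuming ``year_dayofyear.zarr`` (illustrative path) holds the
    preprocessed dataset with ``dayofyear``, ``year``, ``sample`` and
    ``ensemble_member`` dimensions and an ``init_time`` coordinate::

        ds = xr.open_zarr("year_dayofyear.zarr", decode_timedelta=True)
        ds_sim = setup_lazy_similarity_dataset(ds, window_size=7)
        # ds_sim["similarities"] is a lazy, NaN-filled array until it is computed.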
"""
ds_similarities = xr.Dataset(
coords={
"dayofyear": ds_year_dayofyear_format.dayofyear.data,
"year": ds_year_dayofyear_format.year.data,
"sample": ds_year_dayofyear_format.sample.data,
"ensemble_member": ds_year_dayofyear_format.ensemble_member.data,
"d_shift": np.arange(-(window_size + 1), (window_size + 1) + 1),
"c_year": ds_year_dayofyear_format.year.data,
"c_sample": ds_year_dayofyear_format.sample.data,
"c_ensemble_member": ds_year_dayofyear_format.ensemble_member.data,
}
)
l_dims = [len(ds_similarities[d]) for d in ds_similarities.dims]
l_chunks = []
for d in ds_similarities.dims:
if d in ["dayofyear", "year"]:
l_chunks.append(1)
else:
l_chunks.append(len(ds_similarities[d]))
ds_similarities["similarities"] = (
ds_similarities.dims,
dask.array.full(l_dims, fill_value=np.nan, chunks=l_chunks),
)
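    # Day-of-year arithmetic below wraps around the year boundary:
    # ((dayofyear + d_shift) - 1) % 366 + 1 maps e.g. dayofyear=366, d_shift=1 to 1.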
ds_similarities = ds_similarities.assign_coords(
{
"init_time": ds_year_dayofyear_format.init_time.load(),
"c_init_time": ds_year_dayofyear_format.init_time.load()
.rename({"sample": "c_sample", "year": "c_year"})
.sel(
dayofyear=(
((ds_similarities.dayofyear + ds_similarities.d_shift) - 1) % 366
)
+ 1
),
"c_dayofyear": (
((ds_similarities.dayofyear + ds_similarities.d_shift) - 1) % 366
)
+ 1,
}
)
valid_time_reference = xr.apply_ufunc(
np.vectorize(dayofyear_year_to_datetime64_naive),
ds_similarities.dayofyear,
ds_similarities.year,
)
valid_time_candidates = xr.apply_ufunc(
np.vectorize(dayofyear_year_to_datetime64_naive),
(((ds_similarities.dayofyear + ds_similarities.d_shift) - 1) % 366) + 1,
ds_similarities.c_year,
)
ds_similarities = ds_similarities.assign_coords(
m_is_near=is_in_window_from_time(
valid_time_reference,
valid_time_candidates,
window_size=window_size,
ref_time=ref_time,
)
)
ds_similarities = ds_similarities.assign_coords(
{
"valid_time": xr.apply_ufunc(
np.vectorize(dayofyear_year_to_datetime64_naive),
ds_similarities.dayofyear,
ds_similarities.year,
),
"c_valid_time": xr.apply_ufunc(
np.vectorize(dayofyear_year_to_datetime64_naive),
(((ds_similarities.dayofyear + ds_similarities.d_shift) - 1) % 366) + 1,
ds_similarities.c_year,
),
}
)
ds_similarities = ds_similarities.assign_coords(
{
"lead_time": ds_similarities.valid_time - ds_similarities.init_time,
"c_lead_time": ds_similarities.c_valid_time - ds_similarities.c_init_time,
}
)
# cf-conventions:
ds_similarities.attrs["Conventions"] = "CF-1.7"
ds_similarities.attrs["Title"] = (
"Similarities between pairs of states in the underlying dataset"
)
ds_similarities.attrs["Source"] = (
"Computed according to the selected similarity measure from underlying dataset."
)
ds_similarities["dayofyear"].attrs.update(
long_name="day of year",
units="1", # dimensionless integer index
)
ds_similarities["year"].attrs.update(
long_name="year",
units="1",
)
ds_similarities["sample"].attrs.update(
long_name="index for data on same valid_time.",
units="1",
)
ds_similarities["ensemble_member"].attrs.update(
long_name="reforecast ensemble member",
units="1",
)
ds_similarities["d_shift"].attrs.update(
long_name="shift between day of year of base sample "
"and day of year of candidate",
units="1",
)
ds_similarities["c_year"].attrs.update(
long_name="year of candidate",
units="1",
)
ds_similarities["c_sample"].attrs.update(
long_name="index for data on same valid_time of candidate",
units="1",
)
ds_similarities["c_ensemble_member"].attrs.update(
long_name="reforecas ensemble member of candidate",
units="1",
)
ds_similarities["valid_time"].attrs.update(
long_name="valid time of forecast",
standard_name="time",
)
ds_similarities["lead_time"].attrs.update(
standard_name="forecast_period",
long_name="reforecast lead time",
)
ds_similarities["init_time"].attrs.update(
long_name="reforecast initialisation time",
standard_name="forecast_reference_time",
)
ds_similarities["c_valid_time"].attrs.update(
long_name="valid time of candidate",
standard_name="time",
)
ds_similarities["c_init_time"].attrs.update(
long_name="reforecast initialisation time of candidate",
standard_name="forecast_reference_time",
)
ds_similarities["c_lead_time"].attrs.update(
long_name="reforecast lead time of candidate",
)
ds_similarities["c_dayofyear"].attrs.update(
long_name="day of year of candidate",
units="1", # dimensionless integer index
)
ds_similarities["m_is_near"].attrs.update(
long_name="mask: candidate date is near target date",
units="1",
)
ds_similarities["similarities"].attrs.update(
long_name="similarity score between forecast and candidate analog",
units="1",
)
for var in ["init_time", "c_init_time", "valid_time", "c_valid_time"]:
ds_similarities[var].encoding.update(
units="seconds since 1970-01-01",
dtype="int64",
)
# timedelta variables: encode as seconds (a plain number + units)
for var in ["lead_time", "c_lead_time"]:
ds_similarities[var].encoding.update(
units="seconds",
dtype="int64",
)
return ds_similarities
class WeatherGenerator:
"""Analog-based weather generator for creating synthetic weather trajectories.
    This class implements an analog-based approach to weather generation, in which
    weather states are sampled from historical data while ensuring that successive
    states either follow each other in the historical dataset or are analogs of the
    true successor. The generator uses configurable similarity measures and
    probability models to create realistic weather sequences; additional parameters
    can be specified when sampling time series.
Parameters
----------
params : dict[str, Any]
Configuration parameters containing:
- weather_generator.window_size : int
Half-window size within which states are considered as potential analogs.
- weather_generator.var : str
Name of variable to use for similarity calculations.
- weather_generator.similarity : str
Name of similarity function to use.
- weather_generator.use_precomputed_similarities : bool
Whether to use precomputed similarities or not. If True, similarities
are precomputed when a WeatherGenerator instance is initialized.
Otherwise, a lazy array is set up and similarities are computed on-the-fly
during the sampling process. This slows down the sampling process.
- weather_generator.n_samples : int
Number of samples to use from dataset. By "sample" we denote
datapoints that possess the same valid_time (but a different init_time).
Providing a low n_samples allows restricting the number of included
states.
- dir_wg : str
Directory path for weather generator outputs.
- zarr_year_dayofyear : str
Path to zarr store containing the preprocessed input dataset.
Attributes
----------
window_size : int
Half-window size within which states are considered as potential analogs.
var : str
Name of variable to use for similarity calculations.
similarity_function : callable
Function used to compute similarities between states.
path_wg : str
Path to weather generator working directory.
path_dataset : str
Path to input dataset.
use_precomputed_similarities : bool
Flag indicating whether to use precomputed similarities.
ds_similarities : xr.Dataset
Dataset containing results of similarity computations. Is initialized as
lazy dataset and computed during initialization if
use_precomputed_similarities is True.
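
    Examples
    --------
    A minimal sketch; the parameter values (paths, variable and similarity names) are
    illustrative only, and ``weather_generator.similarity`` must name a function in
    ``unseen_awg.similarity_measures``::

        params = {
            "weather_generator.window_size": 7,
            "weather_generator.var": "z500",
            "weather_generator.similarity": "mean_squared_error",
            "weather_generator.use_precomputed_similarities": False,
            "weather_generator.n_samples": 1,
            "dir_wg": "output/wg",
            "zarr_year_dayofyear": "data/year_dayofyear.zarr",
        }
        wg = WeatherGenerator(params=params)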
"""
def __init__(self, params: dict[str, Any]) -> None:
self.window_size = params["weather_generator.window_size"]
self.var = params["weather_generator.var"]
self.similarity_function = getattr(
unseen_awg.similarity_measures, params["weather_generator.similarity"]
)
self.path_wg = os.path.join(params["dir_wg"])
self.path_dataset = params["zarr_year_dayofyear"]
self.use_precomputed_similarities = params[
"weather_generator.use_precomputed_similarities"
]
# load the reshaped year dayofyear dataset (chunks along dayofyear).
ic("Load circulation dataset")
ds = xr.open_zarr(self.path_dataset, decode_timedelta=True).isel(
sample=slice(0, params["weather_generator.n_samples"])
)
# set up store for similarities
path_similarities = os.path.join(self.path_wg, "similarities.zarr")
ic("Set up array for similarities")
self.ds_similarities = setup_lazy_similarity_dataset(
ds_year_dayofyear_format=ds,
window_size=self.window_size,
)
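        # d_shift spans one day more than the half-window on each side; candidates are
        # later restricted to the actual window via the "m_is_near" mask.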
d_shifts = xr.DataArray(
np.arange(-(self.window_size + 1), (self.window_size + 1) + 1),
dims="d_shift",
)
ds = ds.drop_vars("init_time")
ds_candidate = (
ds.rename(
{
"year": "c_year",
"ensemble_member": "c_ensemble_member",
"sample": "c_sample",
}
)
.sel(dayofyear=((((ds.dayofyear + d_shifts) - 1) % 366) + 1))
.assign_coords({"d_shifts": d_shifts})
)
ic("Set up similarity computation lazily")
similarities = apply_similarity_metric(
ds_reference=ds,
ds_candidate=ds_candidate,
similarity_func=self.similarity_function,
variable_name=self.var,
).rename("similarities")
if self.use_precomputed_similarities:
self.ds_similarities.to_zarr(
path_similarities,
compute=False,
)
for i_doy, _ in enumerate(
tqdm(
self.ds_similarities.similarities.dayofyear,
desc="Compute similarities",
)
):
similarities.isel(
dayofyear=slice(i_doy, i_doy + 1),
).drop_vars(
[
"ensemble_member",
"c_sample",
"c_ensemble_member",
"c_year",
"d_shifts",
"sample",
"year",
]
).load().to_zarr(
path_similarities,
region={
"dayofyear": slice(i_doy, i_doy + 1),
},
)
else:
os.makedirs(self.path_wg, exist_ok=False)
self.ds_similarities["similarities"] = similarities
def sample_trajectory(
self,
blocksize: int,
probability_model: ProbabilityModel,
stepper_class: type[TimeStepper],
n_steps: int,
rng: np.random.Generator,
initialization: InitTimeLeadTimeMemberState | None = None,
start_by_taking_analog: bool = False,
show_progressbar: bool = False,
) -> xr.Dataset:
"""Sample a synthetic weather trajectory using the analog method.
Generates a weather trajectory by iteratively sampling analog states
        from historical data. The sampling alternates between following a historical
        trajectory and sampling analogs of the true successor state: blocks of
        ``blocksize`` days are copied from the historical data, and at each transition
        between blocks a close analog of the "true" state that would follow the block
        is chosen instead.
Parameters
----------
blocksize : int
Number of days to sample contiguously from the same historical trajectory.
probability_model : ProbabilityModel
Model defining sampling probabilities given similarities between base states
and candidate states.
stepper_class : type[TimeStepper]
Class for managing the output time assigned to each sample in the resulting
trajectory. This is used as a means for supporting different calendars in
sampled datasets.
n_steps : int
            Number of sampling steps to perform; each step appends ``blocksize`` days,
            so this is not necessarily the length of the sampled series in days.
rng : np.random.Generator
Random number generator for sampling.
initialization : InitTimeLeadTimeMemberState | None, optional
Initial state specification, by default None (random initialization).
start_by_taking_analog : bool, optional
Whether to start by taking an analog of the initial state, by default False.
show_progressbar : bool, optional
Whether to display progress bar, by default False.
Returns
-------
xr.Dataset
Generated weather trajectory with time series of sampled states.
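
        Examples
        --------
        A minimal sketch, assuming ``wg`` is an initialized WeatherGenerator and that
        ``probability_model`` and ``SomeTimeStepper`` stand in for a concrete
        ProbabilityModel instance and TimeStepper subclass::

            rng = np.random.default_rng(42)
            trajectory = wg.sample_trajectory(
                blocksize=5,
                probability_model=probability_model,
                stepper_class=SomeTimeStepper,
                n_steps=73,  # 73 blocks of 5 days, roughly one year
                rng=rng,
            )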
"""
# set up a "map" between states and their successors states
# for blocksize-long transitions:
path_map = os.path.join(self.path_wg, f"map_{blocksize}_steps_transition.nc")
map_n_step_transition = self._load_or_create_map_file(path_map, blocksize)
current_block_start_state = self.get_initial_state(
initialization=initialization,
map_n_step_transition=map_n_step_transition,
blocksize=blocksize,
rng=rng,
)
# initialize stepper with starting condition.
# Stepper is iterator that on each call returns tuple of (time, year_fraction)
stepper = stepper_class(
init_year=current_block_start_state.valid_time.dt.year,
init_month=current_block_start_state.valid_time.dt.month,
init_day=current_block_start_state.valid_time.dt.day,
blocksize=blocksize,
)
current_out_time, current_year_fraction = next(stepper)
current_block_start_state["out_time"] = current_out_time
if start_by_taking_analog:
current_block_start_state = self.sampling_step(
next_state=current_block_start_state,
next_year_fraction=current_year_fraction,
map_n_step_transition=map_n_step_transition,
probability_model=probability_model,
rng=rng,
)
# initialize empty trajectory
trajectory: list[xr.Dataset] = []
for _ in tqdm(
range(n_steps), disable=(not show_progressbar), desc="Sampling trajectory"
):
            # alternate between following what actually happened and analog sampling
next_state, next_year_fraction = self.time_evolution_step(
trajectory=trajectory,
current_block_start_state=current_block_start_state,
map_n_step_transition=map_n_step_transition,
stepper=stepper,
blocksize=blocksize,
)
next_state = self.sampling_step(
next_state=next_state,
next_year_fraction=next_year_fraction,
map_n_step_transition=map_n_step_transition,
probability_model=probability_model,
rng=rng,
)
current_block_start_state = next_state
return xr.concat(trajectory, dim="out_time")
@classmethod
def load(cls, wg_path: str) -> "WeatherGenerator":
"""Load a WeatherGenerator instance from saved configuration.
Parameters
----------
wg_path : str
Path to directory containing saved weather generator configuration.
Returns
-------
WeatherGenerator
Loaded weather generator instance.
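
        Examples
        --------
        A minimal sketch, assuming the directory was written by a previous run and
        contains ``params.yaml`` (and ``similarities.zarr`` if similarities were
        precomputed)::

            wg = WeatherGenerator.load("output/wg")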
"""
with open(os.path.join(wg_path, "params.yaml"), "r") as file:
params = yaml.safe_load(file)
instance = super().__new__(cls)
instance.window_size = params["weather_generator.window_size"]
instance.var = params["weather_generator.var"]
instance.similarity_function = getattr(
unseen_awg.similarity_measures, params["weather_generator.similarity"]
)
instance.path_wg = params["dir_wg"]
instance.path_dataset = params["zarr_year_dayofyear"]
instance.use_precomputed_similarities = params[
"weather_generator.use_precomputed_similarities"
]
if instance.use_precomputed_similarities:
instance.ds_similarities = xr.open_zarr(
os.path.join(wg_path, "similarities.zarr"), decode_timedelta=True
)
else:
ds = xr.open_zarr(instance.path_dataset, decode_timedelta=True).isel(
sample=slice(0, params["weather_generator.n_samples"])
)
instance.ds_similarities = setup_lazy_similarity_dataset(
ds_year_dayofyear_format=ds,
window_size=instance.window_size,
)
d_shifts = xr.DataArray(
np.arange(-(instance.window_size + 1), (instance.window_size + 1) + 1),
dims="d_shift",
)
ds = ds.drop_vars("init_time")
ds_candidate = (
ds.rename(
{
"year": "c_year",
"ensemble_member": "c_ensemble_member",
"sample": "c_sample",
}
)
.sel(dayofyear=((((ds.dayofyear + d_shifts) - 1) % 366) + 1))
.assign_coords({"d_shifts": d_shifts})
)
similarities = apply_similarity_metric(
ds_reference=ds,
ds_candidate=ds_candidate,
similarity_func=instance.similarity_function,
variable_name=instance.var,
).rename("similarities")
instance.ds_similarities["similarities"] = similarities
return instance
def time_evolution_step(
self,
trajectory: list[xr.Dataset],
current_block_start_state: xr.Dataset,
map_n_step_transition: xr.Dataset,
stepper: TimeStepper,
blocksize: int,
) -> tuple[xr.Dataset, float]:
"""Perform one time evolution step in trajectory generation.
Advances the trajectory by one block of time steps, following the
evolution in the underlying historical data set starting from current_state.
Parameters
----------
trajectory : list[xr.Dataset]
List of trajectory states to append new states to.
current_block_start_state : xr.Dataset
            State from which the current block starts.
map_n_step_transition : xr.Dataset
Mapping of allowed n-day transitions between states.
stepper : TimeStepper
Time stepper instance for managing temporal progression.
blocksize : int
Number of days in each time block.
Returns
-------
tuple[xr.Dataset, float]
Next state and corresponding year fraction.
"""
for i in range(0, blocksize):
trajectory.append(
xr.Dataset(
{
"lead_time": current_block_start_state.lead_time
+ i * np.timedelta64(1, "D"),
"init_time": current_block_start_state.init_time,
"ensemble_member": current_block_start_state.ensemble_member,
},
coords={
"out_time": current_block_start_state.out_time
+ datetime.timedelta(days=i)
},
)
)
next_coords = map_n_step_transition.sel(
dayofyear=current_block_start_state.dayofyear,
year=current_block_start_state.year,
sample=current_block_start_state.sample,
)
next_state = xr.Dataset(
self.ds_similarities.init_time.sel(
dayofyear=next_coords.next_dayofyear,
year=next_coords.next_year,
sample=next_coords.next_sample,
).coords
).reset_coords(drop=False)
next_out_time, next_year_fraction = next(stepper)
next_state["out_time"] = next_out_time
next_state["ensemble_member"] = current_block_start_state.ensemble_member
return next_state, next_year_fraction
def sampling_step(
self,
next_state: xr.Dataset,
next_year_fraction: float,
map_n_step_transition: xr.Dataset,
probability_model: ProbabilityModel,
rng: np.random.Generator,
) -> xr.Dataset:
"""Perform analog sampling step to select next weather state.
Samples an analog state from historical data based on similarity
to the true next state and according to distribution and constraints
defined by the probability_model.
Parameters
----------
next_state : xr.Dataset
True next state in underlying historic dataset.
next_year_fraction : float
Year fraction of next sample. Used to define temporal similarity
            rather than an actual calendar date to simplify calendar handling.
map_n_step_transition : xr.Dataset
Mapping of allowed n-day transitions between states.
probability_model : ProbabilityModel
Model defining sampling probabilities given similarities.
rng : np.random.Generator
Random number generator for sampling.
Returns
-------
xr.Dataset
Sampled analog state for the next time step.
"""
s_sims = self.ds_similarities.sel(
year=next_state.year,
dayofyear=next_state.dayofyear,
sample=next_state.sample,
ensemble_member=next_state.ensemble_member,
).load()
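        # Broadcast the candidate coordinates to the shape of the similarity array so
        # that flat indexing with the combined mask below stays consistent per candidate.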
for var in [
"c_year",
"c_valid_time",
"c_sample",
"c_init_time",
"c_dayofyear",
"c_lead_time",
"c_ensemble_member",
]:
s_sims[f"sampled_{var}"] = s_sims[var].broadcast_like(s_sims.similarities)
# Mask that indicates whether a sample is a valid sample
# (i.e. the corresponding date is actually contained in the data set).
m_is_valid = ~np.isnan(
map_n_step_transition.sel(
sample=s_sims.c_sample,
year=s_sims.c_year,
dayofyear=((s_sims.dayofyear + s_sims.d_shift) - 1) % 366 + 1,
).next_year
)
# Mask that is true if the states are close to the next assigned output date.
m_is_near_to_year_fraction = is_in_window_from_year_fraction(
base_year_fractions=next_year_fraction,
other_dates=s_sims.c_valid_time.load(),
window_size=self.window_size,
ref_time=np.datetime64("2000-01-01", "ns"),
)
# Combine masks with additional mask that is true if the states
# are close to valid_date of the next sample.
m = (
m_is_valid
& m_is_near_to_year_fraction
& s_sims.m_is_near.expand_dims({"c_sample": s_sims.c_sample})
.expand_dims(
{"ensemble_member": self.ds_similarities.ensemble_member}, axis=-1
)
.load()
)
# take the mask subset:
similarities = s_sims.similarities.data[m]
coords = xr.Dataset(
{
c: ("datapoint", v.data[m])
for c, v in s_sims.data_vars.items()
if c
in [
"sampled_c_valid_time",
"sampled_c_init_time",
"sampled_c_ensemble_member",
]
}
).rename(
{
"sampled_c_valid_time": "valid_time",
"sampled_c_init_time": "init_time",
"sampled_c_ensemble_member": "ensemble_member",
}
)
i = probability_model.sample(
similarities=similarities,
coords_s_next=next_state[["init_time", "valid_time", "ensemble_member"]],
coords_candidates=coords,
rng=rng,
size=1,
)
res = xr.Dataset(
{
var: s_sims[f"sampled_c_{var}"].data[m][i][0]
for var in [
"year",
"valid_time",
"sample",
"init_time",
"dayofyear",
"lead_time",
"ensemble_member",
]
}
)
res["out_time"] = next_state["out_time"]
return res
def _load_or_create_map_file(self, path_map: str, blocksize: int) -> xr.Dataset:
"""Load existing transition map or create new one if not found.
Attempts to load a precomputed transition map file. If the file doesn't
exist or cannot be opened, creates a new transition map and saves it
atomically to prevent race conditions in parallel execution.
The map file is used to identify valid samples and provides a mapping between
coordinates of each state and the coordinates of its corresponding true
successor state.
Parameters
----------
path_map : str
Path to the transition map file.
blocksize : int
Number of days for each block of states.
Returns
-------
xr.Dataset
Transition map dataset containing valid n-day transitions.
Raises
------
ValueError
If no valid transitions exist for the chosen blocksize.
"""
# Try to load the existing file first
try:
return xr.open_dataset(
path_map, decode_timedelta=True, lock=False, mode="r"
)
except (FileNotFoundError, OSError):
# File doesn't exist or can't be opened - we'll need to create it
# Create a unique temporary filename
temp_path = os.path.join(
os.path.dirname(path_map),
f"temp_map_{blocksize}_{uuid.uuid4().hex}.nc",
)
try:
# Create the map
map_n_step_transition = get_map_valid_n_day_transitions(
self.ds_similarities.init_time.load(), n=blocksize
)
if np.isnan(map_n_step_transition["next_sample"]).all():
raise ValueError(
f"No valid transitions for chosen blocksize: {blocksize}"
)
# Save to temporary file first
map_n_step_transition.to_netcdf(temp_path)
# Try to atomically move the temp file to the final location
# This will fail if another process has created the file in the meantime
try:
# Make sure directory exists
os.makedirs(os.path.dirname(path_map), exist_ok=True)
# Try to move the file (atomic on same filesystem)
shutil.move(temp_path, path_map)
return map_n_step_transition
except (OSError, shutil.Error):
# If move fails, another process likely created the file first
# Try to load the existing file
if os.path.exists(path_map):
return xr.open_dataset(
path_map, decode_timedelta=True, lock=False, mode="r"
)
else:
# If the file still doesn't exist, return our computed map
return map_n_step_transition
finally:
# Clean up the temp file if it still exists
if os.path.exists(temp_path):
try:
os.remove(temp_path)
except OSError:
pass
def get_initial_state(
self,
initialization: InitTimeLeadTimeMemberState | None,
map_n_step_transition: xr.Dataset,
blocksize: int,
rng: np.random.Generator,
) -> xr.Dataset:
"""Get initial state to start sampling a trajectory from.
Determines the starting state for weather generation, either from
a specified initialization or by random selection from set of valid states.
Parameters
----------
initialization : InitTimeLeadTimeMemberState | None
Specific initialization state, or None for random selection.
map_n_step_transition : xr.Dataset
Mapping of allowed n-day transitions between states.
blocksize : int
Number of days in each time block.
rng : np.random.Generator
Random number generator for random initialization.
Returns
-------
xr.Dataset
Initial state for trajectory generation.
Raises
------
ValueError
If the specified initialization is invalid.
AssertionError
If the initialization state is not found in valid transitions.
"""
if initialization is None:
# if no initialization provided select randomly from possible samples.
stacked_isnotnan_ds = ~np.isnan(map_n_step_transition.next_sample).stack(
datapoint=("dayofyear", "year", "sample")
)
initial_state = xr.Dataset(
stacked_isnotnan_ds[stacked_isnotnan_ds]
.isel(
datapoint=rng.integers(
len(stacked_isnotnan_ds[stacked_isnotnan_ds])
)
)
                .drop_vars("datapoint")
.coords
)
return initial_state.assign_coords(
ensemble_member=rng.choice(self.ds_similarities.ensemble_member)
).reset_coords(drop=False)
elif isinstance(initialization, InitTimeLeadTimeMemberState):
stacked_isnotnan_ds = ~np.isnan(
get_map_valid_n_day_transitions(
self.ds_similarities.init_time.load(), n=blocksize
).next_sample
).stack(datapoint=("dayofyear", "year", "sample"))
vsa = stacked_isnotnan_ds[stacked_isnotnan_ds]
            assert (
                len(
                    vsa.where(
                        (vsa.init_time == initialization.init_time)
                        & (vsa.lead_time == initialization.lead_time),
                        drop=True,
                    ).datapoint
                )
                == 1
            ), f"{initialization} seems to be an invalid starting point."
initial_state = xr.Dataset(
vsa.where(
(vsa.init_time == initialization.init_time)
& (vsa.lead_time == initialization.lead_time),
drop=True,
)
.squeeze()
.coords
).drop_vars("datapoint")
return initial_state.assign_coords(
ensemble_member=initialization.ensemble_member
).reset_coords(drop=False)
else:
raise ValueError(f"Invalid initial condition {initialization}")
def get_similarities_k_closest_neighbors(
self,
states: xr.DataArray,
k: int,
minimum_timedelta_days: int | None = None,
dim_states: str | None = None,
    ) -> xr.DataArray:
"""Get the k closest neighbors based on similarity measures.
Finds the k most similar historical states to the given query states
based on (precomputed) similarity measures.
Parameters
----------
states : xr.DataArray
Query states to find neighbors for.
k : int
Number of closest neighbors to return.
minimum_timedelta_days : int | None, optional
            Minimum time separation in days between query and candidate states. This
            allows excluding analogs that are temporally close to the base state if
            that is undesired. By default None, i.e. no restriction.
dim_states : str | None, optional
Dimension name for states, by default None.
Returns
-------
        xr.DataArray
            Similarities of the k closest neighbor states, with candidate coordinates
            attached along a new ``neighbor`` dimension.
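
        Examples
        --------
        A minimal sketch, assuming ``wg`` is an initialized WeatherGenerator; the base
        state is picked arbitrarily from the similarity dataset and its five closest
        analogs at least ten days apart are returned::

            state = wg.ds_similarities.similarities.isel(
                dayofyear=0, year=0, sample=0, ensemble_member=0
            )
            neighbors = wg.get_similarities_k_closest_neighbors(
                states=state, k=5, minimum_timedelta_days=10
            )

        ``get_similarities_k_random_neighbors`` follows the same pattern but
        additionally takes a random number generator.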
"""
sims = self.ds_similarities.sel(
dayofyear=states.dayofyear,
year=states.year,
sample=states.sample,
ensemble_member=states.ensemble_member,
).load()
sims_flattened = -sims.similarities.stack(
flat_dim=[d for d in sims.dims if d != dim_states]
        )  # similarity is assumed to increase as states become more similar
if minimum_timedelta_days is not None:
keeps_minimum_distance = (
abs(
(sims_flattened.valid_time - sims_flattened.c_valid_time)
/ np.timedelta64(1, "D")
)
>= minimum_timedelta_days
)
else:
keeps_minimum_distance = xr.ones_like(sims_flattened, dtype=bool)
return -sims_flattened.isel(
flat_dim=xr.apply_ufunc(
get_k_smallest_indices,
sims_flattened,
keeps_minimum_distance,
k,
input_core_dims=[["flat_dim"], ["flat_dim"], []],
output_core_dims=[["neighbor"]],
vectorize=True,
)
)
def get_similarities_k_random_neighbors(
self,
states: xr.DataArray,
k: int,
rng: np.random.Generator,
minimum_timedelta_days: int | None = None,
dim_states: str | None = None,
    ) -> xr.DataArray:
"""Get k randomly selected neighbors from valid candidates.
Randomly selects k historical states from valid candidates that meet
the specified temporal constraints.
Parameters
----------
states : xr.DataArray
Query states to find neighbors for.
k : int
Number of random neighbors to return.
rng : np.random.Generator
Random number generator for sampling.
minimum_timedelta_days : int | None, optional
Minimum time separation in days between query and candidate states,
by default None.
dim_states : str | None, optional
Dimension name for states, by default None.
Returns
-------
        xr.DataArray
            Similarities of k randomly selected neighbor states, with candidate
            coordinates attached along a new ``neighbor`` dimension.
"""
sims = self.ds_similarities.sel(
dayofyear=states.dayofyear,
year=states.year,
sample=states.sample,
ensemble_member=states.ensemble_member,
).load()
sims_flattened = -sims.similarities.stack(
flat_dim=[d for d in sims.dims if d != dim_states]
        )  # similarity is assumed to increase as states become more similar
if minimum_timedelta_days is not None:
keeps_minimum_distance = (
abs(
(sims_flattened.valid_time - sims_flattened.c_valid_time)
/ np.timedelta64(1, "D")
)
>= minimum_timedelta_days
)
else:
keeps_minimum_distance = xr.ones_like(sims_flattened, dtype=bool)
return -sims_flattened.isel(
flat_dim=xr.apply_ufunc(
get_k_random_indices,
sims_flattened,
keeps_minimum_distance,
k,
rng,
input_core_dims=[["flat_dim"], ["flat_dim"], [], []],
output_core_dims=[["neighbor"]],
vectorize=True,
)
)
def get_analog_data(
self, queries: xr.DataArray, use_candidate_coords: bool = False
    ) -> xr.DataArray:
"""Retrieve analog weather data for specified query coordinates.
Extracts weather data from the dataset at the coordinates specified
in the query array, either using the query coordinates directly or
the candidate coordinates.
Parameters
----------
queries : xr.DataArray
Query array containing coordinate information.
use_candidate_coords : bool, optional
            Whether to select data at the candidate-state coordinates (``c_*``)
            rather than at the base-state coordinates, by default False.
Returns
-------
        xr.DataArray
            Weather data for the configured variable at the specified coordinates.
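
        Examples
        --------
        A minimal sketch, reusing the ``neighbors`` result from
        ``get_similarities_k_closest_neighbors``; with ``use_candidate_coords=True``
        the fields of the analogs themselves are returned::

            analog_fields = wg.get_analog_data(neighbors, use_candidate_coords=True)
            base_field = wg.get_analog_data(neighbors, use_candidate_coords=False)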
"""
da_wg = xr.open_zarr(self.path_dataset, decode_timedelta=True)[self.var]
if use_candidate_coords:
return da_wg.sel(
dayofyear=queries.c_dayofyear,
year=queries.c_year,
sample=queries.c_sample,
ensemble_member=queries.c_ensemble_member,
)
else:
return da_wg.sel(
dayofyear=queries.dayofyear,
year=queries.year,
sample=queries.sample,
ensemble_member=queries.ensemble_member,
)
def plot_k_nearest_and_random_neighbors(
self,
state: xr.Dataset,
k: int,
rng: np.random.Generator,
minimum_timedelta_days: int | None = None,
vmin: float = 450,
vmax: float = 600,
minor_spacing_contours: float = 10,
major_spacing_contours: float = 30,
) -> matplotlib.figure.Figure:
"""Create comparison plot of nearest neighbors vs random neighbors.
        Generates a visualization comparing the k nearest neighbors and k random
        neighbors for a given weather state: the base state is shown at the top, and
        the nearest and random neighbors (analogs) among the candidates are shown
        side by side below.
Parameters
----------
state : xr.Dataset
Reference weather state to find neighbors for.
k : int
Number of neighbors to display.
rng : np.random.Generator
Random number generator for random neighbor selection.
minimum_timedelta_days : int | None, optional
Minimum time separation constraint, by default None.
vmin : float, optional
Minimum value for color scale, by default 450.
vmax : float, optional
Maximum value for color scale, by default 600.
minor_spacing_contours : float, optional
Spacing for minor contour lines, by default 10.
major_spacing_contours : float, optional
Spacing for major contour lines, by default 30.
Returns
-------
matplotlib.figure.Figure
Figure containing the comparison plots.
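
        Examples
        --------
        A minimal sketch, reusing ``wg`` and ``state`` from the examples above; the
        default colour limits and contour spacings may need adjusting for the variable
        at hand::

            fig = wg.plot_k_nearest_and_random_neighbors(
                state=state,
                k=3,
                rng=np.random.default_rng(0),
                minimum_timedelta_days=10,
            )
            fig.savefig("analog_comparison.png", bbox_inches="tight")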
"""
similarities_nbs = self.get_similarities_k_closest_neighbors(
states=state,
k=k,
minimum_timedelta_days=minimum_timedelta_days,
)
similarities_rands = self.get_similarities_k_random_neighbors(
states=state,
k=k,
rng=rng,
minimum_timedelta_days=minimum_timedelta_days,
)
da_nbs = self.get_analog_data(similarities_nbs, use_candidate_coords=True)
da_rands = self.get_analog_data(similarities_rands, use_candidate_coords=True)
da_base = self.get_analog_data(
similarities_nbs, use_candidate_coords=False
).metpy.quantify()
# preps for contour plots
major_levels = np.arange(vmin, vmax, major_spacing_contours)
minor_levels = np.arange(vmin, vmax, minor_spacing_contours)
minor_levels = minor_levels[~np.isin(minor_levels, major_levels)]
fig = plt.figure(figsize=(8, 14))
gs = GridSpec(
k + 3,
2,
figure=fig,
height_ratios=[2, 0.2]
+ [1] * k
+ [
0.1,
],
)
ax_cbar = fig.add_subplot(gs[-1, :])
title_ax1 = fig.add_subplot(gs[1, 0])
title_ax2 = fig.add_subplot(gs[1, 1])
title_ax1.text(
x=0.5,
y=0.5,
s="Nearest neighbors\namong candidates",
horizontalalignment="center",
fontsize="x-large",
)
title_ax2.text(
x=0.5,
y=0.5,
s="Random samples\namong candidates",
horizontalalignment="center",
fontsize="x-large",
)
title_ax1.axis("off")
title_ax2.axis("off")
# actual plotting
# base state:
ax = fig.add_subplot(gs[0, :], projection=ccrs.Robinson())
map_plot_without_frame_with_bounds(
ax=ax,
da=da_base,
vmin=vmin,
vmax=vmax,
cbar_ax=ax_cbar,
cbar_kwargs={
"orientation": "horizontal",
},
)
add_contours(
ax=ax,
da=da_base,
major_levels=major_levels,
minor_levels=minor_levels,
add_labels=True,
)
if len(self.ds_similarities.c_sample) > 1:
ax.set_title(
r"$t_{init}$"
+ f": {np.datetime_as_string((da_base.init_time).squeeze(), unit='D')} "
+ r"$t_{lead}$"
+ f": {int((da_base.lead_time / np.timedelta64(1, 'D')).data)}d "
+ "$m$"
+ f": {da_base.ensemble_member.data} "
)
else:
vt = da_base.init_time + da_base.lead_time
ax.set_title(
r"$t_{valid}$: "
+ f"{np.datetime_as_string((vt).squeeze(), unit='D')} "
+ "$m$"
+ f": {da_base.ensemble_member.data} "
)
for i in range(k):
            # nearest neighbors:
ax_nb = fig.add_subplot(gs[2 + i, 0], projection=ccrs.Robinson())
da_nb = da_nbs.isel(neighbor=i)
map_plot_without_frame_with_bounds(
ax=ax_nb, da=da_nb, add_colorbar=False, vmin=vmin, vmax=vmax
)
add_contours(
ax=ax_nb, da=da_nb, major_levels=major_levels, minor_levels=minor_levels
)
if len(self.ds_similarities.c_sample) > 1:
t_init_out = (da_nb.c_init_time).squeeze()
ax_nb.set_title(
r"$t_{init}$: "
+ f"{np.datetime_as_string(t_init_out, unit='D')}"
+ r" $t_{lead}$"
+ f": {int((da_nb.c_lead_time / np.timedelta64(1, 'D')).data)}d "
+ "$m$"
+ f": {da_nb.c_ensemble_member.data} "
)
else:
vt = da_nb.c_init_time + da_nb.c_lead_time
ax_nb.set_title(
r"$t_{valid}$"
+ f": {np.datetime_as_string((vt).squeeze(), unit='D')} "
+ "$m$"
+ f": {da_nb.c_ensemble_member.data} "
)
# random neighbors:
ax_rand = fig.add_subplot(gs[2 + i, 1], projection=ccrs.Robinson())
da_rand = da_rands.isel(neighbor=i)
map_plot_without_frame_with_bounds(
ax=ax_rand, da=da_rand, add_colorbar=False, vmin=vmin, vmax=vmax
)
add_contours(
ax=ax_rand,
da=da_rand,
major_levels=major_levels,
minor_levels=minor_levels,
)
if len(self.ds_similarities.c_sample) > 1:
t_init_out = (da_rand.c_init_time).squeeze()
ax_rand.set_title(
r"$t_{init}$"
+ f": {np.datetime_as_string(t_init_out, unit='D')} "
+ r"$t_{lead}$"
+ f": {int((da_rand.c_lead_time / np.timedelta64(1, 'D')).data)}d "
+ "$m$"
+ f": {da_rand.c_sample.data} "
)
else:
vt = da_rand.c_init_time + da_rand.c_lead_time
ax_rand.set_title(
r"$t_{valid}$"
+ f": {np.datetime_as_string((vt).squeeze(), unit='D')} "
+ "$m$"
+ f": {da_rand.c_sample.data} "
)
return fig
@snakemake_handler
def main(snakemake: Any) -> None:
"""Main function for weather generator execution in Snakemake workflow.
Initializes and runs the weather generator with parameters from Snakemake,
handling logging and parameter management for the workflow execution.
Parameters
----------
snakemake : Any
Snakemake object containing input/output paths, parameters, and logging
configuration.
"""
all_params = snakemake.params.all_params.copy()
tracked_params = snakemake.params.tracked_params.copy()
all_params["zarr_year_dayofyear"] = snakemake.input["zarr_year_dayofyear"]
tracked_params["zarr_year_dayofyear"] = snakemake.input["zarr_year_dayofyear"]
all_params["dir_wg"] = snakemake.output["dir_wg"]
tracked_params["dir_wg"] = snakemake.output["dir_wg"]
os.makedirs(snakemake.output["dir_wg"], exist_ok=True)
with open(os.path.join(snakemake.output["dir_wg"], "params.yaml"), "w") as f:
yaml.dump(tracked_params, f, default_flow_style=False, sort_keys=False)
WeatherGenerator(params=all_params)
if __name__ == "__main__":
main(snakemake=snakemake) # noqa: F821