Getting started
In the following, all the code blocks directly related to pastax are opened by default,
while the closed ones are not required to learn about the library.
Show/hide code
%load_ext autoreload
%autoreload 2
Show/hide code
import jax
jax.config.update("jax_enable_x64", True)
Defining a velocity field¶
In the following, we will use an idealized configuration (inspired by this one from Parcels documentation) where the velocity field \(\mathbf{u}\) presents two eddies displaced over time.
We consider a spatial domain spanning from 29° to 31° latitude and -1° to 1° longitude, and a time period of two days.
Show/hide code
import jax.numpy as jnp
import numpy as np
ny = nx = 101
nt = 49
lat = jnp.linspace(29, 31, ny)
lon = jnp.linspace(-1, 1, nx)
dt = 1 * 60 * 60 # 1 hour in seconds
ts = np.linspace(0, dt * (nt - 1), nt) # discretize interval of 2 days in seconds
Below we defined the \(\mathbf{u}_u\) and \(\mathbf{u}_v\) fields, considering circular and geostrophic eddies.
Show/hide code
from pastax.utils import degrees_to_meters
dy, dx = degrees_to_meters(jnp.asarray([lat[1] - lat[0], lon[1] - lon[0]]), lat.mean())
# SSH field parameters
eta0 = 0.4
sig = 0.1
eta_fn = lambda _dx, _dy: eta0 * jnp.exp(-((_dx / (sig * nx)) ** 2) - (_dy / (sig * ny)) ** 2)
xi, yi = jnp.meshgrid(jnp.arange(lon.size), jnp.arange(lat.size), indexing="ij")
# eddies displacement parameters
eddy_speed = 0.05
dxi = eddy_speed * dt / dx
dyi = eddy_speed * dt / dy
# SSH field
eta = jnp.zeros((nt, ny, nx))
for t in range(nt):
# eddy 1
x1 = 0.5 * lon.size - (8 + dxi * t)
y1 = 0.5 * lat.size - (8 + dyi * t)
eta = eta.at[t, :, :].add(eta_fn((x1 - xi), (y1 - yi)))
# eddy 2
x2 = 0.5 * lon.size + (8 + dxi * t)
y2 = 0.5 * lat.size + (8 + dyi * t)
eta = eta.at[t, :, :].add(-eta_fn((x2 - xi), (y2 - yi)))
# geostrophic velocities
coriolis_factor = 2 * 7.292115e-5 * jnp.sin(jnp.radians(lat.mean()))
g = 9.81
u = jnp.zeros_like(eta)
u = u.at[:, :-1, :].set(jnp.diff(eta[:, :, :], axis=1) / dy / coriolis_factor * g)
u = u.at[:, -1, :].set(u[:, -2, :])
v = jnp.zeros_like(eta)
v = v.at[:, :, :-1].set(-jnp.diff(eta[:, :, :], axis=-1) / dx / coriolis_factor * g)
v = v.at[:, :, -1].set(v[:, :, -2])
Looking at \(\| \mathbf{u} \|\) over time we can see how the eddies separate.
Show/hide code
import cmocean.cm as cmo
import matplotlib.pyplot as plt
from IPython.display import HTML
from matplotlib import animation
plt.ioff()
fig, ax = plt.subplots()
im = ax.pcolormesh(lon, lat, jnp.sqrt(u[0] ** 2 + v[0] ** 2), vmax=2.75, vmin=0, cmap=cmo.speed)
clb = fig.colorbar(im, ax=ax, orientation="vertical", label="m s$^{-1}$")
_ = clb.ax.set_title("$\| \mathbf{u} \|$")
quiver = ax.quiver(lon[::5], lat[::5], u[0][::5, ::5], v[0][::5, ::5], scale=50)
ax.set_aspect("equal")
title = ax.set_title(f"$t$ = {ts[0] / 60 / 60:.0f}h")
fig.tight_layout()
def draw_frame(i):
im.set_array(jnp.sqrt(u[i] ** 2 + v[i] ** 2).ravel())
quiver.set_UVC(u[i][::5, ::5], v[i][::5, ::5])
title.set_text(f"$t$ = {ts[i] / 60 / 60:.0f}h")
return im, quiver, title
def init_func():
return draw_frame(0)
anim = animation.FuncAnimation(fig, draw_frame, init_func=init_func, frames=nt, interval=100, blit=True)
HTML(anim.to_html5_video())
In order to interpolate this velocity field during simulations of trajectories, we define a pastax.gridded.Gridded (a xarray.Dataset-likish object compatible with JAX just-in-time compilation mechanism).
Show/hide code
from pastax.gridded import Gridded
two_eddies_grid = Gridded.from_array(
fields={"u": u, "v": v}, # the fields
time=ts,
latitude=lat,
longitude=lon, # the coordinates
)
Simulating trajectories¶
The initial position of the simulated trajectories is chosen to be right in between the two eddies: 30° latitude and 0° longitude.
We use a pastax.trajectory.Location to represent this location.
Show/hide code
from pastax.trajectory import Location
x0 = Location([30, 0])
We also define the time range of the simulated trajectory, and the initial integration timestep.
Show/hide code
ts_sim = ts[1:-1]
dt0 = 60 * 60 # integration timestep: 60 minutes in seconds
Generating a single trajectory¶
Let's start with the simplest use-case: advecting a single particle inside a given velocity field \(\mathbf{u}\).
For this, we use a pastax.simulator.DeterministicSimulator simulator, and a pastax.dynamics.linear_uv dynamics.
Formulation
This combination of simulator and dynamics formulates the displacement at time \(t\) from the position \(\mathbf{X}(t)\) as:
by interpolating the velocity field \(\mathbf{u}\) in space and time.
Show/hide code
from pastax.dynamics import linear_uv
from pastax.simulator import DeterministicSimulator
traj_sim = DeterministicSimulator()
uv_dynamics = linear_uv
pastax.simulator.BaseSimulator simulators are callables returning a simulated pastax.trajectory.Trajectory (or a pastax.trajectory.TrajectoryEnsemble as we will see later).
Show/hide code
traj = traj_sim(dynamics=uv_dynamics, args=two_eddies_grid, x0=x0, ts=ts_sim, dt0=dt0)
We can plot trajectories using their pastax.trajectory.Trajectory.plot method.
Show/hide code
plt.close()
fig, ax = plt.subplots()
# plot the field
im = ax.pcolormesh(lon, lat, jnp.sqrt((u**2 + v**2)).max(axis=0), cmap=cmo.speed, vmin=0)
clb = fig.colorbar(im, ax=ax, orientation="vertical", label="m s$^{-1}$")
_ = clb.ax.set_title("$\| \mathbf{u} \|$")
# add the trajectory
ax = traj.plot(ax, label="Trajectory", color="blue")
ax.set_aspect("equal")
ax.legend()
fig.tight_layout()
plt.show()
Since the particle is released right between the eddies, its trajectory is perpendicular (up to numerical approximations) to the direction of the eddies. Because of smaller scales processes (not represented by our geostrophic velocity fields) this very unlikely to happen in practice, even if two eddies were to separate in exactly opposite ways (which is also very unlikely of course).
To get an idea of the different possible "realistic" trajectories we will next generate a stochastic ensemble.
Sampling an ensemble of trajectories¶
Instead of a single trajectory, we will now simulate an ensemble of 100 trajectories by sampling solutions from a Stochastic Differential Equation with Smagorinsky diffusion using the pastax.simulator.StochasticSimulator simulator and a pastax.dynamics.StochasticSmagorinskyDiffusion.
Formulation
This combination of simulator and dynamics formulates a displacement at time \(t\) from the position \(\mathbf{X}(t)\) as:
where \(V = \sqrt{2 K}\) and \(K\) is the Smagorinsky diffusion:
where \(C_s\) is the tunable Smagorinsky constant, \(\Delta x \Delta y\) a spatial scaling factor, and the rest of the expression represents the horizontal diffusion.
Show/hide code
from pastax.dynamics import StochasticSmagorinskyDiffusion
from pastax.simulator import StochasticSimulator
n_samples = 100
ens_sim = StochasticSimulator()
smag_dynamics = StochasticSmagorinskyDiffusion.from_cs(cs=1e-1)
traj_ens = ens_sim(
dynamics=smag_dynamics,
args=two_eddies_grid,
x0=x0,
ts=ts_sim,
dt0=dt0,
n_samples=n_samples,
)
This time, the simulator returns a pastax.trajectory.TrajectoryEnsemble.
Again, we can plot it using the pastax.trajectory.TrajectoryEnsemble.plot method.
Show/hide code
fig, ax = plt.subplots()
# plot the field
im = ax.pcolormesh(lon, lat, jnp.sqrt((u**2 + v**2)).max(axis=0), cmap=cmo.speed, vmin=0)
clb = fig.colorbar(im, ax=ax, orientation="vertical", label="m s$^{-1}$")
_ = clb.ax.set_title("$\| \mathbf{u} \|$")
# add the ensemble
ax = traj_ens.plot(ax, label="Ensemble", color="black", alpha_factor=0.3)
# add the trajectory on top
ax = traj.plot(ax, label="Trajectory", color="blue")
ax.set_aspect("equal")
ax.legend()
fig.tight_layout()
plt.show()
We see that the ensemble much better represent the presence of the two eddies.
In a controlled deterministic case¶
We stick to our idealized velocity field \(\mathbf{u}\) and consider a reference (pair of simulator and) dynamics (\(\mathcal{S}_\text{ref}\)) whose output will be considered as the truth and a tunable dynamics (\(\mathcal{S}_\text{tun}\)) that we will want to calibrate according to the truth.
Here, the reference dynamics knows and uses \(\mathbf{u}\) to simulate the true trajectory. The tunable dynamics uses as input a modified version of the velocity field: \(\widehat{\mathbf{u}} = (\mathbf{u} - b) / a\), and apply the linear transformation \(x \mapsto \beta_0 + x \beta_1\) to \(\widehat{\mathbf{u}}\) to simulate an estimated trajectory.
Starting from randomly sampled \(\beta_0\) and \(\beta_1\), we will (hopefully) learn their expected values \(b\) and \(a\) by minimizing (using a gradient-based optimizer) a metric \(\mathcal{L}\) quantifying the distance between the true and estimated trajectories:
Show/hide code
true_grid = two_eddies_grid # same as before
# velocity field transofrmation parameters
a = 2
b = 0.1
# transformed velocity field
trans_u = (u - b) / a
trans_v = (v - b) / a
trans_grid = Gridded.from_array(
fields={"u": trans_u, "v": trans_v}, # the fields
time=ts,
latitude=lat,
longitude=lon, # the coordinates
)
As before, we use a pastax.dynamics.linear_uv dynamics for the reference \(\mathcal{S}_\text{ref}\).
For the tunable one \(\mathcal{S}_\text{tun}\), we use a pastax.dynamics.LinearUV, whose intercept and slope attributes can be optimized.
Show/hide code
from pastax.dynamics import LinearUV
ref_dynamics = linear_uv
tun_dynamics_init = LinearUV(intercept=0., slope=0.)
print(f"Initial parameters: (intercept={tun_dynamics_init.intercept.item()}, slope={tun_dynamics_init.slope.item()})")
To make the example more interesting, we are going to use a different initial position and start inside the leftward eddy.
Show/hide code
x0 = Location([29.95, -0.05])
true_traj = traj_sim(dynamics=ref_dynamics, args=true_grid, x0=x0, ts=ts_sim, dt0=dt0)
est_traj_init = traj_sim(dynamics=tun_dynamics_init, args=trans_grid, x0=x0, ts=ts_sim, dt0=dt0)
Show/hide code
fig, ax = plt.subplots()
# plot the field
im = ax.pcolormesh(lon, lat, jnp.sqrt((u**2 + v**2)).max(axis=0), cmap=cmo.speed, vmin=0)
clb = fig.colorbar(im, ax=ax, orientation="vertical", label="m s$^{-1}$")
_ = clb.ax.set_title("$\| \mathbf{u} \|$")
# add the reference
ax = true_traj.plot(ax, label="Truth", color="blue")
# add the estimation
ax = est_traj_init.plot(ax, label="Estimate", color="black")
ax.set_xlim(-1, 1)
ax.set_ylim(29, 31)
ax.set_aspect("equal")
ax.legend()
fig.tight_layout()
plt.show()
We can see that the true trajectory is trapped and follows the eddy. The estimated one is not visible as we used 0 for the initial values of \(\beta_0\) and \(\beta_1\).
To obtain an estimate matching the truth by adjusting the parameters \(\beta_0\) and \(\beta_1\) using automatic differentiation, we need a disimilarity metric and a minimizer.
There is plenty of options for the choice of the disimilarity metric, including:
- the Liu Index,
- the Separation distance (distance on Earth in meters),
- the Euclidean distance (distance in degrees),
- ...
In this case, we will use the the Liu Index calculated every hour between the true and the estimated trajectories.
Show/hide code
import diffrax as dfx
import equinox as eqx
def sim_fn(dynamics):
return traj_sim(dynamics, args=trans_grid, x0=x0, ts=ts_sim, dt0=dt0, adjoint=dfx.ForwardMode())
@eqx.filter_jit
def residual_fn(dynamics, ref_traj):
est_traj = sim_fn(dynamics)
return ref_traj.liu_index(est_traj).value
Important
As our simulator only has two parameters, we want to use forward mode AD rather than reverse mode.
For this we need to tell Diffrax not to use diffrax.RecursiveCheckpointAdjoint adjoint as it does not support forward mode AD but to employ diffrax.ForwardMode.
We then use the eqx.filter_jacfwd wrapper rather than eqx.filter_grad.
Note
pastax.trajectory.Trajectory objects define several methods to compute metrics or distances between trajectories, such as liu_index here.
We use here Optimistix to solve our optimisation problem using a Gauss-Newton solver.
Show/hide code
import optimistix as optx
gn_sol = optx.least_squares(
residual_fn,
solver=optx.BestSoFarLeastSquares(optx.GaussNewton(rtol=1e-8, atol=1e-8)),
y0=tun_dynamics_init,
args=true_traj
)
tun_dynamics_final = gn_sol.value
Note
Optax is also a good option for solving this problem.
After optimization, the fitted parameters of the calibrated simulator correspond to their expected values:
Show/hide code
print(f"Estimated parameters after {gn_sol.stats['num_steps'].item()} iterations: (intercept={tun_dynamics_final.intercept.item():.4f}, slope={tun_dynamics_final.slope.item():.4f})")
Therefore, we should now find a perfect agreement between the two trajectories.
Show/hide code
est_traj_final = sim_fn(tun_dynamics_final)
Show/hide code
fig, ax = plt.subplots()
# plot the field
im = ax.pcolormesh(
lon,
lat,
jnp.sqrt((u**2 + v**2)).max(axis=0),
cmap=cmo.speed,
vmin=0,
vmax=np.sqrt((u**2 + v**2)).max(),
)
clb = fig.colorbar(im, ax=ax, extend="max", orientation="vertical", label="m s$^{-1}$")
_ = clb.ax.set_title("$\| \mathbf{u} \|$")
# add the reference
ax = true_traj.plot(ax, label="Truth", color="blue")
# add the estimation
ax = est_traj_final.plot(ax, label="Estimate", color="black")
ax.set_aspect("equal")
ax.legend()
fig.tight_layout()
plt.show()
And indeed, in this toy use-case, we were able to perfectly simulate the reference trajectory.
Bonus
As we are optimizing only 2 parameters, we can easily inspect the "landscape" of the loss function and its gradient.
Show/hide code
def least_square_fn(dynamics):
residuals = residual_fn(dynamics, true_traj)
loss = jnp.sum(residuals ** 2)
return loss, loss
grad_aux_fn = eqx.filter_jacfwd(least_square_fn, has_aux=True)
# parameters domain
betas0 = jnp.arange(-0.75, 0.75, 0.01)
betas1 = jnp.arange(-1, 3, 0.05)
betas0, betas1 = jnp.meshgrid(betas0, betas1, indexing="ij")
# evaluate the loss and gradient landscapes
grad_landscape, loss_landscape = jax.vmap(
lambda intercept, slope: grad_aux_fn(LinearUV(intercept=intercept, slope=slope)),
in_axes=(0, 0),
)(betas0.ravel(), betas1.ravel())
loss_landscape = loss_landscape.reshape(betas0.shape)
grad_intercept_landscape = grad_landscape.intercept.reshape(betas0.shape)
grad_slope_landscape = grad_landscape.slope.reshape(betas0.shape)
Note
See how using jax.vmap makes it easy (and fast) to simulate a trajectory and evaluate the loss and its gradient for every parameters pair of our domain.
Show/hide code
def residual_fn_wrapper(dynamics, ref_traj):
return residual_fn(dynamics, ref_traj), None
solver = optx.GaussNewton(rtol=1e-8, atol=1e-8)
args = true_traj
options = {}
f_struct = jax.ShapeDtypeStruct((true_traj.length,), jnp.float64)
aux_struct = None
tags = frozenset()
step = eqx.filter_jit(
eqx.Partial(solver.step, fn=residual_fn_wrapper, args=args, options=options, tags=tags)
)
terminate = eqx.filter_jit(
eqx.Partial(solver.terminate, fn=residual_fn_wrapper, args=args, options=options, tags=tags)
)
state = solver.init(residual_fn_wrapper, tun_dynamics_init, args, options, f_struct, aux_struct, tags)
done, result = terminate(y=tun_dynamics_init, state=state)
current_dynamics = tun_dynamics_init
all_dynamics = [tun_dynamics_init]
while not done:
current_dynamics, state, _ = step(y=current_dynamics, state=state)
done, result = terminate(y=current_dynamics, state=state)
all_dynamics.append(current_dynamics)
Show/hide code
import matplotlib.patches as mpatches
from matplotlib.legend_handler import HandlerBase
class ArrowHandler(HandlerBase):
def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans):
arrow = mpatches.FancyArrow(2, 2, 10, 0, color=orig_handle.get_edgecolor(), width=1)
return [arrow]
intercepts = jnp.asarray([dyn.intercept for dyn in all_dynamics]).ravel()
slopes = jnp.asarray([dyn.slope for dyn in all_dynamics]).ravel()
fig, ax = plt.subplots()
ax.set_xlabel("Intercept")
ax.set_ylabel("Slope")
vmax = jnp.quantile(loss_landscape.ravel(), 0.975)
vmin = jnp.quantile(loss_landscape.ravel(), 0.025)
im = ax.pcolormesh(
betas0,
betas1,
loss_landscape,
cmap=cmo.amp,
norm="log",
vmin=vmin,
vmax=vmax,
alpha=0.8,
)
clb = fig.colorbar(im, ax=ax, orientation="vertical", extend="both")
_ = clb.ax.set_title("$\\mathcal{L}$")
ax.quiver(
betas0[::5, ::5],
betas1[::5, ::5],
-grad_intercept_landscape[::5, ::5],
-grad_slope_landscape[::5, ::5],
color="black",
alpha=0.5,
)
ax.scatter(b, a, color="blue", label="True parameters", s=50)
ax.scatter(0., 0., color="red", label="Initial guess", s=20)
ax.scatter(intercepts[-1], slopes[-1], color="orange", label="Final estimate", s=15)
ax.plot(intercepts, slopes, color="black", linewidth=1)
handles, _ = ax.get_legend_handles_labels()
ax.legend(
handles=handles + [mpatches.FancyArrow(0, 0, 0, 0, color="black", label="$-\\nabla_\Theta \\mathcal{L}$")],
handler_map={mpatches.FancyArrow: ArrowHandler()},
)
fig.tight_layout()
plt.show()
Experiment for the reader
This idealize scenario is actually quite interesting to adress some questions related to the choice of the objective function:
- which metric or distance should I use?
- should I consider the error along the whole trajectory or only at the final position?
Reproducing this section by playing with the objective function will show that the choice of using the squared sum of the Liu Index evaluated every hour is an important one.
In a real-world stochastic context¶
For this example we will finally use real data!
As illustrated by the pastax logo, a significant number of real drifter trajectories are available to oceanographers and compiled in the Global Drifter Program (GDP) dataset.
We also observe at a coarse resolution one of the main drivers of the drifters trajectories: the geostrophic surface currents derived from altimetry, and distributed by CNES and CLS as part of the DUACS gridded product.
DUACS surface currents come with inherent uncertainties, which must be considered when using them for trajectory reconstruction. One effective approach is to sample ensembles of trajectories rather than simulate deterministic ones. This approach allows us to account for uncertainty by exploring a range of possible outcomes.
In this section, we will revisit the Smagorinsky diffusion model introduced in Sampling an ensemble of trajectories. We will tune its \(C_S\) parameter by minimizing the loss between a simulated ensemble and an observed trajectory.
The libraries clouddrift and copernicusmarine provide lazy-access to the GDP and DUACS data.
Below we load them as xarray.Dataset...
Show/hide code
import clouddrift as cd
import copernicusmarine as cm
start_time = np.datetime64("2023-03-22T00:00:00")
n_days = 7
gdp6h_ds = cd.datasets.gdp6h() # full GDP 6h dataset
traj_ds = cd.ragged.subset( # select a trajectory of interest
gdp6h_ds,
{
"WMO": 4402885,
"time": (start_time, start_time + np.timedelta64(n_days, "D")),
},
row_dim_name="traj",
)
full_duacs_ds = cm.open_dataset( # full DUACS dataset subset for plotting
dataset_id="cmems_obs-sl_glo_phy-ssh_my_allsat-l4-duacs-0.125deg_P1D",
service="arco-geo-series",
variables=["ugos", "vgos"],
start_datetime=(traj_ds.time.compute().min().values - np.timedelta64(1, "D")).astype("M8[s]").astype("O"),
end_datetime=(traj_ds.time.compute().max().values + np.timedelta64(1, "D")).astype("M8[s]").astype("O"),
minimum_longitude=-77,
maximum_longitude=-60,
minimum_latitude=30,
maximum_latitude=40,
)
duacs_ds = full_duacs_ds.sel( # corresponding DUACS dataset subset
latitude=slice(traj_ds.lat.min() - 2, traj_ds.lat.max() + 2),
longitude=slice(traj_ds.lon.min() - 2, traj_ds.lon.max() + 2),
)
... from which we create pastax.gridded.Gridded and pastax.trajectory.Trajectory objects:
Show/hide code
from pastax.gridded import Gridded
from pastax.trajectory import Trajectory
drifter_traj = Trajectory.from_xarray(traj_ds)
duacs_grid = Gridded.from_xarray(
duacs_ds,
fields={"u": "ugos", "v": "vgos"},
coordinates={"time": "time", "latitude": "latitude", "longitude": "longitude"},
)
The drifter we consider was trapped for a few days in an eddy near the Gulf Stream.
Show/hide code
import cartopy.crs as ccrs
from matplotlib.collections import LineCollection
from matplotlib.patches import Rectangle
drifter_ts = traj_ds.time.values.astype("M8[s]").astype("O").tolist()
ds = full_duacs_ds.sel(time=drifter_ts[0], method="nearest")
fig = plt.figure()
ax = fig.add_subplot(projection=ccrs.PlateCarree())
fig.subplots_adjust(left=0.02, bottom=0, right=1.02, top=1, wspace=None, hspace=None)
im = ax.pcolormesh(
ds.longitude,
ds.latitude,
np.sqrt(ds.ugos**2 + ds.vgos**2),
vmax=1.5,
vmin=0,
cmap=cmo.speed
)
clb = fig.colorbar(im, ax=ax, orientation="vertical", extend="max", shrink=0.63, label="m s$^{-1}$")
_ = clb.ax.set_title("$\| \mathbf{u} \|$")
locations = drifter_traj.locations.value[:1, None, ::-1]
segments = np.concatenate([locations[:-1], locations[1:]], axis=1)
lc = LineCollection(segments, color="blue")
ax.add_collection(lc)
area = Rectangle(
(duacs_ds.longitude.min() + 1, duacs_ds.latitude.min() + 1),
duacs_ds.longitude.max() - duacs_ds.longitude.min() - 2,
duacs_ds.latitude.max() - duacs_ds.latitude.min() - 2,
edgecolor="red",
facecolor="none",
)
ax.add_patch(area)
ax.coastlines()
title = ax.set_title(str(ds.time.values)[:-10])
def draw_frame(i):
ds = full_duacs_ds.sel(time=drifter_ts[i], method="nearest")
im.set_array(np.sqrt(ds.ugos**2 + ds.vgos**2).values.ravel())
locations = drifter_traj.locations.value[:i, None, ::-1]
segments = np.concatenate((locations[:-1], locations[1:]), axis=1)
lc.set_segments(segments)
title.set_text(str(drifter_ts[i]))
return im, title
def init_func():
return draw_frame(0)
anim = animation.FuncAnimation(
fig,
draw_frame,
init_func=init_func,
frames=len(drifter_ts),
interval=200,
blit=True,
)
HTML(anim.to_html5_video())
We initialize \(C_S\) with an arbitrary value of 0.1 and define two types of dynamics: stochastic dynamics, which include a stochastic term, and deterministic dynamics, where the stochastic term is omitted.
Show/hide code
x0 = drifter_traj.origin
ts = drifter_traj.times.value
initial_cs = .1 # default / "classical" Smagorinsky constant
stoch_smag_init = StochasticSmagorinskyDiffusion.from_cs(cs=initial_cs)
stoch_smag_too_large = StochasticSmagorinskyDiffusion.from_cs(cs=1)
stoch_smag_too_small = StochasticSmagorinskyDiffusion.from_cs(cs=.001)
deter_smag_init = LinearUV()
The adequation metric we use to compare one reference trajectory to an ensemble of simulated trajectories is inspired by the Continuous Ranked Probability Score (CRPS), a widely used measure for evaluating probabilistic forecasts. The objective is twofold: reduce ensemble bias relative to a single ground-truth trajectory, while encouraging sufficient dispersion within the ensemble. To quantify the difference between pairs of trajectories, we use the Liu Index, as In a controlled deterministic case.
Show/hide code
dt0_, saveat, stepsize_controller, adjoint, n_steps, brownian_motion = ens_sim.get_diffeqsolve_best_args(
ts, dt0, n_steps=int((ts[-1] - ts[0]) // dt0), constant_step_size=True, save_at_steps=False, ad_mode="forward"
)
def sim_fn(dynamics):
return ens_sim(
dynamics=dynamics, args=duacs_grid, x0=x0, ts=ts,
dt0=dt0_, saveat=saveat, stepsize_controller=stepsize_controller,
adjoint=adjoint, max_steps=n_steps,
n_samples=n_samples, key=jax.random.key(0),
brownian_motion=brownian_motion
)
@eqx.filter_jit
def residual_fn(dynamics, ref_traj):
def pair_pair_residual_fn(traj1: Trajectory, traj2: Trajectory) -> jax.Array:
residuals = traj1.liu_index(traj2).value
residuals = jnp.where(jnp.isnan(residuals), 1, residuals)
return residuals
traj_ens = sim_fn(dynamics)
return traj_ens.crps(ref_traj, pair_pair_residual_fn, is_metric_symmetric=False).value
Bonus
Significant speedups can be achieved by carefully selecting the arguments passed to diffrax.diffeqsolve, which is called internally.
The class method DiffraxSimulator.get_diffeqsolve_best_args applies general heuristics to determine optimal argument values based on a high-level description of the problem, derived from its own input arguments.
The calibration is performed using a Levenberg-Marquardt solver.
Show/hide code
lm_sol = optx.least_squares(
residual_fn,
solver=optx.BestSoFarLeastSquares(optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8)),
y0=stoch_smag_init,
args=drifter_traj,
throw=False
)
stoch_smag_final = lm_sol.value
It results in convergence to a smaller \(C_S\) value, indicating a tighter ensemble. However, the value does not converge to 0, thereby avoiding ensemble collapse (a situation where all ensemble members become identical, losing the intended variability).
Show/hide code
print(f"Estimated parameter after {lm_sol.stats['num_steps'].item()} iterations: Cs={stoch_smag_final.cs.item():.4f}")
Using the initial and calibrated final values of \(C_S\), we generate large ensembles of 1000 trajectories. This allows us to observe the differences between these ensembles and compare them against the deterministic prediction, which is represented by dashed lines below.
Show/hide code
traj_est_init = traj_sim(dynamics=deter_smag_init, args=duacs_grid, x0=x0, ts=ts, dt0=dt0)
traj_ens_too_large = ens_sim(
dynamics=stoch_smag_too_large,
args=duacs_grid,
x0=x0,
ts=ts,
dt0=dt0,
n_samples=1000
)
traj_ens_too_small = ens_sim(
dynamics=stoch_smag_too_small,
args=duacs_grid,
x0=x0,
ts=ts,
dt0=dt0,
n_samples=1000
)
traj_ens_final = ens_sim(
dynamics=stoch_smag_final,
args=duacs_grid,
x0=x0,
ts=ts,
dt0=dt0,
n_samples=1000
)
Show/hide code
from matplotlib.lines import Line2D
ds = duacs_ds.sel(time=drifter_ts[len(drifter_ts) // 2], method="nearest")
fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, subplot_kw={"projection": ccrs.PlateCarree()}, figsize=(10, 6))
im = ax1.pcolormesh(ds.longitude, ds.latitude, np.sqrt(ds.ugos**2 + ds.vgos**2), vmin=0, cmap=cmo.speed, alpha=.75)
ax1 = traj_ens_too_small.plot(ax1, label="Ensemble", color="black", alpha_factor=0.15)
ax1 = traj_est_init.plot(ax1, label="Deterministic", color="black", linestyle=":")
ax1 = drifter_traj.plot(ax1, label="Drifter", color="blue")
ax1.set_extent((ds.longitude.min(), ds.longitude.max(), ds.latitude.min(), ds.latitude.max()))
ax1.coastlines()
ax1.set_title("$C_S = 0.001$")
handles, labels = ax1.get_legend_handles_labels()
handles = [
Line2D([0], [0], color="black", linestyle=":", label="Deterministic") if label == "Deterministic" else h
for h, label in zip(handles, labels)
]
handles = handles[-1:] + handles[:-1]
labels = labels[-1:] + labels[:-1]
ax1.legend(handles=handles, labels=labels, loc="lower right")
im = ax2.pcolormesh(ds.longitude, ds.latitude, np.sqrt(ds.ugos**2 + ds.vgos**2), vmin=0, cmap=cmo.speed, alpha=.75)
ax2 = traj_ens_final.plot(ax2, label="Ensemble", color="black", alpha_factor=0.15)
ax2 = traj_est_init.plot(ax2, label="Deterministic", color="black", linestyle=":")
ax2 = drifter_traj.plot(ax2, label="Drifter", color="blue")
ax2.set_extent((ds.longitude.min(), ds.longitude.max(), ds.latitude.min(), ds.latitude.max()))
ax2.coastlines()
ax2.set_title(f"$C_S = C_S^\\star \\approx {stoch_smag_final.cs.item():.3f}$")
im = ax3.pcolormesh(ds.longitude, ds.latitude, np.sqrt(ds.ugos**2 + ds.vgos**2), vmin=0, cmap=cmo.speed, alpha=.75)
ax3 = traj_ens_too_large.plot(ax3, label="Ensemble", color="black", alpha_factor=0.15)
ax3 = traj_est_init.plot(ax3, label="Deterministic", color="black", linestyle=":")
ax3 = drifter_traj.plot(ax3, label="Drifter", color="blue")
ax3.set_extent((ds.longitude.min(), ds.longitude.max(), ds.latitude.min(), ds.latitude.max()))
ax3.coastlines()
ax3.set_title("$C_S = 1$")
fig.subplots_adjust(right=0.915)
cbar_ax = fig.add_axes([0.96, 0.28, 0.02, 0.43])
cbar_ax.set_title("$\| \mathbf{u} \|$")
fig.colorbar(im, cax=cbar_ax, label="m s$^{-1}$")
plt.show()
With large ensembles of sampled trajectories, we can further visualize the results using 2D histograms that better highlight the differences in spatial distribution and variability of the samplers.
Show/hide code
def compute_hist2d(traj_ens):
def get_bin_edges(bins_center):
dim_hs = jnp.pad(jnp.diff(bins_center) / 2, (1, 1), mode="edge")
return jnp.concatenate(
(bins_center[0:1] - dim_hs[0:1], bins_center + dim_hs[1:]),
axis=0
)
bin_edges = [get_bin_edges(bins_center) for bins_center in (duacs_ds.longitude.values, duacs_ds.latitude.values)]
hist2d = jax.vmap(
lambda x, y: jnp.histogram2d(x, y, bins=bin_edges, density=True)[0],
in_axes=(1, 1),
)(traj_ens.longitudes.value, traj_ens.latitudes.value)
return hist2d.swapaxes(1, 2)
hist2d_too_small = compute_hist2d(traj_ens_too_small)
hist2d_too_large = compute_hist2d(traj_ens_too_large)
hist2d_final = compute_hist2d(traj_ens_final)
Show/hide code
ds = duacs_ds.sel(time=drifter_ts[0], method="nearest")
vmax = np.mean([hist2d_too_small[-1].max(), hist2d_final[-1].max(), hist2d_too_large[-1].max()])
fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, subplot_kw={"projection": ccrs.PlateCarree()}, figsize=(10, 6))
fig.subplots_adjust(left=0.02, bottom=0, right=0.98, top=1, wspace=None, hspace=None)
im1 = ax1.pcolormesh(ds.longitude, ds.latitude, hist2d_too_small[0], cmap=cmo.gray_r, vmax=vmax)
ax1 = drifter_traj.plot(ax1, label="Drifter", color="blue")
ax1.set_extent((ds.longitude.min(), ds.longitude.max(), ds.latitude.min(), ds.latitude.max()))
ax1.coastlines()
ax1.set_title("$C_S = 0.001$")
im2 = ax2.pcolormesh(ds.longitude, ds.latitude, hist2d_final[0], cmap=cmo.gray_r, vmax=vmax)
ax2 = drifter_traj.plot(ax2, label="Drifter", color="blue")
ax2.set_extent((ds.longitude.min(), ds.longitude.max(), ds.latitude.min(), ds.latitude.max()))
ax2.coastlines()
ax2.set_title(f"$C_S = C_S^\\star \\approx {stoch_smag_final.cs.item():.3f}$")
im3 = ax3.pcolormesh(ds.longitude, ds.latitude, hist2d_too_large[0], cmap=cmo.gray_r, vmax=vmax)
ax3 = drifter_traj.plot(ax3, label="Drifter", color="blue")
ax3.set_extent((ds.longitude.min(), ds.longitude.max(), ds.latitude.min(), ds.latitude.max()))
ax3.coastlines()
ax3.set_title("$C_S = 1$")
fig.subplots_adjust(bottom=0.1)
cbar_ax = fig.add_axes([0.2, 0.2, 0.65, 0.04])
fig.colorbar(im3, cax=cbar_ax, extend="max", orientation="horizontal", label="density")
title = ax2.text(0.03, 1.2, str(drifter_ts[0]), fontsize="x-large", transform=ax2.transAxes)
def draw_frame(i):
im1.set_array(hist2d_too_small[i].ravel())
im2.set_array(hist2d_final[i].ravel())
im3.set_array(hist2d_too_large[i].ravel())
title.set_text(str(drifter_ts[i]))
return im1, im2, im3, title
def init_func():
return draw_frame(0)
anim = animation.FuncAnimation(
fig,
draw_frame,
init_func=init_func,
frames=len(drifter_ts),
interval=200,
blit=True
)
HTML(anim.to_html5_video())
With this very simple model we see that we can calibrate both the deterministic and the stochastic terms of the drift dynamics, provided that we employ an appropriate loss function.