
API

pastax.simulator

This module provides the pastax.simulator.BaseSimulator base class and its subclasses for simulating pastax.trajectory.Trajectory and pastax.trajectory.TrajectoryEnsemble objects in JAX.

BaseSimulator

Bases: Module

Base class for defining differentiable pastax.trajectory.Trajectory or pastax.trajectory.TrajectoryEnsemble simulators.

Attributes:

Name Type Description
id str | None

The identifier for the simulator.

Methods:

Name Description
__call__

Simulates a pastax.trajectory.Trajectory or pastax.trajectory.TrajectoryEnsemble following the prescribed drift dynamics and physical field(s) args, from the initial pastax.trajectory.Location x0 at the time steps ts (including t0), using a given solver.

Source code in pastax/simulator/_base_simulator.py
class BaseSimulator(eqx.Module):
    """
    Base class for defining differentiable [`pastax.trajectory.Trajectory`][] or
    [`pastax.trajectory.TrajectoryEnsemble`][] simulators.

    Attributes
    ----------
    id : str | None
        The identifier for the simulator.

    Methods
    -------
    __call__(dynamics, args, x0, ts, n_samples, key, solver, dt0)
        Simulates a [`pastax.trajectory.Trajectory`][] or [`pastax.trajectory.TrajectoryEnsemble`][]
        following the prescribed drift `dynamics` and physical field(s) `args`,
        from the initial [`pastax.trajectory.Location`][] `x0` at the time steps `ts` (including `t0`),
        using a given `solver`.
    """

    id: str | None = None

    def __call__(
        self,
        dynamics: Callable[[Real[Any, ""], PyTree, PyTree], PyTree],
        args: PyTree,
        x0: Location,
        ts: Real[Array, "time"],
        n_samples: Int[Any, ""] | None,
        key: Key[Array, ""] | None,
        solver: Callable,
        dt0: Real[Any, ""],
    ) -> Trajectory | TrajectoryEnsemble:
        r"""
        Simulates a [`pastax.trajectory.Trajectory`][] or [`pastax.trajectory.TrajectoryEnsemble`][]
        following the prescribed drift `dynamics` and physical field(s) `args`,
        from the initial [`pastax.trajectory.Location`][] `x0` at the time steps `ts` (including `t0`),
        using a given `solver`.

        This method must be implemented by its subclasses.

        Parameters
        ----------
        dynamics: Callable[[Real[Any, ""], PyTree, PyTree], PyTree]
            A Callable (including an [`equinox.Module`][] with a `__call__` method) defining the
            right-hand side of the solved differential equation.

            !!! example

                Formulating the displacement at time $t$ from the position $\mathbf{X}(t)$ as:

                $$
                d\mathbf{X}(t) = f(t, \mathbf{X}(t), \text{args}) dt
                $$

                Here, `dynamics` is the function $f$ returning the displacement velocity.
                In the simplest case, $f$ interpolates a velocity field $\mathbf{u}$
                in space and time.

        args : PyTree
            The PyTree of argument(s) required to compute the `dynamics`.
            For example, one or several [`pastax.gridded.Gridded`][] objects holding gridded physical fields
            (SSC, SSH, SST, etc.).
        x0 : Location
            The initial [`pastax.trajectory.Location`][].
        ts : Real[Array, "time"]
            The time steps of the simulation outputs, including $t_0$, expressed in the same unit as `dt0`.
        solver : Callable
            The solver to use for the simulation.
        dt0 : Real[Any, ""]
            The initial time step of the solver, expressed in the same unit as `ts`.

        Returns
        -------
        Trajectory | TrajectoryEnsemble
            The simulated [`pastax.trajectory.Trajectory`][] or [`pastax.trajectory.TrajectoryEnsemble`][],
            including the initial conditions (x0, t0).

        Raises
        ------
        NotImplementedError
            If the method is not implemented by the subclass.
        """
        raise NotImplementedError()
__call__(dynamics: Callable[[Real[Any, ''], PyTree, PyTree], PyTree], args: PyTree, x0: Location, ts: Real[Array, time], n_samples: Int[Any, ''] | None, key: Key[Array, ''] | None, solver: Callable, dt0: Real[Any, '']) -> Trajectory | TrajectoryEnsemble

Simulates a pastax.trajectory.Trajectory or pastax.trajectory.TrajectoryEnsemble following the prescribed drift dynamics and physical field(s) args, from the initial pastax.trajectory.Location x0 at the time steps ts (including t0), using a given solver.

This method must be implemented by its subclasses.
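As a rough illustration of what a subclass has to provide, the sketch below implements a similar calling convention with a fixed-step explicit Euler scheme in plain JAX. It is not pastax code. The class name EulerSketchSimulator is hypothetical, it subclasses nothing, and it drops n_samples, key, solver and dt0 (one step is taken per output interval). It also returns a raw (len(ts), 2) array instead of a pastax.trajectory.Trajectory, keeping the initial state x0 as the first row as the Returns section requires.

```python
import jax
import jax.numpy as jnp


class EulerSketchSimulator:
    """Hypothetical stand-in for a BaseSimulator subclass (not part of pastax)."""

    def __call__(self, dynamics, args, x0, ts):
        # One explicit Euler step per output interval:
        # x_{n+1} = x_n + (t_{n+1} - t_n) * f(t_n, x_n, args)
        def step(x, t_pair):
            t_now, t_next = t_pair
            x_next = x + (t_next - t_now) * dynamics(t_now, x, args)
            return x_next, x_next

        _, xs = jax.lax.scan(step, x0, (ts[:-1], ts[1:]))
        # Prepend the initial state so that the output includes (x0, t0).
        return jnp.concatenate([x0[None, :], xs], axis=0)


def constant_drift(t, y, args):
    # Toy dynamics: a constant velocity stored in args.
    return args


ts = jnp.arange(5.0)              # output times 0, 1, 2, 3, 4
x0 = jnp.array([43.0, 5.0])
velocity = jnp.array([0.5, -0.25])
xs = EulerSketchSimulator()(constant_drift, velocity, x0, ts)
```

A real subclass would also honour the requested solver and dt0 and wrap the states into a Trajectory or TrajectoryEnsemble.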

Parameters:

Name Type Description Default
dynamics Callable[[Real[Any, ''], PyTree, PyTree], PyTree]

A Callable (including an equinox.Module with a __call__ method) defining the right-hand side of the solved differential equation.

Example

Formulating the displacement at time \(t\) from the position \(\mathbf{X}(t)\) as:

\[ d\mathbf{X}(t) = f(t, \mathbf{X}(t), \text{args}) dt \]

Here, dynamics is the function \(f\) returning the displacement velocity. In the simplest case, \(f\) interpolates a velocity field \(\mathbf{u}\) in space and time.

required
args PyTree

The PyTree of argument(s) required to compute the dynamics. For example, one or several pastax.gridded.Gridded objects holding gridded physical fields (SSC, SSH, SST, etc.).

required
x0 Location

The initial pastax.trajectory.Location.

required
ts Real[Array, time]

The time steps of the simulation outputs, including \(t_0\), expressed in the same unit as dt0.

required
solver Callable

The solver to use for the simulation.

required
dt0 Real[Any, '']

The initial time step of the solver, expressed in the same unit as ts.

required
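For instance, a dynamics callable can work on a state holding latitude and longitude in degrees, the convention stated for y in DiffraxSimulator.__call__. It then has to turn an eastward/northward velocity in m/s into degrees per second. The sketch below makes three assumptions:

- the Earth is a sphere of radius 6,371 km;
- the state is ordered (lat, lon);
- a spatially uniform current is stored in args as a hypothetical dict.

A realistic f would instead interpolate gridded fields such as pastax.gridded.Gridded, whose API is not used here.

```python
import jax.numpy as jnp

EARTH_RADIUS_M = 6_371_000.0  # assumed spherical Earth radius, in meters


def uniform_current_dynamics(t, y, args):
    """Toy f(t, X(t), args) returning d(lat, lon)/dt in degrees per second.

    y    : [latitude, longitude] in degrees
    args : {"u": eastward velocity (m/s), "v": northward velocity (m/s)}, hypothetical layout
    """
    lat_rad = jnp.deg2rad(y[0])
    dlat = args["v"] / EARTH_RADIUS_M                        # radians per second
    dlon = args["u"] / (EARTH_RADIUS_M * jnp.cos(lat_rad))   # radians per second
    return jnp.rad2deg(jnp.stack([dlat, dlon]))


y0 = jnp.array([0.0, 10.0])  # on the equator
rate = uniform_current_dynamics(0.0, y0, {"u": 1.0, "v": 0.0})
```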

Returns:

Type Description
Trajectory | TrajectoryEnsemble

The simulated pastax.trajectory.Trajectory or pastax.trajectory.TrajectoryEnsemble, including the initial conditions (x0, t0).

Raises:

Type Description
NotImplementedError

If the method is not implemented by the subclass.

Source code in pastax/simulator/_base_simulator.py
def __call__(
    self,
    dynamics: Callable[[Real[Any, ""], PyTree, PyTree], PyTree],
    args: PyTree,
    x0: Location,
    ts: Real[Array, "time"],
    n_samples: Int[Any, ""] | None,
    key: Key[Array, ""] | None,
    solver: Callable,
    dt0: Real[Any, ""],
) -> Trajectory | TrajectoryEnsemble:
    r"""
    Simulates a [`pastax.trajectory.Trajectory`][] or [`pastax.trajectory.TrajectoryEnsemble`][]
    following the prescribed drift `dynamics` and physical field(s) `args`,
    from the initial [`pastax.trajectory.Location`][] `x0` at the time steps `ts` (including `t0`),
    using a given `solver`.

    This method must be implemented by its subclasses.

    Parameters
    ----------
    dynamics: Callable[[Real[Any, ""], PyTree, PyTree], PyTree]
        A Callable (including an [`equinox.Module`][] with a `__call__` method) defining the
        right-hand side of the solved differential equation.

        !!! example

            Formulating the displacement at time $t$ from the position $\mathbf{X}(t)$ as:

            $$
            d\mathbf{X}(t) = f(t, \mathbf{X}(t), \text{args}) dt
            $$

            Here, `dynamics` is the function $f$ returning the displacement velocity.
            In the simplest case, $f$ interpolates a velocity field $\mathbf{u}$
            in space and time.

    args : PyTree
        The PyTree of argument(s) required to compute the `dynamics`.
        For example, one or several [`pastax.gridded.Gridded`][] objects holding gridded physical fields
        (SSC, SSH, SST, etc.).
    x0 : Location
        The initial [`pastax.trajectory.Location`][].
    ts : Real[Array, "time"]
        The time steps of the simulation outputs, including $t_0$, expressed in the same unit as `dt0`.
    solver : Callable
        The solver to use for the simulation.
    dt0 : Real[Any, ""]
        The initial time step of the solver, expressed in the same unit as `ts`.

    Returns
    -------
    Trajectory | TrajectoryEnsemble
        The simulated [`pastax.trajectory.Trajectory`][] or [`pastax.trajectory.TrajectoryEnsemble`][],
        including the initial conditions (x0, t0).

    Raises
    ------
    NotImplementedError
        If the method is not implemented by the subclass.
    """
    raise NotImplementedError()

DiffraxSimulator

Bases: BaseSimulator

Base class for defining differentiable pastax.trajectory.Trajectory or pastax.trajectory.TrajectoryEnsemble simulators using the diffrax library.

Methods:

Name Description
get_diffeqsolve_best_args

Returns optimal argument values for the diffrax.diffeqsolve function.

__call__

Simulates a pastax.trajectory.Trajectory or pastax.trajectory.TrajectoryEnsemble following the prescribed drift dynamics and physical field(s) args, from the initial pastax.trajectory.Location x0 at the time steps ts (including t0), using the given diffrax.AbstractSolver, diffrax.SaveAt, diffrax.AbstractStepSizeController and diffrax.AbstractAdjoint.

Source code in pastax/simulator/_diffrax_simulator.py
class DiffraxSimulator(BaseSimulator):
    """
    Base class for defining differentiable [`pastax.trajectory.Trajectory`][] or
    [`pastax.trajectory.TrajectoryEnsemble`][] simulators using the `diffrax` library.

    Methods
    -------
    get_diffeqsolve_best_args(ts, dt0, n_steps, constant_step_size, save_at_steps, ad_mode)
        Returns optimal argument values for the [`diffrax.diffeqsolve`][] function.
    __call__(dynamics, args, x0, ts, solver, dt0, saveat, stepsize_controller, adjoint, max_steps, throw)
        Simulates a [`pastax.trajectory.Trajectory`][] or [`pastax.trajectory.TrajectoryEnsemble`][]
        following the prescribed drift `dynamics` and physical field(s) `args`,
        from the initial [`pastax.trajectory.Location`][] `x0` at the time steps `ts` (including `t0`),
        using the given [`diffrax.AbstractSolver`][], [`diffrax.SaveAt`][], [`diffrax.AbstractStepSizeController`][] and
        [`diffrax.AbstractAdjoint`][].
    """

    @classmethod
    def get_diffeqsolve_best_args(
        cls,
        ts: Real[Any, "time"],
        dt0: Real[Any, ""],
        n_steps: Int[Any, ""] = None,
        constant_step_size: bool = True,
        save_at_steps: bool = False,
        ad_mode: Literal["forward", "reverse"] = "forward",
    ) -> tuple[
        Real[Any, ""],
        dfx.SaveAt,
        dfx.AbstractStepSizeController,
        dfx.AbstractAdjoint,
        Int[Any, ""],
        Callable[[tuple[int, ...], Key[Array, ""]], dfx.AbstractBrownianPath],
    ]:
        """
        Returns optimal argument values for the [`diffrax.diffeqsolve`][] function.

        Significant speedups can be achieved by carefully selecting the arguments passed to [`diffrax.diffeqsolve`][],
        which is then called internally by the `__call__` method.
        This method applies general heuristics to determine optimal argument values based on a high-level description
        of the problem, derived from its own input arguments.

        Parameters
        ----------
        ts : Real[Any, "time"]
            The time steps of the simulation outputs, including $t_0$, expressed in the same unit as `dt0`.
        dt0 : Real[Any, ""]
            The initial time step of the solver, expressed in the same unit as `ts`.
        n_steps : Int[Any, ""], optional
            The number of steps to be taken, defaults to `None`.
        constant_step_size : bool, optional
            Whether a constant step size is used, defaults to `True`.
        save_at_steps : bool, optional
            Whether the solution is to be saved at each integration step, defaults to `False`.
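        ad_mode : Literal["forward", "reverse"], optional
            The mode for automatic differentiation, defaults to `"forward"`.

        Returns
        -------
        tuple
            `(dt0, saveat, stepsize_controller, adjoint, n_steps, brownian_motion)`, where `dt0` is `None` whenever
            a [`diffrax.StepTo`][] controller is returned (i.e. when `n_steps` is given), `n_steps` is inferred
            from `ts` and `dt0` in reverse mode if not given, and `brownian_motion` is a factory
            `(shape, key) -> dfx.AbstractBrownianPath`.
        """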
        t0 = ts[0]
        t1 = ts[-1]

        if n_steps is not None:
            dt0_ = None
            stepsize_controller = dfx.StepTo(ts=jnp.linspace(t0, t1, n_steps + 1))
        else:
            dt0_ = dt0
            stepsize_controller = dfx.ConstantStepSize()

        if save_at_steps:
            saveat = dfx.SaveAt(steps=True)
        else:
            saveat = dfx.SaveAt(ts=ts)

        if ad_mode == "reverse":
            adjoint = dfx.RecursiveCheckpointAdjoint()
            if n_steps is None:
                n_steps = (t1 + dt0 - t0) // dt0
        else:
            adjoint = dfx.ForwardMode()

        if constant_step_size and n_steps is not None:
            brownian_motion: Callable[[tuple[int, ...], Key[Array, ""]], dfx.AbstractBrownianPath] = (
                lambda shape, key: PrecomputedBrownianMotion(t0=t0, n_steps=n_steps, dt=dt0, shape=shape, key=key)
            )
        else:
            brownian_motion: Callable[[tuple[int, ...], Key[Array, ""]], dfx.AbstractBrownianPath] = (
                lambda shape, key: dfx.VirtualBrownianTree(t0, t1, tol=dt0, shape=shape, key=key)
            )

        return dt0_, saveat, stepsize_controller, adjoint, n_steps, brownian_motion

    def __call__(
        self,
        dynamics: Callable[[Real[Any, ""], PyTree, PyTree], PyTree],
        args: PyTree,
        x0: Location,
        ts: Real[Any, "time"],
        solver: dfx.AbstractSolver = dfx.Heun(),
        dt0: Real[Any, ""] = None,
        saveat: dfx.SaveAt | None = None,
        stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(),
        adjoint: dfx.AbstractAdjoint = dfx.ForwardMode(),
        max_steps: Int[Any, ""] = 4096,
        throw: bool = False,
    ) -> Trajectory | TrajectoryEnsemble:
        r"""
        Simulates a [`pastax.trajectory.Trajectory`][] or [`pastax.trajectory.TrajectoryEnsemble`][]
        following the prescribed drift `dynamics` and physical field(s) `args`,
        from the initial [`pastax.trajectory.Location`][] `x0` at the time steps `ts` (including `t0`),
        using the given [`diffrax.AbstractSolver`][], [`diffrax.SaveAt`][], [`diffrax.AbstractStepSizeController`][] and
        [`diffrax.AbstractAdjoint`][].

        Parameters
        ----------
        dynamics: Callable[[Real[Any, ""], PyTree, PyTree], PyTree]
            A Callable (including an [`equinox.Module`][] with a `__call__` method) defining the
            right-hand side of the solved differential equation.

            !!! example

                Formulating the displacement at time $t$ from the position $\mathbf{X}(t)$ as:

                $$
                d\mathbf{X}(t) = f(t, \mathbf{X}(t), \text{args}) dt
                $$

                Here, `dynamics` is the function $f$ returning the displacement velocity.
                In the simplest case, $f$ interpolates a velocity field $\mathbf{u}$
                in space and time.

            Parameters
            ----------
            t : Real[Array, ""]
                The current time.
            y : Float[Array, "2"]
                The current state (latitude and longitude in degrees).
            args : PyTree
                Any PyTree of argument(s) used by the simulator.
                For example, one or several [`pastax.gridded.Gridded`][] objects holding gridded physical fields
                (SSC, SSH, SST, etc.).

            Returns
            -------
            PyTree
                The drift dynamics.

        args : PyTree
            The PyTree of argument(s) required to compute the `dynamics`.
            For example, one or several [`pastax.gridded.Gridded`][] objects holding gridded physical fields
            (SSC, SSH, SST, etc.).
        x0 : Location
            The initial [`pastax.trajectory.Location`][].
        ts : Real[Any, "time"]
            The time steps of the simulation outputs, including $t_0$, expressed in the same unit as `dt0`.
        solver : dfx.AbstractSolver, optional
            The [`diffrax.AbstractSolver`][] to use for the simulation, defaults to [`diffrax.Heun`][].
        dt0 : Real[Any, ""], optional
            The initial time step of the solver, expressed in the same unit as `ts`, defaults to `None`.
        saveat : dfx.SaveAt, optional
            The [`diffrax.SaveAt`][] object to use for saving the solution, defaults to `SaveAt(ts=ts)`.
        stepsize_controller : dfx.AbstractStepSizeController, optional
            The [`diffrax.AbstractStepSizeController`][] to use for controlling the stepsize,
            defaults to [`diffrax.ConstantStepSize`][].
        adjoint : dfx.AbstractAdjoint, optional
            The [`diffrax.AbstractAdjoint`][] object to use for the adjoint method,
            defaults to [`diffrax.ForwardMode`][].
            [`diffrax.ForwardMode`][] should be used when computing gradients in forward-mode automatic differentiation
            with respect to few (<50) parameters, while [`diffrax.RecursiveCheckpointAdjoint`][] should be used
            when computing gradients in reverse-mode automatic differentiation with respect to many (>50) parameters.
        max_steps : Int[Any, ""], optional
            The maximum number of steps to take, defaults to `4096`.
        throw : bool, optional
            Whether to raise an exception if the integration fails, defaults to `False`.

        Returns
        -------
        Trajectory | TrajectoryEnsemble
            The simulated [`pastax.trajectory.Trajectory`][] or [`pastax.trajectory.TrajectoryEnsemble`][],
            including the initial conditions (x0, t0).

        Raises
        ------
        NotImplementedError
            If the method is not implemented by the subclass.
        """
        raise NotImplementedError
get_diffeqsolve_best_args(ts: Real[Any, 'time'], dt0: Real[Any, ''], n_steps: Int[Any, ''] = None, constant_step_size: bool = True, save_at_steps: bool = False, ad_mode: Literal['forward', 'reverse'] = 'forward') -> tuple[Real[Any, ''], dfx.SaveAt, dfx.AbstractStepSizeController, dfx.AbstractAdjoint, Int[Any, ''], Callable[[tuple[int, ...], Key[Array, '']], dfx.AbstractBrownianPath]] classmethod

Returns optimal argument values for the diffrax.diffeqsolve function.

Significant speedups can be achieved by carefully selecting the arguments passed to diffrax.diffeqsolve, which is then called internally by the __call__ method. This method applies general heuristics to determine optimal argument values based on a high-level description of the problem, derived from its own input arguments.
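The heuristics in the source listing of this method mostly come down to simple arithmetic on ts and dt0. The sketch below reproduces two of them in plain JAX, without calling pastax:

- the fixed grid of n_steps + 1 points handed to diffrax.StepTo when n_steps is given;
- the step count inferred in reverse mode when n_steps is None.

The example uses hourly outputs over one day and a 15-minute solver step, all in seconds. These values are chosen for illustration only.

```python
import jax.numpy as jnp

ts = jnp.arange(0.0, 86_400.0 + 1.0, 3_600.0)  # hourly outputs over one day, in seconds
dt0 = 900.0                                     # 15-minute solver step, same unit as ts
t0, t1 = ts[0], ts[-1]

# Reverse mode with n_steps=None: same formula as the listed code, n_steps = (t1 + dt0 - t0) // dt0,
# i.e. one more than the number of whole dt0 intervals in [t0, t1].
n_steps_reverse = (t1 + dt0 - t0) // dt0

# Branch taken when n_steps is given: dt0 is returned as None and the solver is
# stepped on a fixed grid of n_steps + 1 points, as passed to diffrax.StepTo.
n_steps = 96
step_grid = jnp.linspace(t0, t1, n_steps + 1)
```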

Parameters:

Name Type Description Default
ts Real[Any, 'time']

The time steps of the simulation outputs, including \(t_0\), expressed in the same unit as dt0.

required
dt0 Real[Any, '']

The initial time step of the solver, expressed in the same unit as ts.

required
n_steps Int[Any, '']

The number of steps to be taken, defaults to None.

None
constant_step_size bool

Whether a constant step size is used, defaults to True.

True
save_at_steps bool

Whether the solution is to be saved at each integration step, defaults to False.

False
ad_mode Literal['forward', 'reverse']

The mode for automatic differentiation, defaults to "forward".

'forward'
Source code in pastax/simulator/_diffrax_simulator.py
@classmethod
def get_diffeqsolve_best_args(
    cls,
    ts: Real[Any, "time"],
    dt0: Real[Any, ""],
    n_steps: Int[Any, ""] = None,
    constant_step_size: bool = True,
    save_at_steps: bool = False,
    ad_mode: Literal["forward", "reverse"] = "forward",
) -> tuple[
    Real[Any, ""],
    dfx.SaveAt,
    dfx.AbstractStepSizeController,
    dfx.AbstractAdjoint,
    Int[Any, ""],
    Callable[[tuple[int, ...], Key[Array, ""]], dfx.AbstractBrownianPath],
]:
    """
    Returns optimal argument values for the [`diffrax.diffeqsolve`][] function.

    Significant speedups can be achieved by carefully selecting the arguments passed to [`diffrax.diffeqsolve`][],
    which is then called internally by the `__call__` method.
    This method applies general heuristics to determine optimal argument values based on a high-level description
    of the problem, derived from its own input arguments.

    Parameters
    ----------
    ts : Real[Any, "time"]
        The time steps of the simulation outputs, including $t_0$, expressed in the same unit as `dt0`.
    dt0 : Real[Any, ""]
        The initial time step of the solver, expressed in the same unit as `ts`.
    n_steps : Int[Any, ""], optional
        The number of steps to be taken, defaults to `None`.
    constant_step_size : bool, optional
        Whether a constant step size is used, defaults to `True`.
    save_at_steps : bool, optional
        Whether the solution is to be saved at each integration step, defaults to `False`.
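    ad_mode : Literal["forward", "reverse"], optional
        The mode for automatic differentiation, defaults to `"forward"`.

    Returns
    -------
    tuple
        `(dt0, saveat, stepsize_controller, adjoint, n_steps, brownian_motion)`, where `dt0` is `None` whenever
        a [`diffrax.StepTo`][] controller is returned (i.e. when `n_steps` is given), `n_steps` is inferred
        from `ts` and `dt0` in reverse mode if not given, and `brownian_motion` is a factory
        `(shape, key) -> dfx.AbstractBrownianPath`.
    """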
    t0 = ts[0]
    t1 = ts[-1]

    if n_steps is not None:
        dt0_ = None
        stepsize_controller = dfx.StepTo(ts=jnp.linspace(t0, t1, n_steps + 1))
    else:
        dt0_ = dt0
        stepsize_controller = dfx.ConstantStepSize()

    if save_at_steps:
        saveat = dfx.SaveAt(steps=True)
    else:
        saveat = dfx.SaveAt(ts=ts)

    if ad_mode == "reverse":
        adjoint = dfx.RecursiveCheckpointAdjoint()
        if n_steps is None:
            n_steps = (t1 + dt0 - t0) // dt0
    else:
        adjoint = dfx.ForwardMode()

    if constant_step_size and n_steps is not None:
        brownian_motion: Callable[[tuple[int, ...], Key[Array, ""]], dfx.AbstractBrownianPath] = (
            lambda shape, key: PrecomputedBrownianMotion(t0=t0, n_steps=n_steps, dt=dt0, shape=shape, key=key)
        )
    else:
        brownian_motion: Callable[[tuple[int, ...], Key[Array, ""]], dfx.AbstractBrownianPath] = (
            lambda shape, key: dfx.VirtualBrownianTree(t0, t1, tol=dt0, shape=shape, key=key)
        )

    return dt0_, saveat, stepsize_controller, adjoint, n_steps, brownian_motion
__call__(dynamics: Callable[[Real[Any, ''], PyTree, PyTree], PyTree], args: PyTree, x0: Location, ts: Real[Any, 'time'], solver: dfx.AbstractSolver = dfx.Heun(), dt0: Real[Any, ''] = None, saveat: dfx.SaveAt | None = None, stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), adjoint: dfx.AbstractAdjoint = dfx.ForwardMode(), max_steps: Int[Any, ''] = 4096, throw: bool = False) -> Trajectory | TrajectoryEnsemble

Simulates a pastax.trajectory.Trajectory or pastax.trajectory.TrajectoryEnsemble following the prescribed drift dynamics and physical field(s) args, from the initial pastax.trajectory.Location x0 at the time steps ts (including t0), using the given diffrax.AbstractSolver, diffrax.SaveAt, diffrax.AbstractStepSizeController and diffrax.AbstractAdjoint.

Parameters:

Name Type Description Default
dynamics Callable[[Real[Any, ''], PyTree, PyTree], PyTree]

A Callable (including an equinox.Module with a __call__ method) defining the right-hand side of the solved differential equation.

Example

Formulating the displacement at time \(t\) from the position \(\mathbf{X}(t)\) as:

\[ d\mathbf{X}(t) = f(t, \mathbf{X}(t), \text{args}) dt \]

Here, dynamics is the function \(f\) returning the displacement velocity. In the simplest case, \(f\) interpolates a velocity field \(\mathbf{u}\) in space and time.

Parameters

t : Real[Array, ""]
    The current time.
y : Float[Array, "2"]
    The current state (latitude and longitude in degrees).
args : PyTree
    Any PyTree of argument(s) used by the simulator. For example, one or several pastax.gridded.Gridded objects holding gridded physical fields (SSC, SSH, SST, etc.).

Returns

PyTree
    The drift dynamics.

required
args PyTree

The PyTree of argument(s) required to compute the dynamics. For example, one or several pastax.gridded.Gridded objects holding gridded physical fields (SSC, SSH, SST, etc.).

required
x0 Location

The initial pastax.trajectory.Location.

required
ts Real[Any, 'time']

The time steps of the simulation outputs, including \(t_0\), expressed in the same unit as dt0.

required
solver AbstractSolver

The diffrax.AbstractSolver to use for the simulation, defaults to diffrax.Heun.

Heun()
dt0 Real[Any, '']

The initial time step of the solver, expressed in the same unit as ts, defaults to None.

None
saveat SaveAt

The diffrax.SaveAt object to use for saving the solution, defaults to SaveAt(ts=ts).

None
stepsize_controller AbstractStepSizeController

The diffrax.AbstractStepSizeController to use for controlling the stepsize, defaults to diffrax.ConstantStepSize.

ConstantStepSize()
adjoint AbstractAdjoint

The diffrax.AbstractAdjoint object to use for the adjoint method, defaults to diffrax.ForwardMode. diffrax.ForwardMode should be used when computing gradients in forward-mode automatic differentiation with respect to few (<50) parameters, while diffrax.RecursiveCheckpointAdjoint should be used when computing gradients in reverse-mode automatic differentiation with respect to many (>50) parameters.

ForwardMode()
max_steps Int[Any, '']

The maximum number of steps to take, defaults to 4096.

4096
throw bool

Whether to raise an exception if the integration fails, defaults to False.

False

Returns:

Type Description
Trajectory | TrajectoryEnsemble

The simulated pastax.trajectory.Trajectory or pastax.trajectory.TrajectoryEnsemble, including the initial conditions (x0, t0).

Raises:

Type Description
NotImplementedError

If the method is not implemented by the subclass.

Source code in pastax/simulator/_diffrax_simulator.py
def __call__(
    self,
    dynamics: Callable[[Real[Any, ""], PyTree, PyTree], PyTree],
    args: PyTree,
    x0: Location,
    ts: Real[Any, "time"],
    solver: dfx.AbstractSolver = dfx.Heun(),
    dt0: Real[Any, ""] = None,
    saveat: dfx.SaveAt | None = None,
    stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(),
    adjoint: dfx.AbstractAdjoint = dfx.ForwardMode(),
    max_steps: Int[Any, ""] = 4096,
    throw: bool = False,
) -> Trajectory | TrajectoryEnsemble:
    r"""
    Simulates a [`pastax.trajectory.Trajectory`][] or [`pastax.trajectory.TrajectoryEnsemble`][]
    following the prescribed drift `dynamics` and physical field(s) `args`,
    from the initial [`pastax.trajectory.Location`][] `x0` at the time steps `ts` (including `t0`),
    using the given [`diffrax.AbstractSolver`][], [`diffrax.SaveAt`][], [`diffrax.AbstractStepSizeController`][] and
    [`diffrax.AbstractAdjoint`][].

    Parameters
    ----------
    dynamics: Callable[[Real[Any, ""], PyTree, PyTree], PyTree]
        A Callable (including an [`equinox.Module`][] with a `__call__` method) defining the
        right-hand side of the solved differential equation.

        !!! example

            Formulating the displacement at time $t$ from the position $\mathbf{X}(t)$ as:

            $$
            d\mathbf{X}(t) = f(t, \mathbf{X}(t), \text{args}) dt
            $$

            Here, `dynamics` is the function $f$ returning the displacement velocity.
            In the simplest case, $f$ interpolates a velocity field $\mathbf{u}$
            in space and time.

        Parameters
        ----------
        t : Real[Array, ""]
            The current time.
        y : Float[Array, "2"]
            The current state (latitude and longitude in degrees).
        args : PyTree
            Any PyTree of argument(s) used by the simulator.
            For example, one or several [`pastax.gridded.Gridded`][] objects holding gridded physical fields
            (SSC, SSH, SST, etc.).

        Returns
        -------
        PyTree
            The drift dynamics.

    args : PyTree
        The PyTree of argument(s) required to compute the `dynamics`.
        For example, one or several [`pastax.gridded.Gridded`][] objects holding gridded physical fields
        (SSC, SSH, SST, etc.).
    x0 : Location
        The initial [`pastax.trajectory.Location`][].
    ts : Real[Any, "time"]
        The time steps of the simulation outputs, including $t_0$, expressed in the same unit as `dt0`.
    solver : dfx.AbstractSolver, optional
        The [`diffrax.AbstractSolver`][] to use for the simulation, defaults to [`diffrax.Heun`][].
    dt0 : Real[Any, ""], optional
        The initial time step of the solver, expressed in the same unit as `ts`, defaults to `None`.
    saveat : dfx.SaveAt, optional
        The [`diffrax.SaveAt`][] object to use for saving the solution, defaults to `SaveAt(ts=ts)`.
    stepsize_controller : dfx.AbstractStepSizeController, optional
        The [`diffrax.AbstractStepSizeController`][] to use for controlling the stepsize,
        defaults to [`diffrax.ConstantStepSize`][].
    adjoint : dfx.AbstractAdjoint, optional
        The [`diffrax.AbstractAdjoint`][] object to use for the adjoint method,
        defaults to [`diffrax.ForwardMode`][].
        [`diffrax.ForwardMode`][] should be used when computing gradients in forward-mode automatic differentiation
        with respect to few (<50) parameters, while [`diffrax.RecursiveCheckpointAdjoint`][] should be used
        when computing gradients in reverse-mode automatic differentiation with respect to many (>50) parameters.
    max_steps : Int[Any, ""], optional
        The maximum number of steps to take, defaults to `4096`.
    throw : bool, optional
        Whether to raise an exception if the integration fails, defaults to `False`.

    Returns
    -------
    Trajectory | TrajectoryEnsemble
        The simulated [`pastax.trajectory.Trajectory`][] or [`pastax.trajectory.TrajectoryEnsemble`][],
        including the initial conditions (x0, t0).

    Raises
    ------
    NotImplementedError
        If the method is not implemented by the subclass.
    """
    raise NotImplementedError

DeterministicSimulator

Bases: DiffraxSimulator

Class defining deterministic differentiable pastax.trajectory.Trajectory simulators.

Methods:

Name Description
__call__

Simulates a pastax.trajectory.Trajectory following the prescribed drift dynamics and physical field(s) args, from the initial pastax.trajectory.Location x0 at the time steps ts (including t0), using the given diffrax.AbstractSolver, diffrax.SaveAt, diffrax.AbstractStepSizeController and diffrax.AbstractAdjoint.

Source code in pastax/simulator/_diffrax_simulator.py
class DeterministicSimulator(DiffraxSimulator):
    """
    Class defining deterministic differentiable [`pastax.trajectory.Trajectory`][] simulators.

    Methods
    -------
    __call__(dynamics, args, x0, ts, solver, dt0, saveat, stepsize_controller, adjoint, max_steps, throw)
        Simulates a [`pastax.trajectory.Trajectory`][] following the prescribed drift `dynamics`
        and physical field(s) `args`, from the initial [`pastax.trajectory.Location`][] `x0`
        at time steps (including t0) `ts`, using given [`diffrax.AbstractSolver`][], [`diffrax.SaveAt`][],
        [`diffrax.AbstractStepSizeController`][] and [`diffrax.AbstractAdjoint`][].
    """

    def __call__(
        self,
        dynamics: Callable[[Real[Any, ""], PyTree, PyTree], PyTree],
        args: PyTree,
        x0: Location,
        ts: Real[Any, "time"],
        solver: dfx.AbstractSolver = dfx.Heun(),
        dt0: Real[Any, ""] = None,
        saveat: dfx.SaveAt | None = None,
        stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(),
        adjoint: dfx.AbstractAdjoint = dfx.ForwardMode(),
        max_steps: Int[Any, ""] = 4096,
        throw: bool = True,
    ) -> Trajectory:
        r"""
        Simulates a [`pastax.trajectory.Trajectory`][] following the prescribed drift `dynamics`
        and physical field(s) `args`, from the initial [`pastax.trajectory.Location`][] `x0`
        at time steps (including t0) `ts`, using given [`diffrax.AbstractSolver`][], [`diffrax.SaveAt`][],
        [`diffrax.AbstractStepSizeController`][] and [`diffrax.AbstractAdjoint`][].

        Parameters
        ----------
        dynamics : Callable[[Real[Any, ""], PyTree, PyTree], PyTree]
            A Callable (including an [`equinox.Module`][] with a `__call__` method) describing the dynamics of the
            right-hand-side of the solved Ordinary Differential Equation.

            !!! example

                Formulating the displacement at time $t$ from the position $\mathbf{X}(t)$ as:

                $$
                d\mathbf{X}(t) = f(t, \mathbf{X}(t), \text{args}) dt
                $$

                Here `dynamics` is the function $f$ returning the displacement velocity.
                In the simplest case, $f$ interpolates a velocity field $\mathbf{u}$
                in space and time.

            Parameters
            ----------
            t : Real[Any, ""]
                The current time.
            y : PyTree
                The current state (latitude and longitude in degrees).
            args : PyTree
                The PyTree of argument(s) required to compute the `dynamics`,
                for example one or several [`pastax.gridded.Gridded`][] objects holding gridded physical fields
                (SSC, SSH, SST, etc.).

            Returns
            -------
            PyTree
                The drift dynamics.

        args : PyTree
            The PyTree of argument(s) required to compute the `dynamics`,
            for example one or several [`pastax.gridded.Gridded`][] objects holding gridded physical fields
            (SSC, SSH, SST, etc.).
        x0 : Location
            The initial [`pastax.trajectory.Location`][].
        ts : Real[Any, "time"]
            The time steps of the simulation outputs, including $t_0$, in the same unit as `dt0`.
        solver : dfx.AbstractSolver, optional
            The [`diffrax.AbstractSolver`][] to use for the simulation, defaults to [`diffrax.Heun`][].
        dt0 : Real[Any, ""], optional
            The initial time step of the solver, in the same unit as `ts`; defaults to `None`.
        saveat : dfx.SaveAt, optional
            The [`diffrax.SaveAt`][] object to use for saving the solution, defaults to `SaveAt(ts=ts)`.
        stepsize_controller : dfx.AbstractStepSizeController, optional
            The [`diffrax.AbstractStepSizeController`][] to use for controlling the stepsize,
            defaults to [`diffrax.ConstantStepSize`][].
        adjoint : dfx.AbstractAdjoint, optional
            The [`diffrax.AbstractAdjoint`][] object to use for the adjoint method,
            defaults to [`diffrax.ForwardMode`][].
            [`diffrax.ForwardMode`][] should be used when computing the gradient in forward automatic differentiation
            mode with respect to few (<50) parameters, while [`diffrax.RecursiveCheckpointAdjoint`][] should be used
            when computing the gradient in reverse automatic differentiation mode with respect to many (>50) parameters.
        max_steps : Int[Any, ""], optional
            The maximum number of steps to take, defaults to `4096`.
        throw : bool, optional
            Whether to raise an exception if the integration fails, defaults to `True`.

        Returns
        -------
        Trajectory
            The simulated [`pastax.trajectory.Trajectory`][], including the initial conditions (x0, t0).
        """
        t0, t1 = ts[0], ts[-1]

        if saveat is None:
            saveat = dfx.SaveAt(ts=ts)

        ys = dfx.diffeqsolve(
            dfx.ODETerm(dynamics),
            solver,
            t0=t0,
            t1=t1,
            dt0=dt0,
            y0=x0.value,
            args=args,
            saveat=saveat,
            stepsize_controller=stepsize_controller,
            adjoint=adjoint,
            max_steps=max_steps,
            throw=throw,
        ).ys

        return Trajectory.from_array(ys, ts, unit=x0.unit)  # type: ignore
__call__(dynamics: Callable[[Real[Any, ''], PyTree, PyTree], PyTree], args: PyTree, x0: Location, ts: Real[Any, 'time'], solver: dfx.AbstractSolver = dfx.Heun(), dt0: Real[Any, ''] = None, saveat: dfx.SaveAt | None = None, stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), adjoint: dfx.AbstractAdjoint = dfx.ForwardMode(), max_steps: Int[Any, ''] = 4096, throw: bool = True) -> Trajectory

Simulates a pastax.trajectory.Trajectory following the prescribed drift dynamics and physical field(s) args, from the initial pastax.trajectory.Location x0 at time steps ts (including t0), using the given diffrax.AbstractSolver, diffrax.SaveAt, diffrax.AbstractStepSizeController and diffrax.AbstractAdjoint.

Parameters:

Name Type Description Default
dynamics Callable[[Real[Any, ''], PyTree, PyTree], PyTree]

A Callable (including an equinox.Module with a __call__ method) describing the dynamics of the right-hand-side of the solved Ordinary Differential Equation.

Example

Formulating the displacement at time \(t\) from the position \(\mathbf{X}(t)\) as:

\[ d\mathbf{X}(t) = f(t, \mathbf{X}(t), \text{args}) dt \]

Here dynamics is the function \(f\) returning the displacement velocity. In the simplest case, \(f\) interpolates a velocity field \(\mathbf{u}\) in space and time.

Parameters

t : Real[Any, ""]
    The current time.
y : PyTree
    The current state (latitude and longitude in degrees).
args : PyTree
    The PyTree of argument(s) required to compute the dynamics, for example one or several pastax.gridded.Gridded objects holding gridded physical fields (SSC, SSH, SST, etc.).

Returns

PyTree
    The drift dynamics.

required
args PyTree

The PyTree of argument(s) required to compute the dynamics, for example one or several pastax.gridded.Gridded objects holding gridded physical fields (SSC, SSH, SST, etc.).

required
x0 Location

The initial pastax.trajectory.Location.

required
ts Real[Any, 'time']

The time steps of the simulation outputs, including \(t_0\), in the same unit as dt0.

required
solver AbstractSolver

The diffrax.AbstractSolver to use for the simulation, defaults to diffrax.Heun.

Heun()
dt0 Real[Any, '']

The initial time step of the solver, in the same unit as ts; defaults to None.

None
saveat SaveAt

The diffrax.SaveAt object to use for saving the solution, defaults to SaveAt(ts=ts).

None
stepsize_controller AbstractStepSizeController

The diffrax.AbstractStepSizeController to use for controlling the stepsize, defaults to diffrax.ConstantStepSize.

ConstantStepSize()
adjoint AbstractAdjoint

The diffrax.AbstractAdjoint object to use for the adjoint method, defaults to diffrax.ForwardMode. diffrax.ForwardMode should be used when computing the gradient in forward automatic differentiation mode with respect to few (<50) parameters, while diffrax.RecursiveCheckpointAdjoint should be used when computing the gradient in reverse automatic differentiation mode with respect to many (>50) parameters.

ForwardMode()
max_steps Int[Any, '']

The maximum number of steps to take, defaults to 4096.

4096
throw bool

Whether to raise an exception if the integration fails, defaults to True.

True

Returns:

Type Description
Trajectory

The simulated pastax.trajectory.Trajectory, including the initial conditions (x0, t0).
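The dynamics parameter notes that, in the simplest case, \(f\) interpolates a velocity field in space and time. A minimal sketch of such a callable with the (t, y, args) signature, using jax.scipy.ndimage.map_coordinates for linear interpolation on a regular (time, latitude, longitude) grid. The layout of args as plain arrays (t_grid, lat_grid, lon_grid, u, v) is a hypothetical stand-in for pastax.gridded.Gridded, whose interface is not used here; the state is assumed to be ordered (latitude, longitude) in degrees, velocities to be in m/s, and the Earth to be a sphere, so the output is in degrees per second.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

EARTH_RADIUS_M = 6_371_000.0  # mean Earth radius, assuming a spherical Earth


def interpolated_drift(t, y, args):
    """Hypothetical `dynamics`: linear interpolation of (u, v) in m/s, returned in degrees/s.

    `args` is assumed to be (t_grid, lat_grid, lon_grid, u, v), with uniformly spaced
    1-D grids and u, v shaped (time, lat, lon); `y` is (latitude, longitude) in degrees.
    """
    t_grid, lat_grid, lon_grid, u, v = args
    lat, lon = y[0], y[1]
    # Fractional grid indices along each axis (only valid for uniformly spaced grids).
    coords = jnp.stack([
        (t - t_grid[0]) / (t_grid[1] - t_grid[0]),
        (lat - lat_grid[0]) / (lat_grid[1] - lat_grid[0]),
        (lon - lon_grid[0]) / (lon_grid[1] - lon_grid[0]),
    ])[:, None]
    # order=1 gives (tri)linear interpolation; "nearest" clamps points outside the grid.
    u_i = map_coordinates(u, coords, order=1, mode="nearest")[0]
    v_i = map_coordinates(v, coords, order=1, mode="nearest")[0]
    # Convert m/s to rad/s on a sphere, then to degrees/s.
    dlat = v_i / EARTH_RADIUS_M
    dlon = u_i / (EARTH_RADIUS_M * jnp.cos(jnp.deg2rad(lat)))
    return jnp.rad2deg(jnp.stack([dlat, dlon]))
```

With velocities in m/s the output is in degrees per second, so ts and dt0 would then have to be given in seconds.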


StochasticSimulator

Bases: DiffraxSimulator

Class defining stochastic differentiable pastax.trajectory.TrajectoryEnsemble simulators.

Methods:

Name Description
__call__

Simulates a pastax.trajectory.TrajectoryEnsemble of n_samples pastax.trajectory.Trajectory following the prescribed drift dynamics and physical field(s) args, from the initial pastax.trajectory.Location x0 at time steps ts (including t0), using the given diffrax.AbstractSolver, diffrax.SaveAt, diffrax.AbstractStepSizeController, diffrax.AbstractAdjoint and diffrax.AbstractBrownianPath.
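The source below splits key into n_samples sub-keys and applies jax.vmap to one SDE solve per sub-key. A minimal pure-JAX sketch of the same pattern, with two simplifications that are assumptions of this example rather than pastax behaviour: a fixed-step Euler-Maruyama scheme replaces the diffrax solver and diffrax.VirtualBrownianTree, and the dynamics are a hypothetical constant drift plus isotropic additive noise of amplitude sigma.

```python
import jax
import jax.numpy as jnp
import jax.random as jrd


def euler_maruyama_ensemble(drift, sigma, y0, ts, n_samples, key):
    """Simulate n_samples paths of dX = drift dt + sigma dW, saved at every `ts`."""
    keys = jrd.split(key, n_samples)  # one independent sub-key per trajectory
    dts = jnp.diff(ts)

    @jax.vmap
    def solve(subkey):
        # Brownian increments dW ~ N(0, dt) for each step and each of the 2 components.
        dws = jrd.normal(subkey, (dts.shape[0], 2)) * jnp.sqrt(dts)[:, None]

        def step(y, inc):
            dt, dw = inc
            y_next = y + drift * dt + sigma * dw
            return y_next, y_next

        _, ys = jax.lax.scan(step, y0, (dts, dws))
        # Include the initial condition as the first saved point.
        return jnp.concatenate([y0[None], ys], axis=0)

    return solve(keys)  # shape (n_samples, len(ts), 2)


ts = jnp.linspace(0.0, 1.0, 101)
y0 = jnp.array([43.0, 5.0])
ensemble = euler_maruyama_ensemble(jnp.array([0.1, 0.0]), 0.05, y0, ts, 500, jrd.key(0))
```

Because every sub-key is derived from key, the same key always yields the same ensemble; with the default key=jrd.key(0), repeated calls return identical samples unless a different key is passed.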

Source code in pastax/simulator/_diffrax_simulator.py
class StochasticSimulator(DiffraxSimulator):
    """
    Class defining stochastic differentiable [`pastax.trajectory.TrajectoryEnsemble`][] simulators.

    Methods
    -------
    __call__(dynamics, args, x0, ts, solver, dt0, saveat, stepsize_controller, adjoint, max_steps, n_samples, key, brownian_motion, throw)
        Simulates a [`pastax.trajectory.TrajectoryEnsemble`][] of `n_samples` [`pastax.trajectory.Trajectory`][]
        following the prescribed drift `dynamics` and physical field(s) `args`,
        from the initial [`pastax.trajectory.Location`][] `x0` at time steps (including t0) `ts`,
        using given [`diffrax.AbstractSolver`][], [`diffrax.SaveAt`][], [`diffrax.AbstractStepSizeController`][],
        [`diffrax.AbstractAdjoint`][] and [`diffrax.AbstractBrownianPath`][].
    """

    def __call__(
        self,
        dynamics: Callable[[Real[Any, ""], PyTree, PyTree], PyTree],
        args: PyTree,
        x0: Location,
        ts: Real[Any, "time"],
        solver: dfx.AbstractSolver = dfx.Heun(),
        dt0: Real[Any, ""] = None,
        saveat: dfx.SaveAt | None = None,
        stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(),
        adjoint: dfx.AbstractAdjoint = dfx.ForwardMode(),
        max_steps: Int[Any, ""] = 4096,
        n_samples: Int[Any, ""] = 100,
        key: Key[Array, ""] = jrd.key(0),
        brownian_motion: Callable[[tuple[int, ...], Key[Array, ""]], dfx.AbstractBrownianPath] | None = None,
        throw: bool = True,
    ) -> TrajectoryEnsemble:
        r"""
        Simulates a [`pastax.trajectory.TrajectoryEnsemble`][] of `n_samples` [`pastax.trajectory.Trajectory`][]
        following the prescribed drift `dynamics` and physical field(s) `args`,
        from the initial [`pastax.trajectory.Location`][] `x0` at time steps (including t0) `ts`,
        using given [`diffrax.AbstractSolver`][], [`diffrax.SaveAt`][], [`diffrax.AbstractStepSizeController`][],
        [`diffrax.AbstractAdjoint`][] and [`diffrax.AbstractBrownianPath`][].

        Parameters
        ----------
        dynamics : Callable[[Real[Any, ""], PyTree, PyTree], PyTree]
            A Callable (including an [`equinox.Module`][] with a `__call__` method) describing the dynamics of the
            right-hand-side of the solved Stochastic Differential Equation.

            !!! example

                Formulating a displacement at time $t$ from the position $\mathbf{X}(t)$ as:

                $$
                d\mathbf{X}(t) = f(t, \mathbf{X}(t), \text{args}) \cdot [dt, d\mathbf{W}(t)]
                $$

                Here `dynamics` is the function $f$ returning the drift velocity and the diffusion as a $2 \times 3$ matrix.

            Parameters
            ----------
            t : Real[Any, ""]
                The current time.
            y : PyTree
                The current state (latitude and longitude in degrees).
            args : PyTree
                The PyTree of argument(s) required to compute the `dynamics`,
                for example one or several [`pastax.gridded.Gridded`][] objects holding gridded physical fields
                (SSC, SSH, SST, etc.).

            Returns
            -------
            PyTree
                The drift and diffusion terms.

        args : PyTree
            The PyTree of argument(s) required to compute the `dynamics`,
            for example one or several [`pastax.gridded.Gridded`][] objects holding gridded physical fields
            (SSC, SSH, SST, etc.).
        x0 : Location
            The initial [`pastax.trajectory.Location`][].
        ts : Real[Any, "time"]
            The time steps of the simulation outputs, including $t_0$, in the same unit as `dt0`.
        solver : dfx.AbstractSolver, optional
            The [`diffrax.AbstractSolver`][] to use for the simulation, defaults to [`diffrax.Heun`][].
        dt0 : Real[Any, ""], optional
            The initial time step of the solver, in the same unit as `ts`; defaults to `None`.
        saveat : dfx.SaveAt, optional
            The [`diffrax.SaveAt`][] object to use for saving the solution, defaults to `SaveAt(ts=ts)`.
        stepsize_controller : dfx.AbstractStepSizeController, optional
            The [`diffrax.AbstractStepSizeController`][] to use for controlling the stepsize,
            defaults to [`diffrax.ConstantStepSize`][].
        adjoint : dfx.AbstractAdjoint, optional
            The [`diffrax.AbstractAdjoint`][] object to use for the adjoint method, defaults to [`diffrax.ForwardMode`][].
            [`diffrax.ForwardMode`][] should be used when computing the gradient in forward automatic differentiation
            mode with respect to few (<50) parameters, while [`diffrax.RecursiveCheckpointAdjoint`][] should be used
            when computing the gradient in reverse automatic differentiation mode with respect to many (>50) parameters.
        max_steps : Int[Any, ""], optional
            The maximum number of steps to take, defaults to `4096`.
        n_samples : Int[Any, ""], optional
            The number of samples to generate, defaults to `100`.
        key : Key[Array, ""], optional
            The random key for sampling, defaults to `jrd.key(0)`.
        brownian_motion : Callable[[tuple[int, ...], Key[Array, ""]], dfx.AbstractBrownianPath] | None, optional
            A Callable returning the [`diffrax.AbstractBrownianPath`][] to use for the simulation of the Brownian motion, defaults to `None`.
            If `None`, a [`diffrax.VirtualBrownianTree`][] is used.

            Parameters
            ----------
            shape : tuple[int, ...]
                The shape of the Brownian motion.
            key : Key[Array, ""]
                The random key for sampling.

            Returns
            -------
            dfx.AbstractBrownianPath
                The [`diffrax.AbstractBrownianPath`][] object.
        throw : bool, optional
            Whether to raise an exception if the integration fails, defaults to `True`.

        Returns
        -------
        TrajectoryEnsemble
            The simulated [`pastax.trajectory.TrajectoryEnsemble`][].
        """
        t0, t1 = ts[0], ts[-1]

        if saveat is None:
            saveat = dfx.SaveAt(ts=ts)

        if brownian_motion is None:
            brownian_motion = lambda shape, key: dfx.VirtualBrownianTree(t0, t1, tol=dt0, shape=shape, key=key)

        keys = jrd.split(key, n_samples)

        @jax.vmap
        def solve(subkey: Array) -> Float[Array, "time 2"]:
            sde_control = SDEControl(t0=t0, t1=t1, brownian_motion=brownian_motion((2,), subkey))
            sde_term = dfx.ControlTerm(dynamics, sde_control)

            ys = dfx.diffeqsolve(
                sde_term,
                solver,
                t0=t0,
                t1=t1,
                dt0=dt0,
                y0=x0.value,
                args=args,
                saveat=saveat,
                stepsize_controller=stepsize_controller,
                adjoint=adjoint,
                max_steps=max_steps,
                throw=throw,
            ).ys

            return ys  # type: ignore

        ys = solve(keys)

        return TrajectoryEnsemble.from_array(ys, ts, unit=x0.unit)
__call__(dynamics: Callable[[Real[Any, ''], PyTree, PyTree], PyTree], args: PyTree, x0: Location, ts: Real[Any, 'time'], solver: dfx.AbstractSolver = dfx.Heun(), dt0: Real[Any, ''] = None, saveat: dfx.SaveAt | None = None, stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), adjoint: dfx.AbstractAdjoint = dfx.ForwardMode(), max_steps: Int[Any, ''] = 4096, n_samples: Int[Any, ''] = 100, key: Key[Array, ''] = jrd.key(0), brownian_motion: Callable[[tuple[int, ...], Key[Array, '']], dfx.AbstractBrownianPath] | None = None, throw: bool = True) -> TrajectoryEnsemble

Simulates a pastax.trajectory.TrajectoryEnsemble of n_samples pastax.trajectory.Trajectory following the prescribed drift dynamics and physical field(s) args, from the initial pastax.trajectory.Location x0 at time steps ts (including t0), using the given diffrax.AbstractSolver, diffrax.SaveAt, diffrax.AbstractStepSizeController, diffrax.AbstractAdjoint and diffrax.AbstractBrownianPath.

Parameters:

Name Type Description Default
dynamics Callable[[Real[Any, ''], PyTree, PyTree], PyTree]

A Callable (including an equinox.Module with a __call__ method) describing the dynamics of the right-hand-side of the solved Stochastic Differential Equation.

Example

Formulating a displacement at time \(t\) from the position \(\mathbf{X}(t)\) as:

\[ d\mathbf{X}(t) = f(t, \mathbf{X}(t), \text{args}) \cdot [dt, d\mathbf{W}(t)] \]

Here dynamics is the function \(f\) returning the drift velocity and the diffusion as a \(2 \times 3\) matrix.

Parameters

t : Real[Any, ""]
    The current time.
y : PyTree
    The current state (latitude and longitude in degrees).
args : PyTree
    The PyTree of argument(s) required to compute the dynamics, for example one or several pastax.gridded.Gridded objects holding gridded physical fields (SSC, SSH, SST, etc.).

Returns

PyTree
    The drift and diffusion terms.

required
args PyTree

The PyTree of argument(s) required to compute the dynamics, for example one or several pastax.gridded.Gridded objects holding gridded physical fields (SSC, SSH, SST, etc.).

required
x0 Location

The initial pastax.trajectory.Location.

required
ts Real[Any, 'time']

The time steps of the simulation outputs, including \(t_0\), in the same unit as dt0.

required
solver AbstractSolver

The diffrax.AbstractSolver to use for the simulation, defaults to diffrax.Heun.

Heun()
dt0 Real[Any, '']

The initial time step of the solver, in the same unit as ts; defaults to None.

None
saveat SaveAt

The diffrax.SaveAt object to use for saving the solution, defaults to SaveAt(ts=ts).

None
stepsize_controller AbstractStepSizeController

The diffrax.AbstractStepSizeController to use for controlling the stepsize, defaults to diffrax.ConstantStepSize.

ConstantStepSize()
adjoint AbstractAdjoint

The diffrax.AbstractAdjoint object to use for the adjoint method, defaults to diffrax.ForwardMode. diffrax.ForwardMode should be used when computing the gradient in forward automatic differentiation mode with respect to few (<50) parameters, while diffrax.RecursiveCheckpointAdjoint should be used when computing the gradient in reverse automatic differentiation mode with respect to many (>50) parameters.

ForwardMode()
max_steps Int[Any, '']

The maximum number of steps to take, defaults to 4096.

4096
n_samples Int[Any, '']

The number of samples to generate, defaults to 100.

100
key Key[Array, '']

The random key for sampling, defaults to jrd.key(0).

key(0)
brownian_motion Callable[[tuple[int, ...], Key[Array, '']], AbstractBrownianPath] | None

A Callable returning the diffrax.AbstractBrownianPath to use for the simulation of the Brownian motion, defaults to None. If None, a diffrax.VirtualBrownianTree is used.

Parameters

shape : tuple[int, ...]
    The shape of the Brownian motion.
key : Key[Array, ""]
    The random key for sampling.

Returns

dfx.AbstractBrownianPath
    The diffrax.AbstractBrownianPath object.

None
throw bool

Whether to raise an exception if the integration fails, defaults to True.

True

Returns:

Type Description
TrajectoryEnsemble

The simulated pastax.trajectory.TrajectoryEnsemble.
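The dynamics example above writes the SDE increment as a matrix-vector product. A small sketch of that convention, assuming the control vector is ordered as [dt, dW_1, dW_2] (as the formula suggests; the internals of pastax's SDEControl are not shown on this page): the hypothetical drift_diffusion_matrix returns a 2-by-3 matrix whose first column is the drift velocity and whose remaining 2-by-2 block is an isotropic diffusion.

```python
import jax.numpy as jnp


def drift_diffusion_matrix(t, y, args):
    """Hypothetical SDE `dynamics`: a 2-by-3 matrix [drift | diffusion].

    `args` is assumed to be (drift, sigma): a (2,) drift velocity and a scalar
    noise amplitude. Column 0 multiplies dt, columns 1-2 multiply (dW_1, dW_2).
    """
    drift, sigma = args
    diffusion = sigma * jnp.eye(2)  # isotropic, additive noise
    return jnp.concatenate([drift[:, None], diffusion], axis=1)


def increment(t, y, args, dt, dw):
    # dX = f(t, X, args) . [dt, dW_1, dW_2]
    control = jnp.concatenate([jnp.atleast_1d(dt), dw])
    return drift_diffusion_matrix(t, y, args) @ control


args = (jnp.array([0.1, -0.2]), 0.05)
f = drift_diffusion_matrix(0.0, jnp.zeros(2), args)
```

A position-dependent diffusion would compute the 2-by-2 block from y instead of using a constant sigma.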
Source code in pastax/simulator/_diffrax_simulator.py
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
def __call__(
    self,
    dynamics: Callable[[Real[Any, ""], PyTree, PyTree], PyTree],
    args: PyTree,
    x0: Location,
    ts: Real[Any, "time"],
    solver: dfx.AbstractSolver = dfx.Heun(),
    dt0: Real[Any, ""] = None,
    saveat: dfx.SaveAt | None = None,
    stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(),
    adjoint: dfx.AbstractAdjoint = dfx.ForwardMode(),
    max_steps: Int[Any, ""] = 4096,
    n_samples: Int[Any, ""] = 100,
    key: Key[Array, ""] = jrd.key(0),
    brownian_motion: Callable[[tuple[int, ...], Key[Array, ""]], dfx.AbstractBrownianPath] | None = None,
    throw: bool = True,
) -> TrajectoryEnsemble:
    r"""
    Simulates a [`pastax.trajectory.TrajectoryEnsemble`][] of `n_samples` [`pastax.trajectory.Trajectory`][]
    following the prescribe drift `dynamics` and physical field(s) `args`,
    from the initial [`pastax.trajectory.Location`][] `x0` at time steps (including t0) `ts`,
    using given [`diffrax.AbstractSolver`][], [`diffrax.SaveAt`][], [`diffrax.AbstractStepSizeController`][],
    [`diffrax.AbstractAdjoint`][] and [`diffrax.AbstractBrownianPath`][].

    Parameters
    ----------
    dynamics : Callable[[Real[Any, ""], PyTree, PyTree], PyTree]
        A Callable (including an [`equinox.Module`][] with a `__call__` method) describing the dynamics of the
        right-hand-side of the solved Stochastic Differential Equation.

        !!! example

            Formulating a displacement at time $t$ from the position $\mathbf{X}(t)$ as:

            $$
            d\mathbf{X}(t) = f(t, \mathbf{X}(t), \text{args}) \cdot [dt, d\mathbf{W}(t)]
            $$

            `dynamics` is here the function $f$ returning the displacement speed and diffusion as a 2*3 matrix.

        Parameters
        ----------
        t : Real[Any, ""]
            The current time.
        y : PyTree
            The current state (latitude and longitude in degrees).
        args : PyTree
            The PyTree of argument(s) required to compute the `dynamics`.
            Could be for example one or several [`pastax.gridded.Gridded`][] of gridded physical fields
            (SSC, SSH, SST, etc...).

        Returns
        -------
        PyTree
            The drift dynamics.

    args : PyTree
        The PyTree of argument(s) required to compute the `dynamics`.
        Could be for example one or several [`pastax.gridded.Gridded`][] of gridded physical fields
        (SSC, SSH, SST, etc...).
    x0 : Location
        The initial [`pastax.trajectory.Location`][].
    ts : Real[Any, "time"]
        The time steps for the simulation outputs, including $t_0$, unit should be the same as for `dt0`.
    solver : dfx.AbstractSolver, optional
        The [`diffrax.AbstractSolver`][] to use for the simulation, defaults to [`diffrax.Heun`][].
    dt0 : Real[Any, ""], optional
        The initial time step of the solver, unit should be the same as for `ts`, defaults to `None`.
    saveat : dfx.SaveAt, optional
        The [`diffrax.SaveAt`][] object to use for saving the solution, defaults to `SaveAt(ts=ts)`.
    stepsize_controller : dfx.AbstractStepSizeController, optional
        The [`diffrax.AbstractStepSizeController`][] to use for controlling the stepsize,
        defaults to [`diffrax.ConstantStepSize`][].
    adjoint : dfx.AbstractAdjoint, optional
        The [`diffrax.AbstractAdjoint`][] object to use for the adjoint method, defaults to [`diffrax.ForwardMode`][].
        [`diffrax.ForwardMode`][] should be used when computing the gradient in forward automtic differentiation
        mode with respect to few (<50) parameters, while [`diffrax.RecursiveCheckpointAdjoint`][] should be used
        when computing the gradient in reverse automatic differentiation mode with respect to many (>50) parameters.
    max_steps : Int[Any, ""], optional
        The maximum number of steps to take, defaults to `4096`.
    n_samples : Int[Any, ""], optional
        The number of samples to generate, defaults to `100`.
    key : Key[Array, ""], optional
        The random key for sampling, defaults to `jrd.key(0)`.
    brownian_motion : Callable[[tuple[int, ...], Key[Array, ""]], dfx.AbstractBrownianPath] | None, optional
        A Callable returning the [`diffrax.AbstractBrownianPath`][] to use for the simulation of the Brownian motion, defaults to `None`.
        If `None`, a [`diffrax.VirtualBrownianTree`][] is used.

        Parameters
        ----------
        shape : tuple[int, ...]
            The shape of the Brownian motion.
        key : Key[Array, ""]
            The random key for sampling.

        Returns
        -------
        dfx.AbstractBrownianPath
            The [`diffrax.AbstractBrownianPath`][] object.
    throw : bool, optional
        Whether to raise an exception if the integration fails, defaults to `True`.

    Returns
    -------
    TrajectoryEnsemble
        The simulated [`pastax.trajectory.TrajectoryEnsemble`][].
    """
    t0, t1 = ts[0], ts[-1]

    if saveat is None:
        saveat = dfx.SaveAt(ts=ts)

    if brownian_motion is None:
        brownian_motion = lambda shape, key: dfx.VirtualBrownianTree(t0, t1, tol=dt0, shape=shape, key=key)

    keys = jrd.split(key, n_samples)

    @jax.vmap
    def solve(subkey: Array) -> Float[Array, "time 2"]:
        sde_control = SDEControl(t0=t0, t1=t1, brownian_motion=brownian_motion((2,), subkey))
        sde_term = dfx.ControlTerm(dynamics, sde_control)

        ys = dfx.diffeqsolve(
            sde_term,
            solver,
            t0=t0,
            t1=t1,
            dt0=dt0,
            y0=x0.value,
            args=args,
            saveat=saveat,
            stepsize_controller=stepsize_controller,
            adjoint=adjoint,
            max_steps=max_steps,
            throw=throw,
        ).ys

        return ys  # type: ignore

    ys = solve(keys)

    return TrajectoryEnsemble.from_array(ys, ts, unit=x0.unit)
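The ensemble simulation above draws one independent Brownian path per subkey and solves the same SDE for each of them. The following is a minimal NumPy sketch of that pattern, assuming a fixed-step Euler-Maruyama scheme in place of diffrax's `Heun` solver and `VirtualBrownianTree`; `simulate_ensemble` and its arguments are hypothetical and not part of pastax.

```python
import numpy as np


def simulate_ensemble(drift, diffusion, x0, ts, n_samples, seed=0):
    """Euler-Maruyama ensemble for dX = drift(t, X) dt + diffusion(t, X) dW.

    Hypothetical helper: returns an array of shape (n_samples, len(ts), 2),
    mirroring the (sample, time, 2) layout stacked by the vmapped solve.
    """
    rng = np.random.default_rng(seed)
    ts = np.asarray(ts, dtype=float)
    dts = np.diff(ts)
    out = np.empty((n_samples, ts.size, 2))
    for i in range(n_samples):
        # one independent Brownian path per sample, like one subkey per sample
        dws = rng.normal(size=(dts.size, 2)) * np.sqrt(dts)[:, None]
        x = np.asarray(x0, dtype=float)
        out[i, 0] = x
        for k, (t, dt) in enumerate(zip(ts[:-1], dts)):
            # drift contributes drift * dt, diffusion maps the 2-D increment dW
            x = x + drift(t, x) * dt + diffusion(t, x) @ dws[k]
            out[i, k + 1] = x
    return out
```

With a zero diffusion matrix every member follows the same deterministic path; with a non-zero one, members diverge because each draws its own increments.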

pastax.dynamics

This module provides dynamics examples to be used with pastax.simulator.BaseSimulator.

LinearUV

Bases: Module

Trainable linear transformation of the Lagrangian drift velocity computed by interpolating the velocity fields in space and time.

Attributes:

Name Type Description
intercept (Float[Array, ''] | Float[Array, '2'], optional)

The intercept of the linear relation, defaults to jnp.asarray([0., 0.]).

slope (Float[Array, ''] | Float[Array, '2'], optional)

The slope of the linear relation, defaults to jnp.asarray([1., 1.]).

Methods:

Name Description
__call__

Computes the Lagrangian drift velocity.

Notes

As the class inherits from equinox.Module, its intercept and slope attributes can be treated as trainable parameters.

Source code in pastax/dynamics/_linear_uv.py
class LinearUV(eqx.Module):
    """
    Trainable linear transformation of the Lagrangian drift velocity
    computed by interpolating the velocity fields in space and time.

    Attributes
    ----------
    intercept : Float[Array, ""] | Float[Array, "2"], optional
        The intercept of the linear relation, defaults to `jnp.asarray([0., 0.])`.
    slope : Float[Array, ""] | Float[Array, "2"], optional
        The slope of the linear relation, defaults to `jnp.asarray([1., 1.])`.

    Methods
    -------
    __call__(t, y, args)
        Computes the Lagrangian drift velocity.

    Notes
    -----
    As the class inherits from [`equinox.Module`][], its `intercept` and `slope` attributes can be treated as
    trainable parameters.
    """

    intercept: Float[Array, ""] | Float[Array, "2"] = eqx.field(
        converter=lambda x: jnp.asarray(x, dtype=float), default_factory=lambda: [0, 0]
    )
    slope: Float[Array, ""] | Float[Array, "2"] = eqx.field(
        converter=lambda x: jnp.asarray(x, dtype=float), default_factory=lambda: [1, 1]
    )

    def __call__(self, t: Real[Array, ""], y: Float[Array, "2"], args: Gridded) -> Float[Array, "2"]:
        """
        Computes the Lagrangian drift velocity as the linear relation `intercept + slope * [v, u]`.

        Parameters
        ----------
        t : Real[Array, ""]
            The current time.
        y : Float[Array, "2"]
            The current state (latitude and longitude in degrees).
        args : Gridded
            The [`pastax.gridded.Gridded`][] containing the physical fields (only u and v here).

        Returns
        -------
        Float[Array, "2"]
            The Lagrangian drift velocity.
        """
        vu = _linear_uv(t, y, args)

        dlatlon = self.intercept + self.slope * vu

        dataset = args
        if dataset.is_spherical_mesh and not dataset.use_degrees:
            dlatlon = meters_to_degrees(dlatlon, latitude=y[0])

        return dlatlon
__call__(t: Real[Array, ''], y: Float[Array, '2'], args: Gridded) -> Float[Array, '2']

Computes the Lagrangian drift velocity as the linear relation intercept + slope * [v, u].

Parameters:

Name Type Description Default
t Real[Array, '']

The current time.

required
y Float[Array, '2']

The current state (latitude and longitude in degrees).

required
args Gridded

The pastax.gridded.Gridded containing the physical fields (only u and v here).

required

Returns:

Type Description
Float[Array, '2']

The Lagrangian drift velocity.
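The linear relation applied by `LinearUV.__call__` can be sketched in NumPy as follows. The meters-to-degrees step is an assumption: pastax's `meters_to_degrees` is not shown here, so the sketch uses the common spherical-Earth approximation with an assumed mean radius of 6371 km. `linear_uv_velocity` is a hypothetical name.

```python
import numpy as np

EARTH_RADIUS_M = 6_371_000.0  # assumed mean Earth radius, not taken from pastax


def linear_uv_velocity(vu, latitude_deg, intercept=(0.0, 0.0), slope=(1.0, 1.0), to_degrees=True):
    """Return intercept + slope * [v, u], optionally converted from m/s to deg/s.

    vu is ordered [v, u] (northward, eastward), matching the [latitude, longitude] state.
    """
    dlatlon = np.asarray(intercept, dtype=float) + np.asarray(slope, dtype=float) * np.asarray(vu, dtype=float)
    if to_degrees:
        # spherical approximation: one degree of latitude spans R * pi / 180 meters,
        # and one degree of longitude shrinks by cos(latitude)
        meters_per_degree = EARTH_RADIUS_M * np.pi / 180.0
        dlatlon = dlatlon / meters_per_degree
        dlatlon[1] = dlatlon[1] / np.cos(np.deg2rad(latitude_deg))
    return dlatlon
```

With the default `intercept` and `slope`, the relation reduces to the interpolated `[v, u]` itself, which is why those defaults leave the drift unchanged before training.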


SmagorinskyDiffusion

Bases: Module

Trainable Smagorinsky diffusion dynamics.

Formulation

This dynamics formulates the displacement at time \(t\) from the position \(\mathbf{X}(t)\) as:

\[ d\mathbf{X}(t) = (\mathbf{u} + \nabla K)(t, \mathbf{X}(t)) dt + V(t, \mathbf{X}(t)) d\mathbf{W}(t) \]

where \(V = \sqrt{2 K}\) and \(K\) is the Smagorinsky diffusion:

\[ K = C_s \Delta x \Delta y \sqrt{\left(\frac{\partial u}{\partial x} \right)^2 + \left(\frac{\partial v}{\partial y} \right)^2 + \frac{1}{2} \left(\frac{\partial u}{\partial y} + \frac{\partial v}{\partial x} \right)^2} \]

where \(C_s\) is the trainable Smagorinsky constant, \(\Delta x \Delta y\) a spatial scaling factor, and the rest of the expression represents the horizontal diffusion.
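To make the formula concrete, here is a minimal NumPy sketch of \(K\) on a small grid. It assumes a uniform grid with scalar spacings `dx` and `dy` (in meters) and ignores the land mask and per-cell `cell_area` that pastax passes to its `spatial_derivative` helper; `smagorinsky_k` is a hypothetical name. As in the source comments, central differences shrink a `7 x 7` neighborhood to its `5 x 5` interior.

```python
import numpy as np


def smagorinsky_k(u, v, dx, dy, cs=0.1):
    """Smagorinsky K on the interior of a (lat, lon) grid using central differences.

    u and v have shape (ny, nx); the result has shape (ny - 2, nx - 2).
    dx (longitude spacing) and dy (latitude spacing) are in meters.
    """
    # axis 1 is x (longitude), axis 0 is y (latitude); only interior points are kept
    dudx = (u[1:-1, 2:] - u[1:-1, :-2]) / (2 * dx)
    dvdx = (v[1:-1, 2:] - v[1:-1, :-2]) / (2 * dx)
    dudy = (u[2:, 1:-1] - u[:-2, 1:-1]) / (2 * dy)
    dvdy = (v[2:, 1:-1] - v[:-2, 1:-1]) / (2 * dy)
    cell_area = dx * dy  # the Delta x Delta y scaling factor, uniform grid assumed
    return cs * cell_area * np.sqrt(dudx**2 + dvdy**2 + 0.5 * (dudy + dvdx) ** 2)
```

For a pure strain flow \(u = a x\), \(v = -a y\), the shear terms vanish and \(K = C_s \Delta x \Delta y \, a \sqrt{2}\) everywhere.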

Methods:

Name Description
cs

Returns the Smagorinsky constant.

_neighborhood

Restricts the pastax.gridded.Gridded to a neighborhood around the given location and time.

_smagorinsky_diffusion

Computes the Smagorinsky diffusion:

\[ K = C_s \Delta x \Delta y \sqrt{\left(\frac{\partial u}{\partial x} \right)^2 + \left(\frac{\partial v}{\partial y} \right)^2 + \frac{1}{2} \left(\frac{\partial u}{\partial y} + \frac{\partial v}{\partial x} \right)^2} \]

where \(C_s\) is the trainable Smagorinsky constant, \(\Delta x \Delta y\) a spatial scaling factor, and the rest of the expression represents the horizontal diffusion.

_deterministic_dynamics

Computes the deterministic part of the dynamics: \((\mathbf{u} + \nabla K)(t, \mathbf{X}(t))\).

_stochastic_dynamics

Computes the stochastic part of the dynamics: \(V(t, \mathbf{X}(t)) = \sqrt{2 K(t, \mathbf{X}(t))}\).
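The deterministic and stochastic parts listed above can be sketched together as follows. The sketch assumes a uniform grid and evaluates at an interior grid node so that no spatial interpolation is needed (pastax instead interpolates \(\nabla K\) and \(K\) to the particle position with `interpax.interp2d` and `Gridded.interp`). The helper name `drift_and_diffusion` is hypothetical; components are ordered `[latitude, longitude]`, matching the `[v, u]` and `[dkdy, dkdx]` stacking in the source.

```python
import numpy as np


def drift_and_diffusion(k, u_at_x, v_at_x, i, j, dx, dy):
    """Drift [v + dK/dy, u + dK/dx] and diffusion sqrt(2K) * I at interior node (i, j) of K.

    k has shape (ny, nx) on a uniform grid; (i, j) must not lie on the border.
    """
    dkdx = (k[i, j + 1] - k[i, j - 1]) / (2 * dx)
    dkdy = (k[i + 1, j] - k[i - 1, j]) / (2 * dy)
    drift = np.array([v_at_x + dkdy, u_at_x + dkdx])  # [latitude, longitude] order
    diffusion = np.eye(2) * np.sqrt(2 * k[i, j])  # isotropic noise amplitude V = sqrt(2K)
    return drift, diffusion
```

The \(\nabla K\) term pushes particles toward regions of higher diffusivity, which is what keeps a random walk with spatially varying \(K\) from artificially accumulating particles in low-\(K\) regions.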

Notes

As the class inherits from equinox.Module, its _cs field (exposed through the cs property) can be treated as a trainable parameter.

Source code in pastax/dynamics/_smagorinsky_diffusion.py
class SmagorinskyDiffusion(eqx.Module):
    r"""
    Trainable Smagorinsky diffusion dynamics.

    !!! example "Formulation"

        This dynamics formulates the displacement at time $t$ from the position $\mathbf{X}(t)$ as:

        $$
        d\mathbf{X}(t) = (\mathbf{u} + \nabla K)(t, \mathbf{X}(t)) dt + V(t, \mathbf{X}(t)) d\mathbf{W}(t)
        $$

        where $V = \sqrt{2 K}$ and $K$ is the Smagorinsky diffusion:

        $$
        K = C_s \Delta x \Delta y \sqrt{\left(\frac{\partial u}{\partial x} \right)^2 +
        \left(\frac{\partial v}{\partial y} \right)^2 + \frac{1}{2} \left(\frac{\partial u}{\partial y} +
        \frac{\partial v}{\partial x} \right)^2}
        $$

        where $C_s$ is the ***trainable*** Smagorinsky constant, $\Delta x \Delta y$ a spatial scaling factor, and the
        rest of the expression represents the horizontal diffusion.

    Methods
    -------
    cs
        Returns the Smagorinsky constant.
    _neighborhood(*fields, t, y, gridded)
        Restricts the [`pastax.gridded.Gridded`][] to a neighborhood around the given location and time.
    _smagorinsky_diffusion(t, y, gridded)
        Computes the Smagorinsky diffusion:

        $$
        K = C_s \Delta x \Delta y \sqrt{\left(\frac{\partial u}{\partial x} \right)^2 +
        \left(\frac{\partial v}{\partial y} \right)^2 + \frac{1}{2} \left(\frac{\partial u}{\partial y} +
        \frac{\partial v}{\partial x} \right)^2}
        $$

        where $C_s$ is the ***trainable*** Smagorinsky constant, $\Delta x \Delta y$ a spatial scaling factor,
        and the rest of the expression represents the horizontal diffusion.
    _deterministic_dynamics(t, y, gridded, smag_ds)
        Computes the deterministic part of the dynamics: $(\mathbf{u} + \nabla K)(t, \mathbf{X}(t))$.
    _stochastic_dynamics(y, smag_ds)
        Computes the stochastic part of the dynamics: $V(t, \mathbf{X}(t)) = \sqrt{2 K(t, \mathbf{X}(t))}$.

    Notes
    -----
    As the class inherits from [`equinox.Module`][], its `_cs` field (exposed through the `cs` property) can be treated as a trainable parameter.
    """

    _cs: Real[Array, ""] = eqx.field(
        converter=lambda x: jnp.asarray(x, dtype=float),
        default_factory=lambda: _from_cs(0.1),
    )

    @property
    def cs(self) -> Float[Array, ""]:
        """
        Returns the Smagorinsky constant.
        """
        return _to_cs(self._cs)

    @staticmethod
    def _neighborhood(*fields: str, t: Real[Array, ""], y: Float[Array, "2"], gridded: Gridded) -> Gridded:
        """
        Restricts the [`pastax.gridded.Gridded`][] to a neighborhood around the given location and time.

        Parameters
        ----------
        *fields : str
            The fields to retain in the neighborhood.
        t : Real[Array, ""]
            The current time.
        y : Float[Array, "2"]
            The current state (latitude and longitude).
        gridded : Gridded
            The [`pastax.gridded.Gridded`][] containing the physical fields.

        Returns
        -------
        Gridded
            The neighborhood [`pastax.gridded.Gridded`][].
        """
        # restrict gridded to the neighborhood around X(t)
        neighborhood = gridded.neighborhood(
            *fields, time=t, latitude=y[0], longitude=y[1], t_width=3, x_width=7
        )  # "x_width x_width"

        return neighborhood

    def _smagorinsky_diffusion(self, t: Real[Array, ""], y: Float[Array, "2"], gridded: Gridded) -> Gridded:
        r"""
        Computes the Smagorinsky diffusion:

        $$
        K = C_s \Delta x \Delta y \sqrt{\left(\frac{\partial u}{\partial x} \right)^2 +
        \left(\frac{\partial v}{\partial y} \right)^2 + \frac{1}{2} \left(\frac{\partial u}{\partial y} +
        \frac{\partial v}{\partial x} \right)^2}
        $$

        where $C_s$ is the ***trainable*** Smagorinsky constant, $\Delta x \Delta y$ a spatial scaling factor,
        and the rest of the expression represents the horizontal diffusion.

        Parameters
        ----------
        t : Real[Array, ""]
            The simulation time.
        y : Float[Array, "2"]
            The current state (latitude and longitude).
        gridded : Gridded
            The [`pastax.gridded.Gridded`][] containing the physical fields.

        Returns
        -------
        Gridded
            The [`pastax.gridded.Gridded`][] containing the Smagorinsky coefficients.

        Notes
        -----
        The physical fields are first restricted to a small neighborhood, then interpolated in time and
        finally spatial derivatives are computed using central finite differences.
        """
        neighborhood = self._neighborhood("u", "v", t=t, y=y, gridded=gridded)

        fields = neighborhood.interp("u", "v", time=t)  # "x_width x_width"
        (dudx, dudy), (dvdx, dvdy) = spatial_derivative(
            fields["u"], fields["v"], dx=neighborhood.dx, dy=neighborhood.dy, is_masked=neighborhood.is_masked.values
        )  # "x_width-2 x_width-2"

        # computes Smagorinsky coefficients
        cell_area = neighborhood.cell_area[1:-1, 1:-1]  # "x_width-2 x_width-2"
        smag_k = self.cs * cell_area * ((dudx**2 + dvdy**2 + 0.5 * (dudy + dvdx) ** 2) ** (1 / 2))

        smag_ds = Gridded.from_array(
            {"smag_k": smag_k[None, ...]},
            time=t[None],
            latitude=neighborhood.coordinates["latitude"][1:-1],
            longitude=neighborhood.coordinates["longitude"][1:-1],
            interpolation_method="linear",
            is_spherical_mesh=neighborhood.is_spherical_mesh,
            use_degrees=neighborhood.use_degrees,
            is_uv_mps=False,  # no uv anyway...
        )

        return smag_ds

    @staticmethod
    def _deterministic_dynamics(
        t: Real[Array, ""], y: Float[Array, "2"], gridded: Gridded, smag_ds: Gridded
    ) -> Float[Array, "2"]:
        r"""
        Computes the deterministic part of the dynamics: $(\mathbf{u} + \nabla K)(t, \mathbf{X}(t))$.

        Parameters
        ----------
        t : Real[Array, ""]
            The current time.
        y : Float[Array, "2"]
            The current state (latitude and longitude).
        gridded : Gridded
            The [`pastax.gridded.Gridded`][] containing the physical fields.
        smag_ds : Gridded
            The [`pastax.gridded.Gridded`][] containing the Smagorinsky coefficients for the given fields.

        Returns
        -------
        Float[Array, "2"]
            The deterministic part of the dynamics.
        """
        latitude, longitude = y[0], y[1]

        smag_k = jnp.squeeze(smag_ds.fields["smag_k"].values)  # "x_width-2 x_width-2"

        # $\mathbf{u}(t, \mathbf{X}(t))$ term
        scalar_values = gridded.interp("u", "v", time=t, latitude=latitude, longitude=longitude)
        vu = jnp.asarray([scalar_values["v"], scalar_values["u"]])  # "2"

        # $(\nabla \cdot \mathbf{K})(t, \mathbf{X}(t))$ term
        ((dkdx, dkdy),) = spatial_derivative(
            smag_k, dx=smag_ds.dx, dy=smag_ds.dy, is_masked=smag_ds.is_masked.values
        )  # "x_width-4 x_width-4"
        dkdx = ipx.interp2d(
            latitude,
            longitude,
            smag_ds.coordinates["latitude"][1:-1],
            smag_ds.coordinates["longitude"][1:-1],
            dkdx,
            method="linear",
            extrap=True,
        )
        dkdy = ipx.interp2d(
            latitude,
            longitude,
            smag_ds.coordinates["latitude"][1:-1],
            smag_ds.coordinates["longitude"][1:-1],
            dkdy,
            method="linear",
            extrap=True,
        )
        gradk = jnp.asarray([dkdy, dkdx])  # "2"

        return vu + gradk

    @staticmethod
    def _stochastic_dynamics(y: Float[Array, "2"], smag_ds: Gridded) -> Float[Array, "2 2"]:
        r"""
        Computes the stochastic part of the dynamics: $V(t, \mathbf{X}(t)) = \sqrt{2 K(t, \mathbf{X}(t))}$.

        Parameters
        ----------
        y : Float[Array, "2"]
            The current state (latitude and longitude).
        smag_ds : Gridded
            The [`pastax.gridded.Gridded`][] containing the Smagorinsky coefficients.

        Returns
        -------
        Float[Array, "2 2"]
            The stochastic part of the dynamics.
        """
        latitude, longitude = y[0], y[1]

        scalar_value = smag_ds.interp("smag_k", latitude=latitude, longitude=longitude)
        smag_k = jnp.squeeze(scalar_value["smag_k"])  # scalar
        smag_k = (2 * smag_k) ** (1 / 2)

        return jnp.eye(2) * smag_k
cs: Float[Array, ''] property

Returns the Smagorinsky constant.

StochasticSmagorinskyDiffusion

Bases: SmagorinskyDiffusion

Trainable stochastic Smagorinsky diffusion dynamics.

Formulation

This dynamics formulates the displacement at time \(t\) from the position \(\mathbf{X}(t)\) as:

\[ d\mathbf{X}(t) = (\mathbf{u} + \nabla K)(t, \mathbf{X}(t)) dt + V(t, \mathbf{X}(t)) d\mathbf{W}(t) \]

where \(V = \sqrt{2 K}\) and \(K\) is the Smagorinsky diffusion:

\[ K = C_s \Delta x \Delta y \sqrt{\left(\frac{\partial u}{\partial x} \right)^2 + \left(\frac{\partial v}{\partial y} \right)^2 + \frac{1}{2} \left(\frac{\partial u}{\partial y} + \frac{\partial v}{\partial x} \right)^2} \]

where \(C_s\) is the trainable Smagorinsky constant, \(\Delta x \Delta y\) a spatial scaling factor, and the rest of the expression represents the horizontal diffusion.

Methods:

Name Description
__call__

Computes the deterministic and stochastic terms of the dynamics and returns them as lineax.PyTreeLinearOperator.

Source code in pastax/dynamics/_smagorinsky_diffusion.py
class StochasticSmagorinskyDiffusion(SmagorinskyDiffusion):
    r"""
    Trainable stochastic Smagorinsky diffusion dynamics.

    !!! example "Formulation"

        This dynamics formulates the displacement at time $t$ from the position $\mathbf{X}(t)$ as:

        $$
        d\mathbf{X}(t) = (\mathbf{u} + \nabla K)(t, \mathbf{X}(t)) dt + V(t, \mathbf{X}(t)) d\mathbf{W}(t)
        $$

        where $V = \sqrt{2 K}$ and $K$ is the Smagorinsky diffusion:

        $$
        K = C_s \Delta x \Delta y \sqrt{\left(\frac{\partial u}{\partial x} \right)^2 +
        \left(\frac{\partial v}{\partial y} \right)^2 + \frac{1}{2} \left(\frac{\partial u}{\partial y} +
        \frac{\partial v}{\partial x} \right)^2}
        $$

        where $C_s$ is the ***trainable*** Smagorinsky constant, $\Delta x \Delta y$ a spatial scaling factor, and the
        rest of the expression represents the horizontal diffusion.

    Methods
    -------
    __call__(t, y, args)
        Computes the deterministic and stochastic terms of the dynamics and returns them as lineax.PyTreeLinearOperator.
    """

    def __call__(self, t: Real[Array, ""], y: Float[Array, "2"], args: Gridded) -> lx.PyTreeLinearOperator:
        r"""
        Computes the deterministic and stochastic terms of the dynamics and returns them as [`lineax.PyTreeLinearOperator`][].

        Parameters
        ----------
        t : Real[Array, ""]
            The current time.
        y : Float[Array, "2"]
            The current state (latitude and longitude).
        args : Gridded
            The [`pastax.gridded.Gridded`][] containing the velocity fields.

        Returns
        -------
        lx.PyTreeLinearOperator
            The stacked deterministic and stochastic parts of the dynamics.
        """
        gridded = args

        smag_ds = self._smagorinsky_diffusion(t, y, gridded)  # "1 x_width-2 x_width-2"

        dlatlon_deter = self._deterministic_dynamics(t, y, gridded, smag_ds)
        dlatlon_stoch = self._stochastic_dynamics(y, smag_ds)

        if gridded.is_spherical_mesh and not gridded.use_degrees:
            dlatlon_deter = meters_to_degrees(dlatlon_deter, latitude=y[0])
            dlatlon_stoch = meters_to_degrees(dlatlon_stoch, latitude=y[0])

        return lx.PyTreeLinearOperator((dlatlon_deter, dlatlon_stoch), jax.ShapeDtypeStruct((2,), float))

    @classmethod
    def from_cs(cls, cs: Real[Any, ""] = 0.1):
        """
        Initializes the stochastic Smagorinsky diffusion with the given Smagorinsky constant.

        Parameters
        ----------
        cs : Real[Any, ""], optional
            The Smagorinsky constant, defaults to `jnp.asarray(0.1, dtype=float)`.

        Returns
        -------
        StochasticSmagorinskyDiffusion
            The [`pastax.dynamics.StochasticSmagorinskyDiffusion`][] initialized with the given Smagorinsky constant.
        """
        return cls(_cs=_from_cs(cs))
__call__(t: Real[Array, ''], y: Float[Array, '2'], args: Gridded) -> lx.PyTreeLinearOperator

Computes the deterministic and stochastic terms of the dynamics and returns them as lineax.PyTreeLinearOperator.

Parameters:

Name Type Description Default
t Real[Array, '']

The current time.

required
y Float[Array, '2']

The current state (latitude and longitude).

required
args Gridded

The pastax.gridded.Gridded containing the velocity fields.

required

Returns:

Type Description
PyTreeLinearOperator

The stacked deterministic and stochastic parts of the dynamics.
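`StochasticSmagorinskyDiffusion.__call__` stacks the drift (shape `2`) and the diffusion matrix (shape `2 x 2`) in a single `lineax.PyTreeLinearOperator`, which `diffrax.ControlTerm` applies to a control increment. Assuming that control is the pair `(dt, dW)` produced by pastax's `SDEControl` (its definition is not shown here), one step reduces to the Euler-Maruyama update below. `apply_operator` and `euler_maruyama_step` are hypothetical NumPy stand-ins, not lineax or diffrax calls.

```python
import numpy as np


def apply_operator(drift, diffusion, control):
    """Mimic a PyTreeLinearOperator over (drift, diffusion) acting on control = (dt, dW)."""
    dt, dw = control
    # leaf of shape (2,) times a scalar input, plus leaf of shape (2, 2) times a (2,) input
    return drift * dt + diffusion @ dw


def euler_maruyama_step(y, drift, diffusion, dt, rng):
    """One step y + drift * dt + diffusion @ dW, with dW ~ N(0, dt * I)."""
    dw = rng.normal(size=2) * np.sqrt(dt)
    return y + apply_operator(drift, diffusion, (dt, dw))
```

Packing both terms into one operator lets a single `ControlTerm` carry the whole SDE, instead of combining a separate ODE term and noise term.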

from_cs(cs: Real[Any, ''] = 0.1) classmethod

Initializes the stochastic Smagorinsky diffusion with the given Smagorinsky constant.

Parameters:

Name Type Description Default
cs Real[Any, '']

The Smagorinsky constant, defaults to jnp.asarray(0.1, dtype=float).

0.1

Returns:

Type Description
StochasticSmagorinskyDiffusion

The pastax.dynamics.StochasticSmagorinskyDiffusion initialized with the given Smagorinsky constant.


DeterministicSmagorinskyDiffusion

Bases: SmagorinskyDiffusion

Trainable deterministic Smagorinsky diffusion dynamics.

Formulation

This dynamics formulates the displacement at time \(t\) from the position \(\mathbf{X}(t)\) as:

\[ d\mathbf{X}(t) = (\mathbf{u} + \nabla K)(t, \mathbf{X}(t)) dt + V(t, \mathbf{X}(t)) d\mathbf{W}(t) \]

where \(V = \sqrt{2 K}\) and \(K\) is the Smagorinsky diffusion:

\[ K = C_s \Delta x \Delta y \sqrt{\left(\frac{\partial u}{\partial x} \right)^2 + \left(\frac{\partial v}{\partial y} \right)^2 + \frac{1}{2} \left(\frac{\partial u}{\partial y} + \frac{\partial v}{\partial x} \right)^2} \]

where \(C_s\) is the trainable Smagorinsky constant, \(\Delta x \Delta y\) a spatial scaling factor, and the rest of the expression represents the horizontal diffusion.

Methods:

Name Description
__call__

Computes the deterministic term of the dynamics.
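Because `DeterministicSmagorinskyDiffusion.__call__` returns only the drift \((\mathbf{u} + \nabla K)(t, \mathbf{X}(t))\) and drops the \(V \, d\mathbf{W}\) term, it defines an ODE rather than an SDE. A minimal sketch of integrating such a drift with Heun's method (the solver named as the simulators' default) follows; `heun_integrate` is hypothetical and uses fixed steps, unlike diffrax's configurable step-size controllers.

```python
import numpy as np


def heun_integrate(drift, x0, ts):
    """Integrate dX/dt = drift(t, X) with Heun's (explicit trapezoidal) method on the grid ts."""
    xs = [np.asarray(x0, dtype=float)]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        dt = t1 - t0
        x = xs[-1]
        k1 = drift(t0, x)  # slope at the start of the step
        k2 = drift(t1, x + dt * k1)  # slope at the Euler predictor
        xs.append(x + 0.5 * dt * (k1 + k2))  # trapezoidal corrector
    return np.stack(xs)
```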

Source code in pastax/dynamics/_smagorinsky_diffusion.py
class DeterministicSmagorinskyDiffusion(SmagorinskyDiffusion):
    r"""
    Trainable deterministic Smagorinsky diffusion dynamics.

    !!! example "Formulation"

        This dynamics formulates the displacement at time $t$ from the position $\mathbf{X}(t)$ as:

        $$
        d\mathbf{X}(t) = (\mathbf{u} + \nabla K)(t, \mathbf{X}(t)) dt + V(t, \mathbf{X}(t)) d\mathbf{W}(t)
        $$

        where $V = \sqrt{2 K}$ and $K$ is the Smagorinsky diffusion:

        $$
        K = C_s \Delta x \Delta y \sqrt{\left(\frac{\partial u}{\partial x} \right)^2 +
        \left(\frac{\partial v}{\partial y} \right)^2 + \frac{1}{2} \left(\frac{\partial u}{\partial y} +
        \frac{\partial v}{\partial x} \right)^2}
        $$

        where $C_s$ is the ***trainable*** Smagorinsky constant, $\Delta x \Delta y$ a spatial scaling factor, and the
        rest of the expression represents the horizontal diffusion.

    Methods
    -------
    __call__(t, y, args)
        Computes the deterministic term of the dynamics.
    """

    def __call__(self, t: Real[Array, ""], y: Float[Array, "2"], args: Gridded) -> Float[Array, "2"]:
        r"""
        Computes the deterministic term of the dynamics.

        Parameters
        ----------
        t : Real[Array, ""]
            The current time.
        y : Float[Array, "2"]
            The current state (latitude and longitude).
        args : Gridded
            The [`pastax.gridded.Gridded`][] containing the velocity fields.

        Returns
        -------
        Float[Array, "2"]
            The deterministic part of the dynamics.
        """
        gridded = args

        smag_ds = self._smagorinsky_diffusion(t, y, gridded)  # "1 x_width-2 x_width-2"
        dlatlon = self._deterministic_dynamics(t, y, gridded, smag_ds)

        if gridded.is_spherical_mesh and not gridded.use_degrees:
            dlatlon = meters_to_degrees(dlatlon, latitude=y[0])

        return dlatlon

    @classmethod
    def from_cs(cls, cs: Real[Any, ""] = 0.1):
        """
        Initializes the deterministic Smagorinsky diffusion with the given Smagorinsky constant.

        Parameters
        ----------
        cs : Real[Any, ""], optional
            The Smagorinsky constant, defaults to `jnp.asarray(0.1, dtype=float)`.

        Returns
        -------
        DeterministicSmagorinskyDiffusion
            The [`pastax.dynamics.DeterministicSmagorinskyDiffusion`][] initialized with the given Smagorinsky constant.
        """
        return cls(_cs=_from_cs(cs))
__call__(t: Real[Array, ''], y: Float[Array, '2'], args: Gridded) -> Float[Array, '2']

Computes the deterministic term of the dynamics.

Parameters:

Name Type Description Default
t Real[Array, '']

The current time.

required
y Float[Array, '2']

The current state (latitude and longitude).

required
args Gridded

The pastax.gridded.Gridded containing the velocity fields.

required

Returns:

Type Description
Float[Array, '2']

The deterministic part of the dynamics.

from_cs(cs: Real[Any, ''] = 0.1) classmethod

Initializes the deterministic Smagorinsky diffusion with the given Smagorinsky constant.

Parameters:

Name Type Description Default
cs Real[Any, '']

The Smagorinsky constant, defaults to jnp.asarray(0.1, dtype=float).

0.1

Returns:

Type Description
DeterministicSmagorinskyDiffusion

The pastax.dynamics.DeterministicSmagorinskyDiffusion initialized with the given Smagorinsky constant.


linear_uv(t: Real[Array, ''], y: Float[Array, '2'], args: Gridded) -> Float[Array, '2']

Computes the Lagrangian drift velocity by interpolating the velocity fields in space and time.

Parameters:

Name Type Description Default
t Real[Array, '']

The current time.

required
y Float[Array, '2']

The current state (latitude and longitude in degrees).

required
args Gridded

The pastax.gridded.Gridded containing the physical fields (only u and v here).

required

Returns:

Type Description
Float[Array, '2']

The Lagrangian drift velocity.
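`linear_uv` interpolates the velocity fields in time and space at the particle position. A self-contained NumPy sketch of that step, using trilinear interpolation on a regular `(time, latitude, longitude)` grid, follows; `trilinear` is a hypothetical helper, and pastax's own `Gridded.interp` may differ in edge handling and interpolation method.

```python
import numpy as np


def trilinear(field, axes, point):
    """Linearly interpolate field (shape nt x nlat x nlon) at point = (t, lat, lon).

    axes is a tuple of increasing 1-D coordinate arrays; points outside are clamped to the grid.
    """
    lo_idx, weights = [], []
    for coord, p in zip(axes, point):
        # index of the lower grid neighbour, kept within [0, len - 2]
        i = int(np.clip(np.searchsorted(coord, p) - 1, 0, len(coord) - 2))
        w = (p - coord[i]) / (coord[i + 1] - coord[i])
        lo_idx.append(i)
        weights.append(float(np.clip(w, 0.0, 1.0)))
    value = 0.0
    # sum over the 8 surrounding corners, each weighted by the product of 1-D weights
    for corner in np.ndindex(2, 2, 2):
        weight = 1.0
        for c, w in zip(corner, weights):
            weight *= w if c else 1.0 - w
        value += weight * field[tuple(i + c for i, c in zip(lo_idx, corner))]
    return value
```

Applying it to `u` and `v` separately and stacking the results as `[v, u]` gives a `[latitude, longitude]` velocity comparable to what `linear_uv` returns before any unit conversion.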

Source code in pastax/dynamics/_linear_uv.py
def linear_uv(t: Real[Array, ""], y: Float[Array, "2"], args: Gridded) -> Float[Array, "2"]:
    """
    Computes the Lagrangian drift velocity by interpolating the velocity fields in space and time.

    Parameters
    ----------
    t : Real[Array, ""]
        The current time.
    y : Float[Array, "2"]
        The current state (latitude and longitude in degrees).
    args : Gridded
        The [`pastax.gridded.Gridded`][] containing the physical fields (only u and v here).

    Returns
    -------
    Float[Array, "2"]
        The Lagrangian drift velocity.
    """
    dlatlon = _linear_uv(t, y, args)

    dataset = args
    if dataset.is_spherical_mesh and not dataset.use_degrees:
        dlatlon = meters_to_degrees(dlatlon, latitude=y[0])

    return dlatlon
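When the mesh is spherical but the velocities are stored in meters per unit time (`is_spherical_mesh` and not `use_degrees`), `linear_uv` converts the result to degrees with `meters_to_degrees`, using the current latitude `y[0]`. That helper's body is not shown here. The sketch below assumes a spherical Earth with a mean radius of 6,371 km, the usual relations Δlat = Δy / R and Δlon = Δx / (R cos φ), and the [latitude, longitude] component order used throughout this module. Both the function name and the radius are assumptions.

```python
import numpy as np

EARTH_RADIUS_M = 6_371_000.0  # assumed mean Earth radius; pastax may use a different constant


def meters_to_degrees_sketch(dlatlon_m, latitude):
    """Convert a [dlat, dlon] displacement or velocity from meters to degrees (hypothetical helper)."""
    dlatlon_m = np.asarray(dlatlon_m, dtype=float)
    deg_per_m = 180.0 / (np.pi * EARTH_RADIUS_M)  # degrees of arc per meter along a great circle
    dlat = dlatlon_m[..., 0] * deg_per_m
    # meridians converge toward the poles, so a degree of longitude spans cos(latitude) times fewer meters;
    # the division blows up at the poles
    dlon = dlatlon_m[..., 1] * deg_per_m / np.cos(np.deg2rad(latitude))
    return np.stack([dlat, dlon], axis=-1)


one_degree_m = np.pi * EARTH_RADIUS_M / 180.0  # about 111.2 km
print(meters_to_degrees_sketch([one_degree_m, one_degree_m], 60.0))  # about [1, 2]
```

At 60°N, the same eastward distance covers twice as many degrees of longitude as at the equator, because cos(60°) = 0.5.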

pastax.trajectory

This module provides classes for handling pastax.trajectory.State, pastax.trajectory.Timeseries, pastax.trajectory.Trajectory, and pastax.trajectory.TrajectoryEnsemble in JAX.

State

Bases: Unitful

Class representing a pastax.trajectory.State with a value, unit, and name.

Attributes:

Name Type Description
name (str | None, optional)

The name of the state, defaults to None.

Methods:

Name Description
__init__

Initializes the pastax.trajectory.State with given value, unit and name.

attach_name

Attaches a name to the pastax.trajectory.State.

Source code in pastax/trajectory/_state.py
class State(Unitful):
    """
    Class representing a [`pastax.trajectory.State`][] with a value, unit, and name.

    Attributes
    ----------
    name : str | None, optional
        The name of the state, defaults to `None`.

    Methods
    -------
    __init__(value, unit={}, name=None)
        Initializes the [`pastax.trajectory.State`][] with given value, unit and name.
    attach_name(name)
        Attaches a name to the [`pastax.trajectory.State`].
    """

    name: str | None = eqx.field(static=True)

    def __init__(
        self,
        value: Real[Array, "..."],
        unit: dict[Unit, int | float] = {},
        name: str | None = None,
    ):
        """
        Initializes the [`pastax.trajectory.State`][] with given value, unit and name.

        Parameters
        ----------
        value : Real[Array, "..."]
            The value of the state.
        unit : dict[Unit, int | float], optional
            The unit of the state, defaults to an empty `dict`.
        name : str | None, optional
            The name of the state, defaults to `None`.
        """
        super().__init__(value, unit)
        self.name = name

    def attach_name(self, name: str) -> State:
        """
        Attaches a name to the [`pastax.trajectory.State`].

        Parameters
        ----------
        name : str
            The name to attach to the state.

        Returns
        -------
        State
            A new [`pastax.trajectory.State`][] with the attached name.
        """
        return self.__class__(self.value, self.unit, name)
__init__(value: Real[Array, '...'], unit: dict[Unit, int | float] = {}, name: str | None = None)

Initializes the pastax.trajectory.State with given value, unit and name.

Parameters:

Name Type Description Default
value Real[Array, '...']

The value of the state.

required
unit dict[Unit, int | float]

The unit of the state, defaults to an empty dict.

{}
name str | None

The name of the state, defaults to None.

None
Source code in pastax/trajectory/_state.py
def __init__(
    self,
    value: Real[Array, "..."],
    unit: dict[Unit, int | float] = {},
    name: str | None = None,
):
    """
    Initializes the [`pastax.trajectory.State`][] with given value, unit and name.

    Parameters
    ----------
    value : Real[Array, "..."]
        The value of the state.
    unit : dict[Unit, int | float], optional
        The unit of the state, defaults to an empty `dict`.
    name : str | None, optional
        The name of the state, defaults to `None`.
    """
    super().__init__(value, unit)
    self.name = name
attach_name(name: str) -> State

Attaches a name to the pastax.trajectory.State.

Parameters:

Name Type Description Default
name str

The name to attach to the state.

required

Returns:

Type Description
State

A new pastax.trajectory.State with the attached name.

Source code in pastax/trajectory/_state.py
def attach_name(self, name: str) -> State:
    """
    Attaches a name to the [`pastax.trajectory.State`].

    Parameters
    ----------
    name : str
        The name to attach to the state.

    Returns
    -------
    State
        A new [`pastax.trajectory.State`][] with the attached name.
    """
    return self.__class__(self.value, self.unit, name)
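`attach_name` does not modify the state in place. `State` is an Equinox `Module`, and Equinox modules are immutable dataclasses, so the method builds a new instance from the same value and unit. The same pattern can be mimicked with a frozen standard-library dataclass. `NamedValue` below is a hypothetical stand-in used only to illustrate the behavior; it is not part of pastax.

```python
from dataclasses import dataclass, replace
from typing import Optional

import numpy as np


@dataclass(frozen=True)
class NamedValue:
    """Hypothetical stand-in for pastax.trajectory.State: a value, a unit mapping and an optional name."""

    value: np.ndarray
    unit: dict
    name: Optional[str] = None

    def attach_name(self, name: str) -> "NamedValue":
        # like State.attach_name: return a copy carrying the new name and leave self untouched
        return replace(self, name=name)


speed = NamedValue(np.array([0.4, 0.7]), unit={})
named = speed.attach_name("Speed")
print(speed.name, named.name)  # None Speed
```

Because the original object is never mutated, code that still holds a reference to it (for example inside a traced JAX function) keeps seeing the old name.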

Location

Bases: State

Class representing a geographical location with latitude and longitude.

Attributes:

Name Type Description
name str

The name of the pastax.trajectory.Location, set to "Location in [latitude, longitude]".

Methods:

Name Description
__init__

Initializes the pastax.trajectory.Location with given latitude and longitude.

latitude

Returns the latitude of the pastax.trajectory.Location.

longitude

Returns the longitude of the pastax.trajectory.Location.

distance_on_earth

Computes the distance in meters along the Earth's surface between this pastax.trajectory.Location and another pastax.trajectory.Location.

Source code in pastax/trajectory/_states.py
class Location(State):
    """
    Class representing a geographical location with latitude and longitude.

    Attributes
    ----------
    name : str
        The name of the [`pastax.trajectory.Location`][], set to `"Location in [latitude, longitude]"`.

    Methods
    -------
    __init__(value, **_)
        Initializes the [`pastax.trajectory.Location`][] with given latitude and longitude.
    latitude
        Returns the latitude of the [`pastax.trajectory.Location`][].
    longitude
        Returns the longitude of the [`pastax.trajectory.Location`][].
    distance_on_earth(other)
        Computes the Earth distance between this [`pastax.trajectory.Location`][] and another
        [`pastax.trajectory.Location`][].
    """

    _value: Float[Array, "... 2"] = eqx.field(converter=lambda x: jnp.asarray(x, dtype=float))

    def __init__(
        self,
        value: Float[Any, "... 2"],
        unit: dict[Unit, int | float] = unit_converter(UNIT["°"]),
        **_,
    ):
        """
        Initializes the [`pastax.trajectory.Location`][] with given latitude and longitude.

        Parameters
        ----------
        value : Float[Array, "... 2"] | Sequence[float]
            The latitude and longitude of the location.
        unit : dict[Unit, int | float], optional
            The [`pastax.utils.Unit`][] of the location, defaults to [`pastax.utils.LatLonDegrees`][].
        """
        if unit == unit_converter(UNIT["°"]):
            value = jnp.asarray(value)
            value = value.at[..., 1].set(longitude_in_180_180_degrees(value[..., 1]))

        super().__init__(value, unit=unit, name="Location in [latitude, longitude]")

    @property
    def latitude(self) -> State:
        """
        Returns the latitude of the [`pastax.trajectory.Location`][].

        Returns
        -------
        State
            The latitude of the [`pastax.trajectory.Location`][].
        """
        return State(self.value[..., 0], unit=UNIT["°"], name="Latitude")

    @property
    def longitude(self) -> State:
        """
        Returns the longitude of the [`pastax.trajectory.Location`][].

        Returns
        -------
        State
            The longitude of the [`pastax.trajectory.Location`][].
        """
        return State(self.value[..., 1], unit=UNIT["°"], name="Longitude")

    def distance_on_earth(self, other: Location) -> State:
        """
        Computes the distance in meters between this [`pastax.trajectory.Location`][]
        and another [`pastax.trajectory.Location`][].

        Parameters
        ----------
        other : Location
            The other [`pastax.trajectory.Location`][] to compute the distance to.

        Returns
        -------
        State
            The Earth distance in meters between the two [`pastax.trajectory.Location`][].

        Notes
        -----
        This function uses the Haversine formula to compute the distance between two points on the Earth surface.
        """
        if not self.unit == unit_converter(UNIT["°"]) or not other.unit == unit_converter(UNIT["°"]):
            raise ValueError("Both locations must be in degrees.")

        return State(
            distance_on_earth(self.value, other.value),
            unit=UNIT["m"],
            name="Distance on Earth",
        )
latitude: State property

Returns the latitude of the pastax.trajectory.Location.

Returns:

Type Description
State

The latitude of the pastax.trajectory.Location.

longitude: State property

Returns the longitude of the pastax.trajectory.Location.

Returns:

Type Description
State

The longitude of the pastax.trajectory.Location.

__init__(value: Float[Any, '... 2'], unit: dict[Unit, int | float] = unit_converter(UNIT['°']), **_)

Initializes the pastax.trajectory.Location with given latitude and longitude.

Parameters:

Name Type Description Default
value Float[Array, '... 2'] | Sequence[float]

The latitude and longitude of the location.

required
unit dict[Unit, int | float]

The pastax.utils.Unit of the location, defaults to pastax.utils.LatLonDegrees.

unit_converter(UNIT['°'])
Source code in pastax/trajectory/_states.py
def __init__(
    self,
    value: Float[Any, "... 2"],
    unit: dict[Unit, int | float] = unit_converter(UNIT["°"]),
    **_,
):
    """
    Initializes the [`pastax.trajectory.Location`][] with given latitude and longitude.

    Parameters
    ----------
    value : Float[Array, "... 2"] | Sequence[float]
        The latitude and longitude of the location.
    unit : dict[Unit, int | float], optional
        The [`pastax.utils.Unit`][] of the location, defaults to [`pastax.utils.LatLonDegrees`][].
    """
    if unit == unit_converter(UNIT["°"]):
        value = jnp.asarray(value)
        value = value.at[..., 1].set(longitude_in_180_180_degrees(value[..., 1]))

    super().__init__(value, unit=unit, name="Location in [latitude, longitude]")
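When the unit is degrees, `Location.__init__` normalizes the longitude (index 1 of the last axis) with `longitude_in_180_180_degrees` before storing it. It uses JAX's functional `.at[..., 1].set(...)` update because JAX arrays cannot be modified in place. The helper's body is not shown here, so the code below is a hedged sketch using a common modular formula. Whether pastax maps exactly 180° to -180° or keeps +180° is not documented in this section, so treat that edge case as an assumption.

```python
import numpy as np


def wrap_longitude_sketch(lon):
    """Map longitudes in degrees into [-180, 180) (hypothetical stand-in for longitude_in_180_180_degrees)."""
    lon = np.asarray(lon, dtype=float)
    return (lon + 180.0) % 360.0 - 180.0


# Normalize the longitude column of a batch of [latitude, longitude] pairs.
latlon = np.array([[45.0, 190.0], [-10.0, -190.0], [0.0, 360.0]])
# NumPy allows this in-place assignment; with JAX arrays use latlon.at[..., 1].set(...) instead
latlon[..., 1] = wrap_longitude_sketch(latlon[..., 1])
print(latlon[..., 1])  # longitudes become -170, 170 and 0
```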
distance_on_earth(other: Location) -> State

Computes the distance in meters between this pastax.trajectory.Location and another pastax.trajectory.Location.

Parameters:

Name Type Description Default
other Location

The other pastax.trajectory.Location to compute the distance to.

required

Returns:

Type Description
State

The Earth distance in meters between the two pastax.trajectory.Location.

Notes

This function uses the Haversine formula to compute the distance between two points on the Earth surface.

Source code in pastax/trajectory/_states.py
def distance_on_earth(self, other: Location) -> State:
    """
    Computes the distance in meters between this [`pastax.trajectory.Location`][]
    and another [`pastax.trajectory.Location`][].

    Parameters
    ----------
    other : Location
        The other [`pastax.trajectory.Location`][] to compute the distance to.

    Returns
    -------
    State
        The Earth distance in meters between the two [`pastax.trajectory.Location`][].

    Notes
    -----
    This function uses the Haversine formula to compute the distance between two points on the Earth surface.
    """
    if not self.unit == unit_converter(UNIT["°"]) or not other.unit == unit_converter(UNIT["°"]):
        raise ValueError("Both locations must be in degrees.")

    return State(
        distance_on_earth(self.value, other.value),
        unit=UNIT["m"],
        name="Distance on Earth",
    )
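`distance_on_earth` delegates to a module-level `distance_on_earth` helper that, according to the Notes, applies the Haversine formula. That helper's body and Earth radius are not shown here. The NumPy sketch below implements the standard Haversine great-circle distance, d = 2R·arcsin(√(sin²(Δφ/2) + cos φ₁ cos φ₂ sin²(Δλ/2))). It assumes a mean Earth radius of 6,371,000 m and [latitude, longitude] inputs in degrees with shape (..., 2), matching the value layout of `Location`.

```python
import numpy as np

EARTH_RADIUS_M = 6_371_000.0  # assumed mean Earth radius; pastax's own constant may differ


def haversine_m(latlon1, latlon2):
    """Great-circle distance in meters between [lat, lon] points in degrees with shapes (..., 2)."""
    p1 = np.deg2rad(np.asarray(latlon1, dtype=float))
    p2 = np.deg2rad(np.asarray(latlon2, dtype=float))
    lat1, lon1 = p1[..., 0], p1[..., 1]
    lat2, lon2 = p2[..., 0], p2[..., 1]
    # haversine of the central angle between the two points
    h = np.sin((lat2 - lat1) / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2
    # clipping guards against h drifting slightly above 1 from rounding for near-antipodal points
    return 2 * EARTH_RADIUS_M * np.arcsin(np.sqrt(np.clip(h, 0.0, 1.0)))


print(haversine_m([0.0, 0.0], [0.0, 1.0]))  # about 111195 m: one degree along the equator
batch_a = np.array([[0.0, 0.0], [45.0, 10.0], [-30.0, 170.0]])
batch_b = np.array([[1.0, 0.0], [45.0, 11.0], [-30.0, -170.0]])
print(haversine_m(batch_a, batch_b).shape)  # (3,)
```

The third pair straddles the antimeridian. The squared sine of the longitude difference handles this without any special casing. If you port a formula like this to JAX and differentiate it, note that the square root has an infinite derivative when the two points coincide.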

Displacement

Bases: State

Class representing a pastax.trajectory.Displacement with latitude and longitude components.

Attributes:

Name Type Description
name str

The name of the pastax.trajectory.Displacement, set to "Displacement in [latitude, longitude]".

Methods:

Name Description
__init__

Initializes the pastax.trajectory.Displacement with given latitude and longitude components and unit.

latitude

Returns the latitude component of the pastax.trajectory.Displacement.

longitude

Returns the longitude component of the pastax.trajectory.Displacement.

Source code in pastax/trajectory/_states.py
class Displacement(State):
    """
    Class representing a [`pastax.trajectory.Displacement`][] with latitude and longitude components.

    Attributes
    ----------
    name : str
        The name of the [`pastax.trajectory.Displacement`], set to `"Displacement in [latitude, longitude]"`.

    Methods
    -------
    __init__(value, unit)
        Initializes the [`pastax.trajectory.Displacement`][] with given latitude and longitude components and unit.
    latitude
        Returns the latitude component of the [`pastax.trajectory.Displacement`].
    longitude
        Returns the longitude component of the [`pastax.trajectory.Displacement`].
    """

    def __init__(
        self,
        value: Real[Any, "... 2"],
        unit: dict[Unit, int | float] = UNIT["°"],
        **_,
    ):
        """
        Initializes the [`pastax.trajectory.Displacement`][] with given latitude and longitude components and
        [`pastax.utils.Unit`][].

        Parameters
        ----------
        value : Real[Any, "... 2"]
            The latitude and longitude components of the [`pastax.trajectory.Displacement`].
        unit : dict[Unit, int | float], optional
            The [`pastax.utils.Unit`][] of the [`pastax.trajectory.Displacement`], defaults to
            [`pastax.utils.LatLonDegrees`].
        """
        super().__init__(value, unit=unit, name="Displacement in [latitude, longitude]")

    @property
    def latitude(self) -> State:
        """
        Returns the latitude component of the [`pastax.trajectory.Displacement`].

        Returns
        -------
        State
            The latitude component of the [`pastax.trajectory.Displacement`].
        """
        return State(self.value[..., 0], unit=self.unit, name="Displacement in latitude")

    @property
    def longitude(self) -> State:
        """
        Returns the longitude component of the [`pastax.trajectory.Displacement`].

        Returns
        -------
        State
            The longitude component of the [`pastax.trajectory.Displacement`].
        """
        return State(self.value[..., 1], unit=self.unit, name="Displacement in longitude")
latitude: State property

Returns the latitude component of the pastax.trajectory.Displacement.

Returns:

Type Description
State

The latitude component of the pastax.trajectory.Displacement.

longitude: State property

Returns the longitude component of the pastax.trajectory.Displacement.

Returns:

Type Description
State

The longitude component of the pastax.trajectory.Displacement.

__init__(value: Real[Any, '... 2'], unit: dict[Unit, int | float] = UNIT['°'], **_)

Initializes the pastax.trajectory.Displacement with given latitude and longitude components and pastax.utils.Unit.

Parameters:

Name Type Description Default
value Real[Any, '... 2']

The latitude and longitude components of the pastax.trajectory.Displacement.

required
unit dict[Unit, int | float]

The pastax.utils.Unit of the pastax.trajectory.Displacement, defaults to pastax.utils.LatLonDegrees.

UNIT['°']
Source code in pastax/trajectory/_states.py
def __init__(
    self,
    value: Real[Any, "... 2"],
    unit: dict[Unit, int | float] = UNIT["°"],
    **_,
):
    """
    Initializes the [`pastax.trajectory.Displacement`][] with given latitude and longitude components and
    [`pastax.utils.Unit`][].

    Parameters
    ----------
    value : Real[Any, "... 2"]
        The latitude and longitude components of the [`pastax.trajectory.Displacement`].
    unit : dict[Unit, int | float], optional
        The [`pastax.utils.Unit`][] of the [`pastax.trajectory.Displacement`], defaults to
        [`pastax.utils.LatLonDegrees`].
    """
    super().__init__(value, unit=unit, name="Displacement in [latitude, longitude]")
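The `"... 2"` shape hint allows any number of leading batch axes followed by a final axis that holds the [latitude, longitude] components. `Displacement.latitude` and `.longitude` select `value[..., 0]` and `value[..., 1]` and keep the parent unit. The NumPy snippet below uses synthetic values to show what that indexing returns for an ensemble of trajectories. The (members, time, 2) shape is an illustrative assumption.

```python
import numpy as np

# Synthetic displacements in meters: 3 ensemble members x 4 time steps x [dlat, dlon]
rng = np.random.default_rng(0)
disp_m = rng.normal(scale=500.0, size=(3, 4, 2))

dlat = disp_m[..., 0]  # what Displacement.latitude selects: shape (3, 4)
dlon = disp_m[..., 1]  # what Displacement.longitude selects: shape (3, 4)

# Both components share the displacement's unit (meters here), so they can be combined directly,
# for example into a per-step displacement magnitude.
magnitude_m = np.hypot(dlat, dlon)
print(dlat.shape, magnitude_m.shape)  # (3, 4) (3, 4)
```

Combining the two components like this only makes sense in a metric unit. In degrees, a degree of longitude shrinks with latitude, so the components are not directly comparable.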

Time

Bases: State

Class representing a pastax.trajectory.Time value.

Attributes:

Name Type Description
name str

The name of the pastax.trajectory.Time, set to "Time since epoch".

Methods:

Name Description
__init__

Initializes the pastax.trajectory.Time with given time value.

to_datetime

Converts the pastax.trajectory.Time to a numpy.ndarray of datetime64[s].

Source code in pastax/trajectory/_states.py
class Time(State):
    """
    Class representing a [`pastax.trajectory.Time`][] value.

    Attributes
    ----------
    name : str
        The name of the [`pastax.trajectory.Time`], set to `"Time since epoch"`.

    Methods
    -------
    __init__(value, unit)
        Initializes the [`pastax.trajectory.Time`][] with given time value.
    to_datetime()
        Converts the [`pastax.trajectory.Time`][] to a `numpy.ndarray` of `datetime64[s]`.
    """

    def __init__(self, value: Real[Any, "..."], unit: dict[Unit, int | float] = UNIT["s"], **_):
        """
        Initializes the [`pastax.trajectory.Time`][] with given time value.

        Parameters
        ----------
        value : Real[Any, "..."]
            The time value.
        unit : dict[Unit, int | float], optional
            The [`pastax.utils.Unit`][] of the [`pastax.trajectory.Time`], defaults to [`pastax.utils.Seconds`][].
        """
        super().__init__(value, unit=unit, name="Time since epoch")

    def to_datetime(self) -> np.ndarray:
        """
        Converts the [`pastax.trajectory.Time`][] to a `numpy.ndarray` of `datetime64[s]`.

        Returns
        -------
        np.ndarray
            The [`pastax.trajectory.Time`][] as a `numpy.ndarray` of `datetime64[s]`.
        """
        return np.asarray(self.value).astype("datetime64[s]")
__init__(value: Real[Any, '...'], unit: dict[Unit, int | float] = UNIT['s'], **_)

Initializes the pastax.trajectory.Time with given time value.

Parameters:

Name Type Description Default
value Real[Any, '...']

The time value.

required
unit dict[Unit, int | float]

The pastax.utils.Unit of the pastax.trajectory.Time, defaults to pastax.utils.Seconds.

UNIT['s']
Source code in pastax/trajectory/_states.py
def __init__(self, value: Real[Any, "..."], unit: dict[Unit, int | float] = UNIT["s"], **_):
    """
    Initializes the [`pastax.trajectory.Time`][] with given time value.

    Parameters
    ----------
    value : Real[Any, "..."]
        The time value.
    unit : dict[Unit, int | float], optional
        The [`pastax.utils.Unit`][] of the [`pastax.trajectory.Time`], defaults to [`pastax.utils.Seconds`][].
    """
    super().__init__(value, unit=unit, name="Time since epoch")
to_datetime() -> np.ndarray

Converts the pastax.trajectory.Time to a numpy.ndarray of datetime64[s].

Returns:

Type Description
ndarray

The pastax.trajectory.Time as a numpy.ndarray of datetime64[s].

Source code in pastax/trajectory/_states.py
def to_datetime(self) -> np.ndarray:
    """
    Converts the [`pastax.trajectory.Time`][] to a `numpy.ndarray` of `datetime64[s]`.

    Returns
    -------
    np.ndarray
        The [`pastax.trajectory.Time`][] as a `numpy.ndarray` of `datetime64[s]`.
    """
    return np.asarray(self.value).astype("datetime64[s]")
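`to_datetime` casts the stored float values directly to `datetime64[s]`. With the default seconds unit, `Time` values are therefore read as seconds since the Unix epoch (1970-01-01T00:00:00), and any sub-second part is lost. The NumPy round trip below shows that conversion, plus the reverse one that could be used to build a times array from calendar dates. It is a sketch, not a pastax API.

```python
import numpy as np

# Calendar dates -> float seconds since 1970-01-01, the representation Time stores with its default unit
dates = np.array(["2024-01-01T00:00:00", "2024-01-01T06:00:00"], dtype="datetime64[s]")
seconds = dates.astype("int64").astype(float)

# Float seconds -> datetime64[s], mirroring Time.to_datetime
back = np.asarray(seconds).astype("datetime64[s]")
print(back)  # 2024-01-01T00:00:00 and 2024-01-01T06:00:00
```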

Timeseries

Bases: Unitful

Class representing a pastax.trajectory.Timeseries: a sequence of pastax.trajectory.State values paired with pastax.trajectory.Time points.

Attributes:

Name Type Description
states State

The pastax.trajectory.State values of the pastax.trajectory.Timeseries.

times Time

The pastax.trajectory.Time points of the pastax.trajectory.Timeseries.

length int

The number of time points in the pastax.trajectory.Timeseries.

Methods:

Name Description
__init__

Initializes the pastax.trajectory.Timeseries with given pastax.trajectory.State, pastax.trajectory.Time, and optional parameters.

value

Returns the value of the pastax.trajectory.Timeseries.

unit

Returns the unit of the pastax.trajectory.Timeseries.

name

Returns the name of the pastax.trajectory.Timeseries.

attach_name

Attaches a name to the pastax.trajectory.Timeseries.

euclidean_distance

Computes the Euclidean distance between this pastax.trajectory.Timeseries and another pastax.trajectory.Timeseries.

map

Applies a function to each pastax.trajectory.State in the pastax.trajectory.Timeseries.

to_xarray

Converts the pastax.trajectory.Timeseries to an xarray.Dataset.

from_array

Creates a pastax.trajectory.Timeseries from arrays of values and time points.

Source code in pastax/trajectory/_timeseries.py
class Timeseries(Unitful):
    """
    Class representing a [`pastax.trajectory.Timeseries`].

    Attributes
    ----------
    states : State
        The [`pastax.trajectory.State`][] of the [`pastax.trajectory.Timeseries`].
    times : Time
        The [`pastax.trajectory.Time`][] points of the [`pastax.trajectory.Timeseries`].
    length : int
        The length of the [`pastax.trajectory.Timeseries`].

    Methods
    -------
    __init__(states, times, **__)
        Initializes the [`pastax.trajectory.Timeseries`][] with given [`pastax.trajectory.State`][],
        [`pastax.trajectory.Time`][], and optional parameters.
    value
        Returns the value of the [`pastax.trajectory.Timeseries`].
    unit
        Returns the unit of the [`pastax.trajectory.Timeseries`].
    name
        Returns the name of the [`pastax.trajectory.Timeseries`].
    attach_name(name)
        Attaches a name to the [`pastax.trajectory.Timeseries`].
    euclidean_distance(other)
        Computes the Euclidean distance between this [`pastax.trajectory.Timeseries`][] and another
        [`pastax.trajectory.Timeseries`].
    map(func)
        Applies a function to each [`pastax.trajectory.State`][] in the [`pastax.trajectory.Timeseries`].
    to_xarray()
        Converts the [`pastax.trajectory.Timeseries`][] to an `xarray.Dataset`.
    from_array(values, times, unit={}, name=None, **kwargs)
        Creates a [`pastax.trajectory.Timeseries`][] from arrays of values and time points.
    """

    states: State
    _states_type: ClassVar = State
    times: Time
    length: int = eqx.field(static=True)

    _value: None = eqx.field(repr=False)
    _unit: None = eqx.field(repr=False)

    def __init__(self, states: State, times: Time, **_: Any):
        """
        Initializes the [`pastax.trajectory.Timeseries`][] with given [`pastax.trajectory.State`][],
        [`pastax.trajectory.Time`][], and optional parameters.

        Parameters
        ----------
        states : State
            The states of the [`pastax.trajectory.Timeseries`].
        times : Time
            The time points for the [`pastax.trajectory.Timeseries`].
        """
        super().__init__()
        self.states = states
        self.times = times
        self.length = times.value.shape[-1]

    @property
    def value(self) -> Float[Array, "... time state"]:
        """
        Returns the value of the [`pastax.trajectory.Timeseries`].

        Returns
        -------
        Float[Array, "... time state"]
            The value of the [`pastax.trajectory.Timeseries`].
        """
        return self.states.value

    @property
    def unit(self) -> dict[Unit, int | float]:
        """
        Returns the unit of the [`pastax.trajectory.Timeseries`].

        Returns
        -------
        dict[Unit, int | float]
            The unit of the [`pastax.trajectory.Timeseries`].
        """
        return self.states.unit

    @property
    def name(self) -> str | None:
        """
        Returns the name of the [`pastax.trajectory.Timeseries`].

        Returns
        -------
        str | None
            The name of the [`pastax.trajectory.Timeseries`].
        """
        return self.states.name

    def attach_name(self, name: str) -> Timeseries:
        """
        Attaches a name to the [`pastax.trajectory.Timeseries`].

        Parameters
        ----------
        name : str
            The name to attach to the [`pastax.trajectory.Timeseries`].

        Returns
        -------
        Timeseries
            A new [`pastax.trajectory.Timeseries`][] with the attached name.
        """
        return Timeseries.from_array(self.states.value, self.times.value, unit=self.unit, name=name)

    def euclidean_distance(self, other: Timeseries | Array) -> Timeseries:
        """
        Computes the Euclidean distance between this timeseries and another timeseries.

        Parameters
        ----------
        other : Timeseries | Array
            The other [`pastax.trajectory.Timeseries`][] to compute the distance to.

        Returns
        -------
        Timeseries
            The Euclidean distance between the two [`pastax.trajectory.Timeseries`].
        """
        if isinstance(other, Timeseries):
            other = other.states

        res = eqx.filter_vmap(lambda p1, p2: p1.euclidean_distance(p2))(self.states, other)

        return Timeseries.from_array(res.value, self.times.value, self.unit, name="Euclidean distance")

    def map(self, func: Callable[[State], Unitful | Array]) -> Timeseries:
        """
        Applies a function to each [`pastax.trajectory.State`][] in the [`pastax.trajectory.Timeseries`].

        Parameters
        ----------
        func : Callable[[State], Unitful | Array]
            The function to apply to each [`pastax.trajectory.State`][].

        Returns
        -------
        Timeseries
            The result of applying the function to each [`pastax.trajectory.State`][].
        """
        in_axes = eqx.filter(self.states, False)
        in_axes = eqx.tree_at(lambda x: x._value, in_axes, 0, is_leaf=lambda x: x is None)
        res = eqx.filter_vmap(func, in_axes=(in_axes,))(self.states)

        unit = {}
        if isinstance(res, Unitful):
            unit = res.unit
            res = res.value

        return Timeseries.from_array(res, self.times.value, unit)

    def to_xarray(self) -> xr.Dataset:
        """
        Converts the [`pastax.trajectory.Timeseries`][] to an `xarray.Dataset`.

        Returns
        -------
        xr.Dataset
            The corresponding `xarray.Dataset`.
        """
        da = self.to_dataarray()
        ds = da.to_dataset()

        return ds

    @classmethod
    def from_array(
        cls,
        values: Float[Array, "time state"],
        times: Float[Array, "time"],
        unit: dict[Unit, int | float] = {},
        name: str | None = None,
        **kwargs: Any,
    ) -> Timeseries:
        """
        Creates a [`pastax.trajectory.Timeseries`][] from arrays of values and time points.

        Parameters
        ----------
        values : Float[Array, "time state"]
            The array of values for the timeseries.
        times : Float[Array, "time"]
            The time points for the timeseries.
        unit : dict[Unit, int | float], optional
            The unit of the timeseries, defaults to an empty dict.
        name : str | None, optional
            The name of the timeseries, defaults to None.
        **kwargs : Any
            Additional keyword arguments.

        Returns
        -------
        Timeseries
            The corresponding [`pastax.trajectory.Timeseries`][].
        """
        values = jnp.asarray(values, dtype=float)
        times = jnp.asarray(times, dtype=float)

        return cls(cls._states_type(values, unit=unit, name=name), Time(times), **kwargs)

    def to_dataarray(self) -> xr.DataArray:
        da = xr.DataArray(
            data=self.states.value,
            dims=["time"],
            coords={"time": self.times.to_datetime()},
            name=self.name,
            attrs={"units": units_to_str(self.unit)},
        )

        return da
value: Float[Array, '... time state'] property

Returns the value of the pastax.trajectory.Timeseries.

Returns:

Type Description
Float[Array, '... time state']

The value of the pastax.trajectory.Timeseries.

unit: dict[Unit, int | float] property

Returns the unit of the pastax.trajectory.Timeseries.

Returns:

Type Description
dict[Unit, int | float]

The unit of the pastax.trajectory.Timeseries.

name: str | None property

Returns the name of the pastax.trajectory.Timeseries.

Returns:

Type Description
str | None

The name of the pastax.trajectory.Timeseries.

__init__(states: State, times: Time, **_: Any)

Initializes the pastax.trajectory.Timeseries with given pastax.trajectory.State, pastax.trajectory.Time, and optional parameters.

Parameters:

Name Type Description Default
states State

The states of the pastax.trajectory.Timeseries.

required
times Time

The time points of the pastax.trajectory.Timeseries.

required
Source code in pastax/trajectory/_timeseries.py
def __init__(self, states: State, times: Time, **_: Any):
    """
    Initializes the [`pastax.trajectory.Timeseries`][] with given [`pastax.trajectory.State`][],
    [`pastax.trajectory.Time`][], and optional parameters.

    Parameters
    ----------
    states : State
        The states of the [`pastax.trajectory.Timeseries`].
    times : Time
        The time points for the [`pastax.trajectory.Timeseries`].
    """
    super().__init__()
    self.states = states
    self.times = times
    self.length = times.value.shape[-1]
attach_name(name: str) -> Timeseries

Attaches a name to the pastax.trajectory.Timeseries.

Parameters:

Name Type Description Default
name str

The name to attach to the pastax.trajectory.Timeseries.

required

Returns:

Type Description
Timeseries

A new pastax.trajectory.Timeseries with the attached name.

Source code in pastax/trajectory/_timeseries.py
def attach_name(self, name: str) -> Timeseries:
    """
    Attaches a name to the [`pastax.trajectory.Timeseries`].

    Parameters
    ----------
    name : str
        The name to attach to the [`pastax.trajectory.Timeseries`].

    Returns
    -------
    Timeseries
        A new [`pastax.trajectory.Timeseries`][] with the attached name.
    """
    return Timeseries.from_array(self.states.value, self.times.value, unit=self.unit, name=name)
euclidean_distance(other: Timeseries | Array) -> Timeseries

Computes the Euclidean distance between this timeseries and another timeseries.

Parameters:

Name Type Description Default
other Timeseries | Array

The other pastax.trajectory.Timeseries, or an array of states, to compute the distance to.

required

Returns:

Type Description
Timeseries

The Euclidean distance between the two pastax.trajectory.Timeseries.

Source code in pastax/trajectory/_timeseries.py
def euclidean_distance(self, other: Timeseries | Array) -> Timeseries:
    """
    Computes the Euclidean distance between this timeseries and another timeseries.

    Parameters
    ----------
    other : Timeseries | Array
        The other [`pastax.trajectory.Timeseries`][] to compute the distance to.

    Returns
    -------
    Timeseries
        The Euclidean distance between the two [`pastax.trajectory.Timeseries`][].
    """
    if isinstance(other, Timeseries):
        other = other.states

    res = eqx.filter_vmap(lambda p1, p2: p1.euclidean_distance(p2))(self.states, other)

    return Timeseries.from_array(res.value, self.times.value, self.unit, name="Euclidean distance")
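`euclidean_distance` pairs the two state sequences time step by time step (the `eqx.filter_vmap` over `self.states` and `other`). A minimal NumPy sketch of that pairing on plain `(time, state)` arrays, assuming `State.euclidean_distance` is the usual L2 norm of the component-wise difference (its body is not shown in this section); `euclidean_distance_over_time` is a hypothetical name:

```python
import numpy as np


def euclidean_distance_over_time(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Per-time-step L2 distance between two (time, state) arrays.

    Hypothetical helper mirroring the vmap over time in
    Timeseries.euclidean_distance, assuming an L2 norm per state.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    # Taking the norm over the trailing "state" axis leaves one value per time step.
    return np.linalg.norm(a - b, axis=-1)


states_a = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
states_b = np.array([[0.0, 0.0], [1.0, 2.0], [5.0, 6.0]])
distances = euclidean_distance_over_time(states_a, states_b)  # [0., 1., 5.]
```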
map(func: Callable[[State], Unitful | Array]) -> Timeseries

Applies a function to each pastax.trajectory.State in the pastax.trajectory.Timeseries.

Parameters:

Name Type Description Default
func Callable[[State], Unitful | Array]

The function to apply to each pastax.trajectory.State.

required

Returns:

Type Description
Timeseries

The result of applying the function to each pastax.trajectory.State.

Source code in pastax/trajectory/_timeseries.py
def map(self, func: Callable[[State], Unitful | Array]) -> Timeseries:
    """
    Applies a function to each [`pastax.trajectory.State`][] in the [`pastax.trajectory.Timeseries`][].

    Parameters
    ----------
    func : Callable[[State], Unitful | Array]
        The function to apply to each [`pastax.trajectory.State`][].

    Returns
    -------
    Timeseries
        The result of applying the function to each [`pastax.trajectory.State`][].
    """
    in_axes = eqx.filter(self.states, False)
    in_axes = eqx.tree_at(lambda x: x._value, in_axes, 0, is_leaf=lambda x: x is None)
    res = eqx.filter_vmap(func, in_axes=(in_axes,))(self.states)

    unit = {}
    if isinstance(res, Unitful):
        unit = res.unit
        res = res.value

    return Timeseries.from_array(res, self.times.value, unit)
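`map` vectorises `func` over the leading time axis of the states with `eqx.filter_vmap`, then rewraps the result (and its unit, if `func` returns a `Unitful`) as a Timeseries. Ignoring units, the values match a plain loop over time steps. A minimal sketch under that simplification, where `map_over_time` is a hypothetical name and `func` receives a raw state vector rather than a `pastax.trajectory.State`:

```python
from typing import Callable

import numpy as np


def map_over_time(states: np.ndarray, func: Callable[[np.ndarray], np.ndarray]) -> np.ndarray:
    """Apply ``func`` to each state (row) and stack the results along time.

    Loop-based stand-in for the vectorised map in Timeseries.map;
    jax.vmap over axis 0 would compute the same values in one call.
    """
    return np.stack([np.asarray(func(state)) for state in states])


states = np.array([[3.0, 4.0], [6.0, 8.0], [0.0, 0.0]])
norms = map_over_time(states, np.linalg.norm)  # one scalar per time step
```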
to_xarray() -> xr.Dataset

Converts the pastax.trajectory.Timeseries to an xarray.Dataset.

Returns:

Type Description
Dataset

The corresponding xarray.Dataset.

Source code in pastax/trajectory/_timeseries.py
def to_xarray(self) -> xr.Dataset:
    """
    Converts the [`pastax.trajectory.Timeseries`][] to an `xarray.Dataset`.

    Returns
    -------
    xr.Dataset
        The corresponding `xarray.Dataset`.
    """
    da = self.to_dataarray()
    ds = da.to_dataset()

    return ds
from_array(values: Float[Array, 'time state'], times: Float[Array, 'time'], unit: dict[Unit, int | float] = {}, name: str | None = None, **kwargs: Any) -> Timeseries classmethod

Creates a pastax.trajectory.Timeseries from arrays of values and time points.

Parameters:

Name Type Description Default
values Float[Array, 'time state']

The array of values for the timeseries.

required
times Float[Array, 'time']

The time points for the timeseries.

required
unit dict[Unit, int | float]

The unit of the timeseries, defaults to an empty dict.

{}
name str | None

The name of the timeseries, defaults to None.

None
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
Timeseries

The corresponding pastax.trajectory.Timeseries.

Source code in pastax/trajectory/_timeseries.py
@classmethod
def from_array(
    cls,
    values: Float[Array, "time state"],
    times: Float[Array, "time"],
    unit: dict[Unit, int | float] = {},
    name: str | None = None,
    **kwargs: Any,
) -> Timeseries:
    """
    Creates a [`pastax.trajectory.Timeseries`][] from arrays of values and time points.

    Parameters
    ----------
    values : Float[Array, "time state"]
        The array of values for the timeseries.
    times : Float[Array, "time"]
        The time points for the timeseries.
    unit : dict[Unit, int | float], optional
        The unit of the timeseries, defaults to an empty dict.
    name : str | None, optional
        The name of the timeseries, defaults to None.
    **kwargs : Any
        Additional keyword arguments.

    Returns
    -------
    Timeseries
        The corresponding [`pastax.trajectory.Timeseries`][].
    """
    values = jnp.asarray(values, dtype=float)
    times = jnp.asarray(times, dtype=float)

    return cls(cls._states_type(values, unit=unit, name=name), Time(times), **kwargs)

Trajectory

Bases: Timeseries

Class representing a trajectory with 2D geographical locations over time.

Attributes:

Name Type Description
states Location

The locations of the trajectory.

id (Int[Array, ''] | None, optional)

The ID of the trajectory, defaults to None.

Methods:

Name Description
__init__

Initializes the pastax.trajectory.Trajectory with given locations, times, and optional trajectory ID.

latitudes

Returns the latitudes of the trajectory.

locations

Returns the locations of the trajectory.

longitudes

Returns the longitudes of the trajectory.

origin

Returns the origin of the trajectory.

lengths

Returns the cumulative lengths of the trajectory.

liu_index

Computes the Liu Index between this trajectory and another trajectory.

mae

Computes the Mean Absolute Error (MAE) between this trajectory and another trajectory.

plot

Plots the trajectory.

rmse

Computes the Root Mean Square Error (RMSE) between this trajectory and another trajectory.

separation_distance

Computes the separation distance between this trajectory and another trajectory.

steps

Returns the steps of the trajectory.

velocities

Returns the velocities of the trajectory.

to_xarray

Converts the pastax.trajectory.Trajectory to an xarray.Dataset.

from_array

Creates a pastax.trajectory.Trajectory from arrays of (latitudes, longitudes) values and time points.

from_xarray

Creates a pastax.trajectory.Trajectory from an xarray.Dataset.

Source code in pastax/trajectory/_trajectory.py
class Trajectory(Timeseries):
    """
    Class representing a trajectory with 2D geographical locations over time.

    Attributes
    ----------
    states : Location
        The locations of the trajectory.
    id : Int[Array, ""] | None, optional
        The ID of the trajectory, defaults to `None`.

    Methods
    -------
    __init__(locations, times, id=None, **_)
        Initializes the [`pastax.trajectory.Trajectory`][] with given locations, times, and optional trajectory ID.
    latitudes
        Returns the latitudes of the trajectory.
    locations
        Returns the locations of the trajectory.
    longitudes
        Returns the longitudes of the trajectory.
    origin
        Returns the origin of the trajectory.
    lengths()
        Returns the cumulative lengths of the trajectory.
    liu_index(other)
        Computes the Liu Index between this trajectory and another trajectory.
    mae(other)
        Computes the Mean Absolute Error (MAE) between this trajectory and another trajectory.
    plot(ax=None, label=None, color=None, alpha_factor=1, ti=None)
        Plots the trajectory.
    rmse(other)
        Computes the Root Mean Square Error (RMSE) between this trajectory and another trajectory.
    separation_distance(other)
        Computes the separation distance between this trajectory and another trajectory.
    steps()
        Returns the steps of the trajectory.
    velocities()
        Returns the velocities of the trajectory.
    to_xarray()
        Converts the [`pastax.trajectory.Trajectory`][] to an `xarray.Dataset`.
    from_array(values, times, unit=UNIT["°"], id=None)
        Creates a [`pastax.trajectory.Trajectory`][] from arrays of (latitudes, longitudes) values and time points.
    from_xarray(dataset, time_varname="time", lat_varname="lat", lon_varname="lon", unit=UNIT["°"], id=None)
        Creates a [`pastax.trajectory.Trajectory`][] from an `xarray.Dataset`.
    """

    states: Location
    _states_type: ClassVar = Location
    id: Int[Array, ""] | None = None

    def __init__(
        self,
        locations: Location,
        times: Time,
        id: Int[Array, ""] | None = None,
        *_,
        **__,
    ):
        """
        Initializes the Trajectory with given locations, times, and optional trajectory ID.

        Parameters
        ----------
        locations : Float[Array, "... time 2"]
            The locations for the trajectory.
        times : Int[Array, "... time"]
            The time points for the trajectory.
        id : Int[Array, ""] | None, optional
            The ID of the trajectory, defaults to None.
        """
        super().__init__(locations, times)
        self.id = id

    @property
    def latitudes(self) -> State:
        """
        Returns the latitudes of the trajectory.

        Returns
        -------
        State
            The latitudes of the trajectory.
        """
        return self.locations.latitude

    @property
    def locations(self) -> Location:
        """
        Returns the locations of the trajectory.

        Returns
        -------
        Location
            The locations of the trajectory.
        """
        return self.states

    @property
    def longitudes(self) -> State:
        """
        Returns the longitudes of the trajectory.

        Returns
        -------
        State
            The longitudes of the trajectory.
        """
        return self.locations.longitude

    @property
    def origin(self) -> Location:
        """
        Returns the origin of the trajectory.

        Returns
        -------
        Location
            The origin of the trajectory.
        """
        return Location(
            self.locations.value[..., 0, :],
            unit=self.unit,
            name="Origin in [latitude, longitude]",
        )

    def lengths(self) -> Timeseries:
        """
        Returns the cumulative lengths of the trajectory.

        Returns
        -------
        Timeseries
            The cumulative lengths of the trajectory.
        """
        lengths = self.steps().cumsum()
        return Timeseries.from_array(
            lengths.value,
            self.times.value,
            unit=lengths.unit,
            name="Cumulative lengths",
        )

    def liu_index(self, other: Trajectory) -> Timeseries:
        """
        Computes the Liu Index (over time) between this trajectory and another trajectory.

        Parameters
        ----------
        other : Trajectory
            The other trajectory to compare against.

        Returns
        -------
        Timeseries
            The Liu Index between the two trajectories.
        """
        error = self.separation_distance(other).value.cumsum()
        cum_lengths = self.lengths().value.cumsum()
        cum_lengths = jnp.where(cum_lengths == 0, jnp.inf, cum_lengths)
        liu_index = error / cum_lengths

        return Timeseries.from_array(liu_index, self.times.value, name="Liu index")

    def mae(self, other: Trajectory) -> Timeseries:
        """
        Computes the Mean Absolute Error (MAE) (over time) between this trajectory and another trajectory.

        Parameters
        ----------
        other : Trajectory
            The other trajectory to compare against.

        Returns
        -------
        Timeseries
            The MAE between the two trajectories.
        """
        error = self.separation_distance(other).cumsum()
        length = jnp.arange(self.length) + 1
        mae = error / length

        return Timeseries.from_array(mae.value, self.times.value, mae.unit, name="MAE")

    def plot(
        self,
        ax: Axes | None = None,
        label: str | None = None,
        color: str | None = None,
        alpha_factor: float = 1,
        ti: int | None = None,
        **kwargs,
    ) -> Axes:
        """
        Plots the trajectory.

        Parameters
        ----------
        ax : Axes | None, optional
            The matplotlib axis to plot on, defaults to `None`.
            If `None`, a new figure and axis are created.
        label : str | None, optional
            The label for the plot, defaults to `None`.
        color : str | None, optional
            The color for the plot, defaults to `None`.
        alpha_factor : float, optional
            A factor controlling the overall transparency of the plotted trajectory, defaults to `1`.
        ti : int | None, optional
            The time index to plot up to, defaults to `None`.
        **kwargs : dict, optional
            Additional arguments passed to `LineCollection`.

        Returns
        -------
        Axes
            The matplotlib axis with the plot.
        """
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(projection=ccrs.PlateCarree())

        if ti is None:
            ti = self.length

        alpha = jnp.geomspace(0.25, 1, ti - 1) * alpha_factor

        locations = self.locations.value[:ti, None, ::-1]
        segments = jnp.concat([locations[:-1], locations[1:]], axis=1)

        lc = LineCollection(segments, color=color, alpha=alpha, **kwargs)  # type: ignore
        ax.add_collection(lc)

        # plot the last point as an invisible one-point line so the legend label is shown with alpha=1
        ax.plot(
            self.longitudes.value[-1],
            self.latitudes.value[-1],
            label=label,
            color=color,
        )

        return ax

    def rmse(self, other: Trajectory) -> Timeseries:
        """
        Computes the Root Mean Square Error (RMSE) (over time) between this trajectory and another trajectory.

        Parameters
        ----------
        other : Trajectory
            The other trajectory to compare against.

        Returns
        -------
        Timeseries
            The RMSE between the two trajectories.
        """
        error = (self.separation_distance(other) ** 2).cumsum()
        length = jnp.arange(self.length) + 1
        rmse = (error / length) ** (1 / 2)

        return Timeseries.from_array(rmse.value, self.times.value, rmse.unit, name="RMSE")

    def separation_distance(self, other: Trajectory) -> Timeseries:
        """
        Computes the separation distance (over time) between this trajectory and another trajectory.

        Parameters
        ----------
        other : Trajectory
            The other trajectory to compare against.

        Returns
        -------
        Timeseries
            The separation distance between the two trajectories.
        """
        separation_distance = eqx.filter_vmap(lambda p1, p2: p1.distance_on_earth(p2))(self.locations, other.locations)

        return Timeseries.from_array(
            separation_distance.value,
            self.times.value,
            separation_distance.unit,
            name="Separation distance",
        )

    def steps(self) -> Timeseries:
        """
        Returns the steps of the trajectory.

        Returns
        -------
        Timeseries
            The steps of the trajectory.
        """
        steps = eqx.filter_vmap(lambda p1, p2: distance_on_earth(p1, p2))(
            self.locations.value[1:], self.locations.value[:-1]
        )

        steps = jnp.pad(steps, (1, 0), constant_values=0.0)  # adds a 1st 0 step

        return Timeseries.from_array(steps, self.times.value, UNIT["m"], name="Trajectory steps")

    def velocities(self) -> Timeseries:
        """
        Returns the velocities of the trajectory.

        Returns
        -------
        Timeseries
            The velocities of the trajectory.
        """
        steps = self.steps()

        times = self.times.value[1:] - self.times.value[:-1]
        times = jnp.pad(times, (1, 0), constant_values=1e-4)  # adds a 1st small time step to avoid division by zero

        velocities = steps.value / times

        return Timeseries.from_array(velocities, self.times.value, UNIT["m/s"], name="Trajectory velocities")

    def to_xarray(self) -> xr.Dataset:
        """
        Converts the [`pastax.trajectory.Trajectory`][] to an `xarray.Dataset`.

        Returns
        -------
        xr.Dataset
            The corresponding `xarray.Dataset`.
        """
        return xr.Dataset(self.to_dataarray())

    @classmethod
    def from_array(
        cls,
        values: Float[Array, "... time 2"],
        times: Float[Array, "time"],
        unit: dict[Unit, int | float] = UNIT["°"],
        id: Int[Array, ""] | None = None,
        **_: Any,
    ) -> Trajectory:
        """
        Creates a [`pastax.trajectory.Trajectory`][] from arrays of (latitudes, longitudes) values and time points.

        Parameters
        ----------
        values : Float[Array, "... time 2"]
            The array of (latitudes, longitudes) values for the trajectory.
        times : Float[Array, "time"]
            The time points for the trajectory.
        unit : dict[Unit, int | float], optional
            Unit of the trajectory locations, defaults to UNIT["°"].
        id : Int[Array, ""] | None, optional
            The ID of the trajectory, defaults to None.

        Returns
        -------
        Trajectory
            The corresponding [`pastax.trajectory.Trajectory`][].
        """
        return super().from_array(values, times, unit=unit, name="Location in [latitude, longitude]", id=id)  # type: ignore

    @classmethod
    def from_xarray(
        cls,
        dataset: xr.Dataset,
        time_varname: str = "time",
        lat_varname: str = "lat",  # follows clouddrift "convention"
        lon_varname: str = "lon",  # follows clouddrift "convention"
        unit: dict[Unit, int | float] = UNIT["°"],
        id: Int[Array, ""] | None = None,
        **_: Any,
    ) -> Trajectory:
        """
        Creates a [`pastax.trajectory.Trajectory`][] from an `xarray.Dataset`.

        Parameters
        ----------
        dataset : xr.Dataset
            The `xarray.Dataset` containing the trajectory data.
        time_varname : str, optional
            A string indicating the name of the time variable in the dataset, defaults to `time`.
        lat_varname : str, optional
            A string indicating the name of the latitude variable in the dataset, defaults to `lat`.
        lon_varname : str, optional
            A string indicating the name of the longitude variable in the dataset, defaults to `lon`.
        unit : dict[Unit, int | float], optional
            Unit of the trajectory locations, defaults to UNIT["°"].
        id : Int[Array, ""] | None, optional
            The ID of the trajectory, defaults to None.

        Returns
        -------
        Trajectory
            The corresponding [`pastax.trajectory.Trajectory`][].
        """
        values = jnp.stack([dataset[lat_varname].values, dataset[lon_varname].values], axis=-1)
        times: Array = time_in_seconds(dataset[time_varname].values)
        return cls.from_array(values, times, unit=unit, id=id)

    def to_dataarray(self) -> dict[str, xr.DataArray]:
        times = self.times.to_datetime()
        unit = units_to_str(self.unit)

        latitude_da = xr.DataArray(
            data=self.latitudes.value,
            dims=["obs"],
            coords={"time": ("obs", times)},
            name="lat",
            attrs={"units": unit},
        )
        longitude_da = xr.DataArray(
            data=self.longitudes.value,
            dims=["obs"],
            coords={"time": ("obs", times)},
            name="lon",
            attrs={"units": unit},
        )

        return {"lat": latitude_da, "lon": longitude_da}
latitudes: State property

Returns the latitudes of the trajectory.

Returns:

Type Description
State

The latitudes of the trajectory.

locations: Location property

Returns the locations of the trajectory.

Returns:

Type Description
Location

The locations of the trajectory.

longitudes: State property

Returns the longitudes of the trajectory.

Returns:

Type Description
State

The longitudes of the trajectory.

origin: Location property

Returns the origin of the trajectory.

Returns:

Type Description
Location

The origin of the trajectory.

__init__(locations: Location, times: Time, id: Int[Array, ''] | None = None, *_, **__)

Initializes the Trajectory with given locations, times, and optional trajectory ID.

Parameters:

Name Type Description Default
locations Float[Array, '... time 2']

The locations for the trajectory.

required
times Int[Array, '... time']

The time points for the trajectory.

required
id Int[Array, ''] | None

The ID of the trajectory, defaults to None.

None
Source code in pastax/trajectory/_trajectory.py
def __init__(
    self,
    locations: Location,
    times: Time,
    id: Int[Array, ""] | None = None,
    *_,
    **__,
):
    """
    Initializes the Trajectory with given locations, times, and optional trajectory ID.

    Parameters
    ----------
    locations : Float[Array, "... time 2"]
        The locations for the trajectory.
    times : Int[Array, "... time"]
        The time points for the trajectory.
    id : Int[Array, ""] | None, optional
        The ID of the trajectory, defaults to None.
    """
    super().__init__(locations, times)
    self.id = id
lengths() -> Timeseries

Returns the cumulative lengths of the trajectory.

Returns:

Type Description
Timeseries

The cumulative lengths of the trajectory.

Source code in pastax/trajectory/_trajectory.py
def lengths(self) -> Timeseries:
    """
    Returns the cumulative lengths of the trajectory.

    Returns
    -------
    Timeseries
        The cumulative lengths of the trajectory.
    """
    lengths = self.steps().cumsum()
    return Timeseries.from_array(
        lengths.value,
        self.times.value,
        unit=lengths.unit,
        name="Cumulative lengths",
    )
liu_index(other: Trajectory) -> Timeseries

Computes the Liu Index (over time) between this trajectory and another trajectory.

Parameters:

Name Type Description Default
other Trajectory

The other trajectory to compare against.

required

Returns:

Type Description
Timeseries

The Liu Index between the two trajectories.

Source code in pastax/trajectory/_trajectory.py
def liu_index(self, other: Trajectory) -> Timeseries:
    """
    Computes the Liu Index (over time) between this trajectory and another trajectory.

    Parameters
    ----------
    other : Trajectory
        The other trajectory to compare against.

    Returns
    -------
    Timeseries
        The Liu Index between the two trajectories.
    """
    error = self.separation_distance(other).value.cumsum()
    cum_lengths = self.lengths().value.cumsum()
    cum_lengths = jnp.where(cum_lengths == 0, jnp.inf, cum_lengths)
    liu_index = error / cum_lengths

    return Timeseries.from_array(liu_index, self.times.value, name="Liu index")
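The index at time step n is the sum of the separation distances up to n divided by the sum of the reference trajectory's cumulative lengths up to n. That is why the source applies `cumsum` to the output of `lengths()`, which is already cumulative. A minimal NumPy sketch of the same arithmetic (NumPy standing in for `jax.numpy`; `liu_index_from_arrays` is a hypothetical name), taking plain arrays of separations and step lengths in the same unit:

```python
import numpy as np


def liu_index_from_arrays(separation: np.ndarray, steps: np.ndarray) -> np.ndarray:
    """Running Liu index from per-time separations and reference step lengths.

    separation[i]: distance between the two trajectories at time i.
    steps[i]: length of step i of the reference trajectory (steps[0] == 0).
    """
    error = np.cumsum(separation)              # sum of separations up to i
    cum_lengths = np.cumsum(np.cumsum(steps))  # sum of cumulative lengths up to i
    # At the origin the denominator is 0; dividing by inf gives 0 instead of NaN.
    cum_lengths = np.where(cum_lengths == 0, np.inf, cum_lengths)
    return error / cum_lengths


separation = np.array([0.0, 100.0, 300.0])  # metres
steps = np.array([0.0, 1000.0, 1000.0])     # metres, leading 0 step
result = liu_index_from_arrays(separation, steps)  # [0., 0.1, 0.1333...]
```

Replacing zero denominators by infinity mirrors the `jnp.where` guard in the source, so the first value is 0 rather than NaN.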
mae(other: Trajectory) -> Timeseries

Computes the Mean Absolute Error (MAE) (over time) between this trajectory and another trajectory.

Parameters:

Name Type Description Default
other Trajectory

The other trajectory to compare against.

required

Returns:

Type Description
Timeseries

The MAE between the two trajectories.

Source code in pastax/trajectory/_trajectory.py
def mae(self, other: Trajectory) -> Timeseries:
    """
    Computes the Mean Absolute Error (MAE) (over time) between this trajectory and another trajectory.

    Parameters
    ----------
    other : Trajectory
        The other trajectory to compare against.

    Returns
    -------
    Timeseries
        The MAE between the two trajectories.
    """
    error = self.separation_distance(other).cumsum()
    length = jnp.arange(self.length) + 1
    mae = error / length

    return Timeseries.from_array(mae.value, self.times.value, mae.unit, name="MAE")
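Because separation distances are non-negative, their running mean is the mean absolute error: the value at time index n averages the first n + 1 separations, which is what dividing the cumulative sum by `jnp.arange(self.length) + 1` does. A minimal NumPy sketch over a plain array of separation distances (`running_mae` is a hypothetical name):

```python
import numpy as np


def running_mae(separation: np.ndarray) -> np.ndarray:
    """Mean of the separation distances from the first time step up to each step."""
    separation = np.asarray(separation, dtype=float)
    n_points = np.arange(separation.shape[-1]) + 1  # 1, 2, ..., T
    return np.cumsum(separation, axis=-1) / n_points


separation = np.array([0.0, 2.0, 4.0, 6.0])
mae = running_mae(separation)  # [0., 1., 2., 3.]
```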
plot(ax: Axes | None = None, label: str | None = None, color: str | None = None, alpha_factor: float = 1, ti: int | None = None, **kwargs) -> Axes

Plots the trajectory.

Parameters:

Name Type Description Default
ax Axes | None

The matplotlib axis to plot on, defaults to None. If None, a new figure and axis are created.

None
label str | None

The label for the plot, defaults to None.

None
color str | None

The color for the plot, defaults to None.

None
alpha_factor float

A factor controlling the overall transparency of the plotted trajectory, defaults to 1.

1
ti int | None

The time index to plot up to, defaults to None.

None
kwargs

Additional arguments passed to LineCollection.

{}

Returns:

Type Description
Axes

The matplotlib axis with the plot.

Source code in pastax/trajectory/_trajectory.py
def plot(
    self,
    ax: Axes | None = None,
    label: str | None = None,
    color: str | None = None,
    alpha_factor: float = 1,
    ti: int | None = None,
    **kwargs,
) -> Axes:
    """
    Plots the trajectory.

    Parameters
    ----------
    ax : Axes | None, optional
        The matplotlib axis to plot on, defaults to `None`.
        If `None`, a new figure and axis are created.
    label : str | None, optional
        The label for the plot, defaults to `None`.
    color : str | None, optional
        The color for the plot, defaults to `None`.
    alpha_factor : float, optional
        A factor controlling the overall transparency of the plotted trajectory, defaults to `1`.
    ti : int | None, optional
        The time index to plot up to, defaults to `None`.
    **kwargs : dict, optional
        Additional arguments passed to `LineCollection`.

    Returns
    -------
    Axes
        The matplotlib axis with the plot.
    """
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(projection=ccrs.PlateCarree())

    if ti is None:
        ti = self.length

    alpha = jnp.geomspace(0.25, 1, ti - 1) * alpha_factor

    locations = self.locations.value[:ti, None, ::-1]
    segments = jnp.concat([locations[:-1], locations[1:]], axis=1)

    lc = LineCollection(segments, color=color, alpha=alpha, **kwargs)  # type: ignore
    ax.add_collection(lc)

    # plot the last point as an invisible one-point line so the legend label is shown with alpha=1
    ax.plot(
        self.longitudes.value[-1],
        self.latitudes.value[-1],
        label=label,
        color=color,
    )

    return ax
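The arrays that `plot` passes to `LineCollection` can be checked without drawing anything. Below is a minimal NumPy sketch (NumPy standing in for `jax.numpy`; `trajectory_segments` is a hypothetical name) that reproduces the indexing in the source. Points stored as (latitude, longitude) are flipped to (longitude, latitude), i.e. (x, y). Consecutive points are then paired into `ti - 1` segments, whose alphas grow geometrically from 0.25 to 1:

```python
from typing import Optional

import numpy as np


def trajectory_segments(latlon: np.ndarray, ti: Optional[int] = None, alpha_factor: float = 1.0):
    """Return (segments, alpha) built the way Trajectory.plot builds them.

    latlon: (time, 2) array of (latitude, longitude) points.
    """
    latlon = np.asarray(latlon, dtype=float)
    if ti is None:
        ti = latlon.shape[0]
    # One alpha per segment, growing geometrically from 0.25 (oldest) to 1 (newest).
    alpha = np.geomspace(0.25, 1, ti - 1) * alpha_factor
    # Keep the first ti points and swap to (longitude, latitude): shape (ti, 1, 2).
    points = latlon[:ti, None, ::-1]
    # Pair each point with the next one: shape (ti - 1, 2, 2), as LineCollection expects.
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    return segments, alpha


latlon = np.array([[10.0, 0.0], [10.5, 0.5], [11.0, 1.5], [11.2, 2.0]])
segments, alpha = trajectory_segments(latlon)
# segments.shape == (3, 2, 2); the first segment runs from (x=0, y=10) to (x=0.5, y=10.5)
```

Later segments get larger alphas, so the most recent part of the path is drawn most opaque, and `alpha_factor` scales every alpha uniformly.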
rmse(other: Trajectory) -> Timeseries

Computes the Root Mean Square Error (RMSE) (over time) between this trajectory and another trajectory.

Parameters:

Name Type Description Default
other Trajectory

The other trajectory to compare against.

required

Returns:

Type Description
Timeseries

The RMSE between the two trajectories.

Source code in pastax/trajectory/_trajectory.py
def rmse(self, other: Trajectory) -> Timeseries:
    """
    Computes the Root Mean Square Error (RMSE) (over time) between this trajectory and another trajectory.

    Parameters
    ----------
    other : Trajectory
        The other trajectory to compare against.

    Returns
    -------
    Timeseries
        The RMSE between the two trajectories.
    """
    error = (self.separation_distance(other) ** 2).cumsum()
    length = jnp.arange(self.length) + 1
    rmse = (error / length) ** (1 / 2)

    return Timeseries.from_array(rmse.value, self.times.value, rmse.unit, name="RMSE")
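The RMSE follows the same running pattern as the MAE, but on squared separations, with a square root at the end. A minimal NumPy sketch over a plain array of separation distances (`running_rmse` is a hypothetical name):

```python
import numpy as np


def running_rmse(separation: np.ndarray) -> np.ndarray:
    """Square root of the running mean of squared separation distances."""
    separation = np.asarray(separation, dtype=float)
    n_points = np.arange(separation.shape[-1]) + 1  # 1, 2, ..., T
    return np.sqrt(np.cumsum(separation**2, axis=-1) / n_points)


separation = np.array([3.0, 4.0, 0.0])
rmse = running_rmse(separation)  # [3.0, sqrt(12.5), sqrt(25 / 3)]
```

Squaring weights large separations more heavily, so at every time step this value is at least the running MAE of the same separations.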
separation_distance(other: Trajectory) -> Timeseries

Computes the separation distance (over time) between this trajectory and another trajectory.

Parameters:

Name Type Description Default
other Trajectory

The other trajectory to compare against.

required

Returns:

Type Description
Timeseries

The separation distance between the two trajectories.

Source code in pastax/trajectory/_trajectory.py
def separation_distance(self, other: Trajectory) -> Timeseries:
    """
    Computes the separation distance (over time) between this trajectory and another trajectory.

    Parameters
    ----------
    other : Trajectory
        The other trajectory to compare against.

    Returns
    -------
    Timeseries
        The separation distance between the two trajectories.
    """
    separation_distance = eqx.filter_vmap(lambda p1, p2: p1.distance_on_earth(p2))(self.locations, other.locations)

    return Timeseries.from_array(
        separation_distance.value,
        self.times.value,
        separation_distance.unit,
        name="Separation distance",
    )
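The per-time-step distance comes from `Location.distance_on_earth`, whose formula is not shown in this section. The sketch below assumes a great-circle (haversine) distance on a sphere of radius 6,371 km. Both the formula and the radius are assumptions, and `haversine_m` is a hypothetical helper. As in `separation_distance`, the two trajectories are compared point by point at the same time indices:

```python
import numpy as np

EARTH_RADIUS_M = 6_371_000.0  # assumed mean Earth radius, in metres


def haversine_m(latlon1: np.ndarray, latlon2: np.ndarray) -> np.ndarray:
    """Great-circle distance (metres) between (lat, lon) points given in degrees.

    Works on any matching leading shape, e.g. two (time, 2) trajectories.
    """
    lat1, lon1 = np.radians(latlon1[..., 0]), np.radians(latlon1[..., 1])
    lat2, lon2 = np.radians(latlon2[..., 0]), np.radians(latlon2[..., 1])
    h = np.sin((lat2 - lat1) / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2
    # Clipping guards against rounding pushing h slightly above 1 for near-antipodal points.
    return 2 * EARTH_RADIUS_M * np.arcsin(np.sqrt(np.clip(h, 0.0, 1.0)))


# Two trajectories sampled at the same times, in (latitude, longitude) order.
track_a = np.array([[0.0, 0.0], [0.0, 1.0], [0.0, 2.0]])
track_b = np.array([[0.0, 0.0], [1.0, 1.0], [0.0, 3.0]])
separation = haversine_m(track_a, track_b)  # one distance per time step, the first is 0
```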
steps() -> Timeseries

Returns the steps of the trajectory.

Returns:

Type Description
Timeseries

The steps of the trajectory.

Source code in pastax/trajectory/_trajectory.py
def steps(self) -> Timeseries:
    """
    Returns the steps of the trajectory.

    Returns
    -------
    Timeseries
        The steps of the trajectory.
    """
    steps = eqx.filter_vmap(lambda p1, p2: distance_on_earth(p1, p2))(
        self.locations.value[1:], self.locations.value[:-1]
    )

    steps = jnp.pad(steps, (1, 0), constant_values=0.0)  # adds a 1st 0 step

    return Timeseries.from_array(steps, self.times.value, UNIT["m"], name="Trajectory steps")
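`steps` measures the distance between each point and the previous one and prepends a 0 so the result has one entry per time point. `lengths` is the cumulative sum of these steps. A minimal NumPy sketch in which the distance function is passed in: `trajectory_steps` and `planar_distance` are hypothetical names. The toy planar distance is only there to keep the numbers easy to check, whereas the source uses `distance_on_earth`:

```python
from typing import Callable

import numpy as np


def trajectory_steps(
    latlon: np.ndarray,
    distance: Callable[[np.ndarray, np.ndarray], np.ndarray],
) -> np.ndarray:
    """Distance covered between consecutive points, with a leading 0 step.

    ``distance`` stands in for pastax's ``distance_on_earth``: it must take
    two (n, 2) arrays of points and return n distances.
    """
    latlon = np.asarray(latlon, dtype=float)
    # Entry i measures the move from point i to point i + 1.
    steps = distance(latlon[1:], latlon[:-1])
    # Prepend 0 so there is one entry per time point, as in the source.
    return np.pad(steps, (1, 0), constant_values=0.0)


def planar_distance(p: np.ndarray, q: np.ndarray) -> np.ndarray:
    # Toy Euclidean distance on the (lat, lon) plane, for illustration only.
    return np.linalg.norm(p - q, axis=-1)


latlon = np.array([[0.0, 0.0], [3.0, 4.0], [3.0, 5.0]])
steps = trajectory_steps(latlon, planar_distance)  # [0., 5., 1.]
lengths = np.cumsum(steps)  # cumulative lengths, as lengths() computes them: [0., 5., 6.]
```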
velocities() -> Timeseries

Returns the velocities of the trajectory.

Returns:

Type Description
Timeseries

The velocities of the trajectory.

Source code in pastax/trajectory/_trajectory.py
def velocities(self) -> Timeseries:
    """
    Returns the velocities of the trajectory.

    Returns
    -------
    Timeseries
        The velocities of the trajectory.
    """
    steps = self.steps()

    times = self.times.value[1:] - self.times.value[:-1]
    times = jnp.pad(times, (1, 0), constant_values=1e-4)  # adds a 1st small time step to avoid division by zero

    velocities = steps.value / times

    return Timeseries.from_array(velocities, self.times.value, UNIT["m/s"], name="Trajectory velocities")
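The NumPy sketch below mirrors `velocities` on hypothetical step lengths and time points (in seconds). Because `steps()` returns scalar distances, the result is a speed rather than a vector velocity, and the padded first time step of 1e-4 s only serves to avoid a 0 / 0 division for the leading zero step. The function name is illustrative, not part of pastax.

```python
import numpy as np


def trajectory_speeds(steps_m, times_s, first_dt=1e-4):
    """Speed (m/s) as step length over time step, mirroring the padding logic above."""
    dt = np.diff(times_s)
    # Prepend a small positive time step so the first entry (a 0 step) is not 0 / 0.
    dt = np.pad(dt, (1, 0), constant_values=first_dt)
    return steps_m / dt


steps_m = np.array([0.0, 3600.0, 7200.0])  # hypothetical step lengths, in meters
times_s = np.array([0.0, 3600.0, 7200.0])  # hypothetical time points, in seconds
speeds = trajectory_speeds(steps_m, times_s)
```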
to_xarray() -> xr.Dataset

Converts the pastax.trajectory.Trajectory to an xarray.Dataset.

Returns:

Type Description
Dataset

The corresponding xarray.Dataset.

Source code in pastax/trajectory/_trajectory.py
def to_xarray(self) -> xr.Dataset:
    """
    Converts the [`pastax.trajectory.Trajectory`][] to an `xarray.Dataset`.

    Returns
    -------
    xr.Dataset
        The corresponding `xarray.Dataset`.
    """
    return xr.Dataset(self.to_dataarray())
from_array(values: Float[Array, '... time 2'], times: Float[Array, 'time'], unit: dict[Unit, int | float] = UNIT['°'], id: Int[Array, ''] | None = None, **_: Any) -> Trajectory classmethod

Creates a pastax.trajectory.Trajectory from an array of (latitude, longitude) values and an array of time points.

Parameters:

Name Type Description Default
values Float[Array, '... time 2']

The array of (latitude, longitude) values for the trajectory.

required
times Float[Array, 'time']

The time points for the trajectory.

required
unit dict[Unit, int | float]

Unit of the trajectory locations, defaults to UNIT["°"].

UNIT['°']
id Int[Array, ''] | None

The ID of the trajectory, defaults to None.

None

Returns:

Type Description
Trajectory

The corresponding pastax.trajectory.Trajectory.

Source code in pastax/trajectory/_trajectory.py
@classmethod
def from_array(
    cls,
    values: Float[Array, "... time 2"],
    times: Float[Array, "time"],
    unit: dict[Unit, int | float] = UNIT["°"],
    id: Int[Array, ""] | None = None,
    **_: Any,
) -> Trajectory:
    """
    Creates a [`pastax.trajectory.Trajectory`][] from an array of (latitude, longitude) values and an array of time points.

    Parameters
    ----------
    values : Float[Array, "... time 2"]
        The array of (latitude, longitude) values for the trajectory.
    times : Float[Array, "time"]
        The time points for the trajectory.
    unit : dict[Unit, int | float], optional
        Unit of the trajectory locations, defaults to UNIT["°"].
    id : Int[Array, ""] | None, optional
        The ID of the trajectory, defaults to None.

    Returns
    -------
    Trajectory
        The corresponding [`pastax.trajectory.Trajectory`][].
    """
    return super().from_array(values, times, unit=unit, name="Location in [latitude, longitude]", id=id)  # type: ignore
from_xarray(dataset: xr.Dataset, time_varname: str = 'time', lat_varname: str = 'lat', lon_varname: str = 'lon', unit: dict[Unit, int | float] = UNIT['°'], id: Int[Array, ''] | None = None, **_: Any) -> Trajectory classmethod

Creates a pastax.trajectory.Trajectory from an xarray.Dataset.

Parameters:

Name Type Description Default
dataset Dataset

The xarray.Dataset containing the trajectory data.

required
time_varname str

A string indicating the name of the time variable in the dataset, defaults to time.

'time'
lat_varname str

A string indicating the name of the latitude variable in the dataset, defaults to lat.

'lat'
lon_varname str

A string indicating the name of the longitude variable in the dataset, defaults to lon.

'lon'
unit dict[Unit, int | float]

Unit of the trajectory locations, defaults to UNIT["°"].

UNIT['°']
id Int[Array, ''] | None

The ID of the trajectory, defaults to None.

None

Returns:

Type Description
Trajectory

The corresponding pastax.trajectory.Trajectory.

Source code in pastax/trajectory/_trajectory.py
@classmethod
def from_xarray(
    cls,
    dataset: xr.Dataset,
    time_varname: str = "time",
    lat_varname: str = "lat",  # follows clouddrift "convention"
    lon_varname: str = "lon",  # follows clouddrift "convention"
    unit: dict[Unit, int | float] = UNIT["°"],
    id: Int[Array, ""] | None = None,
    **_: Any,
) -> Trajectory:
    """
    Creates a [`pastax.trajectory.Trajectory`][] from an `xarray.Dataset`.

    Parameters
    ----------
    dataset : xr.Dataset
        The `xarray.Dataset` containing the trajectory data.
    time_varname : str, optional
        A string indicating the name of the time variable in the dataset, defaults to `time`.
    lat_varname : str, optional
        A string indicating the name of the latitude variable in the dataset, defaults to `lat`.
    lon_varname : str, optional
        A string indicating the name of the longitude variable in the dataset, defaults to `lon`.
    unit : dict[Unit, int | float], optional
        Unit of the trajectory locations, defaults to UNIT["°"].
    id : Int[Array, ""] | None, optional
        The ID of the trajectory, defaults to None.

    Returns
    -------
    Trajectory
        The corresponding [`pastax.trajectory.Trajectory`][].
    """
    values = jnp.stack([dataset[lat_varname].values, dataset[lon_varname].values], axis=-1)
    times: Array = time_in_seconds(dataset[time_varname].values)
    return cls.from_array(values, times, unit=unit, id=id)
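`from_xarray` only reshapes the dataset's variables before delegating to `from_array`. The sketch below mirrors that preparation in NumPy on made-up drifter fixes. The conversion to seconds since the Unix epoch is an assumption about what `time_in_seconds` does, not its confirmed behavior.

```python
import numpy as np

# Hypothetical drifter fixes, as they might appear in a dataset's lat/lon/time variables.
lat = np.array([10.0, 10.5, 11.0])
lon = np.array([-30.0, -29.5, -29.0])
time = np.array(
    ["2024-01-01T00:00", "2024-01-01T01:00", "2024-01-01T02:00"], dtype="datetime64[s]"
)

# Same stacking as from_xarray: one [lat, lon] pair per time point, shape (time, 2).
values = np.stack([lat, lon], axis=-1)

# Assumption: time_in_seconds converts datetimes to seconds since the Unix epoch.
times = (time - np.datetime64("1970-01-01T00:00", "s")).astype(np.int64).astype(float)
```

The resulting `values` and `times` have the layouts that `from_array` expects (`Float[Array, "... time 2"]` and `Float[Array, "time"]`).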

TimeseriesEnsemble

Bases: Unitful

Class representing an ensemble of pastax.trajectory.Timeseries members.

Attributes:

Name Type Description
members Timeseries

The members of the pastax.trajectory.TimeseriesEnsemble.

size int

The number of members in the pastax.trajectory.TimeseriesEnsemble.

Methods:

Name Description
__init__

Initializes the pastax.trajectory.TimeseriesEnsemble with pastax.trajectory.Timeseries members.

value

Returns the value of the pastax.trajectory.TimeseriesEnsemble.

states

Returns the pastax.trajectory.State of the pastax.trajectory.TimeseriesEnsemble.

times

Returns the pastax.trajectory.Time points of the pastax.trajectory.TimeseriesEnsemble.

unit

Returns the unit of the pastax.trajectory.TimeseriesEnsemble.

name

Returns the name of the pastax.trajectory.TimeseriesEnsemble.

length

Returns the length of the pastax.trajectory.TimeseriesEnsemble.

attach_name

Attaches a name to the pastax.trajectory.TimeseriesEnsemble.

crps

Computes the Continuous Ranked Probability Score (CRPS) for the pastax.trajectory.TimeseriesEnsemble.

ensemble_dispersion

Computes the pastax.trajectory.TimeseriesEnsemble dispersion.

map

Applies a function to each pastax.trajectory.Timeseries of the pastax.trajectory.TimeseriesEnsemble.

to_xarray

Converts the pastax.trajectory.TimeseriesEnsemble to an xarray.Dataset.

from_array

Creates a pastax.trajectory.TimeseriesEnsemble from arrays of values and time points.

Source code in pastax/trajectory/_timeseries_ensemble.py
class TimeseriesEnsemble(Unitful):
    """
    Class representing an ensemble of [`pastax.trajectory.Timeseries`][] members.

    Attributes
    ----------
    members : Timeseries
        The members of the [`pastax.trajectory.TimeseriesEnsemble`][].
    size : int
        The number of members in the [`pastax.trajectory.TimeseriesEnsemble`][].

    Methods
    -------
    __init__(members)
        Initializes the [`pastax.trajectory.TimeseriesEnsemble`][] with [`pastax.trajectory.Timeseries`][] members.
    value
        Returns the value of the [`pastax.trajectory.TimeseriesEnsemble`][].
    states
        Returns the [`pastax.trajectory.State`][] of the [`pastax.trajectory.TimeseriesEnsemble`][].
    times
        Returns the [`pastax.trajectory.Time`][] points of the [`pastax.trajectory.TimeseriesEnsemble`][].
    unit
        Returns the unit of the [`pastax.trajectory.TimeseriesEnsemble`][].
    name
        Returns the name of the [`pastax.trajectory.TimeseriesEnsemble`][].
    length
        Returns the length of the [`pastax.trajectory.TimeseriesEnsemble`][].
    attach_name(name)
        Attaches a name to the [`pastax.trajectory.TimeseriesEnsemble`][].
    crps(other, metric_func, is_metric_symmetric=True)
        Computes the Continuous Ranked Probability Score (CRPS) for the [`pastax.trajectory.TimeseriesEnsemble`][].
    ensemble_dispersion(metric_func, is_metric_symmetric=True)
        Computes the [`pastax.trajectory.TimeseriesEnsemble`][] dispersion.
    map(func)
        Applies a function to each [`pastax.trajectory.Timeseries`][] of the [`pastax.trajectory.TimeseriesEnsemble`][].
    to_xarray()
        Converts the [`pastax.trajectory.TimeseriesEnsemble`][] to an `xarray.Dataset`.
    from_array(values, times, unit={}, name=None, **kwargs)
        Creates a [`pastax.trajectory.TimeseriesEnsemble`][] from arrays of values and time points.
    """

    members: Timeseries
    _members_type: ClassVar = Timeseries
    size: int = eqx.field(static=True)

    _value: None = eqx.field(repr=False)
    _unit: None = eqx.field(repr=False)

    def __init__(self, members: Timeseries):
        """
        Initializes the [`pastax.trajectory.TimeseriesEnsemble`][] with [`pastax.trajectory.Timeseries`][] members.

        Parameters
        ----------
        members : Timeseries
            The members of the [`pastax.trajectory.TimeseriesEnsemble`][].
        """
        super().__init__()
        self.members = members
        self.size = members.states.value.shape[0]

    @property
    def value(self) -> Float[Array, "member time state"]:
        """
        Returns the value of the [`pastax.trajectory.TimeseriesEnsemble`][].

        Returns
        -------
        Float[Array, "member time state"]
            The value of the [`pastax.trajectory.TimeseriesEnsemble`][].
        """
        return self.members.value

    @property
    def states(self) -> State:
        """
        Returns the states of the [`pastax.trajectory.TimeseriesEnsemble`][].

        Returns
        -------
        State
            The states of the [`pastax.trajectory.TimeseriesEnsemble`][].
        """
        return self.members.states

    @property
    def times(self) -> Time:
        """
        Returns the time points of the [`pastax.trajectory.TimeseriesEnsemble`][].

        Returns
        -------
        Time
            The time points of the [`pastax.trajectory.TimeseriesEnsemble`][].
        """
        return self.members.times

    @property
    def unit(self) -> dict[Unit, int | float]:
        """
        Returns the unit of the [`pastax.trajectory.TimeseriesEnsemble`][].

        Returns
        -------
        dict[Unit, int | float]
            The unit of the [`pastax.trajectory.TimeseriesEnsemble`][].
        """
        return self.members.unit

    @property
    def name(self) -> str | None:
        """
        Returns the name of the [`pastax.trajectory.TimeseriesEnsemble`][].

        Returns
        -------
        str | None
            The name of the [`pastax.trajectory.TimeseriesEnsemble`][].
        """
        return self.members.name

    @property
    def length(self) -> int:
        """
        Returns the length of the [`pastax.trajectory.TimeseriesEnsemble`][].

        Returns
        -------
        int
            The length of the [`pastax.trajectory.TimeseriesEnsemble`][].
        """
        return self.members.length

    def attach_name(self, name: str) -> TimeseriesEnsemble:
        """
        Attaches a name to the [`pastax.trajectory.TimeseriesEnsemble`][].

        Parameters
        ----------
        name : str
            The name to attach to the [`pastax.trajectory.TimeseriesEnsemble`][].

        Returns
        -------
        TimeseriesEnsemble
            A new [`pastax.trajectory.TimeseriesEnsemble`][] with the attached name.
        """
        return TimeseriesEnsemble.from_array(self.states.value, self.times.value, unit=self.unit, name=name)

    def crps(
        self,
        other: Timeseries,
        metric_func: Callable[[Timeseries, Timeseries], Unitful | Array],
        is_metric_symmetric: bool = True,
    ) -> Unitful:
        """
        Computes the Continuous Ranked Probability Score (CRPS) for the [`pastax.trajectory.TimeseriesEnsemble`][].

        Parameters
        ----------
        other : Timeseries
            The other timeseries to compare against.
        metric_func : Callable[[Timeseries, Timeseries], Unitful | Array]
            The metric function to use.
        is_metric_symmetric : bool, optional
            Whether the metric function is symmetric,
            in which case half of the intra ensemble "distances" are evaluated when computing the ensemble dispersion,
            defaults to `True`.

        Returns
        -------
        Unitful
            The CRPS for the [`pastax.trajectory.TimeseriesEnsemble`][].
        """
        biases = self.map(lambda member: metric_func(other, member))
        bias = biases.mean(axis=0)

        n_members = self.size
        dispersion = self.ensemble_dispersion(metric_func, is_metric_symmetric=is_metric_symmetric)
        dispersion /= 2 * n_members * (n_members - 1)

        return bias - dispersion

    def ensemble_dispersion(
        self,
        metric_func: Callable[[Timeseries, Timeseries], Unitful | Array],
        is_metric_symmetric: bool = True,
    ) -> Unitful:
        """
        Computes the [`pastax.trajectory.TimeseriesEnsemble`][] dispersion.

        Parameters
        ----------
        metric_func : Callable[[Timeseries, Timeseries], Unitful | Array]
            The metric function to use.
        is_metric_symmetric : bool, optional
            Whether the metric function is symmetric,
            in which case half of the intra ensemble "distances" are evaluated, defaults to `True`.

        Returns
        -------
        Unitful
            The [`pastax.trajectory.TimeseriesEnsemble`][] dispersion.
        """
        ij = jnp.column_stack(jnp.triu_indices(self.size, k=1))

        if not is_metric_symmetric:
            ij = jnp.vstack([ij, ij[:, ::-1]])

        vmap_metric_fn = eqx.filter_vmap(
            lambda _ij: metric_func(
                self._members_type.from_array(self.value[_ij[0], ...], self.times.value),
                self._members_type.from_array(self.value[_ij[1], ...], self.times.value),
            )
        )

        intra_distances = vmap_metric_fn(ij)
        dispersion = intra_distances.sum(axis=0)

        if is_metric_symmetric:
            dispersion *= 2

        return dispersion

    def map(self, func: Callable[[Timeseries], Unitful | Array]) -> Unitful:
        """
        Applies a function to each [`pastax.trajectory.Timeseries`][] of the [`pastax.trajectory.TimeseriesEnsemble`][].

        Parameters
        ----------
        func : Callable[[Timeseries], Unitful | Array]
            The function to apply to each [`pastax.trajectory.Timeseries`][].

        Returns
        -------
        Unitful
            The result of applying the function to each [`pastax.trajectory.Timeseries`][].
        """
        in_axes = eqx.filter(self.members, False)
        in_axes = eqx.tree_at(lambda x: (x.states._value, x.times._value), in_axes, (0, 0), is_leaf=lambda x: x is None)
        res = eqx.filter_vmap(func, in_axes=(in_axes,))(self.members)

        unit = {}
        if isinstance(res, Unitful):
            unit = res.unit
            res = res.value

        return Unitful(res, unit)

    def to_xarray(self) -> xr.Dataset:
        """
        Converts the [`pastax.trajectory.TimeseriesEnsemble`][] to an `xarray.Dataset`.

        Returns
        -------
        xr.Dataset
            The corresponding `xarray.Dataset`.
        """
        da = self.to_dataarray()
        ds = da.to_dataset()

        return ds

    @classmethod
    def from_array(
        cls,
        values: Float[Array, "member time state"],
        times: Float[Array, "time"],
        unit: dict[Unit, int | float] = {},
        name: str | None = None,
        **kwargs: Any,
    ) -> TimeseriesEnsemble:
        """
        Creates a [`pastax.trajectory.TimeseriesEnsemble`][] from arrays of values and time points.

        Parameters
        ----------
        values : Float[Array, "member time state"]
            The values for the members of the ensemble.
        times : Float[Array, "time"]
            The time points for the timeseries.
        unit : dict[Unit, int | float], optional
            The unit of the timeseries, defaults to {}.
        name : str | None, optional
            The name of the timeseries, defaults to None.
        **kwargs : Any
            Additional keyword arguments.

        Returns
        -------
        TimeseriesEnsemble
            The corresponding [`pastax.trajectory.TimeseriesEnsemble`][].
        """
        members = eqx.filter_vmap(
            lambda x: cls._members_type.from_array(x, times, unit=unit, name=name, **kwargs), out_axes=_if_mapped(0)
        )(values)

        return cls(members)

    def to_dataarray(self) -> xr.DataArray:
        da = xr.DataArray(
            data=self.states.value,
            dims=["member", "time"],
            coords={
                "member": np.arange(self.size),
                "time": self.members.times.to_datetime(),
            },
            name=self.name,
            attrs={"units": units_to_str(self.unit)},
        )

        return da
value: Float[Array, 'member time state'] property

Returns the value of the pastax.trajectory.TimeseriesEnsemble.

Returns:

Type Description
Float[Array, 'member time state']

The value of the pastax.trajectory.TimeseriesEnsemble.

states: State property

Returns the states of the pastax.trajectory.TimeseriesEnsemble.

Returns:

Type Description
State

The states of the pastax.trajectory.TimeseriesEnsemble.

times: Time property

Returns the time points of the pastax.trajectory.TimeseriesEnsemble.

Returns:

Type Description
Time

The time points of the pastax.trajectory.TimeseriesEnsemble.

unit: dict[Unit, int | float] property

Returns the unit of the pastax.trajectory.TimeseriesEnsemble.

Returns:

Type Description
dict[Unit, int | float]

The unit of the pastax.trajectory.TimeseriesEnsemble.

name: str | None property

Returns the name of the pastax.trajectory.TimeseriesEnsemble.

Returns:

Type Description
str | None

The name of the pastax.trajectory.TimeseriesEnsemble.

length: int property

Returns the length of the pastax.trajectory.TimeseriesEnsemble.

Returns:

Type Description
int

The length of the pastax.trajectory.TimeseriesEnsemble.

__init__(members: Timeseries)

Initializes the pastax.trajectory.TimeseriesEnsemble with pastax.trajectory.Timeseries members.

Parameters:

Name Type Description Default
members Timeseries

The members of the pastax.trajectory.TimeseriesEnsemble.

required
Source code in pastax/trajectory/_timeseries_ensemble.py
def __init__(self, members: Timeseries):
    """
    Initializes the [`pastax.trajectory.TimeseriesEnsemble`][] with [`pastax.trajectory.Timeseries`][] members.

    Parameters
    ----------
    members : Timeseries
        The members of the [`pastax.trajectory.TimeseriesEnsemble`][].
    """
    super().__init__()
    self.members = members
    self.size = members.states.value.shape[0]
attach_name(name: str) -> TimeseriesEnsemble

Attaches a name to the pastax.trajectory.TimeseriesEnsemble.

Parameters:

Name Type Description Default
name str

The name to attach to the pastax.trajectory.TimeseriesEnsemble.

required

Returns:

Type Description
TimeseriesEnsemble

A new pastax.trajectory.TimeseriesEnsemble with the attached name.

Source code in pastax/trajectory/_timeseries_ensemble.py
def attach_name(self, name: str) -> TimeseriesEnsemble:
    """
    Attaches a name to the [`pastax.trajectory.TimeseriesEnsemble`][].

    Parameters
    ----------
    name : str
        The name to attach to the [`pastax.trajectory.TimeseriesEnsemble`][].

    Returns
    -------
    TimeseriesEnsemble
        A new [`pastax.trajectory.TimeseriesEnsemble`][] with the attached name.
    """
    return TimeseriesEnsemble.from_array(self.states.value, self.times.value, unit=self.unit, name=name)
crps(other: Timeseries, metric_func: Callable[[Timeseries, Timeseries], Unitful | Array], is_metric_symmetric: bool = True) -> Unitful

Computes the Continuous Ranked Probability Score (CRPS) for the pastax.trajectory.TimeseriesEnsemble.

Parameters:

Name Type Description Default
other Timeseries

The reference timeseries to compare the ensemble against.

required
metric_func Callable[[Timeseries, Timeseries], Unitful | Array]

The metric function, applied between the reference timeseries and each member, and between pairs of members.

required
is_metric_symmetric bool

Whether the metric function is symmetric, in which case only the intra-ensemble "distances" between pairs (i, j) with i < j are evaluated, and their sum doubled, when computing the ensemble dispersion, defaults to True.

True

Returns:

Type Description
Unitful

The CRPS for the pastax.trajectory.TimeseriesEnsemble.
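Reading the source code of `crps` together with `ensemble_dispersion`, the returned score for members x_1, ..., x_M, a reference timeseries y and a metric d (`metric_func`) appears to be the estimator below, evaluated elementwise when d returns one value per time point. The pairwise term is normalized by M(M - 1) rather than M^2, which corresponds to what is usually called the fair ensemble CRPS. This is a reading of the code, not a formula stated by the authors.

```latex
\mathrm{CRPS}(x_{1:M}, y)
  = \frac{1}{M} \sum_{i=1}^{M} d(y, x_i)
  - \frac{1}{2M(M-1)} \sum_{i=1}^{M} \sum_{\substack{j=1 \\ j \neq i}}^{M} d(x_i, x_j)
```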

Source code in pastax/trajectory/_timeseries_ensemble.py
def crps(
    self,
    other: Timeseries,
    metric_func: Callable[[Timeseries, Timeseries], Unitful | Array],
    is_metric_symmetric: bool = True,
) -> Unitful:
    """
    Computes the Continuous Ranked Probability Score (CRPS) for the [`pastax.trajectory.TimeseriesEnsemble`][].

    Parameters
    ----------
    other : Timeseries
        The other timeseries to compare against.
    metric_func : Callable[[Timeseries, Timeseries], Unitful | Array]
        The metric function to use.
    is_metric_symmetric : bool, optional
        Whether the metric function is symmetric,
        in which case half of the intra ensemble "distances" are evaluated when computing the ensemble dispersion,
        defaults to `True`.

    Returns
    -------
    Unitful
        The CRPS for the [`pastax.trajectory.TimeseriesEnsemble`][].
    """
    biases = self.map(lambda member: metric_func(other, member))
    bias = biases.mean(axis=0)

    n_members = self.size
    dispersion = self.ensemble_dispersion(metric_func, is_metric_symmetric=is_metric_symmetric)
    dispersion /= 2 * n_members * (n_members - 1)

    return bias - dispersion
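The NumPy sketch below restates `crps` for plain arrays, with absolute difference as a hypothetical `metric_func`; the function names are illustrative, not part of pastax. Like the original, it needs at least two members because of the `n_members - 1` factor in the denominator.

```python
import numpy as np


def absolute_error(a, b):
    """Hypothetical stand-in for `metric_func`: elementwise absolute difference."""
    return np.abs(a - b)


def ensemble_crps(members, obs, metric=absolute_error):
    """Fair ensemble CRPS with the same structure as `crps` above (needs >= 2 members).

    members: array of shape (member, ...); obs: array broadcastable to one member.
    """
    m = members.shape[0]
    # Bias term: mean distance between the reference and each member.
    bias = np.mean([metric(obs, x) for x in members], axis=0)
    # Dispersion term: evaluate pairs i < j only, then double (the metric is symmetric).
    rows, cols = np.triu_indices(m, k=1)
    dispersion = 2 * np.sum(
        [metric(members[i], members[j]) for i, j in zip(rows, cols)], axis=0
    )
    # Same normalization as `crps`: divide the dispersion by 2 * M * (M - 1).
    return bias - dispersion / (2 * m * (m - 1))


members = np.array([1.0, 2.0, 4.0])  # hypothetical 3-member ensemble at one time point
obs = np.array(2.5)                  # hypothetical reference value
score = ensemble_crps(members, obs)
```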
ensemble_dispersion(metric_func: Callable[[Timeseries, Timeseries], Unitful | Array], is_metric_symmetric: bool = True) -> Unitful

Computes the pastax.trajectory.TimeseriesEnsemble dispersion.

Parameters:

Name Type Description Default
metric_func Callable[[Timeseries, Timeseries], Unitful | Array]

The metric function, applied between pairs of members.

required
is_metric_symmetric bool

Whether the metric function is symmetric, in which case only the intra-ensemble "distances" between pairs (i, j) with i < j are evaluated and their sum is doubled, defaults to True.

True

Returns:

Type Description
Unitful

The pastax.trajectory.TimeseriesEnsemble dispersion.

Source code in pastax/trajectory/_timeseries_ensemble.py
def ensemble_dispersion(
    self,
    metric_func: Callable[[Timeseries, Timeseries], Unitful | Array],
    is_metric_symmetric: bool = True,
) -> Unitful:
    """
    Computes the [`pastax.trajectory.TimeseriesEnsemble`][] dispersion.

    Parameters
    ----------
    metric_func : Callable[[Timeseries, Timeseries], Unitful | Array]
        The metric function to use.
    is_metric_symmetric : bool, optional
        Whether the metric function is symmetric,
        in which case half of the intra ensemble "distances" are evaluated, defaults to `True`.

    Returns
    -------
    Unitful
        The [`pastax.trajectory.TimeseriesEnsemble`][] dispersion.
    """
    ij = jnp.column_stack(jnp.triu_indices(self.size, k=1))

    if not is_metric_symmetric:
        ij = jnp.vstack([ij, ij[:, ::-1]])

    vmap_metric_fn = eqx.filter_vmap(
        lambda _ij: metric_func(
            self._members_type.from_array(self.value[_ij[0], ...], self.times.value),
            self._members_type.from_array(self.value[_ij[1], ...], self.times.value),
        )
    )

    intra_distances = vmap_metric_fn(ij)
    dispersion = intra_distances.sum(axis=0)

    if is_metric_symmetric:
        dispersion *= 2

    return dispersion
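`ensemble_dispersion` enumerates member pairs with `jnp.triu_indices(self.size, k=1)` and, for a non-symmetric metric, appends the reversed pairs. The NumPy sketch below reproduces that enumeration on scalar members (the helper names are hypothetical) and shows why `is_metric_symmetric=True` should only be used with a metric satisfying d(a, b) = d(b, a).

```python
import numpy as np


def pair_indices(size, symmetric=True):
    """(i, j) pairs as in ensemble_dispersion: i < j, plus the reversed pairs if not symmetric."""
    ij = np.column_stack(np.triu_indices(size, k=1))
    if not symmetric:
        ij = np.vstack([ij, ij[:, ::-1]])
    return ij


def dispersion(values, metric, symmetric=True):
    """Sum of metric(values[i], values[j]) over the enumerated pairs."""
    total = sum(metric(values[i], values[j]) for i, j in pair_indices(values.shape[0], symmetric))
    # With symmetric=True only half of the ordered pairs were evaluated, so double the sum.
    return 2 * total if symmetric else total


def abs_diff(a, b):
    return abs(a - b)  # symmetric: abs_diff(a, b) == abs_diff(b, a)


def signed_diff(a, b):
    return a - b  # not symmetric: signed_diff(a, b) == -signed_diff(b, a)


values = np.array([0.0, 1.0, 3.0])  # hypothetical scalar ensemble members
sym = dispersion(values, abs_diff)
full = dispersion(values, abs_diff, symmetric=False)
```

With `abs_diff` both enumerations give 12. With `signed_diff`, the doubled upper-triangular sum is -12 while the full enumeration gives 0, so passing `is_metric_symmetric=True` for an asymmetric metric silently changes the result.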
map(func: Callable[[Timeseries], Unitful | Array]) -> Unitful

Applies a function to each pastax.trajectory.Timeseries of the pastax.trajectory.TimeseriesEnsemble.

Parameters:

Name Type Description Default
func Callable[[Timeseries], Unitful | Array]

The function to apply to each pastax.trajectory.Timeseries.

required

Returns:

Type Description
Unitful

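The results of applying the function to each pastax.trajectory.Timeseries, stacked along a leading member axis and wrapped in a Unitful (with an empty unit if the function returns a plain array).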

Source code in pastax/trajectory/_timeseries_ensemble.py
def map(self, func: Callable[[Timeseries], Unitful | Array]) -> Unitful:
    """
    Applies a function to each [`pastax.trajectory.Timeseries`][] of the [`pastax.trajectory.TimeseriesEnsemble`][].

    Parameters
    ----------
    func : Callable[[Timeseries], Unitful | Array]
        The function to apply to each [`pastax.trajectory.Timeseries`][].

    Returns
    -------
    Unitful
        The result of applying the function to each [`pastax.trajectory.Timeseries`][].
    """
    in_axes = eqx.filter(self.members, False)
    in_axes = eqx.tree_at(lambda x: (x.states._value, x.times._value), in_axes, (0, 0), is_leaf=lambda x: x is None)
    res = eqx.filter_vmap(func, in_axes=(in_axes,))(self.members)

    unit = {}
    if isinstance(res, Unitful):
        unit = res.unit
        res = res.value

    return Unitful(res, unit)
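`map` vectorizes `func` over the member axis with `eqx.filter_vmap` (only the states and times values are mapped, along axis 0) and then wraps the stacked results in a `Unitful`. The eager NumPy loop below is a stand-in that reproduces the resulting layout under that reading; it is not the JAX implementation, the helper name is hypothetical, and the `Unitful` wrapping is left out.

```python
import numpy as np


def map_members(values, func):
    """Apply func to each member (leading axis) and stack the results along a new axis 0.

    Eager stand-in for the eqx.filter_vmap call in `map`; no Unitful wrapping.
    """
    return np.stack([func(member) for member in values], axis=0)


# Hypothetical ensemble: 2 members, 3 time points, 1 state variable.
values = np.arange(6.0).reshape(2, 3, 1)
per_member_mean = map_members(values, lambda member: member.mean())
```

Because results are stacked along a leading member axis, `crps` can average the mapped biases with `mean(axis=0)`.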
to_xarray() -> xr.Dataset

Converts the pastax.trajectory.TimeseriesEnsemble to an xarray.Dataset.

Returns:

Type Description
Dataset

The corresponding xarray.Dataset.

Source code in pastax/trajectory/_timeseries_ensemble.py
def to_xarray(self) -> xr.Dataset:
    """
    Converts the [`pastax.trajectory.TimeseriesEnsemble`][] to an `xarray.Dataset`.

    Returns
    -------
    xr.Dataset
        The corresponding `xarray.Dataset`.
    """
    da = self.to_dataarray()
    ds = da.to_dataset()

    return ds
from_array(values: Float[Array, 'member time state'], times: Float[Array, 'time'], unit: dict[Unit, int | float] = {}, name: str | None = None, **kwargs: Any) -> TimeseriesEnsemble classmethod

Creates a pastax.trajectory.TimeseriesEnsemble from arrays of values and time points.

Parameters:

Name Type Description Default
values Float[Array, 'member time state']

The values for the members of the ensemble.

required
times Float[Array, 'time']

The time points for the timeseries.

required
unit dict[Unit, int | float]

The unit of the timeseries, defaults to {}.

{}
name str | None

The name of the timeseries, defaults to None.

None
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
TimeseriesEnsemble

The corresponding pastax.trajectory.TimeseriesEnsemble.

Source code in pastax/trajectory/_timeseries_ensemble.py
@classmethod
def from_array(
    cls,
    values: Float[Array, "member time state"],
    times: Float[Array, "time"],
    unit: dict[Unit, int | float] = {},
    name: str | None = None,
    **kwargs: Any,
) -> TimeseriesEnsemble:
    """
    Creates a [`pastax.trajectory.TimeseriesEnsemble`][] from arrays of values and time points.

    Parameters
    ----------
    values : Float[Array, "member time state"]
        The values for the members of the ensemble.
    times : Float[Array, "time"]
        The time points for the timeseries.
    unit : dict[Unit, int | float], optional
        The unit of the timeseries, defaults to {}.
    name : str | None, optional
        The name of the timeseries, defaults to None.
    **kwargs : Any
        Additional keyword arguments.

    Returns
    -------
    TimeseriesEnsemble
        The corresponding [`pastax.trajectory.TimeseriesEnsemble`][].
    """
    members = eqx.filter_vmap(
        lambda x: cls._members_type.from_array(x, times, unit=unit, name=name, **kwargs), out_axes=_if_mapped(0)
    )(values)

    return cls(members)

TrajectoryEnsemble

Bases: TimeseriesEnsemble

Class representing an ensemble of trajectories.

Attributes:

Name Type Description
members Trajectory

The members of the trajectory ensemble.

Methods:

Name Description
id

Returns the IDs of the trajectories.

latitudes

Returns the latitudes of the trajectories.

locations

Returns the locations of the trajectories.

longitudes

Returns the longitudes of the trajectories.

origin

Returns the origin of the trajectories.

crps

Computes the Continuous Ranked Probability Score (CRPS) for the ensemble.

liu_index

Computes the Liu Index for each ensemble trajectory.

lengths

Returns the lengths of the trajectories.

mae

Computes the Mean Absolute Error (MAE) for each ensemble trajectory.

plot

Plots the trajectories.

rmse

Computes the Root Mean Square Error (RMSE) for each ensemble trajectory.

separation_distance

Computes the separation distance for each ensemble trajectory.

steps

Returns the steps of the trajectories.

to_xarray

Converts the pastax.trajectory.TrajectoryEnsemble to an xarray.Dataset.

from_array

Creates a pastax.trajectory.TrajectoryEnsemble from arrays of values and time points.

Source code in pastax/trajectory/_trajectory_ensemble.py
class TrajectoryEnsemble(TimeseriesEnsemble):
    """
    Class representing an ensemble of trajectories.

    Attributes
    ----------
    members : Trajectory
        The members of the trajectory ensemble.

    Methods
    -------
    id
        Returns the IDs of the trajectories.
    latitudes
        Returns the latitudes of the trajectories.
    locations
        Returns the locations of the trajectories.
    longitudes
        Returns the longitudes of the trajectories.
    origin
        Returns the origin of the trajectories.
    crps(other, distance_func=Trajectory.separation_distance)
        Computes the Continuous Ranked Probability Score (CRPS) for the ensemble.
    liu_index(other)
        Computes the Liu Index for each ensemble trajectory.
    lengths()
        Returns the lengths of the trajectories.
    mae(other)
        Computes the Mean Absolute Error (MAE) for each ensemble trajectory.
    plot(ax=None, label=None, color=None, alpha_factor=1, ti=None)
        Plots the trajectories.
    rmse(other)
        Computes the Root Mean Square Error (RMSE) for each ensemble trajectory.
    separation_distance(other)
        Computes the separation distance for each ensemble trajectory.
    steps()
        Returns the steps of the trajectories.
    velocities()
        Returns the velocities of the trajectories.
    to_xarray()
        Converts the [`pastax.trajectory.TrajectoryEnsemble`][] to an `xarray.Dataset`.
    from_array(values, times, unit=UNIT["°"], id=None)
        Creates a [`pastax.trajectory.TrajectoryEnsemble`][] from arrays of values and time points.
    """

    members: Trajectory
    _members_type: ClassVar = Trajectory

    @property
    def id(self) -> Int[Array, "member"] | None:
        """
        Returns the IDs of the trajectories.

        Returns
        -------
        Int[Array, "member"] | None
            The IDs of the trajectories.
        """
        return self.members.id

    @property
    def latitudes(self) -> State:
        """
        Returns the latitudes of the trajectories.

        Returns
        -------
        State
            The latitudes of the trajectories.
        """
        return self.members.latitudes

    @property
    def locations(self) -> Location:
        """
        Returns the locations of the trajectories.

        Returns
        -------
        Location
            The locations of the trajectories.
        """
        return self.members.locations

    @property
    def longitudes(self) -> State:
        """
        Returns the longitudes of the trajectories.

        Returns
        -------
        State
            The longitudes of the trajectories.
        """
        return self.members.longitudes

    @property
    def origin(self) -> State:
        """
        Returns the origin of the trajectories.

        Returns
        -------
        State
            The origin of the trajectories.
        """
        return self.members.origin

    def liu_index(self, other: Trajectory) -> TimeseriesEnsemble:
        """
        Computes the Liu Index for each ensemble trajectory.

        Parameters
        ----------
        other : Trajectory
            The other trajectory to compare against.

        Returns
        -------
        TimeseriesEnsemble
            The Liu Index for each ensemble trajectory.
        """
        liu_index = self.map(lambda trajectory: other.liu_index(trajectory))  # type: ignore
        return TimeseriesEnsemble.from_array(liu_index.value, self.times.value, name="Liu index")

    def lengths(self) -> TimeseriesEnsemble:
        """
        Returns the lengths of the trajectories.

        Returns
        -------
        TimeseriesEnsemble
            The lengths of the trajectories.
        """
        lengths = self.map(lambda trajectory: trajectory.lengths())  # type: ignore
        return TimeseriesEnsemble.from_array(lengths.value, self.times.value, unit=lengths.unit, name="lengths")

    def mae(self, other: Trajectory) -> TimeseriesEnsemble:
        """
        Computes the Mean Absolute Error (MAE) for each ensemble trajectory.

        Parameters
        ----------
        other : Trajectory
            The other trajectory to compare against.

        Returns
        -------
        TimeseriesEnsemble
            The MAE for each ensemble trajectory.
        """
        mae = self.map(lambda trajectory: other.mae(trajectory))  # type: ignore
        return TimeseriesEnsemble.from_array(mae.value, self.times.value, unit=mae.unit, name="MAE")

    def plot(
        self,
        ax: Axes | None = None,
        label: str | list[str] | None = None,
        color: str | list[str | float | int] | None = None,
        alpha_factor: float = 1,
        ti: int | None = None,
        **kwargs,
    ) -> Axes:
        """
        Plots the trajectories.

        Parameters
        ----------
        ax : Axes | None, optional
            The matplotlib axis to plot on, defaults to `None` (a new figure with a Cartopy `PlateCarree` axis is created).
        label : str | list[str] | None, optional
            The label(s) for the plot, defaults to `None`.
        color : str | list[str | float | int] | None, optional
            The color(s) for the plot, defaults to `None`.
        alpha_factor : float, optional
            A factor controlling the overall transparency of the plotted ensemble, defaults to `1`.
        ti : int | None, optional
            The (exclusive) time index up to which to plot, defaults to `None` (all time steps).
        **kwargs : dict, optional
            Additional keyword arguments passed to `LineCollection`.

        Returns
        -------
        plt.Axes
            The matplotlib axis with the plot.
        """
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(projection=ccrs.PlateCarree())

        if ti is None:
            ti = self.length

        alpha_factor *= np.clip(1 / ((self.size / 10) ** 0.5), 0.05, 1).item()
        alpha = np.geomspace(0.25, 1, ti - 1) * alpha_factor

        locations = self.locations.value.swapaxes(0, 1)[:ti, :, None, ::-1]
        segments = np.concat([locations[:-1], locations[1:]], axis=2).reshape(-1, 2, 2)
        alpha = np.repeat(alpha, self.size)

        if not (isinstance(label, str) or label is None) and color is not None:
            colors = np.tile(color, ti - 1)
        else:
            colors = color
        lc = LineCollection(segments, color=colors, alpha=alpha, **kwargs)  # type: ignore
        ax.add_collection(lc)

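        # Trick to display the legend label(s) with alpha=1: plot a single point
        # (no visible line) carrying each label, since the LineCollection is faded.
        if not (isinstance(label, str) or label is None):
            for i in range(len(label)):
                if color is None or isinstance(color, str):
                    color_ = color
                else:
                    color_ = color[i]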
                ax.plot(
                    self.longitudes.value[i, -1],
                    self.latitudes.value[i, -1],
                    label=label[i],
                    color=color_,
                )
        else:
            ax.plot(
                self.longitudes.value[0, -1],
                self.latitudes.value[0, -1],
                label=label,
                color=color,
            )

        return ax

    def rmse(self, other: Trajectory) -> TimeseriesEnsemble:
        """
        Computes the Root Mean Square Error (RMSE) for each ensemble trajectory.

        Parameters
        ----------
        other : Trajectory
            The other trajectory to compare against.

        Returns
        -------
        TimeseriesEnsemble
            The RMSE for each ensemble trajectory.
        """
        rmse = self.map(lambda trajectory: other.rmse(trajectory))  # type: ignore
        return TimeseriesEnsemble.from_array(rmse.value, self.times.value, unit=rmse.unit, name="RMSE")

    def separation_distance(self, other: Trajectory) -> TimeseriesEnsemble:
        """
        Computes the separation distance for each ensemble trajectory.

        Parameters
        ----------
        other : Trajectory
            The other trajectory to compare against.

        Returns
        -------
        TimeseriesEnsemble
            The separation distance for each ensemble trajectory.
        """
        separation_distance = self.map(
            lambda trajectory: other.separation_distance(trajectory)  # type: ignore
        )
        return TimeseriesEnsemble.from_array(
            separation_distance.value,
            self.times.value,
            unit=separation_distance.unit,
            name="Separation distance",
        )

    def steps(self) -> TimeseriesEnsemble:
        """
        Returns the steps of the trajectories.

        Returns
        -------
        TimeseriesEnsemble
            The steps of the trajectories.
        """
        steps = self.map(lambda trajectory: trajectory.steps())  # type: ignore
        return TimeseriesEnsemble.from_array(steps.value, self.times.value, unit=steps.unit, name="steps")

    def velocities(self) -> TimeseriesEnsemble:
        """
        Returns the velocities of the trajectories.

        Returns
        -------
        TimeseriesEnsemble
            The velocities of the trajectories.
        """
        velocities = self.map(lambda trajectory: trajectory.velocities())  # type: ignore
        return TimeseriesEnsemble.from_array(
            velocities.value, self.times.value, unit=velocities.unit, name="velocities"
        )

    def to_xarray(self) -> xr.Dataset:
        """
        Converts the [`pastax.trajectory.TrajectoryEnsemble`][] to an `xarray.Dataset`.

        Returns
        -------
        xr.Dataset
            The corresponding `xarray.Dataset`.
        """
        return xr.Dataset(self.to_dataarray())

    @classmethod
    def from_array(
        cls,
        values: Float[Array, "member time 2"],
        times: Float[Array, "time"],
        unit: Unit | dict[Unit, int | float] = UNIT["°"],
        id: Int[Array, ""] | None = None,
        **_: dict,
    ) -> TrajectoryEnsemble:
        """
        Creates a [`pastax.trajectory.TrajectoryEnsemble`][] from arrays of values and time points.

        Parameters
        ----------
        values : Float[Array, "member time 2"]
            The array of (latitudes, longitudes) values for the members of the trajectory ensemble.
        times : Float[Array, "time"]
            The time points for the trajectories.
        unit : Unit | dict[Unit, int | float], optional
            Unit of the trajectory locations, defaults to `UNIT["°"]`.
        id : Int[Array, ""] | None, optional
            The ID of the trajectories, defaults to None.

        Returns
        -------
        TrajectoryEnsemble
            The corresponding [`pastax.trajectory.TrajectoryEnsemble`][].
        """
        return super().from_array(values, times, unit=unit, id=id)  # type: ignore

    def to_dataarray(self) -> dict[str, xr.DataArray]:
        """
        Converts the [`pastax.trajectory.TrajectoryEnsemble`][] to a dictionary of `xarray.DataArray`.

        Returns
        -------
        dict[str, xr.DataArray]
            A dictionary where keys are the variable names and values are the corresponding `xarray.DataArray`.
        """
        member = np.arange(self.size)
        times = self.members.times.to_datetime()
        unit = units_to_str(self.unit)

        wmo_da = xr.DataArray(
            data=self.id,
            dims=["traj"],  # `id` holds one value per ensemble member
            coords={"id": ("traj", member)},
            name="WMO",
        )
        latitude_da = xr.DataArray(
            data=self.latitudes.value,  # unwrap the State into its array
            dims=["traj", "obs"],
            coords={"id": ("traj", member), "time": ("obs", times)},
            name="lat",
            attrs={"units": unit},
        )
        longitude_da = xr.DataArray(
            data=self.longitudes.value,  # unwrap the State into its array
            dims=["traj", "obs"],
            coords={"id": ("traj", member), "time": ("obs", times)},
            name="lon",
            attrs={"units": unit},
        )

        return {"WMO": wmo_da, "latitude": latitude_da, "longitude": longitude_da}
id: Int[Array, 'member'] | None property

Returns the IDs of the trajectories.

Returns:

Type Description
Int[Array, 'member'] | None

The IDs of the trajectories.

latitudes: State property

Returns the latitudes of the trajectories.

Returns:

Type Description
State

The latitudes of the trajectories.

locations: Location property

Returns the locations of the trajectories.

Returns:

Type Description
Location

The locations of the trajectories.

longitudes: State property

Returns the longitudes of the trajectories.

Returns:

Type Description
State

The longitudes of the trajectories.

origin: State property

Returns the origin of the trajectories.

Returns:

Type Description
State

The origin of the trajectories.

liu_index(other: Trajectory) -> TimeseriesEnsemble

Computes the Liu Index for each ensemble trajectory.

Parameters:

Name Type Description Default
other Trajectory

The other trajectory to compare against.

required

Returns:

Type Description
TimeseriesEnsemble

The Liu Index for each ensemble trajectory.
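As a rough illustration, the sketch below computes one common form of the Liu index, the normalized cumulative Lagrangian separation of Liu and Weisberg (2011), for every ensemble member against a reference trajectory. It assumes planar (x, y) positions and drops the t0 entry (where the reference has not moved yet). The function name `liu_index_sketch` is hypothetical, and pastax's `Trajectory.liu_index` may differ in details such as the distance metric.

```python
import numpy as np


def liu_index_sketch(members, reference):
    """Normalized cumulative separation for each member (hypothetical helper).

    members: (member, time, 2) planar positions.
    reference: (time, 2) planar positions of the reference trajectory.
    Returns an array of shape (member, time - 1).
    """
    # Separation distance between each member and the reference at every time.
    d = np.linalg.norm(members - reference[None], axis=-1)  # (member, time)
    # Cumulative length travelled by the reference up to each time after t0.
    steps = np.linalg.norm(np.diff(reference, axis=0), axis=-1)  # (time - 1,)
    lengths = np.cumsum(steps)  # (time - 1,)
    # Ratio of summed separations to summed reference lengths, from t1 onwards.
    return np.cumsum(d[:, 1:], axis=-1) / np.cumsum(lengths)[None]
```

A member that exactly follows the reference gets an index of zero at every time; larger values mean a worse match relative to the distance travelled.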

Source code in pastax/trajectory/_trajectory_ensemble.py
def liu_index(self, other: Trajectory) -> TimeseriesEnsemble:
    """
    Computes the Liu Index for each ensemble trajectory.

    Parameters
    ----------
    other : Trajectory
        The other trajectory to compare against.

    Returns
    -------
    TimeseriesEnsemble
        The Liu Index for each ensemble trajectory.
    """
    liu_index = self.map(lambda trajectory: other.liu_index(trajectory))  # type: ignore
    return TimeseriesEnsemble.from_array(liu_index.value, self.times.value, name="Liu index")
lengths() -> TimeseriesEnsemble

Returns the lengths of the trajectories.

Returns:

Type Description
TimeseriesEnsemble

The lengths of the trajectories.
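A minimal sketch of what per-member trajectory lengths could look like, assuming "length" means the cumulative great-circle distance travelled since t0 and that positions are (latitude, longitude) in degrees, as in `from_array`. The haversine formula, the 6371 km mean Earth radius, and the helper names are choices made for this illustration, not necessarily what pastax uses.

```python
import numpy as np

EARTH_RADIUS_KM = 6371.0  # mean Earth radius (assumption of this sketch)


def haversine_km(p, q):
    """Great-circle distance between (lat, lon) points given in degrees."""
    lat1, lon1 = np.radians(p[..., 0]), np.radians(p[..., 1])
    lat2, lon2 = np.radians(q[..., 0]), np.radians(q[..., 1])
    a = (np.sin((lat2 - lat1) / 2) ** 2
         + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2)
    # Clip guards against tiny floating-point overshoots above 1.
    return 2 * EARTH_RADIUS_KM * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0)))


def cumulative_lengths_km(values):
    """values: (member, time, 2) array of (lat, lon); returns (member, time)."""
    steps = haversine_km(values[:, :-1], values[:, 1:])  # (member, time - 1)
    # Prepend a zero so the length at t0 is 0 and the time axis is preserved.
    zeros = np.zeros_like(steps[:, :1])
    return np.concatenate([zeros, np.cumsum(steps, axis=-1)], axis=-1)
```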

Source code in pastax/trajectory/_trajectory_ensemble.py
def lengths(self) -> TimeseriesEnsemble:
    """
    Returns the lengths of the trajectories.

    Returns
    -------
    TimeseriesEnsemble
        The lengths of the trajectories.
    """
    lengths = self.map(lambda trajectory: trajectory.lengths())  # type: ignore
    return TimeseriesEnsemble.from_array(lengths.value, self.times.value, unit=lengths.unit, name="lengths")
mae(other: Trajectory) -> TimeseriesEnsemble

Computes the Mean Absolute Error (MAE) for each ensemble trajectory.

Parameters:

Name Type Description Default
other Trajectory

The other trajectory to compare against.

required

Returns:

Type Description
TimeseriesEnsemble

The MAE for each ensemble trajectory.
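The method maps `other.mae` over every member; with NumPy the same per-member computation can be expressed by broadcasting the reference over the member axis. The sketch below assumes that the MAE at each time is the running mean of the separation distance up to that time and uses planar distances for brevity. Both are assumptions, and `running_mae` is a hypothetical name; `Trajectory.mae` defines the actual computation.

```python
import numpy as np


def running_mae(members, reference):
    """Running mean absolute separation for each member (hypothetical helper).

    members: (member, time, 2) planar positions; reference: (time, 2).
    Returns an array of shape (member, time).
    """
    # Broadcasting the reference plays the role of mapping over members.
    d = np.linalg.norm(members - reference[None], axis=-1)  # (member, time)
    n = np.arange(1, d.shape[-1] + 1)  # number of samples up to each time
    return np.cumsum(np.abs(d), axis=-1) / n
```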

Source code in pastax/trajectory/_trajectory_ensemble.py
def mae(self, other: Trajectory) -> TimeseriesEnsemble:
    """
    Computes the Mean Absolute Error (MAE) for each ensemble trajectory.

    Parameters
    ----------
    other : Trajectory
        The other trajectory to compare against.

    Returns
    -------
    TimeseriesEnsemble
        The MAE for each ensemble trajectory.
    """
    mae = self.map(lambda trajectory: other.mae(trajectory))  # type: ignore
    return TimeseriesEnsemble.from_array(mae.value, self.times.value, unit=mae.unit, name="MAE")
plot(ax: Axes | None = None, label: str | list[str] | None = None, color: str | list[str | float | int] | None = None, alpha_factor: float = 1, ti: int | None = None, **kwargs) -> Axes

Plots the trajectories.

Parameters:

Name Type Description Default
ax Axes | None

The matplotlib axis to plot on, defaults to None (a new figure with a Cartopy PlateCarree axis is created).

None
label str | list[str] | None

The label(s) for the plot, defaults to None.

None
color str | list[str | float | int] | None

The color(s) for the plot, defaults to None.

None
alpha_factor float

A factor controlling the overall transparency of the plotted ensemble, defaults to 1.

1
ti int | None

The (exclusive) time index up to which to plot, defaults to None (all time steps).

None
kwargs

Additional keyword arguments passed to LineCollection.

{}

Returns:

Type Description
Axes

The matplotlib axis with the plot.
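The bookkeeping inside `plot` (fading alphas, swapping (lat, lon) to (lon, lat), and pairing consecutive points into segments) can be reproduced with plain NumPy arrays. The sketch below mirrors the method's source on a raw `(member, time, 2)` array instead of a `TrajectoryEnsemble`. `plot_layout` is a hypothetical name, and it uses `np.concatenate` because `np.concat` only exists in NumPy 2.0 and later.

```python
import numpy as np


def plot_layout(values, ti=None, alpha_factor=1.0):
    """Mirror the segment/alpha layout used by TrajectoryEnsemble.plot.

    values: (member, time, 2) array of (lat, lon), as in from_array.
    Returns (segments, alpha) suitable for a matplotlib LineCollection.
    """
    size, length = values.shape[:2]
    if ti is None:
        ti = length
    # Fade large ensembles: scale by 1 / sqrt(size / 10), clipped to [0.05, 1].
    alpha_factor *= float(np.clip(1 / ((size / 10) ** 0.5), 0.05, 1))
    # Older segments are more transparent (geometric ramp from 0.25 to 1).
    alpha = np.geomspace(0.25, 1, ti - 1) * alpha_factor
    # (time, member, 1, 2), with (lat, lon) flipped to (lon, lat) for plotting.
    locations = values.swapaxes(0, 1)[:ti, :, None, ::-1]
    # Pair consecutive points: segments ordered time-major, then member.
    segments = np.concatenate([locations[:-1], locations[1:]], axis=2).reshape(-1, 2, 2)
    # One alpha per segment: each time step's alpha repeated for every member.
    alpha = np.repeat(alpha, size)
    return segments, alpha
```

As in the method, the two arrays line up one-to-one, so they can be handed to `LineCollection(segments, alpha=alpha)`.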

Source code in pastax/trajectory/_trajectory_ensemble.py
def plot(
    self,
    ax: Axes | None = None,
    label: str | list[str] | None = None,
    color: str | list[str | float | int] | None = None,
    alpha_factor: float = 1,
    ti: int | None = None,
    **kwargs,
) -> Axes:
    """
    Plots the trajectories.

    Parameters
    ----------
    ax : Axes | None, optional
        The matplotlib axis to plot on, defaults to `None` (a new figure with a Cartopy `PlateCarree` axis is created).
    label : str | list[str] | None, optional
        The label(s) for the plot, defaults to `None`.
    color : str | list[str | float | int] | None, optional
        The color(s) for the plot, defaults to `None`.
    alpha_factor : float, optional
        A factor controlling the overall transparency of the plotted ensemble, defaults to `1`.
    ti : int | None, optional
        The (exclusive) time index up to which to plot, defaults to `None` (all time steps).
    **kwargs : dict, optional
        Additional keyword arguments passed to `LineCollection`.

    Returns
    -------
    plt.Axes
        The matplotlib axis with the plot.
    """
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(projection=ccrs.PlateCarree())

    if ti is None:
        ti = self.length

    alpha_factor *= np.clip(1 / ((self.size / 10) ** 0.5), 0.05, 1).item()
    alpha = np.geomspace(0.25, 1, ti - 1) * alpha_factor

    locations = self.locations.value.swapaxes(0, 1)[:ti, :, None, ::-1]
    segments = np.concat([locations[:-1], locations[1:]], axis=2).reshape(-1, 2, 2)
    alpha = np.repeat(alpha, self.size)

    if not (isinstance(label, str) or label is None) and color is not None:
        colors = np.tile(color, ti - 1)
    else:
        colors = color
    lc = LineCollection(segments, color=colors, alpha=alpha, **kwargs)  # type: ignore
    ax.add_collection(lc)

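    # Trick to display the legend label(s) with alpha=1: plot a single point
    # (no visible line) carrying each label, since the LineCollection is faded.
    if not (isinstance(label, str) or label is None):
        for i in range(len(label)):
            if color is None or isinstance(color, str):
                color_ = color
            else:
                color_ = color[i]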
            ax.plot(
                self.longitudes.value[i, -1],
                self.latitudes.value[i, -1],
                label=label[i],
                color=color_,
            )
    else:
        ax.plot(
            self.longitudes.value[0, -1],
            self.latitudes.value[0, -1],
            label=label,
            color=color,
        )

    return ax
rmse(other: Trajectory) -> TimeseriesEnsemble

Computes the Root Mean Square Error (RMSE) for each ensemble trajectory.

Parameters:

Name Type Description Default
other Trajectory

The other trajectory to compare against.

required

Returns:

Type Description
TimeseriesEnsemble

The RMSE for each ensemble trajectory.
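Under the same assumptions as a running-mean error (planar distances, error accumulated over all times up to t), a root mean square variant squares the separations before averaging. This is a hypothetical sketch, not pastax's `Trajectory.rmse`; it mainly shows that RMSE weights large separations more heavily than MAE.

```python
import numpy as np


def running_rmse(members, reference):
    """Running root mean square separation for each member (hypothetical helper).

    members: (member, time, 2) planar positions; reference: (time, 2).
    Returns an array of shape (member, time).
    """
    # Separation of every member from the reference, broadcast over members.
    d = np.linalg.norm(members - reference[None], axis=-1)  # (member, time)
    n = np.arange(1, d.shape[-1] + 1)  # number of samples up to each time
    return np.sqrt(np.cumsum(d ** 2, axis=-1) / n)
```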

Source code in pastax/trajectory/_trajectory_ensemble.py
def rmse(self, other: Trajectory) -> TimeseriesEnsemble:
    """
    Computes the Root Mean Square Error (RMSE) for each ensemble trajectory.

    Parameters
    ----------
    other : Trajectory
        The other trajectory to compare against.

    Returns
    -------
    TimeseriesEnsemble
        The RMSE for each ensemble trajectory.
    """
    rmse = self.map(lambda trajectory: other.rmse(trajectory))  # type: ignore
    return TimeseriesEnsemble.from_array(rmse.value, self.times.value, unit=rmse.unit, name="RMSE")
separation_distance(other: Trajectory) -> TimeseriesEnsemble

Computes the separation distance for each ensemble trajectory.

Parameters:

Name Type Description Default
other Trajectory

The other trajectory to compare against.

required

Returns:

Type Description
TimeseriesEnsemble

The separation distance for each ensemble trajectory.
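For geographic positions, a per-member separation distance is typically a great-circle distance evaluated at each time. The sketch below broadcasts a reference trajectory over the member axis and applies the haversine formula; the 6371 km Earth radius, the (lat, lon) degree layout, and the name `separation_km` are assumptions, and pastax may use a different geodesic formula.

```python
import numpy as np

EARTH_RADIUS_KM = 6371.0  # mean Earth radius (assumption of this sketch)


def separation_km(members, reference):
    """Great-circle distance from each member to the reference at every time.

    members: (member, time, 2) and reference: (time, 2), both (lat, lon) in degrees.
    Returns an array of shape (member, time) in kilometres.
    """
    lat1, lon1 = np.radians(members[..., 0]), np.radians(members[..., 1])
    # Broadcast the reference over the member axis: shape (1, time).
    lat2, lon2 = np.radians(reference[None, :, 0]), np.radians(reference[None, :, 1])
    a = (np.sin((lat2 - lat1) / 2) ** 2
         + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2)
    return 2 * EARTH_RADIUS_KM * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0)))
```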

Source code in pastax/trajectory/_trajectory_ensemble.py
def separation_distance(self, other: Trajectory) -> TimeseriesEnsemble:
    """
    Computes the separation distance for each ensemble trajectory.

    Parameters
    ----------
    other : Trajectory
        The other trajectory to compare against.

    Returns
    -------
    TimeseriesEnsemble
        The separation distance for each ensemble trajectory.
    """
    separation_distance = self.map(
        lambda trajectory: other.separation_distance(trajectory)  # type: ignore
    )
    return TimeseriesEnsemble.from_array(
        separation_distance.value,
        self.times.value,
        unit=separation_distance.unit,
        name="Separation distance",
    )
steps() -> TimeseriesEnsemble

Returns the steps of the trajectories.

Returns:

Type Description
TimeseriesEnsemble

The steps of the trajectories.
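Assuming "steps" are the distances between consecutive positions of each member, they reduce to a difference along the time axis followed by a norm. This sketch uses planar positions for brevity (pastax presumably works with geographic coordinates), returns one fewer value than there are time points, and `steps_sketch` is a hypothetical name. Summing the steps cumulatively gives the distance travelled so far.

```python
import numpy as np


def steps_sketch(values):
    """Distance between consecutive positions of each member.

    values: (member, time, 2) planar positions; returns (member, time - 1).
    """
    return np.linalg.norm(np.diff(values, axis=1), axis=-1)
```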

Source code in pastax/trajectory/_trajectory_ensemble.py
def steps(self) -> TimeseriesEnsemble:
    """
    Returns the steps of the trajectories.

    Returns
    -------
    TimeseriesEnsemble
        The steps of the trajectories.
    """
    steps = self.map(lambda trajectory: trajectory.steps())  # type: ignore
    return TimeseriesEnsemble.from_array(steps.value, self.times.value, unit=steps.unit, name="steps")
velocities() -> TimeseriesEnsemble

Returns the velocities of the trajectories.

Returns:

Type Description
TimeseriesEnsemble

The velocities of the trajectories.
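Whether pastax returns speed magnitudes or vector velocity components is not stated here, so the sketch below only computes the mean speed over each time step, dividing step distances by the time elapsed. Planar positions in metres, times in seconds, and the name `speeds_sketch` are assumptions of this illustration.

```python
import numpy as np


def speeds_sketch(values, times):
    """Mean speed over each time step (hypothetical helper).

    values: (member, time, 2) planar positions in metres;
    times: (time,) in seconds. Returns (member, time - 1) in m/s.
    """
    dist = np.linalg.norm(np.diff(values, axis=1), axis=-1)  # (member, time - 1)
    dt = np.diff(times)  # (time - 1,)
    return dist / dt[None]
```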

Source code in pastax/trajectory/_trajectory_ensemble.py
def velocities(self) -> TimeseriesEnsemble:
    """
    Returns the velocities of the trajectories.

    Returns
    -------
    TimeseriesEnsemble
        The velocities of the trajectories.
    """
    velocities = self.map(lambda trajectory: trajectory.velocities())  # type: ignore
    return TimeseriesEnsemble.from_array(
        velocities.value, self.times.value, unit=velocities.unit, name="velocities"
    )
to_xarray() -> xr.Dataset

Converts the pastax.trajectory.TrajectoryEnsemble to an xarray.Dataset.

Returns:

Type Description
Dataset

The corresponding xarray.Dataset.

Source code in pastax/trajectory/_trajectory_ensemble.py
def to_xarray(self) -> xr.Dataset:
    """
    Converts the [`pastax.trajectory.TrajectoryEnsemble`][] to an `xarray.Dataset`.

    Returns
    -------
    xr.Dataset
        The corresponding `xarray.Dataset`.
    """
    return xr.Dataset(self.to_dataarray())
from_array(values: Float[Array, 'member time 2'], times: Float[Array, 'time'], unit: Unit | dict[Unit, int | float] = UNIT['°'], id: Int[Array, ''] | None = None, **_: dict) -> TrajectoryEnsemble classmethod

Creates a pastax.trajectory.TrajectoryEnsemble from arrays of values and time points.

Parameters:

Name Type Description Default
values Float[Array, 'member time 2']

The array of (latitudes, longitudes) values for the members of the trajectory ensemble.

required
times Float[Array, 'time']

The time points for the trajectories.

required
unit Unit | dict[Unit, int | float]

Unit of the trajectory locations, defaults to UNIT["°"].

UNIT['°']
id Int[Array, ''] | None

The ID of the trajectories, defaults to None.

None

Returns:

Type Description
TrajectoryEnsemble

The corresponding pastax.trajectory.TrajectoryEnsemble.
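The `values` argument stacks (latitude, longitude) pairs along its last axis, one row per member and per time point. The NumPy sketch below builds arrays with the documented shapes for a small illustrative ensemble; the track, the member offsets, and the choice of seconds as the time unit are assumptions, not pastax requirements.

```python
import numpy as np

n_members, n_times = 3, 4
# Hourly time points; expressing them in seconds is an assumption of this sketch.
times = np.arange(n_times) * 3600.0
# A shared illustrative track in degrees (latitude, longitude).
lat = np.linspace(43.0, 44.0, n_times)
lon = np.linspace(5.0, 6.0, n_times)
# Offset each member's longitude slightly to form a small ensemble.
offsets = np.linspace(-0.1, 0.1, n_members)[:, None]  # (member, 1)
values = np.stack(
    [np.broadcast_to(lat, (n_members, n_times)), lon[None, :] + offsets],
    axis=-1,
)  # (member, time, 2), last axis ordered (latitude, longitude)
```

These arrays match the documented `values : Float[Array, "member time 2"]` and `times : Float[Array, "time"]` parameters, so they could be passed (for example after `jnp.asarray`) as `TrajectoryEnsemble.from_array(values, times)`.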
Source code in pastax/trajectory/_trajectory_ensemble.py
@classmethod
def from_array(
    cls,
    values: Float[Array, "member time 2"],
    times: Float[Array, "time"],
    unit: Unit | dict[Unit, int | float] = UNIT["°"],
    id: Int[Array, ""] | None = None,
    **_: dict,
) -> TrajectoryEnsemble:
    """
    Creates a [`pastax.trajectory.TrajectoryEnsemble`][] from arrays of values and time points.

    Parameters
    ----------
    values : Float[Array, "member time 2"]
        The array of (latitudes, longitudes) values for the members of the trajectory ensemble.
    times : Float[Array, "time"]
        The time points for the trajectories.
    unit : Unit | dict[Unit, int | float], optional
        Unit of the trajectory locations, defaults to `UNIT["°"]`.
    id : Int[Array, ""] | None, optional
        The ID of the trajectories, defaults to None.

    Returns
    -------
    TrajectoryEnsemble
        The corresponding [`pastax.trajectory.TrajectoryEnsemble`][].
    """
    return super().from_array(values, times, unit=unit, id=id)  # type: ignore
to_dataarray() -> dict[str, xr.DataArray]

Converts the pastax.trajectory.TrajectoryEnsemble to a dictionary of xarray.DataArray.

Returns:

Type Description
dict[str, DataArray]

A dictionary where keys are the variable names and values are the corresponding xarray.DataArray.

Source code in pastax/trajectory/_trajectory_ensemble.py
def to_dataarray(self) -> dict[str, xr.DataArray]:
    """
    Converts the [`pastax.trajectory.TrajectoryEnsemble`][] to a dictionary of `xarray.DataArray`.

    Returns
    -------
    dict[str, xr.DataArray]
        A dictionary where keys are the variable names and values are the corresponding `xarray.DataArray`.
    """
    member = np.arange(self.size)
    times = self.members.times.to_datetime()
    unit = units_to_str(self.unit)

    wmo_da = xr.DataArray(
        data=self.id,
        dims=["traj"],  # `id` holds one value per ensemble member
        coords={"id": ("traj", member)},
        name="WMO",
    )
    latitude_da = xr.DataArray(
        data=self.latitudes.value,  # unwrap the State into its array
        dims=["traj", "obs"],
        coords={"id": ("traj", member), "time": ("obs", times)},
        name="lat",
        attrs={"units": unit},
    )
    longitude_da = xr.DataArray(
        data=self.longitudes.value,  # unwrap the State into its array
        dims=["traj", "obs"],
        coords={"id": ("traj", member), "time": ("obs", times)},
        name="lon",
        attrs={"units": unit},
    )

    return {"WMO": wmo_da, "latitude": latitude_da, "longitude": longitude_da}

Set

Bases: Module

Base class representing a set of PyTrees.

Attributes:

Name Type Description
size int

The number of members in the set.

Source code in pastax/trajectory/_set.py
class Set(eqx.Module):
    """
    Base class representing a set of PyTrees.

    Attributes
    ----------
    size : int
        The number of members in the set.
    """

    _members: PyTree | Mapping[str, PyTree] | Sequence[PyTree]
    size: int

pastax.gridded

This module provides classes and functions for handling coordinates, grids, and gridded fields (pastax.gridded.Gridded) in JAX.

Coordinate

Bases: Module

Class for handling 1D coordinates (i.e., the coordinates of rectilinear grids).

Attributes:

Name Type Description
indices Interpolator1D

Interpolator mapping coordinate values to their nearest indices.

Methods:

Name Description
values

Returns the coordinate values.

index

Returns the nearest index for the given query coordinates.

from_array

Creates a pastax.gridded.Coordinate instance from an array of values.

__getitem__

Retrieve an item from the coordinate array.

Source code in pastax/gridded/_coordinate.py
class Coordinate(eqx.Module):
    """
    Class for handling 1D coordinates (i.e., the coordinates of rectilinear grids).

    Attributes
    ----------
    indices : ipx.Interpolator1D
        Interpolator mapping coordinate values to their nearest indices.

    Methods
    -------
    values
        Returns the coordinate values.
    index(query)
        Returns the nearest index for the given query coordinates.
    from_array(values, **interpolator_kwargs)
        Creates a [`pastax.gridded.Coordinate`][] instance from an array of values.
    __getitem__(item)
        Retrieve an item from the coordinate array.
    """

    _values: Float[Array, "dim"]  # only handles 1D coordinates, i.e. rectilinear grids
    indices: ipx.Interpolator1D

    @property
    def values(self) -> Float[Array, "dim"]:
        """
        Returns the coordinate values.

        Returns
        -------
        Float[Array, "dim"]
            The coordinate values.
        """
        return self._values

    def index(self, query: Float[Array, "Nq"]) -> Int[Array, "Nq"]:
        """
        Returns the indices of the coordinate values nearest to the given query points.

        Parameters
        ----------
        query : Float[Array, "Nq"]
            The query array for which the nearest indices are to be found.

        Returns
        -------
        Int[Array, "Nq"]
            An array of integers representing the nearest indices.
        """
        return self.indices(query).astype(int)

    @classmethod
    def from_array(cls, values: Float[Array, "dim"], **interpolator_kwargs: Any) -> Coordinate:
        """
        Create a [`pastax.gridded.Coordinate`][] object from an array of values.

        This method initializes a [`pastax.gridded.Coordinate`][] object using the provided array of values.
        It uses a 1D interpolator to map values to their indices; the interpolation method is always set to
        `"nearest"`, overriding any `method` passed in `interpolator_kwargs`.

        Parameters
        ----------
        values : Float[Array, "dim"]
            An array of coordinate values.
        **interpolator_kwargs : Any
            Additional keyword arguments for the interpolator.

        Returns
        -------
        Coordinate
            A [`pastax.gridded.Coordinate`][] object containing the provided values and the corresponding
            nearest-index interpolator.
        """
        interpolator_kwargs["method"] = "nearest"
        indices = ipx.Interpolator1D(values, jnp.arange(values.size), **interpolator_kwargs)

        return cls(_values=values, indices=indices)

    def __getitem__(self, item: Any) -> Float[Array, "..."]:
        """
        Retrieve an item from the coordinate array.

        Parameters
        ----------
        item : Any
            The index or slice used to retrieve the item from the values array.

        Returns
        -------
        Float[Array, "..."] | Int[Array, "..."]
            The item retrieved from the coordinate array.
        """
        return self.values.__getitem__(item)
values: Float[Array, 'dim'] property

Returns the coordinate values.

Returns:

Type Description
Float[Array, 'dim']

The coordinate values.

index(query: Float[Array, 'Nq']) -> Int[Array, 'Nq']

Returns the indices of the coordinate values nearest to the given query points.

Parameters:

Name Type Description Default
query Float[Array, 'Nq']

The query array for which the nearest indices are to be found.

required

Returns:

Type Description
Int[Array, 'Nq']

An array of integers representing the nearest indices.
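`index` delegates to an interpax `Interpolator1D` built with `method="nearest"` over the indices `jnp.arange(values.size)`, so each query returns the position of the closest coordinate value. The NumPy sketch below reproduces that idea for a sorted coordinate with at least two values. `nearest_index` is a hypothetical name, ties go to the lower index by choice, and interpax's behaviour may differ (for example for queries outside the coordinate range).

```python
import numpy as np


def nearest_index(values, query):
    """Index of the closest coordinate value for each query point.

    values: sorted 1D coordinate array with at least two entries.
    query: 1D array of query points.
    """
    values = np.asarray(values, dtype=float)
    query = np.asarray(query, dtype=float)
    # Insertion position of each query, clipped so both neighbours exist.
    right = np.clip(np.searchsorted(values, query), 1, values.size - 1)
    left = right - 1
    # Choose the right neighbour only when it is strictly closer.
    choose_right = np.abs(values[right] - query) < np.abs(query - values[left])
    return np.where(choose_right, right, left)
```

Out-of-range queries snap to the first or last index in this sketch.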

Source code in pastax/gridded/_coordinate.py
def index(self, query: Float[Array, "Nq"]) -> Int[Array, "Nq"]:
    """
    Returns the indices of the coordinate values nearest to the given query points.

    Parameters
    ----------
    query : Float[Array, "Nq"]
        The query array for which the nearest indices are to be found.

    Returns
    -------
    Int[Array, "Nq"]
        An array of integers representing the nearest indices.
    """
    return self.indices(query).astype(int)
from_array(values: Float[Array, 'dim'], **interpolator_kwargs: Any) -> Coordinate classmethod

Create a pastax.gridded.Coordinate object from an array of values.

This method initializes a pastax.gridded.Coordinate object using the provided array of values. It uses a 1D interpolator to map values to their indices; the interpolation method is always set to "nearest", overriding any method passed in interpolator_kwargs.

Parameters:

Name Type Description Default
values Float[Array, 'dim']

An array of coordinate values.

required
**interpolator_kwargs Any

Additional keyword arguments for the interpolator.

{}

Returns:

Type Description
Coordinate

A pastax.gridded.Coordinate object containing the provided values and corresponding indices interpolator.

Source code in pastax/gridded/_coordinate.py
@classmethod
def from_array(cls, values: Float[Array, "dim"], **interpolator_kwargs: Any) -> Coordinate:
    """
    Create a [`pastax.gridded.Coordinate`][] object from an array of values.

    This method initializes a [`pastax.gridded.Coordinate`][] object using the provided array of values.
    It uses a 1D interpolator to generate indices from values, with the interpolation method set to `"nearest"`.

    Parameters
    ----------
    values : Float[Array, "dim"]
        An array of coordinate values.
    **interpolator_kwargs : Any
        Additional keyword arguments for the interpolator.

    Returns
    -------
    Coordinate
        A [`pastax.gridded.Coordinate`][] object containing the provided values and corresponding indices
        interpolator.
    """
    interpolator_kwargs["method"] = "nearest"
    indices = ipx.Interpolator1D(values, jnp.arange(values.size), **interpolator_kwargs)

    return cls(_values=values, indices=indices)
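`from_array` assigns `interpolator_kwargs["method"]` unconditionally. Any `method` a caller passes is therefore silently replaced by `"nearest"`. Other keywords, such as interpax's `extrap`, are forwarded unchanged. The hypothetical helper below isolates that step. Because `**kwargs` always arrives as a fresh dict, the caller's own dictionary is not modified.

```python
from typing import Any


def nearest_interpolator_kwargs(**interpolator_kwargs: Any) -> dict:
    """Mimic how Coordinate.from_array prepares the keyword arguments it forwards."""
    # **kwargs is a fresh dict, so this mutation never reaches the caller's dict
    interpolator_kwargs["method"] = "nearest"  # unconditional: a user-supplied method is replaced
    return interpolator_kwargs


user_kwargs = {"method": "linear", "extrap": True}
print(nearest_interpolator_kwargs(**user_kwargs))  # {'method': 'nearest', 'extrap': True}
print(user_kwargs)  # {'method': 'linear', 'extrap': True}
```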
__getitem__(item: Any) -> Float[Array, '...']

Retrieve an item from the coordinate array.

Parameters:

Name Type Description Default
item Any

The index or slice used to retrieve the item from the values array.

required

Returns:

Type Description
Float[Array, '...'] | Int[Array, '...']

The item retrieved from the coordinate array.

Source code in pastax/gridded/_coordinate.py
def __getitem__(self, item: Any) -> Float[Array, "..."]:
    """
    Retrieve an item from the coordinate array.

    Parameters
    ----------
    item : Any
        The index or slice used to retrieve the item from the values array.

    Returns
    -------
    Float[Array, "..."] | Int[Array, "..."]
        The item retrieved from the coordinate array.
    """
    return self.values.__getitem__(item)
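`__getitem__` simply indexes the stored values, so it pairs naturally with `index`: `coord[coord.index(q)]` snaps query points to their nearest grid values. The sketch below uses a hypothetical `CoordinateSketch` class. A brute-force NumPy nearest-neighbour search stands in for the interpax interpolator. It illustrates the usage pattern, not the pastax implementation.

```python
import numpy as np


class CoordinateSketch:
    """Tiny NumPy stand-in for pastax.gridded.Coordinate (hypothetical, for illustration only)."""

    def __init__(self, values):
        self._values = np.asarray(values, dtype=float)

    @property
    def values(self):
        return self._values

    def index(self, query):
        # brute-force nearest neighbour: adequate for small 1D grids
        query = np.atleast_1d(np.asarray(query, dtype=float))
        return np.argmin(np.abs(self._values[None, :] - query[:, None]), axis=1)

    def __getitem__(self, item):
        # same delegation as Coordinate.__getitem__: plain array indexing
        return self.values[item]


coord = CoordinateSketch([10.0, 20.0, 30.0])
snapped = coord[coord.index([12.0, 26.0])]
print(snapped.tolist())  # [10.0, 30.0]
print(coord[1:].tolist())  # [20.0, 30.0]
```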

LongitudeCoordinate

Bases: Coordinate

Class for handling 1D longitude coordinates (i.e., those of rectilinear grids). This class accounts for the periodic (circular) nature of longitude.

Attributes:

Name Type Description
indices Interpolator1D

Interpolator for nearest index interpolation.

is_periodic bool

Whether the longitude dimension is treated as periodic (with a 360° period).

Methods:

Name Description
values

Returns the coordinate values.

index

Returns the nearest index for the given query coordinates.

from_array

Creates a pastax.gridded.LongitudeCoordinate instance from an array of values.

Source code in pastax/gridded/_coordinate.py
class LongitudeCoordinate(Coordinate):
    """
    Class for handling 1D longitude coordinates (i.e., those of rectilinear grids).
    This class accounts for the periodic (circular) nature of longitude.

    Attributes
    ----------
    indices : ipx.Interpolator1D
        Interpolator for nearest index interpolation.
    is_periodic : bool
        Whether the longitude dimension is treated as periodic (with a 360° period).

    Methods
    -------
    values
        Returns the coordinate values.
    index(query)
        Returns the nearest index for the given query coordinates.
    from_array(values, **interpolator_kwargs)
        Creates a [`pastax.gridded.LongitudeCoordinate`][] instance from an array of values.
    """

    _values: Float[Array, "dim"]  # only handles 1D coordinates, i.e. rectilinear grids
    indices: ipx.Interpolator1D
    is_periodic: bool

    @property
    def values(self) -> Float[Array, "dim"]:
        """
        Returns the coordinate values.

        Returns
        -------
        Float[Array, "dim"]
            The coordinate values.
        """
        values = self._values
        if self.is_periodic:
            values -= 180

        return values

    def index(self, query: Float[Array, "Nq"]) -> Int[Array, "Nq"]:
        """
        Returns the nearest index interpolation for the given query.

        Parameters
        ----------
        query : Float[Array, "Nq"]
            The query array for which the nearest indices are to be found.

        Returns
        -------
        Int[Array, "Nq"]
            An array of integers representing the nearest indices.
        """
        if self.is_periodic:
            query = longitude_in_180_180_degrees(query)  # force to be in -180 to 180 degrees
            query += 180  # shift back to 0 to 360 degrees

        return self.indices(query).astype(int)

    @classmethod
    def from_array(
        cls, values: Float[Array, "dim"], is_periodic: bool = True, **interpolator_kwargs: Any
    ) -> LongitudeCoordinate:
        """
        Create a LongitudeCoordinate object from an array of values.

        This method initializes a LongitudeCoordinate object using the provided array of values.
        It uses a 1D interpolator to generate indices from values, with the interpolation method set to "nearest".

        Parameters
        ----------
        values : Float[Array, "dim"]
            An array of coordinate values.
        is_periodic : bool, optional
            Whether the dimension should be considered periodic, defaults to `True`.
        **interpolator_kwargs : Any
            Additional keyword arguments for the interpolator.

        Returns
        -------
        LongitudeCoordinate
            A [`pastax.gridded.LongitudeCoordinate`][] object containing the provided values and corresponding indices
            interpolator.
        """
        if is_periodic:
            values = longitude_in_180_180_degrees(values)  # force to be in -180 to 180 degrees
            values += 180  # shift back to 0 to 360 degrees
            interpolator_kwargs["period"] = 360

        interpolator_kwargs["method"] = "nearest"
        indices = ipx.Interpolator1D(values, jnp.arange(values.size), **interpolator_kwargs)

        return cls(_values=values, indices=indices, is_periodic=is_periodic)
values: Float[Array, 'dim'] property

Returns the coordinate values.

Returns:

Type Description
Float[Array, 'dim']

The coordinate values.
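With `is_periodic=True`, `from_array` stores longitudes shifted into [0, 360). The `values` property then subtracts 180 again, so callers see values in [-180, 180) in their original order, without re-sorting. The NumPy sketch below shows that round trip. It assumes that `longitude_in_180_180_degrees`, which is not shown in this section, maps angles into [-180, 180) with a modulo. `wrap_to_180` is a hypothetical stand-in for that helper.

```python
import numpy as np


def wrap_to_180(lon):
    """Assumed behaviour of pastax's longitude_in_180_180_degrees: map to [-180, 180)."""
    return (np.asarray(lon, dtype=float) + 180.0) % 360.0 - 180.0


lon_in = np.array([0.0, 90.0, 180.0, 270.0])
stored = wrap_to_180(lon_in) + 180.0  # what from_array keeps in _values when is_periodic=True
returned = stored - 180.0  # what the values property hands back
print(stored.tolist())  # [180.0, 270.0, 0.0, 90.0]
print(returned.tolist())  # [0.0, 90.0, -180.0, -90.0]
```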

index(query: Float[Array, 'Nq']) -> Int[Array, 'Nq']

Returns the nearest index interpolation for the given query.

Parameters:

Name Type Description Default
query Float[Array, 'Nq']

The query array for which the nearest indices are to be found.

required

Returns:

Type Description
Int[Array, 'Nq']

An array of integers representing the nearest indices.

Source code in pastax/gridded/_coordinate.py
def index(self, query: Float[Array, "Nq"]) -> Int[Array, "Nq"]:
    """
    Returns the nearest index interpolation for the given query.

    Parameters
    ----------
    query : Float[Array, "Nq"]
        The query array for which the nearest indices are to be found.

    Returns
    -------
    Int[Array, "Nq"]
        An array of integers representing the nearest indices.
    """
    if self.is_periodic:
        query = longitude_in_180_180_degrees(query)  # force to be in -180 to 180 degrees
        query += 180  # shift back to 0 to 360 degrees

    return self.indices(query).astype(int)
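In the periodic case, pastax handles the wrap-around by passing `period=360` to the interpax interpolator. The brute-force sketch below shows the effect by measuring the shortest angular distance explicitly. `longitude_nearest_index` is hypothetical and does not reflect how pastax computes the index. It should nonetheless select the same grid point, except possibly on exact ties. The `is_periodic=False` call shows how a query near the dateline resolves without wrapping.

```python
import numpy as np


def longitude_nearest_index(grid_lon, query_lon, is_periodic=True):
    """Nearest longitude index, optionally treating longitude as 360-periodic (sketch)."""
    grid = np.asarray(grid_lon, dtype=float)
    query = np.atleast_1d(np.asarray(query_lon, dtype=float))
    diff = query[:, None] - grid[None, :]  # (Nq, dim) signed differences in degrees
    if is_periodic:
        # shortest signed angular distance, so 179.9 is only 0.1 degrees away from -180
        diff = (diff + 180.0) % 360.0 - 180.0
    return np.argmin(np.abs(diff), axis=1)


grid = np.array([-180.0, -90.0, 0.0, 90.0])
print(longitude_nearest_index(grid, [179.9, 270.0, -90.0]).tolist())  # [0, 1, 1]
print(longitude_nearest_index(grid, [179.9, 270.0, -90.0], is_periodic=False).tolist())  # [3, 3, 1]
```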
from_array(values: Float[Array, 'dim'], is_periodic: bool = True, **interpolator_kwargs: Any) -> LongitudeCoordinate classmethod

Create a LongitudeCoordinate object from an array of values.

This method initializes a LongitudeCoordinate object using the provided array of values. It uses a 1D interpolator to generate indices from values, with the interpolation method set to "nearest".

Parameters:

Name Type Description Default
values Float[Array, 'dim']

An array of coordinate values.

required
is_periodic bool

Whether the dimension should be considered periodic, defaults to True.

True
**interpolator_kwargs Any

Additional keyword arguments for the interpolator.

{}

Returns:

Type Description
LongitudeCoordinate

A pastax.gridded.LongitudeCoordinate object containing the provided values and corresponding indices interpolator.

Source code in pastax/gridded/_coordinate.py
@classmethod
def from_array(
    cls, values: Float[Array, "dim"], is_periodic: bool = True, **interpolator_kwargs: Any
) -> LongitudeCoordinate:
    """
    Create a LongitudeCoordinate object from an array of values.

    This method initializes a LongitudeCoordinate object using the provided array of values.
    It uses a 1D interpolator to generate indices from values, with the interpolation method set to "nearest".

    Parameters
    ----------
    values : Float[Array, "dim"]
        An array of coordinate values.
    is_periodic : bool, optional
        Whether the dimension should be considered periodic, defaults to `True`.
    **interpolator_kwargs : Any
        Additional keyword arguments for the interpolator.

    Returns
    -------
    LongitudeCoordinate
        A [`pastax.gridded.LongitudeCoordinate`][] object containing the provided values and corresponding indices
        interpolator.
    """
    if is_periodic:
        values = longitude_in_180_180_degrees(values)  # force to be in -180 to 180 degrees
        values += 180  # shift back to 0 to 360 degrees
        interpolator_kwargs["period"] = 360

    interpolator_kwargs["method"] = "nearest"
    indices = ipx.Interpolator1D(values, jnp.arange(values.size), **interpolator_kwargs)

    return cls(_values=values, indices=indices, is_periodic=is_periodic)

Field

Bases: Module

Base class for representing a field on a grid.

Methods:

Name Description
values

Returns the field values.

__getitem__

Retrieves the value(s) at the specified index or slice of the grid.

Source code in pastax/gridded/_field.py
class Field(eqx.Module):
    """
    Base class for representing a field on a grid.

    Methods
    -------
    values
        Returns the field values.
    __getitem__(item)
        Retrieves the value(s) at the specified index or slice of the grid.
    """

    _values: Bool[Array, "..."] | Float[Array, "..."] | Int[Array, "..."]

    @property
    def values(self) -> Bool[Array, "..."] | Float[Array, "..."] | Int[Array, "..."]:
        """
        Returns the field values.

        Returns
        -------
        Bool[Array, "..."] | Float[Array, "..."] | Int[Array, "..."]
            The gridded values.
        """
        return self._values

    def __getitem__(self, item: Any) -> Bool[Array, "..."] | Float[Array, "..."] | Int[Array, "..."]:
        """
        Retrieve an item from the values array.

        Parameters
        ----------
        item : Any
            The index or slice used to retrieve the item from the values array.

        Returns
        -------
        Bool[Array, "..."] | Float[Array, "..."] | Int[Array, "..."]
            The item retrieved from the values array.
        """
        return self.values.__getitem__(item)
values: Bool[Array, '...'] | Float[Array, '...'] | Int[Array, '...'] property

Returns the field values.

Returns:

Type Description
Bool[Array, '...'] | Float[Array, '...'] | Int[Array, '...']

The gridded values.

__getitem__(item: Any) -> Bool[Array, '...'] | Float[Array, '...'] | Int[Array, '...']

Retrieve an item from the values array.

Parameters:

Name Type Description Default
item Any

The index or slice used to retrieve the item from the values array.

required

Returns:

Type Description
Bool[Array, '...'] | Float[Array, '...'] | Int[Array, '...']

The item retrieved from the values array.

Source code in pastax/gridded/_field.py
def __getitem__(self, item: Any) -> Bool[Array, "..."] | Float[Array, "..."] | Int[Array, "..."]:
    """
    Retrieve an item from the values array.

    Parameters
    ----------
    item : Any
        The index or slice used to retrieve the item from the values array.

    Returns
    -------
    Bool[Array, "..."] | Float[Array, "..."] | Int[Array, "..."]
        The item retrieved from the values array.
    """
    return self.values.__getitem__(item)

SpatialField

Bases: Field

Class representing a spatial (2D) field with interpolation capabilities.

Methods:

Name Description
interp

Interpolates the field at the given coordinates.

from_array

Creates a pastax.gridded.SpatialField instance from the given array of values, latitude, and longitude using the specified interpolation method.

Source code in pastax/gridded/_field.py
class SpatialField(Field):
    """
    Class representing a spatial (2D) field with interpolation capabilities.

    Methods
    -------
    interp(**coordinates)
        Interpolates the field at the given coordinates.
    from_array(values, latitude, longitude, interpolation_method)
        Creates a [`pastax.gridded.SpatialField`][] instance from the given array of values, latitude, and longitude
        using the specified interpolation method.
    """

    _values: Bool[Array, "lat lon"] | Float[Array, "lat lon"] | Int[Array, "lat lon"]
    _fx: ipx.Interpolator2D

    def interp(self, **coordinates: Float[Array, "Nq"]) -> Bool[Array, "..."] | Float[Array, "..."] | Int[Array, "..."]:
        """
        Interpolates the field at the given coordinates.

        Parameters
        ----------
        **coordinates : Float[Array, "Nq"]
            Query points, given as `latitude` and `longitude` keyword arrays of length Nq; if either is missing, the raw gridded values are returned.

        Returns
        -------
        Bool[Array, "..."] | Float[Array, "..."] | Int[Array, "..."]
            Interpolated values at the given coordinates.
        """
        if "latitude" in coordinates and "longitude" in coordinates:
            return self._interp_spatial(
                latitude=coordinates["latitude"],
                longitude=coordinates["longitude"],
            )
        else:
            return self.values

    def _interp_spatial(
        self, latitude: Float[Array, "Nq"], longitude: Float[Array, "Nq"]
    ) -> Bool[Array, "Nq ..."] | Float[Array, "Nq ..."] | Int[Array, "Nq ..."]:
        """
        Interpolates spatial data based on given latitude and longitude arrays.

        Parameters
        ----------
        latitude : Float[Array, "Nq"]
            Array of latitude values.
        longitude : Float[Array, "Nq"]
            Array of longitude values.

        Returns
        -------
        Bool[Array, "Nq ..."] | Float[Array, "Nq ..."] | Int[Array, "Nq ..."]
            Interpolated spatial data array.
        """
        longitude = longitude_in_180_180_degrees(longitude)  # force to be in -180 to 180 degrees
        longitude += 180  # shift back to 0 to 360 degrees

        return self._fx(latitude, longitude)

    @classmethod
    def from_array(
        cls,
        values: Bool[Array, "lat lon"] | Float[Array, "lat lon"] | Int[Array, "lat lon"],
        latitude: Float[Array, "lat"],
        longitude: Float[Array, "lon"],
        interpolation_method: str,
    ) -> SpatialField:
        """
        Create a [`pastax.gridded.SpatialField`][] object from given arrays of values, latitude, and longitude.

        Parameters
        ----------
        values : Bool[Array, "lat lon"] | Float[Array, "lat lon"] | Int[Array, "lat lon"]
            A 2D array of values representing the spatial data.
        latitude : Float[Array, "lat"]
            A 1D array of latitude values.
        longitude : Float[Array, "lon"]
            A 1D array of longitude values.
        interpolation_method : str
            The method to use for interpolation.

        Returns
        -------
        SpatialField
            A [`pastax.gridded.SpatialField`][] object containing the values and the interpolated spatial field.
        """
        longitude = longitude_in_180_180_degrees(longitude)  # force to be in -180 to 180 degrees
        longitude += 180  # shift back to 0 to 360 degrees

        _fx = ipx.Interpolator2D(
            latitude,
            longitude,  # periodic domain
            values,
            method=interpolation_method,
            extrap=True,
            period=(None, 360),
        )

        return cls(_values=values, _fx=_fx)
interp(**coordinates: Float[Array, 'Nq']) -> Bool[Array, '...'] | Float[Array, '...'] | Int[Array, '...']

Interpolates the field at the given coordinates.

Parameters:

Name Type Description Default
**coordinates Float[Array, 'Nq']

Query points, given as latitude and longitude keyword arrays of length Nq; if either is missing, the raw gridded values are returned.

{}

Returns:

Type Description
Bool[Array, '...'] | Float[Array, '...'] | Int[Array, '...']

Interpolated values at the given coordinates.

Source code in pastax/gridded/_field.py
def interp(self, **coordinates: Float[Array, "Nq"]) -> Bool[Array, "..."] | Float[Array, "..."] | Int[Array, "..."]:
    """
    Interpolates the field at the given coordinates.

    Parameters
    ----------
    **coordinates : Float[Array, "Nq"]
        Query points, given as `latitude` and `longitude` keyword arrays of length Nq; if either is missing, the raw gridded values are returned.

    Returns
    -------
    Bool[Array, "..."] | Float[Array, "..."] | Int[Array, "..."]
        Interpolated values at the given coordinates.
    """
    if "latitude" in coordinates and "longitude" in coordinates:
        return self._interp_spatial(
            latitude=coordinates["latitude"],
            longitude=coordinates["longitude"],
        )
    else:
        return self.values
from_array(values: Bool[Array, 'lat lon'] | Float[Array, 'lat lon'] | Int[Array, 'lat lon'], latitude: Float[Array, 'lat'], longitude: Float[Array, 'lon'], interpolation_method: str) -> SpatialField classmethod

Create a pastax.gridded.SpatialField object from given arrays of values, latitude, and longitude.

Parameters:

Name Type Description Default
values Bool[Array, 'lat lon'] | Float[Array, 'lat lon'] | Int[Array, 'lat lon']

A 2D array of values representing the spatial data.

required
latitude Float[Array, 'lat']

A 1D array of latitude values.

required
longitude Float[Array, 'lon']

A 1D array of longitude values.

required
interpolation_method str

The method to use for interpolation.

required

Returns:

Type Description
SpatialField

A pastax.gridded.SpatialField object containing the values and the interpolated spatial field.

Source code in pastax/gridded/_field.py
@classmethod
def from_array(
    cls,
    values: Bool[Array, "lat lon"] | Float[Array, "lat lon"] | Int[Array, "lat lon"],
    latitude: Float[Array, "lat"],
    longitude: Float[Array, "lon"],
    interpolation_method: str,
) -> SpatialField:
    """
    Create a [`pastax.gridded.SpatialField`][] object from given arrays of values, latitude, and longitude.

    Parameters
    ----------
    values : Bool[Array, "lat lon"] | Float[Array, "lat lon"] | Int[Array, "lat lon"]
        A 2D array of values representing the spatial data.
    latitude : Float[Array, "lat"]
        A 1D array of latitude values.
    longitude : Float[Array, "lon"]
        A 1D array of longitude values.
    interpolation_method : str
        The method to use for interpolation.

    Returns
    -------
    SpatialField
        A [`pastax.gridded.SpatialField`][] object containing the values and the interpolated spatial field.
    """
    longitude = longitude_in_180_180_degrees(longitude)  # force to be in -180 to 180 degrees
    longitude += 180  # shift back to 0 to 360 degrees

    _fx = ipx.Interpolator2D(
        latitude,
        longitude,  # periodic domain
        values,
        method=interpolation_method,
        extrap=True,
        period=(None, 360),
    )

    return cls(_values=values, _fx=_fx)
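`from_array` builds an interpax `Interpolator2D` with `period=(None, 360)`, so longitude wraps around while latitude does not. The sketch below illustrates this for `interpolation_method="linear"`, assuming that method amounts to bilinear interpolation. The first longitude column is repeated at +360°, so queries between the last grid longitude and 360° blend with the first column. `bilinear_periodic` is a hypothetical NumPy helper, not the interpax implementation. It assumes strictly increasing latitudes and strictly increasing longitudes in [0, 360), the convention `from_array` produces after its +180 shift. Outside the latitude range it extrapolates linearly, which loosely mirrors `extrap=True`.

```python
import numpy as np


def bilinear_periodic(values, lat, lon, lat_q, lon_q):
    """Bilinear interpolation on a (lat, lon) grid, periodic in longitude (sketch)."""
    values = np.asarray(values, dtype=float)
    lat = np.asarray(lat, dtype=float)
    lon = np.asarray(lon, dtype=float)
    lat_q = np.atleast_1d(np.asarray(lat_q, dtype=float))
    lon_q = np.atleast_1d(np.asarray(lon_q, dtype=float)) % 360.0

    # repeat the first longitude column at lon[0] + 360 to close the periodic gap
    lon_ext = np.append(lon, lon[0] + 360.0)
    val_ext = np.concatenate([values, values[:, :1]], axis=1)
    # move queries below the first grid longitude into [lon[0], lon[0] + 360)
    lon_q = np.where(lon_q < lon[0], lon_q + 360.0, lon_q)

    # lower corner of the bracketing cell; clipping lets out-of-range latitudes extrapolate
    i = np.clip(np.searchsorted(lat, lat_q) - 1, 0, lat.size - 2)
    j = np.clip(np.searchsorted(lon_ext, lon_q) - 1, 0, lon_ext.size - 2)

    # fractional position inside the cell along each axis
    ty = (lat_q - lat[i]) / (lat[i + 1] - lat[i])
    tx = (lon_q - lon_ext[j]) / (lon_ext[j + 1] - lon_ext[j])

    v00, v01 = val_ext[i, j], val_ext[i, j + 1]
    v10, v11 = val_ext[i + 1, j], val_ext[i + 1, j + 1]
    return (1 - ty) * ((1 - tx) * v00 + tx * v01) + ty * ((1 - tx) * v10 + tx * v11)


lat = np.array([-10.0, 10.0])
lon = np.array([0.0, 90.0, 180.0, 270.0])
values = np.array([[0.0, 1.0, 2.0, 3.0],
                   [0.0, 1.0, 2.0, 3.0]])
# 315 degrees lies between the last column (270, value 3) and the wrapped first column (360, value 0)
print(bilinear_periodic(values, lat, lon, [0.0, 0.0], [45.0, 315.0]))  # [0.5 1.5]
```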

SpatioTemporalField

Bases: Field

Class representing a spatiotemporal (3D) field with interpolation capabilities.

Methods:

Name Description
interp

Interpolates the field at the given coordinates.

from_array

Creates a pastax.gridded.SpatioTemporalField instance from the given array of values and coordinates.

Source code in pastax/gridded/_field.py
class SpatioTemporalField(Field):
    """
    Class representing a spatiotemporal (3D) field with interpolation capabilities.

    Methods
    -------
    interp(**coordinates)
        Interpolates the field at the given coordinates.
    from_array(values, time, latitude, longitude, interpolation_method)
        Creates a [`pastax.gridded.SpatioTemporalField`][] instance from the given array of values and coordinates.
    """

    _values: Bool[Array, "time lat lon"] | Float[Array, "time lat lon"] | Int[Array, "time lat lon"]
    _ft: ipx.Interpolator1D
    _fx: ipx.Interpolator2D
    _ftx: ipx.Interpolator3D

    def interp(self, **coordinates: Float[Array, "Nq"]) -> Bool[Array, "..."] | Float[Array, "..."] | Int[Array, "..."]:
        """
        Interpolates the field at the given coordinates.

        Parameters
        ----------
        **coordinates : Float[Array, "Nq"]
            Query points, given as `time`, `latitude` and/or `longitude` keyword arrays of length Nq.

        Returns
        -------
        Bool[Array, "..."] | Float[Array, "..."] | Int[Array, "..."]
            Interpolated values at the given coordinates.
        """
        if "time" in coordinates and "latitude" in coordinates and "longitude" in coordinates:
            return self._interp_spatiotemporal(
                time=coordinates["time"],
                latitude=coordinates["latitude"],
                longitude=coordinates["longitude"],
            )
        elif "latitude" in coordinates and "longitude" in coordinates:
            return self._interp_spatial(
                latitude=coordinates["latitude"],
                longitude=coordinates["longitude"],
            )
        elif "time" in coordinates:
            return self._interp_temporal(time=coordinates["time"])
        else:
            return self.values

    def _interp_temporal(
        self, time: Float[Array, "Nq"]
    ) -> Bool[Array, "Nq lat lon"] | Float[Array, "Nq lat lon"] | Int[Array, "Nq lat lon"]:
        """
        Interpolates the spatiotemporal field at the given time points.

        Parameters
        ----------
        time : Float[Array, "Nq"]
            An array of time points at which to interpolate the spatiotemporal field.

        Returns
        -------
        Bool[Array, "Nq lat lon"] | Float[Array, "Nq lat lon"] | Int[Array, "Nq lat lon"]
            Interpolated values at the given time points.
        """
        return self._ft(time)

    def _interp_spatial(
        self, latitude: Float[Array, "Nq"], longitude: Float[Array, "Nq"]
    ) -> Bool[Array, "Nq time"] | Float[Array, "Nq time"] | Int[Array, "Nq time"]:
        """
        Interpolates the spatiotemporal field at the given latitude/longitude points.

        Parameters
        ----------
        latitude : Float[Array, "Nq"]
            Array of latitude values.
        longitude : Float[Array, "Nq"]
            Array of longitude values.

        Returns
        -------
        Bool[Array, "Nq time"] | Float[Array, "Nq time"] | Int[Array, "Nq time"]
            Interpolated values at the given latitude/longitude points.
        """
        longitude = longitude_in_180_180_degrees(longitude)  # force to be in -180 to 180 degrees
        longitude += 180  # shift back to 0 to 360 degrees

        return self._fx(latitude, longitude)

    def _interp_spatiotemporal(
        self,
        time: Float[Array, "Nq"],
        latitude: Float[Array, "Nq"],
        longitude: Float[Array, "Nq"],
    ) -> Bool[Array, "Nq"] | Float[Array, "Nq"] | Int[Array, "Nq"]:
        """
        Interpolates the spatiotemporal field at the given time/latitude/longitude points.

        Parameters
        ----------
        time : Float[Array, "Nq"]
            Array of time values.
        latitude : Float[Array, "Nq"]
            Array of latitude values.
        longitude : Float[Array, "Nq"]
            Array of longitude values.

        Returns
        -------
        Bool[Array, "Nq"] | Float[Array, "Nq"] | Int[Array, "Nq"]
            Interpolated values at the given time/latitude/longitude points.
        """
        longitude = longitude_in_180_180_degrees(longitude)  # force to be in -180 to 180 degrees
        longitude += 180  # shift back to 0 to 360 degrees

        return self._ftx(time, latitude, longitude)

    @classmethod
    def from_array(
        cls,
        values: Float[Array, "time lat lon"],
        time: Float[Array, "time"],
        latitude: Float[Array, "lat"],
        longitude: Float[Array, "lon"],
        interpolation_method: str,
    ) -> SpatioTemporalField:
        """
        Create a [`pastax.gridded.SpatioTemporalField`][] object from given arrays of values, time, latitude, and
        longitude.

        Parameters
        ----------
        values : Float[Array, "time lat lon"]
            The array of values representing the data over time, latitude, and longitude.
        time : Float[Array, "time"]
            The array of time points.
        latitude : Float[Array, "lat"]
            The array of latitude points.
        longitude : Float[Array, "lon"]
            The array of longitude points.
        interpolation_method : str
            The method to be used for interpolation (e.g., 'linear', 'nearest', ...).

        Returns
        -------
        SpatioTemporalField
            A [`pastax.gridded.SpatioTemporalField`][] object containing the original values and temporal, spatial,
            and spatiotemporal interpolators.
        """
        longitude = longitude_in_180_180_degrees(longitude)  # force to be in -180 to 180 degrees
        longitude += 180  # shift back to 0 to 360 degrees

        _ft = ipx.Interpolator1D(time, values, method=interpolation_method, extrap=True)
        _fx = ipx.Interpolator2D(
            latitude,
            longitude,
            jnp.moveaxis(values, 0, -1),  # time dim is moved to the last axis as it is not interpolated
            method=interpolation_method,
            extrap=True,
            period=(None, 360),
        )
        _ftx = ipx.Interpolator3D(
            time,
            latitude,
            longitude,
            values,
            method=interpolation_method,
            extrap=True,
            period=(None, None, 360),
        )

        return cls(
            _values=values,
            _ft=_ft,
            _fx=_fx,
            _ftx=_ftx,
        )
interp(**coordinates: Float[Array, 'Nq']) -> Bool[Array, '...'] | Float[Array, '...'] | Int[Array, '...']

Interpolates the field at the given coordinates.

Parameters:

Name Type Description Default
**coordinates Float[Array, 'Nq']

Query points, given as time, latitude and/or longitude keyword arrays of length Nq.

{}

Returns:

Type Description
Bool[Array, '...'] | Float[Array, '...'] | Int[Array, '...']

Interpolated values at the given coordinates.

Source code in pastax/gridded/_field.py
def interp(self, **coordinates: Float[Array, "Nq"]) -> Bool[Array, "..."] | Float[Array, "..."] | Int[Array, "..."]:
    """
    Interpolates the field at the given coordinates.

    Parameters
    ----------
    **coordinates : Float[Array, "Nq"]
        Query points, given as `time`, `latitude` and/or `longitude` keyword arrays of length Nq.

    Returns
    -------
    Bool[Array, "..."] | Float[Array, "..."] | Int[Array, "..."]
        Interpolated values at the given coordinates.
    """
    if "time" in coordinates and "latitude" in coordinates and "longitude" in coordinates:
        return self._interp_spatiotemporal(
            time=coordinates["time"],
            latitude=coordinates["latitude"],
            longitude=coordinates["longitude"],
        )
    elif "latitude" in coordinates and "longitude" in coordinates:
        return self._interp_spatial(
            latitude=coordinates["latitude"],
            longitude=coordinates["longitude"],
        )
    elif "time" in coordinates:
        return self._interp_temporal(time=coordinates["time"])
    else:
        return self.values
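The branch order in `interp` determines which interpolator runs and also the shape of the result. With all three keys, `interp` returns one value per query point. With `latitude` and `longitude` only, it returns an `(Nq, time)` array, because `_fx` was built on values with the time axis moved last. With `time` alone, it returns `(Nq, lat, lon)` snapshots. A combination such as `time` plus `latitude` without `longitude` falls through to the temporal branch and silently ignores `latitude`. The hypothetical `interp_dispatch` below mirrors that order. It uses nearest-neighbour lookups in place of the interpax interpolators and omits the longitude shift, so it only demonstrates the dispatch and the output shapes.

```python
import numpy as np


def interp_dispatch(values, t_grid, lat_grid, lon_grid, **coords):
    """Mirror SpatioTemporalField.interp's branch order with nearest-neighbour lookups (sketch).

    values has shape (time, lat, lon); each query array in coords has shape (Nq,).
    """
    def nearest(grid, q):
        return np.argmin(np.abs(grid[None, :] - np.asarray(q, dtype=float)[:, None]), axis=1)

    if "time" in coords and "latitude" in coords and "longitude" in coords:
        it = nearest(t_grid, coords["time"])
        iy = nearest(lat_grid, coords["latitude"])
        ix = nearest(lon_grid, coords["longitude"])
        return values[it, iy, ix]  # (Nq,)
    elif "latitude" in coords and "longitude" in coords:
        iy = nearest(lat_grid, coords["latitude"])
        ix = nearest(lon_grid, coords["longitude"])
        return np.moveaxis(values, 0, -1)[iy, ix]  # (Nq, time): time axis moved last
    elif "time" in coords:
        return values[nearest(t_grid, coords["time"])]  # (Nq, lat, lon)
    else:
        return values  # no recognised key combination: raw grid


t_grid = np.array([0.0, 1.0, 2.0])
lat_grid = np.array([-1.0, 0.0, 1.0, 2.0])
lon_grid = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
values = np.arange(3 * 4 * 5, dtype=float).reshape(3, 4, 5)
q = np.array([0.1, 1.9])
print(interp_dispatch(values, t_grid, lat_grid, lon_grid, time=q, latitude=q, longitude=q).shape)  # (2,)
print(interp_dispatch(values, t_grid, lat_grid, lon_grid, latitude=q, longitude=q).shape)  # (2, 3)
print(interp_dispatch(values, t_grid, lat_grid, lon_grid, time=q).shape)  # (2, 4, 5)
print(interp_dispatch(values, t_grid, lat_grid, lon_grid, time=q, latitude=q).shape)  # (2, 4, 5)
```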
from_array(values: Float[Array, 'time lat lon'], time: Float[Array, 'time'], latitude: Float[Array, 'lat'], longitude: Float[Array, 'lon'], interpolation_method: str) -> SpatioTemporalField classmethod

Create a pastax.gridded.SpatioTemporalField object from given arrays of values, time, latitude, and longitude.

Parameters:

Name Type Description Default
values Float[Array, 'time lat lon']

The array of values representing the data over time, latitude, and longitude.

required
time Float[Array, 'time']

The array of time points.

required
latitude Float[Array, 'lat']

The array of latitude points.

required
longitude Float[Array, 'lon']

The array of longitude points.

required
interpolation_method str

The method to be used for interpolation (e.g., 'linear', 'nearest', ...).

required

Returns:

Type Description
SpatioTemporalField

A pastax.gridded.SpatioTemporalField object containing the original values and temporal, spatial, and spatiotemporal interpolators.

Source code in pastax/gridded/_field.py
@classmethod
def from_array(
    cls,
    values: Float[Array, "time lat lon"],
    time: Float[Array, "time"],
    latitude: Float[Array, "lat"],
    longitude: Float[Array, "lon"],
    interpolation_method: str,
) -> SpatioTemporalField:
    """
    Create a [`pastax.gridded.SpatioTemporalField`][] object from given arrays of values, time, latitude, and
    longitude.

    Parameters
    ----------
    values : Float[Array, "time lat lon"]
        The array of values representing the data over time, latitude, and longitude.
    time : Float[Array, "time"]
        The array of time points.
    latitude : Float[Array, "lat"]
        The array of latitude points.
    longitude : Float[Array, "lon"]
        The array of longitude points.
    interpolation_method : str
        The method to be used for interpolation (e.g., 'linear', 'nearest', ...).

    Returns
    -------
    SpatioTemporalField
        A [`pastax.gridded.SpatioTemporalField`][] object containing the original values and temporal, spatial,
        and spatiotemporal interpolators.
    """
    longitude = longitude_in_180_180_degrees(longitude)  # force to be in -180 to 180 degrees
    longitude += 180  # shift back to 0 to 360 degrees

    _ft = ipx.Interpolator1D(time, values, method=interpolation_method, extrap=True)
    _fx = ipx.Interpolator2D(
        latitude,
        longitude,
        jnp.moveaxis(values, 0, -1),  # time dim is moved to the last axis as it is not interpolated
        method=interpolation_method,
        extrap=True,
        period=(None, 360),
    )
    _ftx = ipx.Interpolator3D(
        time,
        latitude,
        longitude,
        values,
        method=interpolation_method,
        extrap=True,
        period=(None, None, 360),
    )

    return cls(
        _values=values,
        _ft=_ft,
        _fx=_fx,
        _ftx=_ftx,
    )
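Before building the interpolators, the first two lines of `from_array` fold the grid longitudes into [-180, 180) and then add 180. The interpolators therefore see an axis in [0, 360) that is declared periodic with `period=360`. The NumPy sketch below reproduces that normalization. `wrap_to_180` is a hypothetical stand-in for `longitude_in_180_180_degrees`, and its handling of the 180° boundary is an assumption. The sketch does not show how pastax transforms query longitudes at interpolation time.

```python
import numpy as np


def wrap_to_180(lon):
    """Hypothetical stand-in for longitude_in_180_180_degrees: map degrees to [-180, 180)."""
    return (np.asarray(lon, dtype=float) + 180.0) % 360.0 - 180.0


def shift_for_periodic_interp(lon):
    """Wrap to [-180, 180), then add 180 so the axis lies in [0, 360),
    the range paired with period=360 in the interpolators."""
    return wrap_to_180(lon) + 180.0


lon = np.array([-190.0, -180.0, 0.0, 179.0, 360.0])
shifted = shift_for_periodic_interp(lon)
```

The axis is shifted by a constant, so a query longitude must go through the same transform to land on the same grid point.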

Gridded

Bases: Module

Class providing some routines for handling gridded spatiotemporal data in JAX.

Attributes:

Name Type Description
cell_area Float[Array, 'lat lon']

Array of cell areas in square meters.

coordinates dict[str, Coordinate | LongitudeCoordinate]

Dictionary of the time, latitude, and longitude coordinates, keyed by "time", "latitude", and "longitude".

dx Float[Array, 'lat lon-1']

Array of longitudinal distances in meters.

dy Float[Array, 'lat-1 lon']

Array of latitudinal distances in meters.

fields dict[str, SpatioTemporalField]

Dictionary of spatiotemporal fields.

is_spherical_mesh bool

Boolean indicating whether the mesh uses spherical coordinates.

interpolation_method Literal['nearest', 'linear', 'cubic', 'cubic2', 'catmull-rom', 'cardinal', 'monotonic', 'monotonic-0', 'akima']

String indicating the interpolation method used when interpolating the fields. For details, see interpax documentation.

use_degrees bool

Boolean indicating whether distance units are degrees.

Methods:

Name Description
indices

Gets the nearest indices of the spatiotemporal point given by the coordinates (e.g. time, latitude, longitude).

interp

Interpolates the given fields at the given coordinates.

neighborhood

Extracts a neighborhood of data around a specified point in time and space.

to_xarray

Returns the pastax.gridded.Gridded object as an xarray.Dataset.

from_array

Constructs a pastax.gridded.Gridded object from arrays of fields and coordinates time, latitude, longitude.

from_xarray

Constructs a pastax.gridded.Gridded object from an xarray.Dataset.

xarray_to_array

Converts an xarray.Dataset to arrays of fields and coordinates.

Source code in pastax/gridded/_gridded.py
class Gridded(eqx.Module):
    """
    Class providing some routines for handling gridded spatiotemporal data in JAX.

    Attributes
    ----------
    cell_area : Float[Array, "lat lon"]
        Array of cell areas in square meters.
    coordinates : dict[str, Coordinate | LongitudeCoordinate]
        Dictionary of the time, latitude, and longitude coordinates.
    dx : Float[Array, "lat lon-1"]
        Array of longitudinal distances in meters.
    dy : Float[Array, "lat-1 lon"]
        Array of latitudinal distances in meters.
    fields : dict[str, SpatioTemporalField]
        Dictionary of spatiotemporal fields.
    is_spherical_mesh : bool
        Boolean indicating whether the mesh uses spherical coordinates.
    interpolation_method : Literal["nearest", "linear", "cubic", "cubic2", "catmull-rom", "cardinal", "monotonic", "monotonic-0", "akima"]
        String indicating the interpolation method used when interpolating the fields.
        For details, see [`interpax` documentation](https://interpax.readthedocs.io/en/latest/index.html).
    use_degrees : bool
        Boolean indicating whether distance units are degrees.

    Methods
    -------
    indices(time, latitude, longitude)
        Gets the nearest indices of the spatiotemporal point `time`, `latitude`, `longitude`.
    interp(*fields, **coordinates)
        Interpolates the given fields at the given coordinates.
    neighborhood(*fields, time, latitude, longitude, t_width, x_width)
        Extracts a neighborhood of data around a specified point in time and space.
    to_xarray()
        Returns the [`pastax.gridded.Gridded`][] object as an `xarray.Dataset`.
    from_array(fields, time, latitude, longitude, interpolation_method="linear", is_spherical_mesh=True,
            use_degrees=False, is_uv_mps=True)
        Constructs a [`pastax.gridded.Gridded`][] object from arrays of fields and coordinates `time`, `latitude`,
        `longitude`.
    from_xarray(dataset, fields, coordinates, interpolation_method="linear", is_spherical_mesh=True, use_degrees=False,
            is_uv_mps=True)
        Constructs a [`pastax.gridded.Gridded`][] object from an `xarray.Dataset`.
    xarray_to_array(dataset, fields, coordinates, transform_fn=lambda x: jnp.asarray(x, dtype=float))
        Converts an `xarray.Dataset` to arrays of fields and coordinates.
    """

    coordinates: dict[str, Coordinate | LongitudeCoordinate]
    dx: Float[Array, "lat lon-1"]
    dy: Float[Array, "lat-1 lon"]
    cell_area: Float[Array, "lat lon"]
    fields: dict[str, SpatioTemporalField]
    is_spherical_mesh: bool
    interpolation_method: Literal[
        "nearest", "linear", "cubic", "cubic2", "catmull-rom", "cardinal", "monotonic", "monotonic-0", "akima"
    ]
    use_degrees: bool

    def indices(self, **coordinates: Int[Array, "Nq"] | Float[Array, "Nq"]) -> tuple[Int[Array, "Nq"], ...]:
        """
        Gets the nearest indices of the N-dimensional point specified by the given coordinates.

        Parameters
        ----------
        **coordinates : Int[Array, "Nq"] | Float[Array, "Nq"]
            The N-dimensional point for which to get the nearest indices.

        Returns
        -------
        tuple[Int[Array, "Nq"], ...]
            A tuple of arrays containing the nearest indices of the N-dimensional point.
        """
        return tuple(self.coordinates[k].index(v) for k, v in coordinates.items())

    def interp(
        self, *fields: str, **coordinates: Int[Array, "Nq"] | Float[Array, "Nq"]
    ) -> dict[str, Bool[Array, "Nq ..."] | Float[Array, "Nq ..."] | Int[Array, "Nq ..."]]:
        """
        Interpolates the given fields at the given coordinates.

        Parameters
        ----------
        *fields : str
            Names of the fields to interpolate.
        **coordinates : Int[Array, "Nq"] | Float[Array, "Nq"]
            The N-dimensional points to interpolate to.

        Returns
        -------
        dict[str, Bool[Array, "Nq ..."] | Float[Array, "Nq ..."] | Int[Array, "Nq ..."]]
            A dict of arrays containing the interpolated values for each field.
        """
        interpolated_fields = {}
        for field_name in fields:
            field = self.fields[field_name]
            interpolated_field = field.interp(**coordinates)
            interpolated_fields[field_name] = interpolated_field

        return interpolated_fields

    def neighborhood(
        self,
        *fields: str,
        time: Int[Scalar, ""],
        latitude: Float[Scalar, ""],
        longitude: Float[Scalar, ""],
        t_width: int,
        x_width: int,
    ) -> Gridded:
        """
        Extracts a neighborhood of data around a specified point in time and space.

        Parameters
        ----------
        *fields : tuple[str, ...]
            Names of the fields to extract.
        time : Int[Scalar, ""]
            The time coordinate for the center of the neighborhood.
        latitude : Float[Scalar, ""]
            The latitude coordinate for the center of the neighborhood.
        longitude : Float[Scalar, ""]
            The longitude coordinate for the center of the neighborhood.
        t_width : int
            The width of the neighborhood in the time dimension.
        x_width : int
            The width of the neighborhood in the spatial dimensions (latitude and longitude).

        Returns
        -------
        Gridded
            A [`pastax.gridded.Gridded`][] object restricted to the neighboring data.
        """
        t_i, lat_i, lon_i = self.indices(time=time, latitude=latitude, longitude=longitude)

        from_t_i = t_i - t_width // 2
        from_lat_i = lat_i - x_width // 2
        from_lon_i = lon_i - x_width // 2

        t_neighborhood = jax.lax.dynamic_slice_in_dim(self.coordinates["time"].values, from_t_i, t_width)
        lat_neighborhood = jax.lax.dynamic_slice_in_dim(self.coordinates["latitude"].values, from_lat_i, x_width)

        def no_edge_cases():
            lon_neighborhood = jax.lax.dynamic_slice_in_dim(self.coordinates["longitude"].values, from_lon_i, x_width)

            fields_neighborhood = dict(
                (
                    field_name,
                    jax.lax.dynamic_slice(
                        self.fields[field_name].values, (from_t_i, from_lat_i, from_lon_i), (t_width, x_width, x_width)
                    ),
                )
                for field_name in fields
            )

            return lon_neighborhood, fields_neighborhood

        def edge_cases():
            dx = jnp.linspace(-(x_width // 2), x_width // 2, x_width) * self.dx[lat_i, lon_i]
            lon = jnp.full(x_width, longitude) + dx
            lon_indices = self.indices(longitude=lon)[0]

            lon_neighborhood = self.coordinates["longitude"][lon_indices]

            fields_neighborhood = dict(
                (
                    field_name,
                    jax.lax.dynamic_slice(
                        self.fields[field_name].values,
                        (from_t_i, from_lat_i, 0),
                        (t_width, x_width, self.coordinates["longitude"].values.size),
                    )[..., lon_indices],
                )
                for field_name in fields
            )

            return lon_neighborhood, fields_neighborhood

        lon_neighborhood, fields_neighborhood = jax.lax.cond(
            (self.is_spherical_mesh and (self.indices(longitude=self.coordinates["longitude"][-1] + self.dx[-1]) == 0))
            and ((from_lon_i < 0) or (from_lon_i + x_width > self.coordinates["longitude"].values.size)),
            edge_cases,
            no_edge_cases,
        )

        return Gridded.from_array(
            fields_neighborhood,
            t_neighborhood,
            lat_neighborhood,
            lon_neighborhood,
            interpolation_method=self.interpolation_method,
            is_spherical_mesh=self.is_spherical_mesh,
            use_degrees=self.use_degrees,
        )

    def to_xarray(self) -> xr.Dataset:
        """
        Converts the [`pastax.gridded.Gridded`][] to an `xarray.Dataset`.

        This method constructs an xarray Dataset from the object's fields and coordinates.
        The fields are added as data variables with coordinates ["time", "latitude", "longitude"].
        The coordinates are added as coordinate variables.

        Returns
        -------
        xr.Dataset
            The corresponding `xarray.Dataset`.
        """
        dataset = xr.Dataset(
            data_vars=dict(
                (var_name, (["time", "latitude", "longitude"], var.values)) for var_name, var in self.fields.items()
            ),
            coords=dict(
                time=np.asarray(self.coordinates["time"].values, dtype="datetime64[s]"),
                latitude=self.coordinates["latitude"].values,
                longitude=self.coordinates["longitude"].values,
            ),
        )

        return dataset

    @classmethod
    def from_array(
        cls,
        fields: dict[str, Float[Array, "time lat lon"]],
        time: Int[Array, "time"],
        latitude: Float[Array, "lat"],
        longitude: Float[Array, "lon"],
        interpolation_method: Literal[
            "nearest", "linear", "cubic", "cubic2", "catmull-rom", "cardinal", "monotonic", "monotonic-0", "akima"
        ] = "linear",
        is_spherical_mesh: bool = True,
        use_degrees: bool = False,
        is_uv_mps: bool = True,
    ) -> Gridded:
        """
        Create a [`pastax.gridded.Gridded`][] object from arrays of fields, time, latitude, and longitude.

        Parameters
        ----------
        fields : dict[str, Float[Array, "time lat lon"]]
            A dictionary where keys are field names and values are 3D arrays representing
            the field data over time, latitude, and longitude.
        time : Int[Array, "time"]
            A 1D array representing the time dimension.
        latitude : Float[Array, "lat"]
            A 1D array representing the latitude dimension.
        longitude : Float[Array, "lon"]
            A 1D array representing the longitude dimension.
        interpolation_method : Literal["nearest", "linear", "cubic", "cubic2", "catmull-rom", "cardinal", "monotonic", "monotonic-0", "akima"], optional
            String indicating the interpolation method used when interpolating the fields, defaults to `"linear"`.
            For details, see [`interpax` documentation](https://interpax.readthedocs.io/en/latest/index.html).
        is_spherical_mesh : bool, optional
            Whether the mesh uses spherical coordinates, defaults to `True`.
        use_degrees : bool, optional
            Whether distance units should be degrees rather than meters, defaults to `False`.
        is_uv_mps : bool, optional
            Whether the velocity data is in m/s, defaults to `True`.

        Returns
        -------
        Gridded
            The corresponding [`pastax.gridded.Gridded`][].
        """

        def compute_cell_dlatlon(dright: Float[Array, "latlon-1"], axis: int) -> Float[Array, "latlon"]:
            if axis == 0:
                dcentered = (dright[1:, :] + dright[:-1, :]) / 2
                dstart = ((dright[0, :] - dcentered[0, :] / 2) * 2)[None, :]
                dend = ((dright[-1, :] - dcentered[-1, :] / 2) * 2)[None, :]
            else:
                dcentered = (dright[:, 1:] + dright[:, :-1]) / 2
                dstart = ((dright[:, 0] - dcentered[:, 0] / 2) * 2)[:, None]
                dend = ((dright[:, -1] - dcentered[:, -1] / 2) * 2)[:, None]
            return jnp.concat((dstart, dcentered, dend), axis=axis)

        use_degrees = use_degrees & is_spherical_mesh  # if not spherical mesh, no reason to use degrees

        time_coord = Coordinate.from_array(time, extrap=True)
        latitude_coord = Coordinate.from_array(latitude, extrap=True)
        longitude_coord = LongitudeCoordinate.from_array(longitude, is_periodic=is_spherical_mesh, extrap=True)

        # compute grid spacings and cells area
        dlat = jnp.diff(latitude)
        dlon = jnp.diff(longitude)

        if is_spherical_mesh and not use_degrees:
            dlatlon = degrees_to_meters(
                jnp.stack([dlat, jnp.zeros_like(dlat)], axis=-1),
                (latitude[:-1] + latitude[1:]) / 2,
            )
            dlat = dlatlon[:, 0]
            _, dlat = jnp.meshgrid(longitude, dlat)

            dlatlon = jax.vmap(
                lambda lat: jax.vmap(
                    lambda _dlon: degrees_to_meters(jnp.stack([jnp.zeros_like(_dlon), _dlon], axis=-1), lat)
                )(dlon)
            )(latitude)
            dlon = dlatlon[:, :, 1]
        else:
            _, dlat = jnp.meshgrid(longitude, dlat)
            dlon, _ = jnp.meshgrid(dlon, latitude)

        cell_dlat = compute_cell_dlatlon(dlat, axis=0)
        cell_dlon = compute_cell_dlatlon(dlon, axis=1)
        cell_area = cell_dlat * cell_dlon

        # if required, convert uv from m/s to °/s
        if use_degrees and is_uv_mps:
            vu = jnp.stack((fields["v"], fields["u"]), axis=-1)
            original_shape = vu.shape
            vu = vu.reshape(vu.shape[0], -1, 2)

            _, lat_grid = jnp.meshgrid(longitude, latitude)
            lat_grid = lat_grid.ravel()

            vu = eqx.filter_vmap(lambda x: meters_to_degrees(x, lat_grid))(vu)
            vu = vu.reshape(original_shape)

            fields["v"] = vu[..., 0]
            fields["u"] = vu[..., 1]

            is_uv_mps = False

        fields_ = dict(
            (
                field_name,
                SpatioTemporalField.from_array(
                    values,
                    time_coord.values,
                    latitude_coord.values,
                    longitude_coord.values,
                    interpolation_method=interpolation_method,
                ),
            )
            for field_name, values in fields.items()
        )

        return cls(
            cell_area=cell_area,
            coordinates={"time": time_coord, "latitude": latitude_coord, "longitude": longitude_coord},
            dx=dlon,
            dy=dlat,
            fields=fields_,
            is_spherical_mesh=is_spherical_mesh,
            interpolation_method=interpolation_method,
            use_degrees=use_degrees,
        )

    @classmethod
    def from_xarray(
        cls,
        dataset: xr.Dataset,
        fields: dict[str, str],
        coordinates: dict[str, str],
        interpolation_method: Literal[
            "nearest", "linear", "cubic", "cubic2", "catmull-rom", "cardinal", "monotonic", "monotonic-0", "akima"
        ] = "linear",
        is_spherical_mesh: bool = True,
        is_uv_mps: bool = True,
        use_degrees: bool = False,
    ) -> Gridded:
        """
        Create a [`pastax.gridded.Gridded`][] object from an `xarray.Dataset`.

        Parameters
        ----------
        dataset : xr.Dataset
            The `xarray.Dataset` containing the data.
        fields : dict[str, str]
            A dictionary mapping the target field names (keys) to the source variable names in the dataset (values).
        coordinates : dict[str, str]
            A dictionary mapping the coordinate names ('time', 'latitude', 'longitude') to their corresponding names in
            the dataset.
        interpolation_method : Literal["nearest", "linear", "cubic", "cubic2", "catmull-rom", "cardinal", "monotonic", "monotonic-0", "akima"], optional
            String indicating the interpolation method used when interpolating the fields, defaults to `"linear"`.
            For details, see [`interpax` documentation](https://interpax.readthedocs.io/en/latest/index.html).
        is_spherical_mesh : bool, optional
            Whether the mesh uses spherical coordinates, defaults to `True`.
        is_uv_mps : bool, optional
            Whether the velocity data is in m/s, defaults to `True`.
        use_degrees : bool, optional
            Whether distance units should be degrees rather than meters, defaults to `False`.

        Returns
        -------
        Gridded
            The corresponding [`pastax.gridded.Gridded`][].
        """
        fields_, t, lat, lon = cls.xarray_to_array(dataset, fields, coordinates)

        return cls.from_array(
            fields_,
            t,
            lat,
            lon,
            interpolation_method=interpolation_method,
            is_spherical_mesh=is_spherical_mesh,
            use_degrees=use_degrees,
            is_uv_mps=is_uv_mps,
        )

    @staticmethod
    def xarray_to_array(
        dataset: xr.Dataset,
        fields: dict[str, str],  # to -> from
        coordinates: dict[str, str],  # to -> from
        transform_fn: Callable[[Array], Array] = lambda x: jnp.asarray(x, dtype=float),
    ) -> tuple[
        dict[str, Float[Array, "..."]],
        Float[Array, "..."],
        Float[Array, "..."],
        Float[Array, "..."],
    ]:
        """
        Converts an `xarray.Dataset` to arrays of fields and coordinates.

        Parameters
        ----------
        dataset : xr.Dataset
            The `xarray.Dataset` to convert.
        fields : dict[str, str]
            A dictionary mapping the target field names to the source variable names in the dataset.
        coordinates : dict[str, str]
            A dictionary mapping the target coordinate names to the source coordinate names in the dataset.
        transform_fn : Callable[[Array], Array], optional
            Function converting dataarrays to JAX (or numpy) arrays,
            defaults to `lambda x: jnp.asarray(x, dtype=float)`.

        Returns
        -------
        tuple[dict[str, Float[Array, "..."]], Float[Array, "..."], Float[Array, "..."], Float[Array, "..."]]
            A tuple containing:
            - A dictionary of converted fields.
            - The time coordinate array.
            - The latitude coordinate array.
            - The longitude coordinate array.
        """
        fields_ = dict((to_name, transform_fn(dataset[from_name].data)) for to_name, from_name in fields.items())

        t = transform_fn(dataset[coordinates["time"]].data.astype("datetime64[s]").astype(int))
        lat = transform_fn(dataset[coordinates["latitude"]].data)
        lon = transform_fn(dataset[coordinates["longitude"]].data)

        return fields_, t, lat, lon
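On a spherical mesh, `from_array` converts grid spacings from degrees to meters. `compute_cell_dlatlon` then turns the spacings between neighboring points into per-cell widths. Each interior width is the mean of the two adjacent spacings. The two edge widths are linearly extrapolated as `2 * d_edge - centered_edge`. The NumPy sketch below reproduces that logic under an assumed spherical Earth of radius 6,371 km, with `dy = R * dlat` and `dx = R * cos(lat) * dlon` (angles in radians). pastax's `degrees_to_meters` may use a different radius or formula. `spacings_m` and `cell_widths` are hypothetical names.

```python
import numpy as np

EARTH_RADIUS_M = 6_371_000.0  # assumed spherical Earth radius (meters)


def spacings_m(latitude, longitude):
    """Spacings between neighboring grid points, in meters.

    Returns dy with shape (lat-1, lon) and dx with shape (lat, lon-1),
    matching the shapes of the `dy` and `dx` attributes of Gridded."""
    dy = np.deg2rad(np.diff(latitude)) * EARTH_RADIUS_M
    dy = np.broadcast_to(dy[:, None], (latitude.size - 1, longitude.size))
    coslat = np.cos(np.deg2rad(latitude))[:, None]
    dx = np.deg2rad(np.diff(longitude))[None, :] * EARTH_RADIUS_M * coslat
    return dy, dx


def cell_widths(d_right, axis):
    """Per-point cell widths from spacings along `axis` (needs >= 2 spacings).

    Interior widths average the two adjacent spacings; the edge widths are
    linearly extrapolated, as in compute_cell_dlatlon."""
    d = np.moveaxis(d_right, axis, -1)
    centered = (d[..., 1:] + d[..., :-1]) / 2
    start = 2 * d[..., :1] - centered[..., :1]
    end = 2 * d[..., -1:] - centered[..., -1:]
    widths = np.concatenate([start, centered, end], axis=-1)
    return np.moveaxis(widths, -1, axis)


latitude = np.array([0.0, 30.0, 60.0])
longitude = np.array([0.0, 1.0, 2.0])
dy, dx = spacings_m(latitude, longitude)
cell_area = cell_widths(dy, axis=0) * cell_widths(dx, axis=1)  # shape (lat, lon), in m^2
```

As in the source, the area of each cell is the product of its latitudinal and longitudinal widths. This is a flat-cell approximation rather than an exact spherical area.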
indices(**coordinates: Int[Array, 'Nq'] | Float[Array, 'Nq']) -> tuple[Int[Array, 'Nq'], ...]

Gets the nearest indices of the N-dimensional point specified by the given coordinates.

Parameters:

Name Type Description Default
**coordinates Int[Array, 'Nq'] | Float[Array, 'Nq']

The N-dimensional point for which to get the nearest indices.

{}

Returns:

Type Description
tuple[Int[Array, 'Nq'], ...]

A tuple of arrays containing the nearest indices of the N-dimensional point.

Source code in pastax/gridded/_gridded.py
def indices(self, **coordinates: Int[Array, "Nq"] | Float[Array, "Nq"]) -> tuple[Int[Array, "Nq"], ...]:
    """
    Gets the nearest indices of the N-dimensional point specified by the given coordinates.

    Parameters
    ----------
    **coordinates : Int[Array, "Nq"] | Float[Array, "Nq"]
        The N-dimensional point for which to get the nearest indices.

    Returns
    -------
    tuple[Int[Array, "Nq"], ...]
        A tuple of arrays containing the nearest indices of the N-dimensional point.
    """
    return tuple(self.coordinates[k].index(v) for k, v in coordinates.items())
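`indices` delegates to the `index` method of each coordinate, which is not shown on this page. For illustration only, the NumPy sketch below implements a plain nearest-neighbor lookup on a sorted 1D axis and reproduces the keyword-to-tuple calling pattern. `nearest_index` and the standalone `indices` helper are hypothetical names. The sketch also assumes that out-of-range queries snap to the first or last grid point.

```python
import numpy as np


def nearest_index(coord, query):
    """Index of the grid point closest to each query value (coord sorted ascending)."""
    coord = np.asarray(coord, dtype=float)
    query = np.atleast_1d(np.asarray(query, dtype=float))
    # searchsorted gives the first grid point >= query; clip so both neighbors exist
    right = np.clip(np.searchsorted(coord, query), 1, coord.size - 1)
    left = right - 1
    # keep the closer neighbor (ties resolve to the left one); out-of-range
    # queries end up at the first or last index
    choose_right = np.abs(coord[right] - query) < np.abs(query - coord[left])
    return np.where(choose_right, right, left)


def indices(coordinates, **query):
    """Mimics the Gridded.indices calling pattern: one index array per keyword, in keyword order."""
    return tuple(nearest_index(coordinates[name], values) for name, values in query.items())


grid = {"time": [0, 3600, 7200], "latitude": [-1.0, 0.0, 1.0, 2.0]}
t_i, lat_i = indices(grid, time=[100, 5000], latitude=[0.9, -3.0])
```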
interp(*fields: str, **coordinates: Int[Array, 'Nq'] | Float[Array, 'Nq']) -> dict[str, Bool[Array, 'Nq ...'] | Float[Array, 'Nq ...'] | Int[Array, 'Nq ...']]

Interpolates the given fields at the given coordinates.

Parameters:

Name Type Description Default
*fields str

Names of the fields to interpolate.

()
**coordinates Int[Array, 'Nq'] | Float[Array, 'Nq']

The N-dimensional points to interpolate to.

{}

Returns:

Type Description
dict[str, Bool[Array, 'Nq ...'] | Float[Array, 'Nq ...'] | Int[Array, 'Nq ...']]

A dict of arrays containing the interpolated values for each field.

Source code in pastax/gridded/_gridded.py
def interp(
    self, *fields: str, **coordinates: Int[Array, "Nq"] | Float[Array, "Nq"]
) -> dict[str, Bool[Array, "Nq ..."] | Float[Array, "Nq ..."] | Int[Array, "Nq ..."]]:
    """
    Interpolates the given fields at the given coordinates.

    Parameters
    ----------
    *fields : str
        Names of the fields to interpolate.
    **coordinates : Int[Array, "Nq"] | Float[Array, "Nq"]
        The N-dimensional points to interpolate to.

    Returns
    -------
    dict[str, Bool[Array, "Nq ..."] | Float[Array, "Nq ..."] | Int[Array, "Nq ..."]]
        A dict of arrays containing the interpolated values for each field.
    """
    interpolated_fields = {}
    for field_name in fields:
        field = self.fields[field_name]
        interpolated_field = field.interp(**coordinates)
        interpolated_fields[field_name] = interpolated_field

    return interpolated_fields
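`interp` loops over the requested field names and returns a dict keyed by name. The sketch below mimics that pattern on a single (latitude, longitude) slice. It uses plain bilinear interpolation in NumPy instead of the `interpax` interpolators that pastax builds, and it ignores time for brevity. `bilinear` and `interp_fields` are hypothetical helpers.

```python
import numpy as np


def bilinear(lat_grid, lon_grid, values, lat_q, lon_q):
    """Bilinear interpolation of values[lat, lon] at the query points.

    Queries outside the grid use the nearest edge cell, i.e. they are
    linearly extrapolated."""
    lat_q = np.asarray(lat_q, dtype=float)
    lon_q = np.asarray(lon_q, dtype=float)
    # index of the lower-left corner of the cell containing each query
    i = np.clip(np.searchsorted(lat_grid, lat_q) - 1, 0, lat_grid.size - 2)
    j = np.clip(np.searchsorted(lon_grid, lon_q) - 1, 0, lon_grid.size - 2)
    wy = (lat_q - lat_grid[i]) / (lat_grid[i + 1] - lat_grid[i])
    wx = (lon_q - lon_grid[j]) / (lon_grid[j + 1] - lon_grid[j])
    return ((1 - wy) * (1 - wx) * values[i, j] + (1 - wy) * wx * values[i, j + 1]
            + wy * (1 - wx) * values[i + 1, j] + wy * wx * values[i + 1, j + 1])


def interp_fields(fields, lat_grid, lon_grid, *names, latitude, longitude):
    """Same calling pattern as Gridded.interp: field names are positional,
    coordinates are keywords, and the result is a dict keyed by field name."""
    return {name: bilinear(lat_grid, lon_grid, fields[name], latitude, longitude) for name in names}


lat_grid = np.array([0.0, 1.0])
lon_grid = np.array([0.0, 1.0, 2.0])
fields = {
    "u": lat_grid[:, None] + 2 * lon_grid[None, :],  # linear, so bilinear is exact
    "v": np.ones((2, 3)),
}
out = interp_fields(fields, lat_grid, lon_grid, "u", "v", latitude=[0.5, 0.0], longitude=[1.5, 2.0])
```

Bilinear interpolation is exact for a field that is linear in latitude and longitude, which makes the sample values easy to check by hand.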
neighborhood(*fields: str, time: Int[Scalar, ''], latitude: Float[Scalar, ''], longitude: Float[Scalar, ''], t_width: int, x_width: int) -> Gridded

Extracts a neighborhood of data around a specified point in time and space.

Parameters:

Name Type Description Default
*fields tuple[str, ...]

Names of the fields to extract.

()
time Int[Scalar, '']

The time coordinate for the center of the neighborhood.

required
latitude Float[Scalar, '']

The latitude coordinate for the center of the neighborhood.

required
longitude Float[Scalar, '']

The longitude coordinate for the center of the neighborhood.

required
t_width int

The width of the neighborhood in the time dimension.

required
x_width int

The width of the neighborhood in the spatial dimensions (latitude and longitude).

required

Returns:

Type Description
Gridded

A pastax.gridded.Gridded object restricted to the neighboring data.

Source code in pastax/gridded/_gridded.py
def neighborhood(
    self,
    *fields: str,
    time: Int[Scalar, ""],
    latitude: Float[Scalar, ""],
    longitude: Float[Scalar, ""],
    t_width: int,
    x_width: int,
) -> Gridded:
    """
    Extracts a neighborhood of data around a specified point in time and space.

    Parameters
    ----------
    *fields : tuple[str, ...]
        Names of the fields to extract.
    time : Int[Scalar, ""]
        The time coordinate for the center of the neighborhood.
    latitude : Float[Scalar, ""]
        The latitude coordinate for the center of the neighborhood.
    longitude : Float[Scalar, ""]
        The longitude coordinate for the center of the neighborhood.
    t_width : int
        The width of the neighborhood in the time dimension.
    x_width : int
        The width of the neighborhood in the spatial dimensions (latitude and longitude).

    Returns
    -------
    Gridded
        A [`pastax.gridded.Gridded`][] object restricted to the neighboring data.
    """
    t_i, lat_i, lon_i = self.indices(time=time, latitude=latitude, longitude=longitude)

    from_t_i = t_i - t_width // 2
    from_lat_i = lat_i - x_width // 2
    from_lon_i = lon_i - x_width // 2

    t_neighborhood = jax.lax.dynamic_slice_in_dim(self.coordinates["time"].values, from_t_i, t_width)
    lat_neighborhood = jax.lax.dynamic_slice_in_dim(self.coordinates["latitude"].values, from_lat_i, x_width)

    def no_edge_cases():
        lon_neighborhood = jax.lax.dynamic_slice_in_dim(self.coordinates["longitude"].values, from_lon_i, x_width)

        fields_neighborhood = dict(
            (
                field_name,
                jax.lax.dynamic_slice(
                    self.fields[field_name].values, (from_t_i, from_lat_i, from_lon_i), (t_width, x_width, x_width)
                ),
            )
            for field_name in fields
        )

        return lon_neighborhood, fields_neighborhood

    def edge_cases():
        dx = jnp.linspace(-(x_width // 2), x_width // 2, x_width) * self.dx[lat_i, lon_i]
        lon = jnp.full(x_width, longitude) + dx
        lon_indices = self.indices(longitude=lon)[0]

        lon_neighborhood = self.coordinates["longitude"][lon_indices]

        fields_neighborhood = dict(
            (
                field_name,
                jax.lax.dynamic_slice(
                    self.fields[field_name].values,
                    (from_t_i, from_lat_i, 0),
                    (t_width, x_width, self.coordinates["longitude"].values.size),
                )[..., lon_indices],
            )
            for field_name in fields
        )

        return lon_neighborhood, fields_neighborhood

    lon_neighborhood, fields_neighborhood = jax.lax.cond(
        (self.is_spherical_mesh and (self.indices(longitude=self.coordinates["longitude"][-1] + self.dx[-1]) == 0))
        and ((from_lon_i < 0) or (from_lon_i + x_width > self.coordinates["longitude"].values.size)),
        edge_cases,
        no_edge_cases,
    )

    return Gridded.from_array(
        fields_neighborhood,
        t_neighborhood,
        lat_neighborhood,
        lon_neighborhood,
        interpolation_method=self.interpolation_method,
        is_spherical_mesh=self.is_spherical_mesh,
        use_degrees=self.use_degrees,
    )
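The neighborhood extraction centers a window of `t_width` time steps and `x_width` by `x_width` grid cells on the nearest indices of the requested point. The NumPy sketch below illustrates the intended behavior under two assumptions. First, windows that would run out of bounds are clamped to the array edges, which is what `jax.lax.dynamic_slice` does with its start indices. Second, on a periodic longitude axis the window wraps across the antimeridian instead of being clamped. `clamped_start` and `neighborhood_window` are hypothetical names.

```python
import numpy as np


def clamped_start(center, width, size):
    """Start index of a window of `width` centered on `center`, clamped so the
    window stays inside [0, size) like jax.lax.dynamic_slice does."""
    return int(np.clip(center - width // 2, 0, size - width))


def neighborhood_window(values, t_i, lat_i, lon_i, t_width, x_width, periodic_lon=True):
    """Extract values[t, lat, lon] around the grid point (t_i, lat_i, lon_i)."""
    n_t, n_lat, n_lon = values.shape
    t0 = clamped_start(t_i, t_width, n_t)
    lat0 = clamped_start(lat_i, x_width, n_lat)
    window = values[t0:t0 + t_width, lat0:lat0 + x_width, :]
    if periodic_lon:
        # wrap across the antimeridian instead of clamping
        lon_idx = (lon_i - x_width // 2 + np.arange(x_width)) % n_lon
    else:
        lon0 = clamped_start(lon_i, x_width, n_lon)
        lon_idx = np.arange(lon0, lon0 + x_width)
    return window[..., lon_idx], lon_idx


values = np.arange(2 * 4 * 6).reshape(2, 4, 6)  # (time, lat, lon)
sub, lon_idx = neighborhood_window(values, t_i=0, lat_i=0, lon_i=0, t_width=1, x_width=3)
```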
to_xarray() -> xr.Dataset

Converts the pastax.gridded.Gridded to an xarray.Dataset.

This method constructs an xarray Dataset from the object's fields and coordinates. The fields are added as data variables with coordinates ["time", "latitude", "longitude"]. The coordinates are added as coordinate variables.

Returns:

Type Description
Dataset

The corresponding xarray.Dataset.

Source code in pastax/gridded/_gridded.py
def to_xarray(self) -> xr.Dataset:
    """
    Converts the [`pastax.gridded.Gridded`][] to an `xarray.Dataset`.

    This method constructs an xarray Dataset from the object's fields and coordinates.
    The fields are added as data variables with coordinates ["time", "latitude", "longitude"].
    The coordinates are added as coordinate variables.

    Returns
    -------
    xr.Dataset
        The corresponding `xarray.Dataset`.
    """
    dataset = xr.Dataset(
        data_vars=dict(
            (var_name, (["time", "latitude", "longitude"], var.values)) for var_name, var in self.fields.items()
        ),
        coords=dict(
            time=np.asarray(self.coordinates["time"].values, dtype="datetime64[s]"),
            latitude=self.coordinates["latitude"].values,
            longitude=self.coordinates["longitude"].values,
        ),
    )

    return dataset
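`to_xarray` rebuilds the time coordinate with `np.asarray(..., dtype="datetime64[s]")`. This inverts the encoding in `xarray_to_array`, which casts timestamps to `datetime64[s]` and then to integers (seconds since the Unix epoch). The short NumPy round trip below shows that encoding. The sample timestamps are arbitrary.

```python
import numpy as np

# Arbitrary sample timestamps at one-second resolution
times = np.array(["2024-01-01T00:00:00", "2024-01-01T06:00:00"], dtype="datetime64[s]")

# Encoding used by xarray_to_array: datetime64[s] -> integer seconds since 1970-01-01
encoded = times.astype("datetime64[s]").astype(int)

# Decoding used by to_xarray for the "time" coordinate
decoded = np.asarray(encoded, dtype="datetime64[s]")
```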
from_array(fields: dict[str, Float[Array, 'time lat lon']], time: Int[Array, 'time'], latitude: Float[Array, 'lat'], longitude: Float[Array, 'lon'], interpolation_method: Literal['nearest', 'linear', 'cubic', 'cubic2', 'catmull-rom', 'cardinal', 'monotonic', 'monotonic-0', 'akima'] = 'linear', is_spherical_mesh: bool = True, use_degrees: bool = False, is_uv_mps: bool = True) -> Gridded classmethod

Create a pastax.gridded.Gridded object from arrays of fields, time, latitude, and longitude.

Parameters:

Name Type Description Default
fields dict[str, Float[Array, 'time lat lon']]

A dictionary where keys are field names and values are 3D arrays representing the field data over time, latitude, and longitude.

required
time Int[Array, 'time']

A 1D array representing the time dimension.

required
latitude Float[Array, 'lat']

A 1D array representing the latitude dimension.

required
longitude Float[Array, 'lon']

A 1D array representing the longitude dimension.

required
interpolation_method Literal['nearest', 'linear', 'cubic', 'cubic2', 'catmull-rom', 'cardinal', 'monotonic', 'monotonic-0', 'akima']

String indicating the interpolation method used when interpolating the fields, defaults to "linear". For details, see interpax documentation.

'linear'
is_spherical_mesh bool

Whether the mesh uses spherical coordinates, defaults to True.

True
use_degrees bool

Whether distance units should be degrees rather than meters, defaults to False. Ignored (treated as False) when is_spherical_mesh is False.

False
is_uv_mps bool

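Whether the velocity data is in m/s, defaults to True. When use_degrees is also in effect, the "u" and "v" fields (which must then be present) are converted from m/s to degrees per second, and the fields dictionary passed in is updated in place.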

True

Returns:

Type Description
Gridded

The corresponding pastax.gridded.Gridded.
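When use_degrees and is_uv_mps are both True, from_array rescales the v and u fields from m/s to degrees per second with pastax's meters_to_degrees helper, whose exact formula is not shown on this page. The sketch below uses a hypothetical function name and assumes a spherical Earth of radius 6371 km (pastax may use a different radius or model); it illustrates why the eastward component depends on latitude:

```python
import numpy as np

# Assumed mean Earth radius in meters; pastax's own constant may differ
EARTH_RADIUS_M = 6_371_000.0


def mps_to_degps(v, u, lat_deg):
    """Hypothetical helper: convert northward (v) and eastward (u) speeds from m/s to degrees/s."""
    lat_rad = np.deg2rad(lat_deg)
    # One degree of latitude spans the same distance everywhere on a sphere
    v_deg = np.rad2deg(v / EARTH_RADIUS_M)
    # One degree of longitude shrinks with cos(latitude), so u in degrees/s grows poleward
    u_deg = np.rad2deg(u / (EARTH_RADIUS_M * np.cos(lat_rad)))
    return v_deg, u_deg


v_deg, u_deg = mps_to_degps(1.0, 1.0, 60.0)
print(v_deg, u_deg)  # at 60°N the eastward value is twice the northward one
```

The same latitude dependence is why from_array converts the longitude spacings to meters separately for each latitude row.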

Source code in pastax/gridded/_gridded.py
@classmethod
def from_array(
    cls,
    fields: dict[str, Float[Array, "time lat lon"]],
    time: Int[Array, "time"],
    latitude: Float[Array, "lat"],
    longitude: Float[Array, "lon"],
    interpolation_method: Literal[
        "nearest", "linear", "cubic", "cubic2", "catmull-rom", "cardinal", "monotonic", "monotonic-0", "akima"
    ] = "linear",
    is_spherical_mesh: bool = True,
    use_degrees: bool = False,
    is_uv_mps: bool = True,
) -> Gridded:
    """
    Create a [`pastax.gridded.Gridded`][] object from arrays of fields, time, latitude, and longitude.

    Parameters
    ----------
    fields : dict[str, Float[Array, "time lat lon"]]
        A dictionary where keys are field names and values are 3D arrays representing
        the field data over time, latitude, and longitude.
    time : Int[Array, "time"]
        A 1D array representing the time dimension.
    latitude : Float[Array, "lat"]
        A 1D array representing the latitude dimension.
    longitude : Float[Array, "lon"]
        A 1D array representing the longitude dimension.
    interpolation_method : Literal["nearest", "linear", "cubic", "cubic2", "catmull-rom", "cardinal", "monotonic", "monotonic-0", "akima"], optional
        String indicating the interpolation method used when interpolating the fields, defaults to `"linear"`.
        For details, see [`interpax` documentation](https://interpax.readthedocs.io/en/latest/index.html).
    is_spherical_mesh : bool, optional
        Whether the mesh uses spherical coordinates, defaults to `True`.
    use_degrees : bool, optional
        Whether distance units should be degrees rather than meters, defaults to `False`.
    is_uv_mps : bool, optional
        Whether the velocity data is in m/s, defaults to `True`.

    Returns
    -------
    Gridded
        The corresponding [`pastax.gridded.Gridded`][].
    """

    def compute_cell_dlatlon(dright: Float[Array, "latlon-1"], axis: int) -> Float[Array, "latlon"]:
        if axis == 0:
            dcentered = (dright[1:, :] + dright[:-1, :]) / 2
            dstart = ((dright[0, :] - dcentered[0, :] / 2) * 2)[None, :]
            dend = ((dright[-1, :] - dcentered[-1, :] / 2) * 2)[None, :]
        else:
            dcentered = (dright[:, 1:] + dright[:, :-1]) / 2
            dstart = ((dright[:, 0] - dcentered[:, 0] / 2) * 2)[:, None]
            dend = ((dright[:, -1] - dcentered[:, -1] / 2) * 2)[:, None]
        return jnp.concat((dstart, dcentered, dend), axis=axis)

    use_degrees = use_degrees & is_spherical_mesh  # if not spherical mesh, no reason to use degrees

    time_coord = Coordinate.from_array(time, extrap=True)
    latitude_coord = Coordinate.from_array(latitude, extrap=True)
    longitude_coord = LongitudeCoordinate.from_array(longitude, is_periodic=is_spherical_mesh, extrap=True)

    # compute grid spacings and cells area
    dlat = jnp.diff(latitude)
    dlon = jnp.diff(longitude)

    if is_spherical_mesh and not use_degrees:
        dlatlon = degrees_to_meters(
            jnp.stack([dlat, jnp.zeros_like(dlat)], axis=-1),
            (latitude[:-1] + latitude[1:]) / 2,
        )
        dlat = dlatlon[:, 0]
        _, dlat = jnp.meshgrid(longitude, dlat)

        dlatlon = jax.vmap(
            lambda lat: jax.vmap(
                lambda _dlon: degrees_to_meters(jnp.stack([jnp.zeros_like(_dlon), _dlon], axis=-1), lat)
            )(dlon)
        )(latitude)
        dlon = dlatlon[:, :, 1]
    else:
        _, dlat = jnp.meshgrid(longitude, dlat)
        dlon, _ = jnp.meshgrid(dlon, latitude)

    cell_dlat = compute_cell_dlatlon(dlat, axis=0)
    cell_dlon = compute_cell_dlatlon(dlon, axis=1)
    cell_area = cell_dlat * cell_dlon

    # if required, convert uv from m/s to °/s
    if use_degrees and is_uv_mps:
        vu = jnp.stack((fields["v"], fields["u"]), axis=-1)
        original_shape = vu.shape
        vu = vu.reshape(vu.shape[0], -1, 2)

        _, lat_grid = jnp.meshgrid(longitude, latitude)
        lat_grid = lat_grid.ravel()

        vu = eqx.filter_vmap(lambda x: meters_to_degrees(x, lat_grid))(vu)
        vu = vu.reshape(original_shape)

        fields["v"] = vu[..., 0]
        fields["u"] = vu[..., 1]

        is_uv_mps = False

    fields_ = dict(
        (
            field_name,
            SpatioTemporalField.from_array(
                values,
                time_coord.values,
                latitude_coord.values,
                longitude_coord.values,
                interpolation_method=interpolation_method,
            ),
        )
        for field_name, values in fields.items()
    )

    return cls(
        cell_area=cell_area,
        coordinates={"time": time_coord, "latitude": latitude_coord, "longitude": longitude_coord},
        dx=dlon,
        dy=dlat,
        fields=fields_,
        is_spherical_mesh=is_spherical_mesh,
        interpolation_method=interpolation_method,
        use_degrees=use_degrees,
    )
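The nested compute_cell_dlatlon helper turns the spacings between adjacent grid nodes into one width per node. Each interior width is the average of the two neighbouring spacings, and the first and last widths are extrapolated linearly as 2 * d_edge - centered_edge. The 1D NumPy transcription below uses a name of our own; the library applies the same arithmetic along one axis of a 2D JAX array. It makes the edge treatment easy to check:

```python
import numpy as np


def cell_widths_1d(d_right):
    """Widths at n nodes from the n - 1 spacings between them (needs n >= 3)."""
    d_right = np.asarray(d_right, dtype=float)
    # Interior nodes: mean of the spacing on each side
    centered = (d_right[1:] + d_right[:-1]) / 2
    # Edge nodes: same formula as compute_cell_dlatlon, i.e. linear extrapolation
    start = (d_right[0] - centered[0] / 2) * 2
    end = (d_right[-1] - centered[-1] / 2) * 2
    return np.concatenate(([start], centered, [end]))


print(cell_widths_1d([1.0, 1.0, 1.0]))  # uniform grid: every width equals the spacing
print(cell_widths_1d([1.0, 2.0, 3.0]))  # stretched grid: widths 0.5, 1.5, 2.5, 3.5
```

In from_array, the resulting latitude and longitude widths are multiplied to give cell_area.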
from_xarray(dataset: xr.Dataset, fields: dict[str, str], coordinates: dict[str, str], interpolation_method: Literal['nearest', 'linear', 'cubic', 'cubic2', 'catmull-rom', 'cardinal', 'monotonic', 'monotonic-0', 'akima'] = 'linear', is_spherical_mesh: bool = True, is_uv_mps: bool = True, use_degrees: bool = False) -> Gridded classmethod

Create a pastax.gridded.Gridded object from an xarray.Dataset.

Parameters:

Name Type Description Default
dataset Dataset

The xarray.Dataset containing the data.

required
fields dict[str, str]

A dictionary mapping the target field names (keys) to the source variable names in the dataset (values).

required
coordinates dict[str, str]

A dictionary mapping the coordinate names ('time', 'latitude', 'longitude') to their corresponding names in the dataset.

required
interpolation_method Literal['nearest', 'linear', 'cubic', 'cubic2', 'catmull-rom', 'cardinal', 'monotonic', 'monotonic-0', 'akima']

String indicating the interpolation method used when interpolating the fields, defaults to "linear". For details, see interpax documentation.

'linear'
is_spherical_mesh bool

Whether the mesh uses spherical coordinates, defaults to True.

True
is_uv_mps bool

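Whether the velocity data is in m/s, defaults to True. When use_degrees is also in effect, the "u" and "v" fields are converted from m/s to degrees per second.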

True
use_degrees bool

Whether distance units should be degrees rather than meters, defaults to False. Ignored (treated as False) when is_spherical_mesh is False.

False

Returns:

Type Description
Gridded

The corresponding pastax.gridded.Gridded.

Source code in pastax/gridded/_gridded.py
@classmethod
def from_xarray(
    cls,
    dataset: xr.Dataset,
    fields: dict[str, str],
    coordinates: dict[str, str],
    interpolation_method: Literal[
        "nearest", "linear", "cubic", "cubic2", "catmull-rom", "cardinal", "monotonic", "monotonic-0", "akima"
    ] = "linear",
    is_spherical_mesh: bool = True,
    is_uv_mps: bool = True,
    use_degrees: bool = False,
) -> Gridded:
    """
    Create a [`pastax.gridded.Gridded`][] object from an `xarray.Dataset`.

    Parameters
    ----------
    dataset : xr.Dataset
        The `xarray.Dataset` containing the data.
    fields : dict[str, str]
        A dictionary mapping the target field names (keys) to the source variable names in the dataset (values).
    coordinates : dict[str, str]
        A dictionary mapping the coordinate names ('time', 'latitude', 'longitude') to their corresponding names in
        the dataset.
    interpolation_method : Literal["nearest", "linear", "cubic", "cubic2", "catmull-rom", "cardinal", "monotonic", "monotonic-0", "akima"], optional
        String indicating the interpolation method used when interpolating the fields, defaults to `"linear"`.
        For details, see [`interpax` documentation](https://interpax.readthedocs.io/en/latest/index.html).
    is_spherical_mesh : bool, optional
        Whether the mesh uses spherical coordinates, defaults to `True`.
    is_uv_mps : bool, optional
        Whether the velocity data is in m/s, defaults to `True`.
    use_degrees : bool, optional
        Whether distance units should be degrees rather than meters, defaults to `False`.

    Returns
    -------
    Gridded
        The corresponding [`pastax.gridded.Gridded`][].
    """
    fields_, t, lat, lon = cls.xarray_to_array(dataset, fields, coordinates)

    return cls.from_array(
        fields_,
        t,
        lat,
        lon,
        interpolation_method=interpolation_method,
        is_spherical_mesh=is_spherical_mesh,
        use_degrees=use_degrees,
        is_uv_mps=is_uv_mps,
    )
xarray_to_array(dataset: xr.Dataset, fields: dict[str, str], coordinates: dict[str, str], transform_fn: Callable[[Array], Array] = lambda x: jnp.asarray(x, dtype=float)) -> tuple[dict[str, Float[Array, '...']], Float[Array, '...'], Float[Array, '...'], Float[Array, '...']] staticmethod

Converts an xarray.Dataset to arrays of fields and coordinates.

Parameters:

Name Type Description Default
dataset Dataset

The xarray.Dataset to convert.

required
fields dict[str, str]

A dictionary mapping the target field names to the source variable names in the dataset.

required
coordinates dict[str, str]

A dictionary mapping the target coordinate names to the source coordinate names in the dataset.

required
transform_fn Callable[[Array], Array]

Function applied to the raw array data of each variable (dataset[name].data) to produce JAX (or NumPy) arrays, defaults to lambda x: jnp.asarray(x, dtype=float).

lambda x: asarray(x, dtype=float)

Returns:

Type Description
tuple[dict[str, Float[Array, '...']], Float[Array, '...'], Float[Array, '...'], Float[Array, '...']]

A tuple containing:
- A dictionary of converted fields.
- The time coordinate array, in seconds since the Unix epoch.
- The latitude coordinate array.
- The longitude coordinate array.

Source code in pastax/gridded/_gridded.py
@staticmethod
def xarray_to_array(
    dataset: xr.Dataset,
    fields: dict[str, str],  # to -> from
    coordinates: dict[str, str],  # to -> from
    transform_fn: Callable[[Array], Array] = lambda x: jnp.asarray(x, dtype=float),
) -> tuple[
    dict[str, Float[Array, "..."]],
    Float[Array, "..."],
    Float[Array, "..."],
    Float[Array, "..."],
]:
    """
    Converts an `xarray.Dataset` to arrays of fields and coordinates.

    Parameters
    ----------
    dataset : xr.Dataset
        The `xarray.Dataset` to convert.
    fields : dict[str, str]
        A dictionary mapping the target field names to the source variable names in the dataset.
    coordinates : dict[str, str]
        A dictionary mapping the target coordinate names to the source coordinate names in the dataset.
    transform_fn : Callable[[Array], Array], optional
        Function applied to the raw array data of each variable (`dataset[name].data`) to produce JAX (or NumPy)
        arrays, defaults to `lambda x: jnp.asarray(x, dtype=float)`.

    Returns
    -------
    tuple[dict[str, Float[Array, "..."]], Float[Array, "..."], Float[Array, "..."], Float[Array, "..."]]
        A tuple containing:
        - A dictionary of converted fields.
        - The time coordinate array, in seconds since the Unix epoch.
        - The latitude coordinate array.
        - The longitude coordinate array.
    """
    fields_ = dict((to_name, transform_fn(dataset[from_name].data)) for to_name, from_name in fields.items())

    t = transform_fn(dataset[coordinates["time"]].data.astype("datetime64[s]").astype(int))
    lat = transform_fn(dataset[coordinates["latitude"]].data)
    lon = transform_fn(dataset[coordinates["longitude"]].data)

    return fields_, t, lat, lon
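Both mappings passed to xarray_to_array (and forwarded by from_xarray) go from the name pastax expects (key) to the variable name in the source dataset (value). The sketch below reproduces that lookup with a plain dictionary of NumPy arrays standing in for `dataset[name].data`. The source names uo, vo, lat and lon are hypothetical examples:

```python
import numpy as np

# Stand-in for an xarray.Dataset: source variable name -> raw array (what dataset[name].data returns)
raw = {
    "uo": np.ones((2, 3, 4)),
    "vo": np.zeros((2, 3, 4)),
    "time": np.array(["2024-01-01T00:00:00", "2024-01-01T06:00:00"], dtype="datetime64[s]"),
    "lat": np.linspace(40.0, 42.0, 3),
    "lon": np.linspace(-5.0, -2.0, 4),
}

# Keys: names used by pastax; values: names found in the source dataset
fields = {"u": "uo", "v": "vo"}
coordinates = {"time": "time", "latitude": "lat", "longitude": "lon"}


def transform_fn(x):
    # NumPy counterpart of the default lambda x: jnp.asarray(x, dtype=float)
    return np.asarray(x, dtype=float)


fields_ = {to_name: transform_fn(raw[from_name]) for to_name, from_name in fields.items()}
# Datetimes become float seconds since the Unix epoch
t = transform_fn(raw[coordinates["time"]].astype("datetime64[s]").astype(int))
lat = transform_fn(raw[coordinates["latitude"]])
lon = transform_fn(raw[coordinates["longitude"]])

print(sorted(fields_), t)
```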

spatial_derivative(*fields: Float[Array, '(time) lat lon'], dx: Float[Array, 'lat lon-1'], dy: Float[Array, 'lat-1 lon'], is_masked: Bool[Array, 'lat lon']) -> tuple[tuple[Float[Array, '(time) lat lon'], Float[Array, '(time) lat lon']], ...]

Computes spatial derivatives for given fields using central finite differences.

This function calculates the spatial derivatives of the provided fields, accounting for masked grid points and for the grid spacing in both the latitude and longitude directions. It uses central finite differences, replacing a masked neighbour by the value at the centre point, and is written with JAX operations so that it can be jit-compiled and differentiated.

Parameters:

Name Type Description Default
*fields Float[Array, '(time) lat lon']

Variable number of fields for which the spatial derivatives are to be computed. Each field is a 2D or 3D array with dimensions (latitude, longitude) or (time, latitude, longitude).

()
dx Float[Array, 'lat lon-1']

Grid spacing in the longitude direction.

required
dy Float[Array, 'lat-1 lon']

Grid spacing in the latitude direction.

required
is_masked Bool[Array, 'lat lon']

Boolean array indicating whether a grid point should be masked (True means masked, False not masked).

required

Returns:

Type Description
tuple[tuple[Float[Array, '(time) lat lon'], Float[Array, '(time) lat lon']], ...]

A tuple containing the spatial derivatives of the input fields. Each derivative is evaluated at interior grid points only, so it is a 2D or 3D array with dimensions (latitude-2, longitude-2) or (time, latitude-2, longitude-2). For fields f1 and f2, returns ((df1_x, df1_y), (df2_x, df2_y)).

Source code in pastax/gridded/_operators.py
def spatial_derivative(
    *fields: Float[Array, "(time) lat lon"],
    dx: Float[Array, "lat lon-1"],
    dy: Float[Array, "lat-1 lon"],
    is_masked: Bool[Array, "lat lon"],
) -> tuple[tuple[Float[Array, "(time) lat lon"], Float[Array, "(time) lat lon"]], ...]:
    """
    Computes spatial derivatives for given fields using central finite differences.

    This function calculates the spatial derivatives of the provided fields, accounting for masked grid points
    and for the grid spacing in both latitude and longitude directions.
    It uses central finite differences, replacing a masked neighbour by the value at the centre point, and is written
    with JAX operations so that it can be jit-compiled and differentiated.

    Parameters
    ----------
    *fields : Float[Array, "(time) lat lon"]
        Variable number of fields for which the spatial derivatives are to be computed.
        Each field is a 2D or 3D array with dimensions (latitude, longitude) or (time, latitude, longitude).
    dx : Float[Array, "lat lon-1"]
        Grid spacing in the longitude direction.
    dy : Float[Array, "lat-1 lon"]
        Grid spacing in the latitude direction.
    is_masked : Bool[Array, "lat lon"]
        Boolean array indicating whether a grid point should be masked (`True` means masked, `False` not masked).

    Returns
    -------
    tuple[tuple[Float[Array, "(time) lat lon"], Float[Array, "(time) lat lon"]], ...]
        A tuple containing the spatial derivatives of the input fields.
        Each derivative is evaluated at interior grid points only, with dimensions (latitude-2, longitude-2) or
        (time, latitude-2, longitude-2).
        For field f1 and f2, returns ((df1_x, df1_y), (df2_x, df2_y)).
    """

    def central_finite_difference(
        field: Float[Array, "(time) lat lon"], axis: int
    ) -> Float[Array, "(time) lat-2 lon-2"]:
        def _axis1(dxy: Float[Array, "lat-1 lon"]) -> tuple[Float[Array, "(time) lat-2 lon-2"], ...]:
            field_start = field[..., :-2, 1:-1]
            field_end = field[..., 2:, 1:-1]

            is_masked_start = is_masked[:-2, 1:-1]
            is_masked_end = is_masked[2:, 1:-1]

            dx_start = dxy[:-1, 1:-1]
            dx_end = dxy[1:, 1:-1]

            return field_start, field_end, is_masked_start, is_masked_end, dx_start, dx_end

        def _axis2(dxy: Float[Array, "lat lon-1"]) -> tuple[Float[Array, "(time) lat-2 lon-2"], ...]:
            field_start = field[..., 1:-1, :-2]
            field_end = field[..., 1:-1, 2:]

            is_masked_start = is_masked[1:-1, :-2]
            is_masked_end = is_masked[1:-1, 2:]

            dx_start = dxy[1:-1, :-1]
            dx_end = dxy[1:-1, 1:]

            return field_start, field_end, is_masked_start, is_masked_end, dx_start, dx_end

        field_start, field_end, is_masked_start, is_masked_end, dx_start, dx_end = jax.lax.cond(
            axis == -1, lambda: _axis2(dx), lambda: _axis1(dy)
        )

        field_center = field[..., 1:-1, 1:-1]
        field_start = jnp.where(is_masked_start, field_center, field_start)
        field_end = jnp.where(is_masked_end, field_center, field_end)

        return (field_end - field_start) / (dx_end + dx_start)  # type: ignore

    derivatives = tuple(tuple(central_finite_difference(field, axis) for axis in (-1, -2)) for field in fields)

    return derivatives  # type: ignore
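spatial_derivative evaluates each derivative at interior grid points only: it divides the difference between the two neighbours by the sum of the two adjacent spacings, and substitutes the centre value for any masked neighbour. The NumPy sketch below reproduces that arithmetic for the longitude direction of a 2D field; the function name is hypothetical and the time dimension is left out:

```python
import numpy as np


def ddx_interior(field, dx, is_masked):
    """Central difference along longitude (last axis), evaluated at interior points only.

    field: (lat, lon) values; dx: (lat, lon - 1) spacing between neighbouring columns;
    is_masked: (lat, lon), True where a grid point is masked.
    Returns an array of shape (lat - 2, lon - 2).
    """
    center = field[1:-1, 1:-1]
    west = field[1:-1, :-2]
    east = field[1:-1, 2:]
    # A masked neighbour is replaced by the centre value
    west = np.where(is_masked[1:-1, :-2], center, west)
    east = np.where(is_masked[1:-1, 2:], center, east)
    # Spacings on the west and east side of each interior point
    dx_west = dx[1:-1, :-1]
    dx_east = dx[1:-1, 1:]
    return (east - west) / (dx_west + dx_east)


x = np.arange(5) * 2.0              # longitudes spaced by 2 units
field = np.tile(3.0 * x, (4, 1))    # f = 3 x on a 4 x 5 grid, so df/dx = 3
dx = np.full((4, 4), 2.0)
mask = np.zeros((4, 5), dtype=bool)

d = ddx_interior(field, dx, mask)   # shape (2, 3), all equal to 3

mask_east = mask.copy()
mask_east[1, 3] = True              # mask the east neighbour of interior point (1, 2)
d_masked = ddx_interior(field, dx, mask_east)
print(d_masked[0, 1])               # 1.5: one-sided difference over the full two-cell distance
```

On a uniform grid, a masked neighbour turns the stencil into a one-sided difference that is still divided by the full two-cell distance, so the value next to a mask is half the one-sided slope.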

pastax.evaluation

This module provides classes for evaluating simulated pastax.trajectory.Trajectory and pastax.trajectory.TrajectoryEnsemble objects.

Evaluation

Bases: Set

Class for accessing and visualizing a dictionary of metric timeseries or timeseries ensembles.

Methods:

Name Description
__init__

Initializes the Evaluation object with a dictionary of metric timeseries or timeseries ensembles.

get

Retrieves a metric by key.

items

Returns the items of the metrics dictionary.

keys

Returns the keys of the metrics dictionary.

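values

Returns the values of the metrics dictionary.

to_dataarray

Converts the evaluation results to a dictionary of xarray.DataArray objects.

to_dataset

Converts the evaluation results to an xarray.Dataset.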

plot

Plots the metrics timeseries or timeseries ensemble up to the time index ti on the figure fig.

__getitem__

Retrieves a metric by key.

Source code in pastax/evaluation/_evaluation.py
class Evaluation(Set):
    """
    Class for accessing and visualizing a dictionary of metric timeseries or timeseries ensemble.

    Methods
    -------
    __init__(states):
        Initializes the `Evaluation` object with a dictionary of metric timeseries or timeseries ensemble.
    get(key):
        Retrieves a metric by key.
    items():
        Returns the items of the metrics dictionary.
    keys():
        Returns the keys of the metrics dictionary.
    values():
        Returns the values of the metrics dictionary.
    plot(fig, ti):
        Plots the metrics timeseries or timeseries ensemble up to the time index `ti` on the figure `fig`.
    __getitem__(key):
        Retrieves a metric by key.
    """

    _members: dict[str, Timeseries | tuple[TimeseriesEnsemble, Timeseries, Timeseries]]

    def __init__(
        self,
        states: dict[str, Timeseries | tuple[TimeseriesEnsemble, Timeseries, Timeseries]],
    ):
        """
        Initializes the [`pastax.evaluation.Evaluation`][] object with a dictionary of metric timeseries or timeseries
        ensemble.

        Parameters
        ----------
        states : dict[str, Timeseries | tuple[TimeseriesEnsemble, Timeseries, Timeseries]]
            The initial metrics dictionary.
        """
        self._members = states
        self.size = len(states)

    def get(
        self, key: str
    ) -> Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries] | None:
        """
        Retrieves a metric by key.

        Parameters
        ----------
        key : str
            The key of the metric to retrieve.

        Returns
        -------
        Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries] | None
            The metric corresponding to the key, or `None` if the key is absent.
        """
        return self._members.get(key)

    def items(
        self,
    ) -> Iterable[
        tuple[
            str,
            Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries],
        ]
    ]:
        """
        Returns the items of the metrics dictionary.

        Returns
        -------
        Iterable[tuple[str, Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]]]
            The items of the metrics dictionary.
        """
        return self._members.items()

    def keys(self) -> Iterable[str]:
        """
        Returns the keys of the metric timeseries dictionary.

        Returns
        -------
        Iterable[str]
            The keys of the metrics dictionary.
        """
        return self._members.keys()

    def values(
        self,
    ) -> Iterable[Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]]:
        """
        Returns the values of the metrics dictionary.

        Returns
        -------
        Iterable[Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]]
            The values of the metrics dictionary.
        """
        return self._members.values()

    def to_dataarray(self) -> dict[str, xr.DataArray]:
        """
        Converts the evaluation results to a dictionary of `xarray.DataArray`s.

        Returns
        -------
        dict[str, xr.DataArray]
            A dictionary where keys are the evaluation metric names and values are the corresponding
            `xarray.DataArray`s.
        """
        da = {}
        for key, value in self.items():
            if isinstance(value, Timeseries) or isinstance(value, TimeseriesEnsemble):
                da[key] = value.to_dataarray()
            else:
                ensemble, mean, crps = value
                da[key] = ensemble.to_dataarray().mean("member")
                da[f"{key} - mean"] = mean.to_dataarray().mean("member")
                da[f"{key} - CRPS"] = crps.to_dataarray().mean("member")

        return da

    def to_dataset(self) -> xr.Dataset:
        """
        Converts the evaluation results to a `xarray.Dataset`.

        Returns
        -------
        xr.Dataset
            A `xarray.Dataset` containing the evaluation results.
        """
        return xr.Dataset(self.to_dataarray())

    def plot(self, fig: Figure, ti: int | None = None):
        """
        Plots the metric timeseries or timeseries ensemble up to the time index `ti` on the figure `fig`.

        Parameters
        ----------
        fig : Figure
            The figure to plot on.
        ti : int | None, optional
            The time index up to which to plot. If `ti=None`, plots the full metric timeseries.
        """
        if ti is None:
            metric = next(iter(self.values()))
            if isinstance(metric, Timeseries) or isinstance(metric, TimeseriesEnsemble):
                ti = metric.length
            else:
                ti = metric[0].length

        n_metrics = self.size

        n_rows, n_cols = self.__guess_metrics_fig_layout(n_metrics)
        # squeeze=False keeps a 2D array of Axes even for single-row or single-cell layouts
        axs = fig.subplots(n_rows, n_cols, squeeze=False)

        for i, metric in zip(range(n_metrics), self.values()):
            rowi = i // n_cols
            coli = i % n_cols
            add_xlabel = rowi == n_rows - 1
            add_legend = rowi == 0 and coli == n_cols - 1
            self._plot_metric(ti, metric, add_xlabel, add_legend, axs[rowi, coli])

    def _plot_metric(
        self,
        ti: int,
        metric: Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries],
        add_xlabel: bool,
        add_legend: bool,
        ax: Axes,
    ):
        min_values, max_values, timedelta, unit, name = self.__do_plot_metric(ti, metric, add_legend, ax)
        self.__set_axis_limits_labels(min_values, max_values, timedelta, unit, name, add_xlabel, ax)

    def __do_plot_metric(
        self,
        ti: int,
        metric: Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries],
        add_legend: bool,
        ax: Axes,
        has_label: bool = False,
    ) -> tuple[
        Float[Scalar, ""],
        Float[Scalar, ""],
        Float[Array, "time"],
        str | None,
        str | None,
    ]:
        if isinstance(metric, Timeseries) or isinstance(metric, TimeseriesEnsemble):
            values, timedelta, unit = self.__parse_metric(metric)

            values_ti = values[..., :ti]
            timedelta_ti = timedelta[:ti]

            min_values = jnp.nanmin(values[values != -jnp.inf])
            max_values = jnp.nanmax(values[values != jnp.inf])

            if isinstance(metric, Timeseries):
                name = None
                label = metric.name if has_label else None
                self.__plot_pair_metric(values_ti, timedelta_ti, ax, label=label)
            else:
                name = metric.name
                self.__plot_ensemble_metric(values_ti, timedelta_ti, ax)
        else:
            min_values, max_values, timedelta = (
                jnp.asarray(jnp.inf),
                -jnp.asarray(jnp.inf),
                jnp.asarray(jnp.nan),
            )
            unit, name = None, None
            for _metric in metric:
                _min_values, _max_values, timedelta, unit, _name = self.__do_plot_metric(
                    ti, _metric, add_legend, ax, has_label=True
                )

                if _min_values < min_values:
                    min_values = _min_values
                if _max_values > max_values:
                    max_values = _max_values

                if _name is not None:
                    name = _name

            if add_legend:
                ax.legend()

        return min_values, max_values, timedelta, unit, name

    @staticmethod
    def __guess_metrics_fig_layout(n_metrics: int) -> tuple[int, int]:
        min_n_cells = float("inf")
        best_layout = (0, 0)
        for n_rows in range(1, int(math.ceil(n_metrics**0.5)) + 1):
            n_cols = math.ceil(n_metrics / n_rows)
            if n_rows * n_cols >= n_metrics:
                n_cells = n_rows + n_cols
                if n_cells < min_n_cells:
                    min_n_cells = n_cells
                    best_layout = (n_rows, n_cols)
        return best_layout

    @staticmethod
    def __parse_metric(
        metric: Timeseries | TimeseriesEnsemble,
    ) -> tuple[Float[Array, "time"] | Float[Array, "member time"], Float[Array, "time"], str]:
        values: Float[Array, "time"] | Float[Array, "member time"] = metric.states.value[..., 1:, :]

        unit = {}
        for k, v in metric.unit.items():
            if k == UNIT["m"]:
                values = UNIT["m"].convert_to(UNIT["km"], values, v)  # to km for visualization
                unit[UNIT["km"]] = v  # label the axis with the converted unit
            else:
                unit[k] = v

        times = metric.times.value
        timedelta = seconds_to_days(times - times[..., 0])
        timedelta: Float[Array, "time"] = timedelta[..., 1:]

        return values, timedelta, units_to_str(unit)

    def __plot_ensemble_metric(
        self,
        values: Float[Array, "member time"],
        timedelta: Float[Array, "time"],
        ax: Axes,
    ):
        timedelta_extended = jnp.tile(timedelta, (values.shape[0], 1))
        segments = jnp.concat([timedelta_extended[..., None], values[..., None]], axis=2)
        alpha = jnp.clip(1 / ((self.size / 10) ** (1 / 2)), 0.1, 1).item() / 2
        lc = LineCollection(segments, alpha=alpha, color="black")  # type: ignore
        ax.add_collection(lc)

    @staticmethod
    def __plot_pair_metric(
        values: Float[Array, "time"],
        timedelta: Float[Array, "time"],
        ax: Axes,
        label: str | None = None,
    ):
        kwargs = {}
        if label is not None:
            kwargs["label"] = label

        ax.plot(timedelta, values, **kwargs)

    @staticmethod
    def __set_axis_limits_labels(
        min_states: Float[Array, ""],
        max_states: Float[Array, ""],
        timedelta: Float[Array, "time"],
        unit: str | None,
        name: str | None,
        add_xlabel: bool,
        ax: Axes,
    ):
        ax.set_xlim(0, (timedelta[-1] + timedelta[0]).item())
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        if add_xlabel:
            ax.set_xlabel("Elapsed days")
        else:
            ax.xaxis.set_major_formatter(NullFormatter())

        ax.set_ylim(
            (min_states - abs(min_states * 0.1)).item(),
            (max_states + abs(max_states * 0.1)).item(),
        )
        ylabel = f"{name}"
        if unit != "":
            ylabel += f" (${unit}$)"
        ax.set_ylabel(ylabel)

    def __getitem__(
        self, key: str
    ) -> Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]:
        """
        Retrieves a metric by key.

        Parameters
        ----------
        key : str
            The key of the metric to retrieve.

        Returns
        -------
        Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]
            The metric [`pastax.trajectory.Timeseries`][], [`pastax.trajectory.TimeseriesEnsemble`][], or
            (ensemble, mean, CRPS) tuple corresponding to the key.
        """
        return self._members[key]
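The private __guess_metrics_fig_layout helper used by plot chooses the subplot grid. For each candidate number of rows up to ceil(sqrt(n_metrics)), it takes the fewest columns that fit every metric, and it keeps the layout with the smallest n_rows + n_cols, with ties going to fewer rows. A standalone copy, renamed because the original is private and name-mangled, shows the layouts it produces:

```python
import math


def guess_metrics_fig_layout(n_metrics: int) -> tuple[int, int]:
    """Pick (n_rows, n_cols) with enough cells and the smallest n_rows + n_cols."""
    min_n_cells = float("inf")
    best_layout = (0, 0)
    for n_rows in range(1, math.ceil(n_metrics**0.5) + 1):
        n_cols = math.ceil(n_metrics / n_rows)  # fewest columns that fit every metric
        if n_rows * n_cols >= n_metrics:  # always true here; kept to mirror the original
            n_cells = n_rows + n_cols  # quantity minimised: rows + columns, not the cell count
            if n_cells < min_n_cells:  # strict "<": ties keep the layout with fewer rows
                min_n_cells = n_cells
                best_layout = (n_rows, n_cols)
    return best_layout


for n in (1, 2, 4, 5, 7):
    print(n, guess_metrics_fig_layout(n))
```

Two or three metrics end up on a single row; matplotlib's Figure.subplots returns a 1D array of Axes for such layouts unless it is called with squeeze=False.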
__init__(states: dict[str, Timeseries | tuple[TimeseriesEnsemble, Timeseries, Timeseries]])

Initializes the pastax.evaluation.Evaluation object with a dictionary of metric timeseries or timeseries ensembles.

Parameters:

Name Type Description Default
states dict[str, Timeseries | tuple[TimeseriesEnsemble, Timeseries, Timeseries]]

The initial metrics dictionary.

required
Source code in pastax/evaluation/_evaluation.py
def __init__(
    self,
    states: dict[str, Timeseries | tuple[TimeseriesEnsemble, Timeseries, Timeseries]],
):
    """
    Initializes the [`pastax.evaluation.Evaluation`][] object with a dictionary of metric timeseries or timeseries
    ensembles.

    Parameters
    ----------
    states : dict[str, Timeseries | tuple[TimeseriesEnsemble, Timeseries, Timeseries]]
        The initial metrics dictionary.
    """
    self._members = states
    self.size = len(states)
get(key: str) -> Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries] | None

Retrieves a metric by key.

Parameters:

Name Type Description Default
key str

The key of the metric to retrieve.

required

Returns:

Type Description
Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries] | None

The metric corresponding to the key, or None if the key is absent.

Source code in pastax/evaluation/_evaluation.py
def get(
    self, key: str
) -> Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries] | None:
    """
    Retrieves a metric by key.

    Parameters
    ----------
    key : str
        The key of the metric to retrieve.

    Returns
    -------
    Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries] | None
        The metric corresponding to the key, or `None` if the key is absent.
    """
    return self._members.get(key)
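`get` and `__getitem__` both delegate to the underlying dictionary, so they differ only on a missing key: `get` returns `None`, while indexing raises `KeyError`. The sketch below illustrates this with a hypothetical stand-in class `MetricsView` (not part of pastax) and plain floats in place of `Timeseries` objects.

```python
class MetricsView:
    """Hypothetical stand-in mirroring how Evaluation delegates to its dict."""

    def __init__(self, states: dict):
        self._members = states
        self.size = len(states)

    def __getitem__(self, key: str):
        # Indexing raises KeyError for an unknown key, like dict indexing.
        return self._members[key]

    def get(self, key: str):
        # get returns None for an unknown key, like dict.get.
        return self._members.get(key)


view = MetricsView({"rmse": 1.5, "mae": 1.2})
print(view.get("liu_index"))  # None
try:
    view["liu_index"]
except KeyError as err:
    print("missing:", err)  # missing: 'liu_index'
```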
items() -> Iterable[tuple[str, Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]]]

Returns the items of the metrics dictionary.

Returns:

Type Description
Iterable[tuple[str, Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]]]

The (key, metric) pairs of the metrics dictionary.

Source code in pastax/evaluation/_evaluation.py
def items(
    self,
) -> Iterable[
    tuple[
        str,
        Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries],
    ]
]:
    """
    Returns the items of the metrics dictionary.

    Returns
    -------
    Iterable[tuple[str, Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]]]
        The (key, metric) pairs of the metrics dictionary.
    """
    return self._members.items()
keys() -> Iterable[str]

Returns the keys of the metrics dictionary.

Returns:

Type Description
Iterable[str]

The keys of the metrics dictionary.

Source code in pastax/evaluation/_evaluation.py
def keys(self) -> Iterable[str]:
    """
    Returns the keys of the metrics dictionary.

    Returns
    -------
    Iterable[str]
        The keys of the metrics dictionary.
    """
    return self._members.keys()
values() -> Iterable[Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]]

Returns the values of the metrics dictionary.

Returns:

Type Description
Iterable[Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]]

The values (metrics) of the metrics dictionary.

Source code in pastax/evaluation/_evaluation.py
def values(
    self,
) -> Iterable[Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]]:
    """
    Returns the values of the metrics dictionary.

    Returns
    -------
    Iterable[Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]]
        The values (metrics) of the metrics dictionary.
    """
    return self._members.values()
to_dataarray() -> dict[str, xr.DataArray]

Converts the evaluation results to a dictionary of xarray.DataArrays.

Returns:

Type Description
dict[str, DataArray]

A dictionary where keys are the evaluation metric names and values are the corresponding xarray.DataArrays.

Source code in pastax/evaluation/_evaluation.py
def to_dataarray(self) -> dict[str, xr.DataArray]:
    """
    Converts the evaluation results to a dictionary of `xarray.DataArray`s.

    Returns
    -------
    dict[str, xr.DataArray]
        A dictionary where keys are the evaluation metric names and values are the corresponding
        `xarray.DataArray`s.
    """
    da = {}
    for key, value in self.items():
        if isinstance(value, (Timeseries, TimeseriesEnsemble)):
            da[key] = value.to_dataarray()
        else:
            ensemble, mean, crps = value
            da[key] = ensemble.to_dataarray().mean("member")
            da[f"{key} - mean"] = mean.to_dataarray().mean("member")
            da[f"{key} - CRPS"] = crps.to_dataarray().mean("member")

    return da
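`to_dataarray` flattens each ensemble entry into three outputs: the member-averaged ensemble under the metric name, plus `"<name> - mean"` and `"<name> - CRPS"`. A minimal pure-Python sketch of that naming scheme, assuming plain lists in place of `Timeseries` objects and a single series for the mean and CRPS (the hypothetical `flatten_metrics` is not a pastax function):

```python
from statistics import fmean


def flatten_metrics(metrics: dict) -> dict:
    """Hypothetical sketch of the to_dataarray flattening, using lists.

    A plain entry is a list of values over time. A triple entry is
    (ensemble, mean, crps): ensemble is a list of members, each a list of
    values over time; mean and crps are single lists of values over time.
    """
    out = {}
    for key, value in metrics.items():
        if isinstance(value, tuple):
            ensemble, mean, crps = value
            # Average over the member axis, one time step at a time.
            out[key] = [fmean(step) for step in zip(*ensemble)]
            out[f"{key} - mean"] = mean
            out[f"{key} - CRPS"] = crps
        else:
            out[key] = value
    return out


flat = flatten_metrics({
    "rmse": [0.0, 1.0],
    "mae": ([[0.0, 2.0], [2.0, 4.0]], [1.0, 3.0], [0.5, 0.7]),
})
print(sorted(flat))  # ['mae', 'mae - CRPS', 'mae - mean', 'rmse']
print(flat["mae"])   # [1.0, 3.0]
```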
to_dataset() -> xr.Dataset

Converts the evaluation results to an xarray.Dataset.

Returns:

Type Description
Dataset

An xarray.Dataset containing the evaluation results.

Source code in pastax/evaluation/_evaluation.py
def to_dataset(self) -> xr.Dataset:
    """
    Converts the evaluation results to an `xarray.Dataset`.

    Returns
    -------
    xr.Dataset
        An `xarray.Dataset` containing the evaluation results.
    """
    return xr.Dataset(self.to_dataarray())
plot(fig: Figure, ti: int | None = None)

Plots the metric timeseries or timeseries ensemble up to the time index ti on the figure fig.

Parameters:

Name Type Description Default
fig Figure

The figure to plot on.

required
ti int | None

The time index up to which to plot. If ti=None, plots the full metric timeseries.

None
Source code in pastax/evaluation/_evaluation.py
def plot(self, fig: Figure, ti: int | None = None):
    """
    Plots the metric timeseries or timeseries ensemble up to the time index `ti` on the figure `fig`.

    Parameters
    ----------
    fig : Figure
        The figure to plot on.
    ti : int | None, optional
        The time index up to which to plot. If `ti=None`, plots the full metric timeseries.
    """
    if ti is None:
        metric = next(iter(self.values()))
        if isinstance(metric, (Timeseries, TimeseriesEnsemble)):
            ti = metric.length
        else:
            ti = metric[0].length

    n_metrics = self.size

    n_rows, n_cols = self.__guess_metrics_fig_layout(n_metrics)
    # squeeze=False always returns a 2-D array of Axes, so axs[rowi, coli] also works for 1-row or 1-column grids.
    axs = fig.subplots(n_rows, n_cols, squeeze=False)

    for i, metric in enumerate(self.values()):
        rowi = i // n_cols
        coli = i % n_cols
        add_xlabel = rowi == n_rows - 1
        add_legend = rowi == 0 and coli == n_cols - 1
        self._plot_metric(ti, metric, add_xlabel, add_legend, axs[rowi, coli])
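`plot` places metric `i` at row `i // n_cols` and column `i % n_cols`, draws x labels only on the bottom row, and puts the legend on the top-right panel. The sketch below reproduces that bookkeeping; the grid size is passed in explicitly because the layout heuristic in `__guess_metrics_fig_layout` is not shown here, and `panel_flags` is a hypothetical helper.

```python
def panel_flags(n_metrics: int, n_rows: int, n_cols: int) -> list:
    """Hypothetical helper reproducing the per-panel indexing of Evaluation.plot."""
    flags = []
    for i in range(n_metrics):
        rowi, coli = i // n_cols, i % n_cols  # row-major placement
        add_xlabel = rowi == n_rows - 1  # x label only on the bottom row
        add_legend = rowi == 0 and coli == n_cols - 1  # legend on the top-right panel
        flags.append((rowi, coli, add_xlabel, add_legend))
    return flags


for flag in panel_flags(4, 2, 2):
    print(flag)
# (0, 0, False, False)
# (0, 1, False, True)
# (1, 0, True, False)
# (1, 1, True, False)
```

Two-dimensional indexing such as `axs[rowi, coli]` assumes `fig.subplots` returns a 2-D array; matplotlib only guarantees that when `squeeze=False` is passed, since the default squeezes single-row or single-column grids to 1-D (and a 1×1 grid to a single `Axes`).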
__getitem__(key: str) -> Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]

Retrieves a metric by key.

Parameters:

Name Type Description Default
key str

The key of the metric to retrieve.

required

Returns:

Type Description
Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]

The metric pastax.trajectory.Timeseries, pastax.trajectory.TimeseriesEnsemble, or tuple of timeseries corresponding to the key.

Source code in pastax/evaluation/_evaluation.py
def __getitem__(
    self, key: str
) -> Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]:
    """
    Retrieves a metric by key.

    Parameters
    ----------
    key : str
        The key of the metric to retrieve.

    Returns
    -------
    Timeseries | TimeseriesEnsemble | tuple[TimeseriesEnsemble, Timeseries, Timeseries]
        The metric [`pastax.trajectory.Timeseries`][], [`pastax.trajectory.TimeseriesEnsemble`][], or tuple of
        timeseries corresponding to the key.
    """
    return self._members[key]

BaseEvaluator

Bases: Module

Base class for evaluating trajectories using a set of predefined metrics.

Attributes:

Name Type Description
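metrics list[Metric]

A list of pastax.evaluation.Metrics used for evaluation. The defaults are pastax.evaluation.SeparationDistance, pastax.evaluation.LiuIndex, pastax.evaluation.Mae, and pastax.evaluation.Rmse.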

Methods:

Name Description
__call__

Evaluates the simulated_trajectory (which might be an ensemble of trajectories) against the reference_trajectory using self.metrics.

Source code in pastax/evaluation/_evaluator.py
class BaseEvaluator(eqx.Module):  # TODO: should it be an eqx Module?
    """
    Base class for evaluating trajectories using a set of predefined metrics.

    Attributes
    ----------
    metrics : list[Metric]
        A list of [`pastax.evaluation.Metric`][]s used for evaluation.
        The default [`pastax.evaluation.Metric`][]s are [`pastax.evaluation.SeparationDistance`][],
        [`pastax.evaluation.LiuIndex`][], [`pastax.evaluation.Mae`][], and [`pastax.evaluation.Rmse`][].

    Methods
    -------
    __call__(self, reference_trajectory, simulated_trajectory)
        Evaluates the `simulated_trajectory` (which might be an ensemble of trajectories)
        against the `reference_trajectory` using `self.metrics`.
    """

    metrics: list[Metric] = eqx.field(default_factory=lambda: [SeparationDistance(), LiuIndex(), Mae(), Rmse()])

    @eqx.filter_jit
    def __call__(
        self,
        reference_trajectory: Trajectory,
        simulated_trajectory: Trajectory | TrajectoryEnsemble,
    ) -> Evaluation:
        """
        Evaluates the `simulated_trajectory` (which might be an ensemble of trajectories)
        against the `reference_trajectory` using `self.metrics`.

        Parameters
        ----------
        reference_trajectory : Trajectory
            The reference [`pastax.trajectory.Trajectory`][] to compare against.
        simulated_trajectory : Trajectory | TrajectoryEnsemble
            The simulated [`pastax.trajectory.Trajectory`][] or [`pastax.trajectory.TrajectoryEnsemble`][] to be
            evaluated.

        Returns
        -------
        Evaluation
            The result of the [`pastax.evaluation.Evaluation`][].

        Raises
        ------
        NotImplementedError
            This method should be implemented by child classes.
        """
        raise NotImplementedError
__call__(reference_trajectory: Trajectory, simulated_trajectory: Trajectory | TrajectoryEnsemble) -> Evaluation

Evaluates the simulated_trajectory (which might be an ensemble of trajectories) against the reference_trajectory using self.metrics.

Parameters:

Name Type Description Default
reference_trajectory Trajectory

The reference pastax.trajectory.Trajectory to compare against.

required
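simulated_trajectory Trajectory | TrajectoryEnsemble

The simulated pastax.trajectory.Trajectory or pastax.trajectory.TrajectoryEnsemble to be evaluated.

required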

Returns:

Type Description
Evaluation

The result of the pastax.evaluation.Evaluation.

Raises:

Type Description
NotImplementedError

This method should be implemented by child classes.

Source code in pastax/evaluation/_evaluator.py
@eqx.filter_jit
def __call__(
    self,
    reference_trajectory: Trajectory,
    simulated_trajectory: Trajectory | TrajectoryEnsemble,
) -> Evaluation:
    """
    Evaluates the `simulated_trajectory` (which might be an ensemble of trajectories)
    against the `reference_trajectory` using `self.metrics`.

    Parameters
    ----------
    reference_trajectory : Trajectory
        The reference [`pastax.trajectory.Trajectory`][] to compare against.
    simulated_trajectory : Trajectory | TrajectoryEnsemble
        The simulated [`pastax.trajectory.Trajectory`][] or [`pastax.trajectory.TrajectoryEnsemble`][] to be
        evaluated.

    Returns
    -------
    Evaluation
        The result of the [`pastax.evaluation.Evaluation`][].

    Raises
    ------
    NotImplementedError
        This method should be implemented by child classes.
    """
    raise NotImplementedError
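`BaseEvaluator` declares `metrics` with `default_factory` rather than a literal list, so every evaluator gets its own list of default metrics. Equinox modules are dataclasses, so the standard-library `dataclasses` module shows the same behavior; in this sketch, metric names stand in for `Metric` objects and `Evaluator` is hypothetical.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Evaluator:
    # default_factory builds a new list for every instance; a bare list default
    # would be rejected by dataclasses as a mutable default.
    metrics: list = field(
        default_factory=lambda: ["separation_distance", "liu_index", "mae", "rmse"]
    )


a, b = Evaluator(), Evaluator()
print(a.metrics == b.metrics, a.metrics is b.metrics)  # True False
custom = Evaluator(metrics=["rmse"])  # overriding the defaults
```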

EnsembleEvaluator

Bases: BaseEvaluator

Class for evaluating an ensemble of simulated trajectories using a set of predefined metrics.

Methods:

Name Description
__call__

Evaluates the simulated_trajectories ensemble against the reference_trajectory using self.metrics.

Source code in pastax/evaluation/_evaluator.py
class EnsembleEvaluator(BaseEvaluator):
    """
    Class for evaluating an ensemble of simulated trajectories using a set of predefined metrics.

    Methods
    -------
    __call__(self, reference_trajectory, simulated_trajectories)
        Evaluates the `simulated_trajectories` ensemble against the `reference_trajectory` using `self.metrics`.
    """

    def __call__(
        self,
        reference_trajectory: Trajectory,
        simulated_trajectories: TrajectoryEnsemble,
    ) -> Evaluation:
        """
        Evaluates the `simulated_trajectories` ensemble against the `reference_trajectory` using `self.metrics`.

        Parameters
        ----------
        reference_trajectory : Trajectory
            The reference [`pastax.trajectory.Trajectory`][] to compare against.
        simulated_trajectories : TrajectoryEnsemble
            The simulated [`pastax.trajectory.TrajectoryEnsemble`][] to be evaluated.

        Returns
        -------
        Evaluation
            The result of the [`pastax.evaluation.Evaluation`][].
        """
        metrics = {}
        for metric in self.metrics:
            metric_fun = metric.metric_fun

            ensemble = getattr(simulated_trajectories, metric_fun)(reference_trajectory)

            crps = simulated_trajectories.crps(reference_trajectory, metric_func=getattr(Trajectory, metric_fun))

            mean = ensemble.mean(axis=0)

            # Same order as the `ensemble, mean, crps = value` unpacking in `Evaluation.to_dataarray`.
            metrics[metric_fun] = (ensemble, mean, crps)

        return Evaluation(metrics)
__call__(reference_trajectory: Trajectory, simulated_trajectories: TrajectoryEnsemble) -> Evaluation

Evaluates the simulated_trajectories ensemble against the reference_trajectory using self.metrics.

Parameters:

Name Type Description Default
reference_trajectory Trajectory

The reference pastax.trajectory.Trajectory to compare against.

required
simulated_trajectories TrajectoryEnsemble

The simulated pastax.trajectory.TrajectoryEnsemble to be evaluated.

required

Returns:

Type Description
Evaluation

The result of the pastax.evaluation.Evaluation.

Source code in pastax/evaluation/_evaluator.py
def __call__(
    self,
    reference_trajectory: Trajectory,
    simulated_trajectories: TrajectoryEnsemble,
) -> Evaluation:
    """
    Evaluates the `simulated_trajectories` ensemble against the `reference_trajectory` using `self.metrics`.

    Parameters
    ----------
    reference_trajectory : Trajectory
        The reference [`pastax.trajectory.Trajectory`][] to compare against.
    simulated_trajectories : TrajectoryEnsemble
        The simulated [`pastax.trajectory.TrajectoryEnsemble`][] to be evaluated.

    Returns
    -------
    Evaluation
        The result of the [`pastax.evaluation.Evaluation`][].
    """
    metrics = {}
    for metric in self.metrics:
        metric_fun = metric.metric_fun

        ensemble = getattr(simulated_trajectories, metric_fun)(reference_trajectory)

        crps = simulated_trajectories.crps(reference_trajectory, metric_func=getattr(Trajectory, metric_fun))

        mean = ensemble.mean(axis=0)

        # Same order as the `ensemble, mean, crps = value` unpacking in `Evaluation.to_dataarray`.
        metrics[metric_fun] = (ensemble, mean, crps)

    return Evaluation(metrics)
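For each metric, `EnsembleEvaluator` also computes a continuous ranked probability score (CRPS) through `TrajectoryEnsemble.crps`, whose estimator is not shown here. For reference, the standard empirical CRPS of an ensemble X_1, ..., X_m against a scalar observation y is mean_i |X_i - y| - 0.5 * mean_{i,j} |X_i - X_j|; the sketch below (hypothetical `ensemble_crps`) computes it for scalars only and may differ from pastax's implementation.

```python
def ensemble_crps(members: list[float], observation: float) -> float:
    """Empirical CRPS of a scalar ensemble (energy form); lower is better."""
    m = len(members)
    # Mean absolute error of the members against the observation.
    skill = sum(abs(x - observation) for x in members) / m
    # Mean absolute difference over all ordered member pairs (spread term).
    spread = sum(abs(x - y) for x in members for y in members) / (m * m)
    return skill - 0.5 * spread


print(ensemble_crps([1.0, 2.0, 3.0], 2.0))  # about 0.2222 (= 2/9)
print(ensemble_crps([5.0], 2.0))  # 3.0, a one-member ensemble reduces to the absolute error
```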

PairEvaluator

Bases: BaseEvaluator

Class for evaluating a simulated trajectory using a set of predefined metrics.

Methods:

Name Description
__call__

Evaluates the simulated_trajectory against the reference_trajectory using self.metrics.

Source code in pastax/evaluation/_evaluator.py
class PairEvaluator(BaseEvaluator):
    """
    Class for evaluating a simulated trajectory using a set of predefined metrics.

    Methods
    -------
    __call__(self, reference_trajectory, simulated_trajectory)
        Evaluates the `simulated_trajectory` against the `reference_trajectory` using `self.metrics`.
    """

    def __call__(self, reference_trajectory: Trajectory, simulated_trajectory: Trajectory) -> Evaluation:
        """
        Evaluates the `simulated_trajectory` against the `reference_trajectory` using `self.metrics`.

        Parameters
        ----------
        reference_trajectory : Trajectory
            The reference [`pastax.trajectory.Trajectory`][] to compare against.
        simulated_trajectory : Trajectory
            The simulated [`pastax.trajectory.Trajectory`][] to be evaluated.

        Returns
        -------
        Evaluation
            The result of the [`pastax.evaluation.Evaluation`][].
        """
        metrics = {}
        for metric in self.metrics:
            metric_fun = metric.metric_fun
            metrics[metric_fun] = getattr(reference_trajectory, metric_fun)(simulated_trajectory)

        return Evaluation(metrics)
__call__(reference_trajectory: Trajectory, simulated_trajectory: Trajectory) -> Evaluation

Evaluates the simulated_trajectory against the reference_trajectory using self.metrics.

Parameters:

Name Type Description Default
reference_trajectory Trajectory

The reference pastax.trajectory.Trajectory to compare against.

required
simulated_trajectory Trajectory

The simulated pastax.trajectory.Trajectory to be evaluated.

required

Returns:

Type Description
Evaluation

The result of the pastax.evaluation.Evaluation.

Source code in pastax/evaluation/_evaluator.py
def __call__(self, reference_trajectory: Trajectory, simulated_trajectory: Trajectory) -> Evaluation:
    """
    Evaluates the `simulated_trajectory` against the `reference_trajectory` using `self.metrics`.

    Parameters
    ----------
    reference_trajectory : Trajectory
        The reference [`pastax.trajectory.Trajectory`][] to compare against.
    simulated_trajectory : Trajectory
        The simulated [`pastax.trajectory.Trajectory`][] to be evaluated.

    Returns
    -------
    Evaluation
        The result of the [`pastax.evaluation.Evaluation`][].
    """
    metrics = {}
    for metric in self.metrics:
        metric_fun = metric.metric_fun
        metrics[metric_fun] = getattr(reference_trajectory, metric_fun)(simulated_trajectory)

    return Evaluation(metrics)
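`PairEvaluator` resolves each metric by name, calling `getattr(reference_trajectory, metric.metric_fun)(simulated_trajectory)`, so `metric_fun` must name an existing `Trajectory` method. A sketch of the same dispatch on a hypothetical one-dimensional `ToyTrajectory`; its `mae` and `rmse` are illustrative and are not pastax's definitions.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTrajectory:
    """Hypothetical 1-D stand-in for pastax.trajectory.Trajectory."""

    positions: tuple

    def mae(self, other: "ToyTrajectory") -> float:
        diffs = [abs(a - b) for a, b in zip(self.positions, other.positions)]
        return sum(diffs) / len(diffs)

    def rmse(self, other: "ToyTrajectory") -> float:
        sq = [(a - b) ** 2 for a, b in zip(self.positions, other.positions)]
        return math.sqrt(sum(sq) / len(sq))


def pair_evaluate(reference, simulated, metric_funs):
    # Same pattern as PairEvaluator: call the method named by each metric.
    return {name: getattr(reference, name)(simulated) for name in metric_funs}


ref = ToyTrajectory((0.0, 1.0, 2.0, 3.0))
sim = ToyTrajectory((0.0, 1.0, 2.0, 7.0))
print(pair_evaluate(ref, sim, ["mae", "rmse"]))  # {'mae': 1.0, 'rmse': 2.0}
```

An unknown name fails early with `AttributeError`, which is why each `Metric` subclass pins `metric_fun` to a fixed method name.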

LiuIndex

Bases: Metric

The Liu index metric.

Attributes:

Name Type Description
metric_fun str

The name of the metric function or method: metric_fun="liu_index".

Source code in pastax/evaluation/_metric.py
class LiuIndex(Metric):
    """
    The Liu index metric.

    Attributes
    ----------
    metric_fun : str
        The name of the metric function or method: `metric_fun="liu_index"`.
    """

    metric_fun: str = eqx.field(static=True, default_factory=lambda: "liu_index")
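The Liu index usually refers to the normalized cumulative Lagrangian separation of Liu and Weisberg (2011): s = sum_i d_i / sum_i l_i, where d_i is the separation between simulated and observed positions at step i and l_i is the observed path length from the start up to step i; a skill score is then derived from s. Whether `liu_index` returns s itself or the derived score is not stated here, so the sketch below (hypothetical `liu_index` helper, planar Euclidean distances instead of distances on the sphere) only illustrates s.

```python
import math


def liu_index(observed, simulated):
    """Planar sketch of the normalized cumulative separation s."""
    sep_total = 0.0    # sum of separations d_i
    path_total = 0.0   # sum of cumulative observed path lengths l_i
    path_length = 0.0  # observed path length from the start to step i
    for i in range(1, len(observed)):
        path_length += math.dist(observed[i - 1], observed[i])
        sep_total += math.dist(observed[i], simulated[i])
        path_total += path_length
    return sep_total / path_total


observed = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
simulated = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)]
print(liu_index(observed, simulated))  # 1.0
```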

Mae

Bases: Metric

The Mean Absolute Error metric.

Attributes:

Name Type Description
metric_fun str

The name of the metric function or method: metric_fun="mae".

Source code in pastax/evaluation/_metric.py
class Mae(Metric):
    """
    The Mean Absolute Error metric.

    Attributes
    ----------
    metric_fun : str
        The name of the metric function or method: `metric_fun="mae"`.
    """

    metric_fun: str = eqx.field(static=True, default_factory=lambda: "mae")

Metric

Bases: Module

Base class for metric objects.

Attributes:

Name Type Description
metric_fun str

The name of the metric function or method.

Source code in pastax/evaluation/_metric.py
class Metric(eqx.Module):
    """
    Base class for metric objects.

    Attributes
    ----------
    metric_fun : str
        The name of the metric function or method.
    """

    metric_fun: str = eqx.field(static=True)

Rmse

Bases: Metric

The Root Mean Square Error metric.

Attributes:

Name Type Description
metric_fun str

The name of the metric function or method: metric_fun="rmse".

Source code in pastax/evaluation/_metric.py
class Rmse(Metric):
    """
    The Root Mean Square Error metric.

    Attributes
    ----------
    metric_fun : str
        The name of the metric function or method: `metric_fun="rmse"`.
    """

    metric_fun: str = eqx.field(static=True, default_factory=lambda: "rmse")

SeparationDistance

Bases: Metric

The separation distance metric.

Attributes:

Name Type Description
metric_fun str

The name of the metric function or method: metric_fun="separation_distance".

Source code in pastax/evaluation/_metric.py
class SeparationDistance(Metric):
    """
    The Separation distance metric.

    Attributes
    ----------
    metric_fun : str
        The name of the metric function or method: `metric_fun="separation_distance"`.
    """

    metric_fun: str = eqx.field(static=True, default_factory=lambda: "separation_distance")
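Assuming the separation distance is the great-circle distance between paired reference and simulated positions (the implementation is not shown here), it can be sketched with the haversine formula and the `EARTH_RADIUS` value documented in `pastax.utils` (6371008.8 m). `haversine` is a hypothetical helper.

```python
import math

EARTH_RADIUS = 6371008.8  # meters, same value as pastax.utils.EARTH_RADIUS


def haversine(lat1, lon1, lat2, lon2, radius=EARTH_RADIUS):
    """Great-circle distance in meters between two (lat, lon) points given in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlmb = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2
    return 2 * radius * math.asin(math.sqrt(a))


# One separation value per time step, from paired (lat, lon) positions.
reference = [(43.0, 5.0), (43.1, 5.1)]
simulated = [(43.0, 5.0), (43.1, 5.2)]
separation = [haversine(*r, *s) for r, s in zip(reference, simulated)]
```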

pastax.utils

This module provides various geographical and pastax.utils.Unit conversion and manipulation utilities in JAX.

EARTH_RADIUS = 6371008.8 module-attribute

float: The radius of the Earth in meters.

UNIT = {'m': Meters(), 'km': Kilometers(), '°': LatLonDegrees(), 's': Seconds(), 'min': Minutes(), 'h': Hours(), 'd': Days()} module-attribute

A dictionary mapping unit symbols to their corresponding pastax.utils.Unit objects.

Keys:

"m": Meters, represents meters as a pastax.utils.Unit of measurement.
"km": Kilometers, represents kilometers as a pastax.utils.Unit of measurement.
"°": LatLonDegrees, represents latitude and longitude degrees as a pastax.utils.Unit of measurement.
"s": Seconds, represents seconds as a pastax.utils.Unit of measurement.
"min": Minutes, represents minutes as a pastax.utils.Unit of measurement.
"h": Hours, represents hours as a pastax.utils.Unit of measurement.
"d": Days, represents days as a pastax.utils.Unit of measurement.

Values:

Unit: the corresponding pastax.utils.Unit object for each unit symbol.
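`UNIT` lets code look up a unit object from its symbol. The same keyed-lookup idea can convert between time units through a common base; the sketch below uses a hypothetical table of seconds per unit and a hypothetical `convert_time` helper rather than pastax's `Unit` classes, whose time conversions are not shown here.

```python
# Hypothetical table: seconds per time unit, keyed by the same symbols as UNIT.
SECONDS_PER = {"s": 1.0, "min": 60.0, "h": 3600.0, "d": 86400.0}


def convert_time(value: float, from_symbol: str, to_symbol: str) -> float:
    # Go through seconds as the common base unit; unknown symbols raise KeyError.
    return value * SECONDS_PER[from_symbol] / SECONDS_PER[to_symbol]


print(convert_time(2, "h", "s"))    # 7200.0
print(convert_time(1.5, "d", "h"))  # 36.0
```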

Unit

Bases: Module

Base class representing a unit of measurement (pastax.utils.Unit).

Attributes:

Name Type Description
name str

The name of the pastax.utils.Unit.

Methods:

Name Description
__eq__

Checks whether two pastax.utils.Unit instances are equal (by name).

__lt__

Checks whether one pastax.utils.Unit is less than another, by comparing their names.

__hash__

Returns the hash of the pastax.utils.Unit.

__repr__

Returns the string representation of the pastax.utils.Unit.

convert_to

Converts the value to the specified pastax.utils.Unit.

Source code in pastax/utils/_unit.py
class Unit(eqx.Module):
    """
    Base class representing [`pastax.utils.Unit`][] of measurement.

    Attributes
    ----------
    name : str
        The name of the [`pastax.utils.Unit`][].

    Methods
    -------
    __eq__(other)
        Checks whether two [`pastax.utils.Unit`][] instances are equal (by name).
    __lt__(other)
        Checks whether one [`pastax.utils.Unit`][] is less than another, by comparing their names.
    __hash__()
        Returns the hash of the [`pastax.utils.Unit`][].
    __repr__()
        Returns the string representation of the [`pastax.utils.Unit`][].
    convert_to(unit, value, exp, *args)
        Converts the `value` to the specified [`pastax.utils.Unit`][].
    """

    name: str = eqx.field(static=True, default_factory=lambda: "")

    def __eq__(self, other):
        if isinstance(other, Unit):
            return self.name == other.name
        return False

    def __lt__(self, other):
        if isinstance(other, Unit):
            return self.name < other.name
        return NotImplemented

    def __hash__(self):
        return hash(self.name)

    def __repr__(self):
        return self.name

    @staticmethod
    def _pre_convert(value: Array, exp: int | float) -> Array:
        """
        Prepares the value for conversion between "base" [`pastax.utils.Unit`][] by raising it to the power of the
        reciprocal of the exponent.

        Parameters
        ----------
        value : Array
            The value to be converted.
        exp : int or float
            The exponent to use for conversion.

        Returns
        -------
        Array
            The prepared value.
        """
        if exp == 1:
            return value
        return value ** (1 / exp)

    @staticmethod
    def _post_convert(value: Array, exp: int | float) -> Array:
        """
        Finalizes the conversion between "base" [`pastax.utils.Unit`][] by raising the value to the power of the
        exponent.

        Parameters
        ----------
        value : Array
            The value to be converted.
        exp : int or float
            The exponent to use for conversion.

        Returns
        -------
        Array
            The converted value.
        """
        if exp == 1:
            return value
        return value**exp

    def convert_to(self, unit: Unit, value: Array, exp: int | float = 1, *args) -> Array:
        """
        Converts the `value` to the specified [`pastax.utils.Unit`][].

        Parameters
        ----------
        unit : Unit
            The [`pastax.utils.Unit`][] to convert to.
        value : Array
            The value to convert.
        exp : int or float, optional
            The exponent to use for conversion, defaults to 1.
        *args
            Additional arguments for conversion.

        Returns
        -------
        Array
            The converted value.

        Raises
        ------
        NotImplementedError
            If not implemented by subclasses.
        """
        raise NotImplementedError
convert_to(unit: Unit, value: Array, exp: int | float = 1, *args) -> Array

Converts the value to the specified pastax.utils.Unit.

Parameters:

Name Type Description Default
unit Unit

The pastax.utils.Unit to convert to.

required
value Array

The value to convert.

required
exp int or float

The exponent to use for conversion, defaults to 1.

1
*args

Additional arguments for conversion.

()

Returns:

Type Description
Array

The converted value.

Raises:

Type Description
NotImplementedError

If not implemented by subclasses.

Source code in pastax/utils/_unit.py
def convert_to(self, unit: Unit, value: Array, exp: int | float = 1, *args) -> Array:
    """
    Converts the `value` to the specified [`pastax.utils.Unit`][].

    Parameters
    ----------
    unit : Unit
        The [`pastax.utils.Unit`][] to convert to.
    value : Array
        The value to convert.
    exp : int or float, optional
        The exponent to use for conversion, defaults to 1.
    *args
        Additional arguments for conversion.

    Returns
    -------
    Array
        The converted value.

    Raises
    ------
    NotImplementedError
        If not implemented by subclasses.
    """
    raise NotImplementedError
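`_pre_convert` and `_post_convert` let a single linear conversion handle powers of a unit: the value is first raised to `1 / exp`, converted, then raised back to `exp`. A sketch of that round trip with a hypothetical `convert_with_exponent` helper and an explicit conversion factor (pastax delegates the middle step to helpers such as `meters_to_kilometers`):

```python
def convert_with_exponent(value: float, factor: float, exp: float = 1) -> float:
    """Mirror of Unit._pre_convert / _post_convert around a linear conversion.

    `factor` turns a value in the source unit into the target unit,
    for example 1e-3 for meters to kilometers.
    """
    base = value if exp == 1 else value ** (1 / exp)  # undo the power (pre-convert)
    base = base * factor  # convert the "base" unit
    return base if exp == 1 else base ** exp  # reapply the power (post-convert)


# 2,000,000 m^2 -> km^2: sqrt gives ~1414.2 m, then ~1.4142 km, squared gives ~2.0 km^2
area_km2 = convert_with_exponent(2_000_000.0, 1e-3, exp=2)
# 5 per meter -> per kilometer with exp=-1: ~5000 per km
rate_per_km = convert_with_exponent(5.0, 1e-3, exp=-1)
```

One caveat of this approach: for even `exp`, `value ** (1 / exp)` is not real-valued for negative inputs (in plain Python, `(-4.0) ** 0.5` is a complex number).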

Meters

Bases: Unit

Class representing meters as a pastax.utils.Unit of measurement.

Attributes:

Name Type Description
name str

The name of the pastax.utils.Unit, set to "m".

Source code in pastax/utils/_unit.py
class Meters(Unit):
    """
    Class representing meters as a [`pastax.utils.Unit`][] of measurement.

    Attributes
    ----------
    name : str
        The name of the [`pastax.utils.Unit`][], set to `"m"`.
    """

    name: str = eqx.field(static=True, default_factory=lambda: "m")

    def convert_to(self, unit: Unit, value: Array, exp: int | float = 1, *args) -> Array:
        """
        Converts the value to the specified [`pastax.utils.Unit`][].

        Parameters
        ----------
        unit : Unit
            The [`pastax.utils.Unit`][] to convert to.
        value : Array
            The value to convert.
        exp : int or float, optional
            The exponent to use for conversion, defaults to 1.
        *args
            Additional arguments for conversion.

        Returns
        -------
        Array
            The converted value.

        Raises
        ------
        ValueError
            If the conversion is not possible.
        """
        if isinstance(unit, Meters):
            return value

        value = self._pre_convert(value, exp)

        if isinstance(unit, Kilometers):
            value = meters_to_kilometers(value)
        elif isinstance(unit, LatLonDegrees):
            value = meters_to_degrees(value, *args)
        else:
            raise ValueError(f"Cannot convert {self} to {unit}")

        return self._post_convert(value, exp)
convert_to(unit: Unit, value: Array, exp: int | float = 1, *args) -> Array

Converts the value to the specified pastax.utils.Unit.

Parameters:

Name Type Description Default
unit Unit

The pastax.utils.Unit to convert to.

required
value Array

The value to convert.

required
exp int or float

The exponent to use for conversion, defaults to 1.

1
*args

Additional arguments for conversion.

()

Returns:

Type Description
Array

The converted value.

Raises:

Type Description
ValueError

If the conversion is not possible.

Source code in pastax/utils/_unit.py
def convert_to(self, unit: Unit, value: Array, exp: int | float = 1, *args) -> Array:
    """
    Converts the value to the specified [`pastax.utils.Unit`][].

    Parameters
    ----------
    unit : Unit
        The [`pastax.utils.Unit`][] to convert to.
    value : Array
        The value to convert.
    exp : int or float, optional
        The exponent to use for conversion, defaults to 1.
    *args
        Additional arguments for conversion.

    Returns
    -------
    Array
        The converted value.

    Raises
    ------
    ValueError
        If the conversion is not possible.
    """
    if isinstance(unit, Meters):
        return value

    value = self._pre_convert(value, exp)

    if isinstance(unit, Kilometers):
        value = meters_to_kilometers(value)
    elif isinstance(unit, LatLonDegrees):
        value = meters_to_degrees(value, *args)
    else:
        raise ValueError(f"Cannot convert {self} to {unit}")

    return self._post_convert(value, exp)
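`Meters.convert_to` forwards `*args` to `meters_to_degrees`, whose signature is not shown here. A plausible reading (an assumption) is that the extra argument is a latitude, needed because a degree of longitude shrinks with cos(latitude). The spherical sketch below is a hypothetical `meters_to_degrees_sketch`, not pastax's function.

```python
import math
from typing import Optional

EARTH_RADIUS = 6371008.8  # meters


def meters_to_degrees_sketch(meters: float, latitude: Optional[float] = None) -> float:
    """Hypothetical: arc length in meters -> degrees on a sphere of radius EARTH_RADIUS.

    Without a latitude the result is degrees of latitude; with one, degrees of
    longitude, which are shorter on the ground by a factor cos(latitude).
    """
    degrees = math.degrees(meters / EARTH_RADIUS)
    if latitude is not None:
        degrees /= math.cos(math.radians(latitude))
    return degrees


one_degree_m = EARTH_RADIUS * math.radians(1.0)  # ~111195 m
lat_deg = meters_to_degrees_sketch(one_degree_m)        # ~1 degree of latitude
lon_deg = meters_to_degrees_sketch(one_degree_m, 60.0)  # ~2 degrees of longitude at 60°N
```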

Kilometers

Bases: Unit

Class representing kilometers as a pastax.utils.Unit of measurement.

Attributes:

Name Type Description
name str

The name of the pastax.utils.Unit, set to "km".

Source code in pastax/utils/_unit.py
class Kilometers(Unit):
    """
    Class representing kilometers as a [`pastax.utils.Unit`][] of measurement.

    Attributes
    ----------
    name : str
        The name of the [`pastax.utils.Unit`][], set to `"km"`.
    """

    name: str = eqx.field(static=True, default_factory=lambda: "km")

    def convert_to(self, unit: Unit, value: Array, exp: int | float = 1, *args) -> Array:
        """
        Converts the value to the specified [`pastax.utils.Unit`][].

        Parameters
        ----------
        unit : Unit
            The [`pastax.utils.Unit`][] to convert to.
        value : Array
            The value to convert.
        exp : int or float, optional
            The exponent to use for conversion, defaults to 1.
        *args
            Additional arguments for conversion.

        Returns
        -------
        Array
            The converted value.

        Raises
        ------
        ValueError
            If the conversion is not possible.
        """
        if isinstance(unit, Kilometers):
            return value

        value = self._pre_convert(value, exp)

        if isinstance(unit, Meters):
            value = kilometers_to_meters(value)
        elif isinstance(unit, LatLonDegrees):
            value = kilometers_to_degrees(value, *args)
        else:
            raise ValueError(f"Cannot convert {self} to {unit}")

        return self._post_convert(value, exp)
convert_to(unit: Unit, value: Array, exp: int | float = 1, *args) -> Array

Converts the value to the specified pastax.utils.Unit.

Parameters:

Name Type Description Default
unit Unit

The pastax.utils.Unit to convert to.

required
value Array

The value to convert.

required
exp int or float

The exponent to use for conversion, defaults to 1.

1
*args

Additional arguments for conversion.

()

Returns:

Type Description
Array

The converted value.

Raises:

Type Description
ValueError

If the conversion is not possible.
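Conversions from kilometers to pastax.utils.LatLonDegrees need a reference latitude, which `convert_to` forwards to `kilometers_to_degrees` through `*args`. The sketch below mimics that forwarding with a hypothetical `km_convert_to` helper and a NumPy re-implementation of the degree conversion (`EARTH_RADIUS` is an assumed value). It also shows a consequence of the code above: if the latitude is omitted, the call fails with a `TypeError` for the missing `latitude` argument, not the documented `ValueError`.

```python
import numpy as np

EARTH_RADIUS = 6371008.8  # ASSUMPTION: mean Earth radius in meters; pastax's own constant is not shown here


def kilometers_to_degrees(arr, latitude):
    # NumPy re-implementation of meters_to_degrees(kilometers_to_meters(arr), latitude).
    arr = np.degrees(np.asarray(arr, dtype=float) * 1000 / EARTH_RADIUS)
    arr[..., 1] /= np.cos(np.radians(latitude))  # longitude degrees widen away from the equator
    return arr


def km_convert_to(target, value, *args):
    # Hypothetical mirror of Kilometers.convert_to's dispatch (exponent handling omitted).
    if target == "km":
        return value
    if target == "m":
        return value * 1000  # any extra *args are ignored on this branch
    if target == "deg":
        return kilometers_to_degrees(value, *args)  # *args must supply `latitude`
    raise ValueError(f"Cannot convert km to {target}")


step_km = np.array([1.0, 1.0])  # (northward, eastward) displacement in km
print(km_convert_to("deg", step_km, 45.0))  # eastward component is about 1.41x the northward one

try:
    km_convert_to("deg", step_km)  # latitude forgotten
except TypeError as err:
    print("TypeError:", err)
```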


LatLonDegrees

Bases: Unit

Class representing latitude and longitude degrees as a pastax.utils.Unit of measurement.

Attributes:

Name Type Description
name str

The name of the pastax.utils.Unit, set to "°".

Source code in pastax/utils/_unit.py
class LatLonDegrees(Unit):
    """
    Class representing latitude and longitude degrees as a [`pastax.utils.Unit`][] of measurement.

    Attributes
    ----------
    name : str
        The name of the [`pastax.utils.Unit`][], set to `"°"`.
    """

    name: str = eqx.field(static=True, default_factory=lambda: "°")

    def convert_to(self, unit: Unit, value: Array, exp: int | float = 1, *args) -> Array:
        """
        Converts the value to the specified [`pastax.utils.Unit`][].

        Parameters
        ----------
        unit : Unit
            The [`pastax.utils.Unit`][] to convert to.
        value : Array
            The value to convert.
        exp : int or float, optional
            The exponent to use for conversion, defaults to 1.
        *args
            Additional arguments for conversion.

        Returns
        -------
        Array
            The converted value.

        Raises
        ------
        ValueError
            If the conversion is not possible.
        """
        if isinstance(unit, LatLonDegrees):
            return value

        value = self._pre_convert(value, exp)

        if isinstance(unit, Meters):
            value = degrees_to_meters(value, *args)
        elif isinstance(unit, Kilometers):
            value = degrees_to_kilometers(value, *args)
        else:
            raise ValueError(f"Cannot convert {self} to {unit}")

        return self._post_convert(value, exp)
convert_to(unit: Unit, value: Array, exp: int | float = 1, *args) -> Array

Converts the value to the specified pastax.utils.Unit.

Parameters:

Name Type Description Default
unit Unit

The pastax.utils.Unit to convert to.

required
value Array

The value to convert.

required
exp int or float

The exponent to use for conversion, defaults to 1.

1
*args

Additional arguments for conversion.

()

Returns:

Type Description
Array

The converted value.

Raises:

Type Description
ValueError

If the conversion is not possible.


Seconds

Bases: Unit

Class representing seconds as a pastax.utils.Unit of measurement.

Attributes:

Name Type Description
name str

The name of the pastax.utils.Unit, set to "s".

Source code in pastax/utils/_unit.py
class Seconds(Unit):
    """
    Class representing seconds as a [`pastax.utils.Unit`][] of measurement.

    Attributes
    ----------
    name : str
        The name of the [`pastax.utils.Unit`][], set to `"s"`.
    """

    name: str = eqx.field(static=True, default_factory=lambda: "s")

    def convert_to(self, unit: Unit, value: Array, exp: int | float = 1, *args) -> Array:
        """
        Converts the value to the specified unit.

        Parameters
        ----------
        unit : Unit
            The [`pastax.utils.Unit`][] to convert to.
        value : Array
            The value to convert.
        exp : int or float, optional
            The exponent to use for conversion, defaults to 1.
        *args
            Additional arguments for conversion.

        Returns
        -------
        Array
            The converted value.

        Raises
        ------
        ValueError
            If the conversion is not possible.
        """
        if isinstance(unit, Seconds):
            return value

        value = self._pre_convert(value, exp)

        if isinstance(unit, Minutes):
            value = seconds_to_minutes(value)
        elif isinstance(unit, Hours):
            value = seconds_to_hours(value)
        elif isinstance(unit, Days):
            value = seconds_to_days(value)
        else:
            raise ValueError(f"Cannot convert {self} to {unit}")

        return self._post_convert(value, exp)
convert_to(unit: Unit, value: Array, exp: int | float = 1, *args) -> Array

Converts the value to the specified unit.

Parameters:

Name Type Description Default
unit Unit

The pastax.utils.Unit to convert to.

required
value Array

The value to convert.

required
exp int or float

The exponent to use for conversion, defaults to 1.

1
*args

Additional arguments for conversion.

()

Returns:

Type Description
Array

The converted value.

Raises:

Type Description
ValueError

If the conversion is not possible.
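The time units follow the same pattern with fixed factors: 60 s per minute, 3600 s per hour, 86400 s per day. The sketch below is a hypothetical, table-driven equivalent written with NumPy. As in the distance case, it *assumes* `exp` is handled by taking the `exp`-th root before the linear conversion and raising back afterwards. Under that assumption, a rate in s^-1 (`exp=-1`) becomes a rate per day.

```python
import numpy as np

# Seconds per unit; these factors are exact by definition.
SECONDS_PER = {"s": 1.0, "min": 60.0, "h": 3600.0, "d": 86400.0}


def convert_time(value, src, dst, exp=1):
    """Hypothetical sketch of the time-unit conversions for a quantity in src**exp.

    ASSUMPTION: exponent handling is modelled as "take the exp-th root,
    convert linearly, raise back to exp".
    """
    if src == dst:
        return value
    value = value ** (1 / exp)  # back to a plain duration
    value = value * SECONDS_PER[src] / SECONDS_PER[dst]
    return value ** exp


rate = np.asarray(1e-5)  # an event rate in s^-1
print(convert_time(rate, "s", "d", exp=-1))  # about 0.864 per day
print(convert_time(np.asarray(90.0), "min", "h"))  # 1.5
```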


Minutes

Bases: Unit

Class representing minutes as a pastax.utils.Unit of measurement.

Attributes:

Name Type Description
name str

The name of the pastax.utils.Unit, set to "min".

Source code in pastax/utils/_unit.py
class Minutes(Unit):
    """
    Class representing minutes as a [`pastax.utils.Unit`][] of measurement.

    Attributes
    ----------
    name : str
        The name of the [`pastax.utils.Unit`][], set to `"min"`.
    """

    name: str = eqx.field(static=True, default_factory=lambda: "min")

    def convert_to(self, unit: Unit, value: Array, exp: int | float = 1, *args) -> Array:
        """
        Converts the value to the specified unit.

        Parameters
        ----------
        unit : Unit
            The [`pastax.utils.Unit`][] to convert to.
        value : Array
            The value to convert.
        exp : int or float, optional
            The exponent to use for conversion, defaults to 1.
        *args
            Additional arguments for conversion.

        Returns
        -------
        Array
            The converted value.

        Raises
        ------
        ValueError
            If the conversion is not possible.
        """
        if isinstance(unit, Minutes):
            return value

        value = self._pre_convert(value, exp)

        if isinstance(unit, Seconds):
            value = minutes_to_seconds(value)
        elif isinstance(unit, Hours):
            value = minutes_to_hours(value)
        elif isinstance(unit, Days):
            value = minutes_to_days(value)
        else:
            raise ValueError(f"Cannot convert {self} to {unit}")

        return self._post_convert(value, exp)
convert_to(unit: Unit, value: Array, exp: int | float = 1, *args) -> Array

Converts the value to the specified unit.

Parameters:

Name Type Description Default
unit Unit

The pastax.utils.Unit to convert to.

required
value Array

The value to convert.

required
exp int or float

The exponent to use for conversion, defaults to 1.

1
*args

Additional arguments for conversion.

()

Returns:

Type Description
Array

The converted value.

Raises:

Type Description
ValueError

If the conversion is not possible.


Hours

Bases: Unit

Class representing hours as a pastax.utils.Unit of measurement.

Attributes:

Name Type Description
name str

The name of the pastax.utils.Unit, set to "h".

Source code in pastax/utils/_unit.py
class Hours(Unit):
    """
    Class representing hours as a [`pastax.utils.Unit`][] of measurement.

    Attributes
    ----------
    name : str
        The name of the [`pastax.utils.Unit`][], set to `"h"`.
    """

    name: str = eqx.field(static=True, default_factory=lambda: "h")

    def convert_to(self, unit: Unit, value: Array, exp: int | float = 1, *args) -> Array:
        """
        Converts the value to the specified unit.

        Parameters
        ----------
        unit : Unit
            The [`pastax.utils.Unit`][] to convert to.
        value : Array
            The value to convert.
        exp : int or float, optional
            The exponent to use for conversion, defaults to 1.
        *args
            Additional arguments for conversion.

        Returns
        -------
        Array
            The converted value.

        Raises
        ------
        ValueError
            If the conversion is not possible.
        """
        if isinstance(unit, Hours):
            return value

        value = self._pre_convert(value, exp)

        if isinstance(unit, Seconds):
            value = hours_to_seconds(value)
        elif isinstance(unit, Minutes):
            value = hours_to_minutes(value)
        elif isinstance(unit, Days):
            value = hours_to_days(value)
        else:
            raise ValueError(f"Cannot convert {self} to {unit}")

        return self._post_convert(value, exp)
convert_to(unit: Unit, value: Array, exp: int | float = 1, *args) -> Array

Converts the value to the specified unit.

Parameters:

Name Type Description Default
unit Unit

The pastax.utils.Unit to convert to.

required
value Array

The value to convert.

required
exp int or float

The exponent to use for conversion, defaults to 1.

1
*args

Additional arguments for conversion.

()

Returns:

Type Description
Array

The converted value.

Raises:

Type Description
ValueError

If the conversion is not possible.


Days

Bases: Unit

Class representing days as a pastax.utils.Unit of measurement.

Attributes:

Name Type Description
name str

The name of the pastax.utils.Unit, set to "d".

Source code in pastax/utils/_unit.py
class Days(Unit):
    """
    Class representing days as a [`pastax.utils.Unit`][] of measurement.

    Attributes
    ----------
    name : str
        The name of the [`pastax.utils.Unit`][], set to `"d"`.
    """

    name: str = eqx.field(static=True, default_factory=lambda: "d")

    def convert_to(self, unit: Unit, value: Array, exp: int | float = 1, *args) -> Array:
        """
        Converts the value to the specified unit.

        Parameters
        ----------
        unit : Unit
            The [`pastax.utils.Unit`][] to convert to.
        value : Array
            The value to convert.
        exp : int or float, optional
            The exponent to use for conversion, defaults to 1.
        *args
            Additional arguments for conversion.

        Returns
        -------
        Array
            The converted value.

        Raises
        ------
        ValueError
            If the conversion is not possible.
        """
        if isinstance(unit, Days):
            return value

        value = self._pre_convert(value, exp)

        if isinstance(unit, Seconds):
            value = days_to_seconds(value)
        elif isinstance(unit, Minutes):
            value = days_to_minutes(value)
        elif isinstance(unit, Hours):
            value = days_to_hours(value)
        else:
            raise ValueError(f"Cannot convert {self} to {unit}")

        return self._post_convert(value, exp)
convert_to(unit: Unit, value: Array, exp: int | float = 1, *args) -> Array

Converts the value to the specified unit.

Parameters:

Name Type Description Default
unit Unit

The pastax.utils.Unit to convert to.

required
value Array

The value to convert.

required
exp int or float

The exponent to use for conversion, defaults to 1.

1
*args

Additional arguments for conversion.

()

Returns:

Type Description
Array

The converted value.

Raises:

Type Description
ValueError

If the conversion is not possible.


distance_on_earth(latlon1: Float[Array, '... 2'], latlon2: Float[Array, '... 2']) -> Array

Calculates the distance in meters between two points on the Earth's surface.

This function uses the Haversine formula to compute the great-circle distance between two points, or two arrays of points, specified by their latitude and longitude coordinates.

Parameters:

Name Type Description Default
latlon1 Float[Array, '... 2']

An array of shape (..., 2) holding the latitude and longitude, in degrees, of the first point(s).

required
latlon2 Float[Array, '... 2']

An array of shape (..., 2) holding the latitude and longitude, in degrees, of the second point(s).

required

Returns:

Type Description
Array

The distance(s) between the two sets of points, in meters.

Source code in pastax/utils/_geo.py
def distance_on_earth(latlon1: Float[Array, "... 2"], latlon2: Float[Array, "... 2"]) -> Array:
    """
    Calculates the distance in meters between two points on the Earth's surface.

    This function uses the Haversine formula to compute the great-circle distance between two points,
    or two arrays of points, specified by their latitude and longitude coordinates.

    Parameters
    ----------
    latlon1 : Float[Array, "... 2"]
        An array of shape (..., 2) holding the latitude and longitude, in degrees, of the first point(s).
    latlon2 : Float[Array, "... 2"]
        An array of shape (..., 2) holding the latitude and longitude, in degrees, of the second point(s).

    Returns
    -------
    Array
        The distance(s) between the two sets of points, in meters.
    """

    def safe_for_grad_sqrt(x):
        # grad(sqrt(x)) is undefined at x = 0; masking that case makes the gradient evaluate to 0 there instead
        mask = x != 0.0
        y = jnp.sqrt(jnp.where(mask, x, 1.0))  # type: ignore
        return jnp.where(mask, y, 0.0)

    lat1_rad = jnp.radians(latlon1[..., 0])
    lat2_rad = jnp.radians(latlon2[..., 0])
    d_rad = jnp.radians(latlon1 - latlon2)

    a = jnp.sin(d_rad[..., 0] / 2) ** 2 + jnp.cos(lat1_rad) * jnp.cos(lat2_rad) * jnp.sin(d_rad[..., 1] / 2) ** 2
    c = 2 * jnp.atan2(safe_for_grad_sqrt(a), jnp.sqrt(1 - a))
    d = EARTH_RADIUS * c

    return d
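A NumPy transcription of the Haversine computation above can be useful for checking values outside JAX. `EARTH_RADIUS` is an assumed mean radius (6371008.8 m), since pastax's constant is not shown on this page. Plain `np.sqrt` is enough here because no gradients are taken. The pastax version instead masks the square root so that its gradient evaluates to 0 when the two points coincide.

```python
import numpy as np

EARTH_RADIUS = 6371008.8  # ASSUMPTION: mean Earth radius in meters


def haversine_distance(latlon1, latlon2):
    """Great-circle distance in meters between (lat, lon) points given in degrees."""
    latlon1 = np.asarray(latlon1, dtype=float)
    latlon2 = np.asarray(latlon2, dtype=float)
    lat1 = np.radians(latlon1[..., 0])
    lat2 = np.radians(latlon2[..., 0])
    d_rad = np.radians(latlon1 - latlon2)  # (dlat, dlon) in radians
    a = np.sin(d_rad[..., 0] / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(d_rad[..., 1] / 2) ** 2
    c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))  # central angle in radians
    return EARTH_RADIUS * c


paris = [48.8566, 2.3522]
london = [51.5074, -0.1278]
print(haversine_distance(paris, london) / 1000)  # a little over 340 km

# Broadcasting: one origin against several destinations.
print(haversine_distance([0.0, 0.0], [[0.0, 1.0], [1.0, 0.0], [90.0, 0.0]]))
```

On a sphere, one degree along the equator and one degree along a meridian have the same length, and equator-to-pole is a quarter of a great circle.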

longitude_in_180_180_degrees(longitude: Array) -> Array

Wraps an array of longitudes into the half-open range [-180, 180) degrees.

Parameters:

Name Type Description Default
longitude Array

An array of longitudes in degrees.

required

Returns:

Type Description
Array

The input longitudes wrapped into the range [-180, 180) degrees.

Notes

This function acts as the identity, up to floating-point rounding, for longitudes already in [-180, 180); a longitude of exactly 180 degrees is mapped to -180.

Source code in pastax/utils/_geo.py
def longitude_in_180_180_degrees(longitude: Array) -> Array:
    """
    Wraps an array of longitudes into the half-open range [-180, 180) degrees.

    Parameters
    ----------
    longitude : Array
        An array of longitudes in degrees.

    Returns
    -------
    Array
        The input longitudes wrapped into the range [-180, 180) degrees.

    Notes
    -----
    This function acts as the identity, up to floating-point rounding, for longitudes already in [-180, 180);
    a longitude of exactly 180 degrees is mapped to -180.
    """
    return (longitude + 180) % 360 - 180
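The wrap-around behaviour of `(longitude + 180) % 360 - 180` can be checked directly, including the edge case at 180 degrees. The snippet evaluates the same expression with NumPy. The name `wrap_longitude` is hypothetical and is used only so the snippet runs without pastax.

```python
import numpy as np


def wrap_longitude(longitude):
    # Same expression as longitude_in_180_180_degrees, evaluated with NumPy.
    return (np.asarray(longitude, dtype=float) + 180) % 360 - 180


lons = np.array([190.0, -190.0, 359.5, 180.0, -180.0, 540.0, 45.0])
print(wrap_longitude(lons))
# values: -170, 170, -0.5, -180, -180, -180, 45
```

Because the range is half-open, both +180 and -180 come out as -180.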

degrees_to_meters(arr: Float[Array, '... 2'], latitude: Float[Array, '...']) -> Float[Array, '... 2']

Converts an array of latitude/longitude distances from degrees to meters.

Parameters:

Name Type Description Default
arr Float[Array, '... 2']

An array of latitude/longitude distances in degrees.

required
latitude Float[Array, '...']

The latitude in degrees at which the conversion is to be performed.

required

Returns:

Type Description
Float[Array, '... 2']

An array of latitude/longitude distances in meters.

Source code in pastax/utils/_unit.py
def degrees_to_meters(arr: Float[Array, "... 2"], latitude: Float[Array, "..."]) -> Float[Array, "... 2"]:
    """
    Converts an array of latitude/longitude distances from degrees to meters.

    Parameters
    ----------
    arr : Float[Array, "... 2"]
        An array of latitude/longitude distances in degrees.
    latitude : Float[Array, "..."]
        The latitude in degrees at which the conversion is to be performed.

    Returns
    -------
    Float[Array, "... 2"]
        An array of latitude/longitude distances in meters.
    """
    arr = jnp.radians(arr) * EARTH_RADIUS
    arr = arr.at[..., 1].multiply(jnp.cos(jnp.radians(latitude)))
    return arr
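`degrees_to_meters` treats the first component as a latitude displacement and the second as a longitude displacement. It shrinks the second component by cos(latitude), a local spherical approximation suited to small displacements. The sketch below re-implements it in NumPy with an assumed `EARTH_RADIUS` and applies it to a 1° × 1° step at the equator and at 60°N. NumPy's in-place `*=` replaces JAX's functional `.at[..., 1].multiply`.

```python
import numpy as np

EARTH_RADIUS = 6371008.8  # ASSUMPTION: mean Earth radius in meters


def degrees_to_meters(arr, latitude):
    # arr[..., 0]: latitude displacement, arr[..., 1]: longitude displacement, both in degrees.
    arr = np.radians(np.asarray(arr, dtype=float)) * EARTH_RADIUS  # arc length on the sphere
    arr[..., 1] *= np.cos(np.radians(latitude))  # parallels shrink with cos(latitude)
    return arr


one_degree = np.array([[1.0, 1.0], [1.0, 1.0]])
lats = np.array([0.0, 60.0])
print(degrees_to_meters(one_degree, lats))
# row 0: about 111 km in both components; row 1: longitude component about half of that
```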

degrees_to_kilometers(arr: Float[Array, '... 2'], latitude: Float[Array, '...']) -> Float[Array, '... 2']

Converts an array of latitude/longitude distances from degrees to kilometers.

Parameters:

Name Type Description Default
arr Float[Array, '... 2']

An array of latitude/longitude distances in degrees.

required
latitude Float[Array, '...']

The latitude in degrees at which the conversion is to be performed.

required

Returns:

Type Description
Float[Array, '... 2']

An array of latitude/longitude distances in kilometers.

Source code in pastax/utils/_unit.py
def degrees_to_kilometers(arr: Float[Array, "... 2"], latitude: Float[Array, "..."]) -> Float[Array, "... 2"]:
    """
    Converts an array of latitude/longitude distances from degrees to kilometers.

    Parameters
    ----------
    arr : Float[Array, "... 2"]
        An array of latitude/longitude distances in degrees.
    latitude : Float[Array, "..."]
        The latitude in degrees at which the conversion is to be performed.

    Returns
    -------
    Float[Array, "... 2"]
        An array of latitude/longitude distances in kilometers.
    """
    return meters_to_kilometers(degrees_to_meters(arr, latitude))

meters_to_degrees(arr: Float[Array, '... 2'], latitude: Float[Array, '...']) -> Float[Array, '... 2']

Converts an array of latitude/longitude distances from meters to degrees.

Parameters:

Name Type Description Default
arr Float[Array, '... 2']

An array of latitude/longitude distances in meters.

required
latitude Float[Array, '...']

The latitude in degrees at which the conversion is to be performed.

required

Returns:

Type Description
Float[Array, '... 2']

An array of latitude/longitude distances in degrees.

Source code in pastax/utils/_unit.py
def meters_to_degrees(arr: Float[Array, "... 2"], latitude: Float[Array, "..."]) -> Float[Array, "... 2"]:
    """
    Converts an array of latitude/longitude distances from meters to degrees.

    Parameters
    ----------
    arr : Float[Array, "... 2"]
        An array of latitude/longitude distances in meters.
    latitude : Float[Array, "..."]
        The latitude in degrees at which the conversion is to be performed.

    Returns
    -------
    Float[Array, "... 2"]
        An array of latitude/longitude distances in degrees.
    """
    arr = jnp.degrees(arr / EARTH_RADIUS)
    arr = arr.at[..., 1].divide(jnp.cos(jnp.radians(latitude)))
    return arr
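`meters_to_degrees` divides the longitude component by cos(latitude), so it inverts `degrees_to_meters` at the same latitude. The longitude displacement in degrees also grows without bound as the latitude approaches ±90°. A NumPy sketch of both helpers (with `EARTH_RADIUS` assumed) checks the round trip and shows that growth.

```python
import numpy as np

EARTH_RADIUS = 6371008.8  # ASSUMPTION: mean Earth radius in meters


def meters_to_degrees(arr, latitude):
    arr = np.degrees(np.asarray(arr, dtype=float) / EARTH_RADIUS)
    arr[..., 1] /= np.cos(np.radians(latitude))  # undoes the cos(latitude) shrink
    return arr


def degrees_to_meters(arr, latitude):
    arr = np.radians(np.asarray(arr, dtype=float)) * EARTH_RADIUS
    arr[..., 1] *= np.cos(np.radians(latitude))
    return arr


step_m = np.array([500.0, 500.0])  # 500 m north, 500 m east
for lat in (0.0, 60.0, 89.0):
    deg = meters_to_degrees(step_m, lat)
    back = degrees_to_meters(deg, lat)
    # Longitude degrees grow as 1 / cos(lat); the round trip recovers the meters.
    print(lat, deg, np.allclose(back, step_m))
```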

meters_to_kilometers(arr: Array) -> Array

Converts an array of distances from meters to kilometers.

Parameters:

Name Type Description Default
arr Array

An array of distances in meters.

required

Returns:

Type Description
Array

An array of distances in kilometers.

Source code in pastax/utils/_unit.py
def meters_to_kilometers(arr: Array) -> Array:
    """
    Converts an array of distances from meters to kilometers.

    Parameters
    ----------
    arr : Array
        An array of distances in meters.

    Returns
    -------
    Array
        An array of distances in kilometers.
    """
    return arr / 1000

kilometers_to_degrees(arr: Float[Array, '... 2'], latitude: Float[Array, '...']) -> Float[Array, '... 2']

Converts an array of latitude/longitude distances from kilometers to degrees.

Parameters:

Name Type Description Default
arr Float[Array, '... 2']

An array of latitude/longitude distances in kilometers.

required
latitude Float[Array, '...']

The latitude in degrees at which the conversion is to be performed.

required

Returns:

Type Description
Float[Array, '... 2']

An array of latitude/longitude distances in degrees.

Source code in pastax/utils/_unit.py
def kilometers_to_degrees(arr: Float[Array, "... 2"], latitude: Float[Array, "..."]) -> Float[Array, "... 2"]:
    """
    Converts an array of latitude/longitude distances from kilometers to degrees.

    Parameters
    ----------
    arr : Float[Array, "... 2"]
        An array of latitude/longitude distances in kilometers.
    latitude : Float[Array, "..."]
        The latitude in degrees at which the conversion is to be performed.

    Returns
    -------
    Float[Array, "... 2"]
        An array of latitude/longitude distances in degrees.
    """
    return meters_to_degrees(kilometers_to_meters(arr), latitude)

kilometers_to_meters(arr: Array) -> Array

Converts an array of distances from kilometers to meters.

Parameters:

Name Type Description Default
arr Array

An array of distances in kilometers.

required

Returns:

Type Description
Array

An array of distances in meters.

Source code in pastax/utils/_unit.py
def kilometers_to_meters(arr: Array) -> Array:
    """
    Converts an array of distances from kilometers to meters.

    Parameters
    ----------
    arr : Array
        An array of distances in kilometers.

    Returns
    -------
    Array
        An array of distances in meters.
    """
    return arr * 1000

time_in_seconds(arr: Array | ArrayLike) -> Array | ArrayLike

Converts an array of datetime64 values to seconds since the Unix epoch.

Parameters:

Name Type Description Default
arr Array | ArrayLike

An array of datetime64 values or a single datetime64 value.

required

Returns:

Type Description
Array | ArrayLike

An array of integers representing the number of seconds since the Unix epoch, or the input unchanged if it does not hold datetime64 values.

Source code in pastax/utils/_unit.py
def time_in_seconds(arr: Array | ArrayLike) -> Array | ArrayLike:
    """
    Converts an array of datetime64 values to seconds since the Unix epoch.

    Parameters
    ----------
    arr : Array | ArrayLike
        An array of datetime64 values or a single datetime64 value.

    Returns
    -------
    Array | ArrayLike
        An array of integers representing the number of seconds since the Unix epoch, or the input unchanged if it does not hold datetime64 values.
    """
    if isinstance(arr, np.datetime64) or (isinstance(arr, np.ndarray) and np.issubdtype(arr.dtype, np.datetime64)):
        arr = arr.astype("datetime64[s]").astype(int)

    return arr
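The snippet below shows how `time_in_seconds` behaves on different inputs. It uses a local copy of the NumPy logic above so that it runs without pastax. JAX arrays are not `np.ndarray` instances, so they pass through unchanged, like any other input that is not datetime64.

```python
import numpy as np


def time_in_seconds(arr):
    # Local copy of the logic above: datetime64 -> integer seconds since 1970-01-01T00:00:00,
    # anything else is returned unchanged.
    if isinstance(arr, np.datetime64) or (isinstance(arr, np.ndarray) and np.issubdtype(arr.dtype, np.datetime64)):
        arr = arr.astype("datetime64[s]").astype(int)
    return arr


print(time_in_seconds(np.datetime64("2024-01-01T00:00:00")))  # 1704067200
ts = np.array(["1970-01-01T00:00", "1970-01-02T12:00"], dtype="datetime64[m]")
print(time_in_seconds(ts))  # values: 0 and 129600
already_seconds = np.array([0.0, 3600.0])
print(time_in_seconds(already_seconds) is already_seconds)  # True: passed through
```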

seconds_to_days(arr: Array) -> Array

Converts an array of time durations from seconds to days.

Parameters:

Name  Type   Description                              Default
arr   Array  An array of time durations in seconds.  required

Returns:

Type   Description
Array  An array of time durations in days.

Source code in pastax/utils/_unit.py
774
def seconds_to_days(arr: Array) -> Array:
    """
    Converts an array of time durations from seconds to days.

    Parameters
    ----------
    arr : Array
        An array of time durations in seconds.

    Returns
    -------
    Array
        An array of time durations in days.
    """
    return minutes_to_days(seconds_to_minutes(arr))
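The bodies of `seconds_to_minutes` and `minutes_to_days` are not shown in this section, so the sketch below uses hypothetical stand-ins (assumed: 60 seconds per minute and 1440 minutes per day) to show how the composition in `seconds_to_days` works out numerically.

```python
import numpy as np

SECONDS_PER_MINUTE = 60.0
MINUTES_PER_DAY = 60.0 * 24.0  # 1440 minutes in a day


def seconds_to_minutes(arr):
    # Hypothetical stand-in for the pastax helper of the same name.
    return arr / SECONDS_PER_MINUTE


def minutes_to_days(arr):
    # Hypothetical stand-in for the pastax helper of the same name.
    return arr / MINUTES_PER_DAY


def seconds_to_days(arr):
    # Same composition as the source above: seconds -> minutes -> days.
    return minutes_to_days(seconds_to_minutes(arr))


durations_s = np.array([86400.0, 43200.0, 3600.0])  # one day, half a day, one hour
durations_d = seconds_to_days(durations_s)
```

Chaining the two helpers is equivalent to a single division by 86400 seconds per day.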

units_to_str(unit: dict[Unit, int | float]) -> str

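Converts a dictionary mapping pastax.utils.Unit members to their exponents into a string representation.

Parameters:

Name  Type                     Description                                                         Default
unit  dict[Unit, int | float]  A dictionary mapping pastax.utils.Unit members to their exponents.  required

Returns:

Type  Description
str   A string representation of the units and their exponents: an exponent of 1 is omitted, units with an exponent of 0 are skipped, and other exponents are written as LaTeX-style superscripts such as ^{-1} or ^{1/2}.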

Source code in pastax/utils/_unit.py
def units_to_str(unit: dict[Unit, int | float]) -> str:
    """
    Converts a dictionary of [`pastax.utils.Unit`][] with their exponents to a string representation.

    Parameters
    ----------
    unit : dict[Unit, int or float]
        A dictionary of [`pastax.utils.Unit`][] with their exponents.

    Returns
    -------
    str
        A string representation of the [`pastax.utils.Unit`][] with their exponents.
    """

    def get_exp_str(exp: int | float) -> str:
        if exp == 1:
            return ""
        else:
            return f"^{{{Fraction(exp).limit_denominator()}}}"

    def get_dim_str(dim: Unit, exp: int | float) -> str:
        if exp == 0:
            return ""
        else:
            return f"{dim}{get_exp_str(exp)}"

    return " ".join(get_dim_str(dim, exp) for dim, exp in unit.items()).strip()
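To show what the formatting rules produce, the sketch below copies the logic above and pairs it with a hypothetical `Unit` enum whose string form is a unit symbol. The real `pastax.utils.Unit` members and their string forms are not shown here, so the exact symbols are an assumption.

```python
from enum import Enum
from fractions import Fraction


class Unit(Enum):
    # Hypothetical stand-in for pastax.utils.Unit; only its string form matters here.
    METERS = "m"
    SECONDS = "s"
    KILOGRAMS = "kg"

    def __str__(self) -> str:
        return self.value


def units_to_str(unit: dict[Unit, float]) -> str:
    # Same logic as the source above.
    def get_exp_str(exp):
        # An exponent of 1 is implicit; others become ^{p/q} via Fraction.
        if exp == 1:
            return ""
        return f"^{{{Fraction(exp).limit_denominator()}}}"

    def get_dim_str(dim, exp):
        # Units with a zero exponent contribute an empty string.
        if exp == 0:
            return ""
        return f"{dim}{get_exp_str(exp)}"

    return " ".join(get_dim_str(dim, exp) for dim, exp in unit.items()).strip()


velocity = {Unit.METERS: 1, Unit.SECONDS: -1}
velocity_str = units_to_str(velocity)
sqrt_meters_str = units_to_str({Unit.METERS: 0.5})
```

Here `velocity_str` is `"m s^{-1}"` and `sqrt_meters_str` is `"m^{1/2}"`; `Fraction(...).limit_denominator()` is what turns a float exponent such as 0.5 into a readable rational.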

compose_units(unit1: dict[Unit, int | float], unit2: dict[Unit, int | float] | None, mul: Literal[-1, 1]) -> dict[Unit, int | float]

Composes two pastax.utils.Unit dictionaries by summing the exponents of matching units, after multiplying the exponents of the second dictionary by mul to account for multiplication or division.

Parameters:

Name   Type                            Description                                                                                     Default
unit1  dict[Unit, int | float]         The first pastax.utils.Unit dictionary.                                                         required
unit2  dict[Unit, int | float] | None  The second pastax.utils.Unit dictionary.                                                        required
mul    Literal[-1, 1]                  The multiplier applied to the second dictionary's exponents: 1 for multiplication, -1 for division.  required

Returns:

Type                     Description
dict[Unit, int | float]  The composed pastax.utils.Unit dictionary, or an empty dictionary if both input dictionaries are empty.

Source code in pastax/utils/_unit.py
def compose_units(
    unit1: dict[Unit, int | float],
    unit2: dict[Unit, int | float] | None,
    mul: Literal[-1, 1],
) -> dict[Unit, int | float]:
    """
    Compose two [`pastax.utils.Unit`][] dictionaries by combining their values,
    optionally multiplying the second dictionary's values by a factor to account for multiplication or division.

    Parameters
    ----------
    unit1 : dict[Unit, int | float]
        The first [`pastax.utils.Unit`][] dictionary.
    unit2 : dict[Unit, int | float]
        The second [`pastax.utils.Unit`][] dictionary.
    mul : Literal[-1, 1]
        The multiplier for the second [`pastax.utils.Unit`][] dictionary's values.
        Should be either 1 in case of multiplication or -1 in case of division.

    Returns
    -------
    dict[Unit, int | float]
        The composed [`pastax.utils.Unit`][] dictionary, or an empty dictionary if both input dictionaries are empty.
    """
    if (not unit1) and (not unit2):
        return {}
    if not unit1:
        return unit2  # type: ignore
    if not unit2:
        return unit1

    unit = unit1.copy()

    for k, v in unit2.items():
        v *= mul
        if k in unit:
            unit[k] += v
        else:
            unit[k] = v

    return unit
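A minimal sketch of how exponents combine, copying the logic above with plain strings standing in for `pastax.utils.Unit` members (an assumption made so the snippet is self-contained). Dividing a velocity by a time gives an acceleration, and multiplying a length by a length gives an area.

```python
def compose_units(unit1, unit2, mul):
    # Mirrors pastax.utils.compose_units, with string keys instead of Unit members.
    if (not unit1) and (not unit2):
        return {}
    if not unit1:
        return unit2  # as in the source, mul is not applied on this path
    if not unit2:
        return unit1

    unit = unit1.copy()  # leave unit1 untouched
    for k, v in unit2.items():
        v *= mul  # flip the sign of the exponents when dividing
        if k in unit:
            unit[k] += v  # same unit on both sides: exponents add
        else:
            unit[k] = v
    return unit


velocity = {"m": 1, "s": -1}
acceleration = compose_units(velocity, {"s": 1}, -1)  # (m s^-1) / s -> m s^-2
area = compose_units({"m": 1}, {"m": 1}, 1)  # m * m -> m^2
ratio = compose_units({"m": 1}, {"m": 1}, -1)  # m / m keeps "m" with exponent 0
```

Zero exponents are kept in the result (as in `ratio`), which is why `units_to_str` skips them when formatting. Note also that, as the source reads, an empty `unit1` returns `unit2` unchanged without applying `mul`, so `compose_units({}, {"s": 1}, -1)` yields `{"s": 1}` rather than `{"s": -1}`.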