Skip to content

Patched edge-case in LinearInterpolation #305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions diffrax/global_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from .custom_types import Array, DenseInfos, Int, PyTree, Scalar
from .local_interpolation import AbstractLocalInterpolation
from .misc import fill_forward, left_broadcast_to
from .misc import fill_forward, left_broadcast_to, linear_rescale
from .path import AbstractPath


Expand Down Expand Up @@ -124,10 +124,10 @@ def _index(_ys):
next_ys = (self.ys**ω)[index + 1].ω
prev_t = self.ts[index]
next_t = self.ts[index + 1]
diff_t = next_t - prev_t

return (
prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t)
prev_ys**ω
+ (next_ys**ω - prev_ys**ω)
* (linear_rescale(prev_t, fractional_part, next_t))
).ω

@eqx.filter_jit
Expand Down Expand Up @@ -407,7 +407,6 @@ def _linear_interpolation_forward(
Tuple[Array["channels":...], Array["channels":...]], # noqa: F821
Array["channels":...], # noqa: F821
]:

prev_ti, prev_yi = carry
ti, yi, next_ti, next_yi = value
cond = jnp.isnan(yi)
Expand All @@ -426,7 +425,6 @@ def _linear_interpolation(
ys: Array["times", "channels":...], # noqa: F821
replace_nans_at_start: Optional[Array["channels":...]] = None, # noqa: F821
) -> Array["times", "channels":...]: # noqa: F821

ts = left_broadcast_to(ts, ys.shape)

if replace_nans_at_start is None:
Expand Down Expand Up @@ -599,7 +597,6 @@ def _hermite_forward(
Array["channels":...], # noqa: F821
],
]:

prev_ti, prev_yi, prev_deriv_i = carry
ti, yi, next_ti, next_yi = value
first_deriv_i = (next_yi - yi) / (next_ti - ti)
Expand Down