Source code for gradgraph.optimization.tf.interp

#!/usr/bin/env python3
# 
# interp.py
# 
# Created by Nicolas Fricker on 08/22/2025.
# Copyright © 2025 Nicolas Fricker. All rights reserved.
# 

import tensorflow as tf

def batch_linear_interp_1d(
    x,
    x_ref_min,
    x_ref_max,
    y_ref,
    clamp=True,
    fill_value_below=0.0,
    fill_value_above=0.0,
):
    """
    Perform batched differentiable 1D linear interpolation with shape safety.

    Parameters
    ----------
    x : array_like
        Array of shape `(batch, n_interp)` containing x values to interpolate.
        Its dtype determines the dtype of every intermediate computation and
        is expected to be a floating point type.
    x_ref_min : array_like or scalar
        Array of shape `(batch,)` or a scalar representing the start of the
        reference grid. It is cast to the dtype of `x`.
    x_ref_max : array_like or scalar
        Array of shape `(batch,)` or a scalar representing the end of the
        reference grid. It is cast to the dtype of `x`.
    y_ref : array_like
        Array of shape `(batch, nx)` containing y values at regularly spaced
        grid points.
    clamp : bool, optional
        If True, clamps x to the range `[x_ref_min, x_ref_max]`.
        Default is True.
    fill_value_below : scalar, optional
        Scalar value to use for out-of-bound handling below `x_ref_min` if
        `clamp` is False. Default is 0.0.
    fill_value_above : scalar, optional
        Scalar value to use for out-of-bound handling above `x_ref_max` if
        `clamp` is False. Default is 0.0.

    Returns
    -------
    y_interp : tf.Tensor
        Tensor of shape `(batch, n_interp)` containing the interpolated y
        values, with the same dtype as `x`.

    Notes
    -----
    This function assumes that `y_ref` is defined on a regularly spaced grid
    between `x_ref_min` and `x_ref_max`. The interpolation is performed in a
    batched manner, allowing for efficient computation over multiple sets of
    input data.

    A small constant (1e-8) is added to the grid extent before dividing so
    that a degenerate grid (`x_ref_max == x_ref_min`) does not divide by
    zero; as a consequence results are accurate up to a tiny relative error
    rather than bit-exact.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.array([[0.5, 1.5], [2.0, 3.0]])
    >>> x_ref_min = np.array([0.0, 1.0])
    >>> x_ref_max = np.array([2.0, 3.0])
    >>> y_ref = np.array([[0.0, 1.0, 2.0], [1.0, 2.0, 3.0]])
    >>> y = batch_linear_interp_1d(x, x_ref_min, x_ref_max, y_ref)
    >>> np.round(y.numpy(), 6)
    array([[0.5, 1.5],
           [2. , 3. ]])
    """
    x = tf.convert_to_tensor(x)
    dtype = x.dtype
    y_ref = tf.convert_to_tensor(y_ref, dtype=dtype)

    batch_size = tf.shape(x)[0]
    n_interp = tf.shape(x)[1]
    nx = tf.shape(y_ref)[1]

    # Cast the bounds to x's dtype before reshaping: a Python scalar would
    # otherwise become float32 and a NumPy array float64, which makes the
    # arithmetic below fail whenever x uses the other precision.
    x_ref_min = tf.reshape(tf.cast(x_ref_min, dtype), [-1, 1])  # (batch, 1)
    x_ref_max = tf.reshape(tf.cast(x_ref_max, dtype), [-1, 1])  # (batch, 1)

    x_clamped = tf.clip_by_value(x, x_ref_min, x_ref_max) if clamp else x

    # Map x onto fractional grid coordinates in [0, nx - 1]. The epsilon
    # guards against a zero-width grid.
    last_idx = tf.cast(nx - 1, dtype)
    t = (x_clamped - x_ref_min) / (x_ref_max - x_ref_min + 1e-8)
    x_idx = t * last_idx

    # Neighbouring grid nodes, clipped so the gather below stays in range
    # even for out-of-bound x when clamp is False.
    lower = tf.floor(x_idx)
    i0 = tf.cast(tf.clip_by_value(lower, 0.0, last_idx), tf.int32)
    i1 = tf.cast(tf.clip_by_value(lower + 1, 0.0, last_idx), tf.int32)

    # Linear weights relative to the (clipped) lower node.
    w1 = x_idx - tf.cast(i0, dtype)
    w0 = 1.0 - w1

    batch_idx = tf.range(batch_size)[:, tf.newaxis]  # (batch, 1)
    batch_idx = tf.tile(batch_idx, [1, n_interp])  # (batch, n_interp)

    # Reshaping to x's shape fails loudly if broadcasting against the bounds
    # produced a different shape than (batch, n_interp).
    i0 = tf.reshape(i0, [batch_size, n_interp])
    i1 = tf.reshape(i1, [batch_size, n_interp])

    idx0 = tf.stack([batch_idx, i0], axis=-1)  # (batch, n_interp, 2)
    idx1 = tf.stack([batch_idx, i1], axis=-1)  # (batch, n_interp, 2)

    y0 = tf.gather_nd(y_ref, idx0)
    y1 = tf.gather_nd(y_ref, idx1)

    y_interp = w0 * y0 + w1 * y1

    if not clamp:
        # Replace extrapolated values with the requested fill values, cast
        # to the working dtype so tf.where never sees mixed precisions.
        below = tf.less(x, x_ref_min)  # (batch, n_interp)
        above = tf.greater(x, x_ref_max)  # (batch, n_interp)
        y_interp = tf.where(below, tf.cast(fill_value_below, dtype), y_interp)
        y_interp = tf.where(above, tf.cast(fill_value_above, dtype), y_interp)

    return y_interp