Skip to content

Commit 5d09d4a

Browse files
authored
Merge pull request #515 from RocketPy-Team/enh/shepard-multiple-opt
ENH: Shepard Optimized Interpolation - Multiple Inputs Support
2 parents 2ebed8d + 48e4f5d commit 5d09d4a

File tree

3 files changed

+94
-39
lines changed

3 files changed

+94
-39
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ straightforward as possible.
3232

3333
### Added
3434

35+
- ENH: Shepard Optimized Interpolation - Multiple Inputs Support [#515](https://github.com/RocketPy-Team/RocketPy/pull/515)
3536
- ENH: adds new Function.savetxt method [#514](https://github.com/RocketPy-Team/RocketPy/pull/514)
3637
- ENH: Argument for Optional Mutation on Function Discretize [#519](https://github.com/RocketPy-Team/RocketPy/pull/519)
3738

rocketpy/mathutils/function.py

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -483,25 +483,9 @@ def get_value_opt(x):
483483
return y
484484

485485
elif self.__interpolation__ == "shepard":
486-
x_data = self.source[:, 0:-1] # Support for N-Dimensions
487-
y_data = self.source[:, -1]
488-
len_y_data = len(y_data) # A little speed up
489-
490486
# change the function's name to avoid mypy's error
491487
def get_value_opt_multiple(*args):
492-
x = np.array([[float(val) for val in args]])
493-
sub_matrix = x_data - x
494-
distances_squared = np.sum(sub_matrix**2, axis=1)
495-
496-
zero_distance_index = np.where(distances_squared == 0)[0]
497-
if len(zero_distance_index) > 0:
498-
return y_data[zero_distance_index[0]]
499-
500-
weights = distances_squared ** (-1.5)
501-
numerator_sum = np.sum(y_data * weights)
502-
denominator_sum = np.sum(weights)
503-
504-
return numerator_sum / denominator_sum
488+
return self.__interpolate_shepard__(args)
505489

506490
get_value_opt = get_value_opt_multiple
507491

@@ -903,28 +887,8 @@ def get_value(self, *args):
903887

904888
# Returns value for shepard interpolation
905889
elif self.__interpolation__ == "shepard":
906-
if all(isinstance(arg, Iterable) for arg in args):
907-
x = list(np.column_stack(args))
908-
else:
909-
x = [[float(x) for x in list(args)]]
910-
ans = x
911-
x_data = self.source[:, 0:-1]
912-
y_data = self.source[:, -1]
913-
for i, _ in enumerate(x):
914-
numerator_sum = 0
915-
denominator_sum = 0
916-
for o, _ in enumerate(y_data):
917-
sub = x_data[o] - x[i]
918-
distance = (sub.dot(sub)) ** (0.5)
919-
if distance == 0:
920-
numerator_sum = y_data[o]
921-
denominator_sum = 1
922-
break
923-
weight = distance ** (-3)
924-
numerator_sum = numerator_sum + y_data[o] * weight
925-
denominator_sum = denominator_sum + weight
926-
ans[i] = numerator_sum / denominator_sum
927-
return ans if len(ans) > 1 else ans[0]
890+
return self.__interpolate_shepard__(args)
891+
928892
# Returns value for polynomial interpolation function type
929893
elif self.__interpolation__ == "polynomial":
930894
if isinstance(args[0], (int, float)):
@@ -1687,6 +1651,47 @@ def __interpolate_akima__(self):
16871651
coeffs[4 * i : 4 * i + 4] = np.linalg.solve(matrix, result)
16881652
self.__akima_coefficients__ = coeffs
16891653

1654+
def __interpolate_shepard__(self, args):
1655+
"""Calculates the shepard interpolation from the given arguments.
1656+
The shepard interpolation is computed by a inverse distance weighting
1657+
in a vectorized manner.
1658+
1659+
Parameters
1660+
----------
1661+
args : scalar, list
1662+
Values where the Function is to be evaluated.
1663+
1664+
Returns
1665+
-------
1666+
result : scalar, list
1667+
The result of the interpolation.
1668+
"""
1669+
x_data = self.source[:, 0:-1] # Support for N-Dimensions
1670+
y_data = self.source[:, -1]
1671+
1672+
arg_stack = np.column_stack(args)
1673+
arg_qty, arg_dim = arg_stack.shape
1674+
result = np.zeros(arg_qty)
1675+
1676+
# Reshape to vectorize calculations
1677+
x = arg_stack.reshape(arg_qty, 1, arg_dim)
1678+
1679+
sub_matrix = x_data - x
1680+
distances_squared = np.sum(sub_matrix**2, axis=2)
1681+
1682+
# Remove zero distances from further calculations
1683+
zero_distances = np.where(distances_squared == 0)
1684+
valid_indexes = np.ones(arg_qty, dtype=bool)
1685+
valid_indexes[zero_distances[0]] = False
1686+
1687+
weights = distances_squared[valid_indexes] ** (-1.5)
1688+
numerator_sum = np.sum(y_data * weights, axis=1)
1689+
denominator_sum = np.sum(weights, axis=1)
1690+
result[valid_indexes] = numerator_sum / denominator_sum
1691+
result[~valid_indexes] = y_data[zero_distances[1]]
1692+
1693+
return result if len(result) > 1 else result[0]
1694+
16901695
def __neg__(self):
16911696
"""Negates the Function object. The result has the same effect as
16921697
multiplying the Function by -1.

tests/test_function.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,55 @@ def test_multivariable_dataset(a, b):
215215
assert np.isclose(func(a, b), a + b, atol=1e-6)
216216

217217

218+
@pytest.mark.parametrize(
219+
"x,y,z_expected",
220+
[
221+
(1, 0, 0),
222+
(0, 1, 0),
223+
(0, 0, 1),
224+
(0.5, 0.5, 1 / 3),
225+
(0.25, 0.25, 25 / (25 + 2 * 5**0.5)),
226+
([0, 0.5], [0, 0.5], [1, 1 / 3]),
227+
],
228+
)
229+
def test_2d_shepard_interpolation(x, y, z_expected):
230+
"""Test the shepard interpolation method of the Function class."""
231+
# Test plane x + y + z = 1
232+
source = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
233+
func = Function(
234+
source=source, inputs=["x", "y"], outputs=["z"], interpolation="shepard"
235+
)
236+
z = func(x, y)
237+
z_opt = func.get_value_opt(x, y)
238+
assert np.isclose(z, z_opt, atol=1e-8).all()
239+
assert np.isclose(z_expected, z, atol=1e-8).all()
240+
241+
242+
@pytest.mark.parametrize(
243+
"x,y,z,w_expected",
244+
[
245+
(0, 0, 0, 1),
246+
(1, 0, 0, 0),
247+
(0, 1, 0, 0),
248+
(0, 0, 1, 0),
249+
(0.5, 0.5, 0.5, 1 / 4),
250+
(0.25, 0.25, 0.25, 0.700632626832),
251+
([0, 0.5], [0, 0.5], [0, 0.5], [1, 1 / 4]),
252+
],
253+
)
254+
def test_3d_shepard_interpolation(x, y, z, w_expected):
255+
"""Test the shepard interpolation method of the Function class."""
256+
# Test plane x + y + z + w = 1
257+
source = [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]
258+
func = Function(
259+
source=source, inputs=["x", "y", "z"], outputs=["w"], interpolation="shepard"
260+
)
261+
w = func(x, y, z)
262+
w_opt = func.get_value_opt(x, y, z)
263+
assert np.isclose(w, w_opt, atol=1e-8).all()
264+
assert np.isclose(w_expected, w, atol=1e-8).all()
265+
266+
218267
@pytest.mark.parametrize("a", [-1, -0.5, 0, 0.5, 1])
219268
@pytest.mark.parametrize("b", [-1, -0.5, 0, 0.5, 1])
220269
def test_multivariable_function(a, b):

0 commit comments

Comments
 (0)