@@ -483,25 +483,9 @@ def get_value_opt(x):
483
483
return y
484
484
485
485
elif self .__interpolation__ == "shepard" :
486
- x_data = self .source [:, 0 :- 1 ] # Support for N-Dimensions
487
- y_data = self .source [:, - 1 ]
488
- len_y_data = len (y_data ) # A little speed up
489
-
490
486
# change the function's name to avoid mypy's error
491
487
def get_value_opt_multiple (* args ):
492
- x = np .array ([[float (val ) for val in args ]])
493
- sub_matrix = x_data - x
494
- distances_squared = np .sum (sub_matrix ** 2 , axis = 1 )
495
-
496
- zero_distance_index = np .where (distances_squared == 0 )[0 ]
497
- if len (zero_distance_index ) > 0 :
498
- return y_data [zero_distance_index [0 ]]
499
-
500
- weights = distances_squared ** (- 1.5 )
501
- numerator_sum = np .sum (y_data * weights )
502
- denominator_sum = np .sum (weights )
503
-
504
- return numerator_sum / denominator_sum
488
+ return self .__interpolate_shepard__ (args )
505
489
506
490
get_value_opt = get_value_opt_multiple
507
491
@@ -903,28 +887,8 @@ def get_value(self, *args):
903
887
904
888
# Returns value for shepard interpolation
905
889
elif self .__interpolation__ == "shepard" :
906
- if all (isinstance (arg , Iterable ) for arg in args ):
907
- x = list (np .column_stack (args ))
908
- else :
909
- x = [[float (x ) for x in list (args )]]
910
- ans = x
911
- x_data = self .source [:, 0 :- 1 ]
912
- y_data = self .source [:, - 1 ]
913
- for i , _ in enumerate (x ):
914
- numerator_sum = 0
915
- denominator_sum = 0
916
- for o , _ in enumerate (y_data ):
917
- sub = x_data [o ] - x [i ]
918
- distance = (sub .dot (sub )) ** (0.5 )
919
- if distance == 0 :
920
- numerator_sum = y_data [o ]
921
- denominator_sum = 1
922
- break
923
- weight = distance ** (- 3 )
924
- numerator_sum = numerator_sum + y_data [o ] * weight
925
- denominator_sum = denominator_sum + weight
926
- ans [i ] = numerator_sum / denominator_sum
927
- return ans if len (ans ) > 1 else ans [0 ]
890
+ return self .__interpolate_shepard__ (args )
891
+
928
892
# Returns value for polynomial interpolation function type
929
893
elif self .__interpolation__ == "polynomial" :
930
894
if isinstance (args [0 ], (int , float )):
@@ -1687,6 +1651,47 @@ def __interpolate_akima__(self):
1687
1651
coeffs [4 * i : 4 * i + 4 ] = np .linalg .solve (matrix , result )
1688
1652
self .__akima_coefficients__ = coeffs
1689
1653
1654
+ def __interpolate_shepard__ (self , args ):
1655
+ """Calculates the shepard interpolation from the given arguments.
1656
+ The shepard interpolation is computed by a inverse distance weighting
1657
+ in a vectorized manner.
1658
+
1659
+ Parameters
1660
+ ----------
1661
+ args : scalar, list
1662
+ Values where the Function is to be evaluated.
1663
+
1664
+ Returns
1665
+ -------
1666
+ result : scalar, list
1667
+ The result of the interpolation.
1668
+ """
1669
+ x_data = self .source [:, 0 :- 1 ] # Support for N-Dimensions
1670
+ y_data = self .source [:, - 1 ]
1671
+
1672
+ arg_stack = np .column_stack (args )
1673
+ arg_qty , arg_dim = arg_stack .shape
1674
+ result = np .zeros (arg_qty )
1675
+
1676
+ # Reshape to vectorize calculations
1677
+ x = arg_stack .reshape (arg_qty , 1 , arg_dim )
1678
+
1679
+ sub_matrix = x_data - x
1680
+ distances_squared = np .sum (sub_matrix ** 2 , axis = 2 )
1681
+
1682
+ # Remove zero distances from further calculations
1683
+ zero_distances = np .where (distances_squared == 0 )
1684
+ valid_indexes = np .ones (arg_qty , dtype = bool )
1685
+ valid_indexes [zero_distances [0 ]] = False
1686
+
1687
+ weights = distances_squared [valid_indexes ] ** (- 1.5 )
1688
+ numerator_sum = np .sum (y_data * weights , axis = 1 )
1689
+ denominator_sum = np .sum (weights , axis = 1 )
1690
+ result [valid_indexes ] = numerator_sum / denominator_sum
1691
+ result [~ valid_indexes ] = y_data [zero_distances [1 ]]
1692
+
1693
+ return result if len (result ) > 1 else result [0 ]
1694
+
1690
1695
def __neg__ (self ):
1691
1696
"""Negates the Function object. The result has the same effect as
1692
1697
multiplying the Function by -1.
0 commit comments