Skip to content

Commit 25374fa

Browse files
Merge pull request #554 from RocketPy-Team/enh/function-remove-outliers
ENH: adds `Function.remove_outliers` method
2 parents 35a9439 + fda6b33 commit 25374fa

File tree

3 files changed

+76
-0
lines changed

3 files changed

+76
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
3232

3333
### Added
3434

35+
- ENH: adds `Function.remove_outliers_iqr` method [#554](https://github.com/RocketPy-Team/RocketPy/pull/554)
3536

3637
### Changed
3738
- ENH: Optional argument to show the plot in Function.compare_plots [#563](https://github.com/RocketPy-Team/RocketPy/pull/563)

rocketpy/mathutils/function.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,6 +1144,51 @@ def low_pass_filter(self, alpha, file_path=None):
11441144
title=self.title,
11451145
)
11461146

1147+
def remove_outliers_iqr(self, threshold=1.5):
    """Build a new Function without the points that Tukey's fences
    classify as outliers. The Function must have an array-like source.

    A point is kept when its y value lies inside the closed interval
    ``[Q1 - threshold * IQR, Q3 + threshold * IQR]``, where ``Q1`` and
    ``Q3`` are the 25th and 75th percentiles of ``y_array`` and
    ``IQR = Q3 - Q1``.

    Parameters
    ----------
    threshold : float, optional
        Multiplier applied to the interquartile range to position the
        lower and upper fences. Default is 1.5.

    Returns
    -------
    Function
        A new Function built from the remaining (x, y) points, created
        with the same inputs, outputs, interpolation, extrapolation and
        title as this one.

    Raises
    ------
    TypeError
        If the source of this Function is a callable object.

    References
    ----------
    [1] https://en.wikipedia.org/wiki/Outlier#Tukey's_fences
    """
    if callable(self.source):
        raise TypeError(
            "Cannot remove outliers if the source is a callable object."
            " The Function.source should be array-like."
        )

    abscissas = self.x_array
    ordinates = self.y_array

    # Tukey's fences: quartiles widened by a multiple of their distance
    first_quartile, third_quartile = np.percentile(ordinates, [25, 75])
    spread = third_quartile - first_quartile
    lower_fence = first_quartile - threshold * spread
    upper_fence = third_quartile + threshold * spread

    # A single boolean mask keeps the x and y selections aligned
    is_inlier = (ordinates >= lower_fence) & (ordinates <= upper_fence)
    kept_points = np.column_stack((abscissas[is_inlier], ordinates[is_inlier]))

    return Function(
        source=kept_points,
        inputs=self.__inputs__,
        outputs=self.__outputs__,
        interpolation=self.__interpolation__,
        extrapolation=self.__extrapolation__,
        title=self.title,
    )
1191+
11471192
# Define all presentation methods
11481193
def __call__(self, *args):
11491194
"""Plot the Function if no argument is given. If an

tests/unit/test_function.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,3 +276,33 @@ def test_set_discrete_based_on_model_non_mutator(linear_func):
276276
assert isinstance(func, Function)
277277
assert discretized_func.source.shape == (4, 2)
278278
assert callable(func.source)
279+
280+
281+
@pytest.mark.parametrize(
    "x, y, expected_x, expected_y",
    [
        (
            np.array([1, 2, 3, 4, 5, 6]),
            np.array([10, 20, 30, 40, 50000, 60]),
            np.array([1, 2, 3, 4, 6]),
            np.array([10, 20, 30, 40, 60]),
        ),
    ],
)
def test_remove_outliers_iqr(x, y, expected_x, expected_y):
    """Check that Function.remove_outliers_iqr drops the samples flagged by
    the Interquartile Range (IQR) method and carries over the metadata of
    the Function it was called on.
    """
    original = Function(source=np.column_stack((x, y)))
    cleaned = original.remove_outliers_iqr(threshold=1.5)

    # Only the non-outlier samples should remain, in their original order
    assert np.array_equal(cleaned.x_array, expected_x)
    assert np.array_equal(cleaned.y_array, expected_y)

    # Every other attribute must match the Function the filter was applied to
    for attribute in (
        "__inputs__",
        "__outputs__",
        "__interpolation__",
        "__extrapolation__",
        "title",
    ):
        assert getattr(cleaned, attribute) == getattr(original, attribute)

0 commit comments

Comments
 (0)