Skip to content

Commit 0fb38d2

Browse files
feat: added lag_plot (#548)
Closes #519 ### Summary of Changes I added the visualization of the lag plot to the timeseries class --------- Co-authored-by: megalinter-bot <[email protected]>
1 parent 2f1d5c5 commit 0fb38d2

File tree

4 files changed

+91
-1
lines changed

4 files changed

+91
-1
lines changed

src/safeds/data/tabular/containers/_column.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
import sys
43
import io
4+
import sys
55
from collections.abc import Sequence
66
from numbers import Number
77
from typing import TYPE_CHECKING, Any, TypeVar, overload

src/safeds/data/tabular/containers/_time_series.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
from __future__ import annotations
22

3+
import io
34
import sys
45
from typing import TYPE_CHECKING
56

7+
import matplotlib.pyplot as plt
8+
import pandas as pd
9+
10+
from safeds.data.image.containers import Image
611
from safeds.data.tabular.containers import Column, Row, Table, TaggedTable
712
from safeds.exceptions import (
813
ColumnIsTargetError,
914
ColumnIsTimeError,
1015
IllegalSchemaModificationError,
16+
NonNumericColumnError,
1117
UnknownColumnNameError,
1218
)
1319

@@ -839,6 +845,13 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Time
839845
840846
The original time series is not modified.
841847
848+
Parameters
849+
----------
850+
name:
851+
The name of the column to be transformed.
852+
transformer:
853+
The transformer to the given column
854+
842855
Returns
843856
-------
844857
result : TimeSeries
@@ -857,3 +870,39 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Time
857870
),
858871
time_name=self.time.name,
859872
)
873+
874+
def plot_lagplot(self, lag: int) -> Image:
875+
"""
876+
Plot a lagplot for the target column.
877+
878+
Parameters
879+
----------
880+
lag:
881+
The amount of lag used to plot
882+
883+
Returns
884+
-------
885+
plot:
886+
The plot as an image.
887+
888+
Raises
889+
------
890+
NonNumericColumnError
891+
If the time series targets contains non-numerical values.
892+
893+
Examples
894+
--------
895+
>>> from safeds.data.tabular.containers import TimeSeries
896+
>>> table = TimeSeries({"time":[1, 2], "target": [3, 4], "feature":[2,2]}, target_name= "target", time_name="time", feature_names=["feature"], )
897+
>>> image = table.plot_lagplot(lag = 1)
898+
899+
"""
900+
if not self.target.type.is_numeric():
901+
raise NonNumericColumnError("This time series target contains non-numerical columns.")
902+
ax = pd.plotting.lag_plot(self.target._data, lag=lag)
903+
fig = ax.figure
904+
buffer = io.BytesIO()
905+
fig.savefig(buffer, format="png")
906+
plt.close() # Prevents the figure from being displayed directly
907+
buffer.seek(0)
908+
return Image.from_bytes(buffer.read())
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import pytest
2+
from safeds.data.tabular.containers import TimeSeries
3+
from safeds.exceptions import NonNumericColumnError
4+
from syrupy import SnapshotAssertion
5+
6+
7+
def test_should_return_table(snapshot_png: SnapshotAssertion) -> None:
8+
table = TimeSeries(
9+
{
10+
"time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
11+
"feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
12+
"target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
13+
},
14+
target_name="target",
15+
time_name="time",
16+
feature_names=None,
17+
)
18+
lag_plot = table.plot_lagplot(lag=1)
19+
assert lag_plot == snapshot_png
20+
21+
22+
def test_should_raise_if_column_contains_non_numerical_values() -> None:
23+
table = TimeSeries(
24+
{
25+
"time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
26+
"feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
27+
"target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
28+
},
29+
target_name="target",
30+
time_name="time",
31+
feature_names=None,
32+
)
33+
with pytest.raises(
34+
NonNumericColumnError,
35+
match=(
36+
r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThis time series target"
37+
r" contains"
38+
r" non-numerical columns."
39+
),
40+
):
41+
table.plot_lagplot(2)

0 commit comments

Comments
 (0)