Skip to content

Commit 3f35ead

Browse files
authored
Merge pull request #542 from RocketPy-Team/enh/function-csv-inputs
ENH: Function inputs from CSV file header.
2 parents 1132a11 + ad7c573 commit 3f35ead

File tree

3 files changed

+44
-16
lines changed

3 files changed

+44
-16
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ straightforward as possible.
3232

3333
### Added
3434

35+
- ENH: Function Support for CSV Header Inputs [#542](https://github.com/RocketPy-Team/RocketPy/pull/542)
3536
- ENH: Shepard Optimized Interpolation - Multiple Inputs Support [#515](https://github.com/RocketPy-Team/RocketPy/pull/515)
3637
- ENH: adds new Function.savetxt method [#514](https://github.com/RocketPy-Team/RocketPy/pull/514)
3738
- ENH: Argument for Optional Mutation on Function Discretize [#519](https://github.com/RocketPy-Team/RocketPy/pull/519)

rocketpy/mathutils/function.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ def __init__(
5959
and 'z' is the output.
6060
6161
- string: Path to a CSV file. The file is read and converted into an
62-
ndarray. The file can optionally contain a single header line.
62+
ndarray. The file can optionally contain a single header line, see
63+
notes below for more information.
6364
6465
- Function: Copies the source of the provided Function object,
6566
creating a new Function with adjusted inputs and outputs.
@@ -94,12 +95,19 @@ def __init__(
9495
9596
Notes
9697
-----
97-
(I) CSV files can optionally contain a single header line. If present,
98-
the header is ignored during processing.
99-
(II) Fields in CSV files may be enclosed in double quotes. If fields are
100-
not quoted, double quotes should not appear inside them.
98+
(I) CSV files may include an optional single header line. If this
99+
header line is present and contains names for each data column, those
100+
names will be used to label the inputs and outputs unless specified
101+
otherwise by the `inputs` and `outputs` arguments.
102+
If the header is specified for only a few columns, it is ignored.
103+
104+
Commas in a header will be interpreted as a delimiter, which may cause
105+
undesired input or output labeling. To avoid this, specify each input
106+
and output name using the `inputs` and `outputs` arguments.
107+
108+
(II) Fields in CSV files may be enclosed in double quotes. If fields
109+
are not quoted, double quotes should not appear inside them.
101110
"""
102-
# Set input and output
103111
if inputs is None:
104112
inputs = ["Scalar"]
105113
if outputs is None:
@@ -184,10 +192,18 @@ def set_source(self, source):
184192
185193
Notes
186194
-----
187-
(I) CSV files can optionally contain a single header line. If present,
188-
the header is ignored during processing.
189-
(II) Fields in CSV files may be enclosed in double quotes. If fields are
190-
not quoted, double quotes should not appear inside them.
195+
(I) CSV files may include an optional single header line. If this
196+
header line is present and contains names for each data column, those
197+
names will be used to label the inputs and outputs unless specified
198+
otherwise. If the header is specified for only a few columns, it is
199+
ignored.
200+
201+
Commas in a header will be interpreted as a delimiter, which may cause
202+
undesired input or output labeling. To avoid this, specify each input
203+
and output name using the `inputs` and `outputs` arguments.
204+
205+
(II) Fields in CSV files may be enclosed in double quotes. If fields
206+
are not quoted, double quotes should not appear inside them.
191207
192208
Returns
193209
-------
@@ -3019,7 +3035,7 @@ def _check_user_input(
30193035
if isinstance(inputs, str):
30203036
inputs = [inputs]
30213037

3022-
elif len(outputs) > 1:
3038+
if len(outputs) > 1:
30233039
raise ValueError(
30243040
"Output must either be a string or have dimension 1, "
30253041
+ f"it currently has dimension ({len(outputs)})."
@@ -3036,8 +3052,19 @@ def _check_user_input(
30363052
try:
30373053
source = np.loadtxt(source, delimiter=",", dtype=float)
30383054
except ValueError:
3039-
# Skip header
3040-
source = np.loadtxt(source, delimiter=",", dtype=float, skiprows=1)
3055+
with open(source, "r") as file:
3056+
header, *data = file.read().splitlines()
3057+
3058+
header = [
3059+
label.strip("'").strip('"') for label in header.split(",")
3060+
]
3061+
source = np.loadtxt(data, delimiter=",", dtype=float)
3062+
3063+
if len(source[0]) == len(header):
3064+
if inputs == ["Scalar"]:
3065+
inputs = header[:-1]
3066+
if outputs == ["Scalar"]:
3067+
outputs = [header[-1]]
30413068
except Exception as e:
30423069
raise ValueError(
30433070
"The source file is not a valid csv or txt file."
@@ -3055,7 +3082,7 @@ def _check_user_input(
30553082

30563083
## single dimension
30573084
if source_dim == 2:
3058-
# possible interpolation values: llinear, polynomial, akima and spline
3085+
# possible interpolation values: linear, polynomial, akima and spline
30593086
if interpolation is None:
30603087
interpolation = "spline"
30613088
elif interpolation.lower() not in [
@@ -3106,7 +3133,7 @@ def _check_user_input(
31063133
in_out_dim = len(inputs) + len(outputs)
31073134
if source_dim != in_out_dim:
31083135
raise ValueError(
3109-
"Source dimension ({source_dim}) does not match input "
3136+
f"Source dimension ({source_dim}) does not match input "
31103137
+ f"and output dimension ({in_out_dim})."
31113138
)
31123139
return inputs, outputs, interpolation, extrapolation

tests/test_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_func_from_csv_with_header(csv_file):
4949
line. It tests cases where the fields are separated by quotes and without
5050
quotes."""
5151
f = Function(csv_file)
52-
assert f.__repr__() == "'Function from R1 to R1 : (Scalar) → (Scalar)'"
52+
assert f.__repr__() == "'Function from R1 to R1 : (time) → (value)'"
5353
assert np.isclose(f(0), 100)
5454
assert np.isclose(f(0) + f(1), 300), "Error summing the values of the function"
5555

0 commit comments

Comments
 (0)