Skip to content

Commit 3550b0f

Browse files
committed
ENH: improved new methods
1 parent 8b78bb5 commit 3550b0f

File tree

1 file changed

+31
-31
lines changed

1 file changed

+31
-31
lines changed

rocketpy/Function.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1986,7 +1986,11 @@ def integral(self, a, b, numerical=False):
19861986
ans : float
19871987
Evaluated integral.
19881988
"""
1989-
if self.__interpolation__ == "spline" and numerical is False:
1989+
if self.__interpolation__ == "linear" and not numerical:
1990+
Xs = np.linspace(a, b, int((b - a) * 5))
1991+
Ys = self.getValue(Xs)
1992+
ans = np.trapz(Ys, x=Xs)
1993+
elif self.__interpolation__ == "spline" and not numerical:
19901994
# Integrate using spline coefficients
19911995
xData = self.xArray
19921996
yData = self.yArray
@@ -2047,8 +2051,9 @@ def integral(self, a, b, numerical=False):
20472051
else:
20482052
# self.__extrapolation__ = 'zero'
20492053
pass
2050-
elif self.__interpolation__ == "linear" and numerical is False:
2051-
return np.trapz(self.yArray, x=self.xArray)
2054+
else:
2055+
# Integrate numerically
2056+
ans, _ = integrate.quad(self, a, b, epsabs=0.01, limit=10000)
20522057
return ans
20532058

20542059
def differentiate(self, x, dx=1e-6):
@@ -2080,18 +2085,21 @@ def derivativeFunction(self):
20802085
# Check if Function object source is array
20812086
if isinstance(self.source, np.ndarray):
20822087
# Operate on grid values
2083-
Ys = np.diff(self.source[:, 1]) / np.diff(self.source[:, 0])
2084-
Xs = self.source[:-1, 0] + np.diff(self.source[:, 0]) / 2
2085-
source = np.concatenate(([Xs], [Ys])).transpose()
2088+
Ys = np.diff(self.yArray) / np.diff(self.xArray)
2089+
Xs = self.source[:-1, 0] + np.diff(self.xArray) / 2
2090+
source = np.column_stack((Xs, Ys))
20862091
# Retrieve inputs, outputs and interpolation
20872092
inputs = self.__inputs__[:]
2088-
outputs = "d(" + self.__outputs__[0] + ")/d(" + inputs[0] + ")"
2089-
outputs = "(" + outputs + ")"
2093+
outputs = f"d({self.__outputs__[0]})/d({inputs[0]})"
20902094
interpolation = "linear"
2091-
# Create new Function object
2092-
return Function(source, inputs, outputs, interpolation)
20932095
else:
2094-
return Function(lambda x: self.differentiate(x))
2096+
source = lambda x: self.differentiate(x)
2097+
inputs = self.__inputs__[:]
2098+
outputs = f"d({self.__outputs__[0]})/d({inputs[0]})"
2099+
interpolation = "linear"
2100+
2101+
# Create new Function object
2102+
return Function(source, inputs, outputs, interpolation)
20952103

20962104
def integralFunction(self, lower=None, upper=None, datapoints=100):
20972105
"""Returns a Function object representing the integral of the Function object.
@@ -2124,7 +2132,7 @@ def integralFunction(self, lower=None, upper=None, datapoints=100):
21242132
for i in range(datapoints):
21252133
yData[i] = self.integral(lower, xData[i])
21262134
return Function(
2127-
np.concatenate(([xData], [yData])).transpose(),
2135+
np.column_stack((xData, yData)),
21282136
inputs=self.__inputs__,
21292137
outputs=[o + " Integral" for o in self.__outputs__],
21302138
)
@@ -2155,27 +2163,18 @@ def inverseFunction(self, approxFunc=None, tol=1e-4):
21552163
"""
21562164
if isinstance(self.source, np.ndarray):
21572165
# Swap the columns
2158-
source = np.concatenate(
2159-
([self.source[:, 1]], [self.source[:, 0]])
2160-
).transpose()
2161-
2162-
return Function(
2163-
source,
2164-
inputs=self.__outputs__,
2165-
outputs=self.__inputs__,
2166-
interpolation=self.__interpolation__,
2167-
)
2166+
source = np.flip(self.source, axis=1)
21682167
else:
21692168
if approxFunc:
2170-
source = lambda x: self.findInput(x, approxFunc(x), tol)
2169+
source = lambda x: self.findInput(x, approxFunc(x), tol=tol)
21712170
else:
21722171
source = lambda x: self.findInput(x, tol=tol)
2173-
return Function(
2174-
source,
2175-
inputs=self.__outputs__,
2176-
outputs=self.__inputs__,
2177-
interpolation=self.__interpolation__,
2178-
)
2172+
return Function(
2173+
source,
2174+
inputs=self.__outputs__,
2175+
outputs=self.__inputs__,
2176+
interpolation=self.__interpolation__,
2177+
)
21792178

21802179
def findInput(self, val, start=0, tol=1e-4):
21812180
"""
@@ -2195,7 +2194,7 @@ def findInput(self, val, start=0, tol=1e-4):
21952194
lambda x: self.getValue(x) - val,
21962195
start,
21972196
tol=tol,
2198-
).x
2197+
).x[0]
21992198

22002199
def average(self, lower, upper):
22012200
"""
@@ -2215,7 +2214,8 @@ def averageFunction(self, lower=None):
22152214
Parameters
22162215
----------
22172216
lower : float
2218-
Lower limit of the new domain. Only required if the Function's source is a callable instead of a list of points.
2217+
Lower limit of the new domain. Only required if the Function's source
2218+
is a callable instead of a list of points.
22192219
22202220
Returns
22212221
-------

0 commit comments

Comments
 (0)