Skip to content

Commit ecaa167

Browse files
committed
ENH: add average operation and minor fixes.
1 parent 2676e44 commit ecaa167

File tree

1 file changed

+69
-125
lines changed

1 file changed

+69
-125
lines changed

rocketpy/Function.py

Lines changed: 69 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ def getValue(self, *args):
580580
return ans if len(ans) > 1 else ans[0]
581581
# Returns value for spline, akima or linear interpolation function type
582582
elif self.__interpolation__ in ["spline", "akima", "linear"]:
583-
if isinstance(args[0], (int, float, complex)):
583+
if isinstance(args[0], (int, float, complex, np.integer)):
584584
args = [list(args)]
585585
x = [arg for arg in args[0]]
586586
xData = self.source[:, 0]
@@ -1004,6 +1004,7 @@ def plot1D(
10041004
forceData=False,
10051005
forcePoints=False,
10061006
returnObject=False,
1007+
equalAxis=False,
10071008
):
10081009
"""Plot 1-Dimensional Function, from a lower limit to an upper limit,
10091010
by sampling the Function several times in the interval. The title of
@@ -1066,6 +1067,8 @@ def plot1D(
10661067
# Plots function
10671068
if forcePoints:
10681069
plt.scatter(x, y, marker="o")
1070+
if equalAxis:
1071+
plt.axis("equal")
10691072
plt.plot(x, y)
10701073
# Turn on grid and set title and axis
10711074
plt.grid(True)
@@ -2056,165 +2059,103 @@ def derivativeFunction(self):
20562059
Ys = np.diff(self.source[:, 1]) / np.diff(self.source[:, 0])
20572060
Xs = self.source[:-1, 0] + np.diff(self.source[:, 0]) / 2
20582061
source = np.concatenate(([Xs], [Ys])).transpose()
2059-
20602062
# Retrieve inputs, outputs and interpolation
20612063
inputs = self.__inputs__[:]
20622064
outputs = "d(" + self.__outputs__[0] + ")/d(" + inputs[0] + ")"
20632065
outputs = "(" + outputs + ")"
20642066
interpolation = "linear"
2065-
20662067
# Create new Function object
20672068
return Function(source, inputs, outputs, interpolation)
20682069
else:
20692070
return Function(lambda x: self.differentiate(x))
20702071

2071-
def integralFunction(self, lower=None):
2072-
"""Returns a Function object representing the integral of the Function
2073-
object.
2072+
def integralFunction(self, lower=None, upper=None, datapoints=100):
2073+
"""Returns a Function object representing the integral of the Function object.
20742074
20752075
Parameters
20762076
----------
20772077
lower : scalar, optional
2078-
The lower integration limit. If the Function is given by a dataset
2079-
of points the default value is the start of the dataset. If the
2080-
Function is defined by a callable, then this parameter must be
2081-
given.
2078+
The lower limit of the interval in which the function is to be
2079+
plotted. If the Function is given by a dataset, the default
2080+
value is the start of the dataset.
2081+
upper : scalar, optional
2082+
The upper limit of the interval in which the function is to be
2083+
plotted. If the Function is given by a dataset, the default
2084+
value is the end of the dataset.
2085+
datapoints : int, optional
2086+
The number of points in which the integral will be evaluated for
2087+
plotting it, which draws lines between each evaluated point.
2088+
The default value is 100.
20822089
20832090
Returns
20842091
-------
20852092
result : Function
2086-
The integral function of the Function object. Note that the domain
2087-
of the integral function is the same as the domain of the original
2088-
Function object.
2089-
"""
2090-
if callable(self.source):
2091-
return Function(lambda x: self.integral(lower, x))
2092-
2093-
# Not callable, i.e., defined by a dataset of points
2094-
lower = lower if lower is not None else self.source[0, 0]
2095-
xData = self.source[:, 0]
2096-
yData = [self.integral(lower, x) for x in xData]
2097-
2098-
return Function(
2099-
np.concatenate(([xData], [yData])).transpose(),
2100-
inputs=self.__inputs__,
2101-
outputs=[o + " Integral" for o in self.__outputs__],
2102-
)
2103-
2104-
def isBijective(self):
2105-
"""Checks whether the Function is bijective. Only applicable to Functions whose source is a list of points, raises an error otherwise.
2106-
2107-
Returns
2108-
-------
2109-
result : bool
2110-
True if the Function is bijective, False otherwise.
2093+
The integral of the Function object.
21112094
"""
21122095
if isinstance(self.source, np.ndarray):
2113-
xDataDistinct = set(self.source[:, 0])
2114-
yDataDistinct = set(self.source[:, 1])
2115-
distinctMap = set(zip(xDataDistinct, yDataDistinct))
2116-
return len(distinctMap) == len(xDataDistinct) == len(yDataDistinct)
2117-
else:
2118-
raise TypeError(
2119-
"Only Functions whose source is a list of points can be checked for bijectivity."
2096+
lower = self.source[0, 0] if lower is None else lower
2097+
upper = self.source[-1, 0] if upper is None else upper
2098+
xData = np.linspace(lower, upper, datapoints)
2099+
yData = np.zeros(datapoints)
2100+
for i in range(datapoints):
2101+
yData[i] = self.integral(lower, xData[i])
2102+
return Function(
2103+
np.concatenate(([xData], [yData])).transpose(),
2104+
inputs=self.__inputs__,
2105+
outputs=[o + " Integral" for o in self.__outputs__],
21202106
)
2121-
2122-
def isStrictlyBijective(self):
2123-
"""Checks whether the Function is "strictly" bijective.
2124-
Only applicable to Functions whose source is a list of points,raises an
2125-
error otherwise.
2126-
2127-
Notes
2128-
-----
2129-
By "strictly" bijective, this implementation considers the
2130-
list-of-points-defined Function bijective between each consecutive pair
2131-
of points. Therefore, the Function may be flagged as not bijective even
2132-
if the mapping between the set of points which define the Function is
2133-
bijective.
2134-
2135-
Returns
2136-
-------
2137-
result : bool
2138-
True if the Function is "strictly" bijective, False otherwise.
2139-
2140-
Examples
2141-
--------
2142-
>>> f = Function([[0, 0], [1, 1], [2, 4]])
2143-
>>> f.isBijective()
2144-
True
2145-
>>> f.isStrictlyBijective()
2146-
True
2147-
2148-
>>> f = Function([[-1, 1], [0, 0], [1, 1], [2, 4]])
2149-
>>> f.isBijective()
2150-
False
2151-
>>> f.isStrictlyBijective()
2152-
False
2153-
2154-
A Function which is not "strictly" bijective, but is bijective, can be
2155-
constructed as x^2 defined at -1, 0 and 2.
2156-
2157-
>>> f = Function([[-1, 1], [0, 0], [2, 4]])
2158-
>>> f.isBijective()
2159-
True
2160-
>>> f.isStrictlyBijective()
2161-
False
2162-
"""
2163-
if isinstance(self.source, np.ndarray):
2164-
# Assuming domain is sorted, range must also be
2165-
yData = self.source[:, 1]
2166-
# Both ascending and descending order means Function is bijective
2167-
yDataDiff = np.diff(yData)
2168-
return np.all(yDataDiff >= 0) or np.all(yDataDiff <= 0)
21692107
else:
2170-
raise TypeError(
2171-
"Only Functions whose source is a list of points can be checked for bijectivity."
2108+
lower = 0 if lower is None else lower
2109+
return Function(
2110+
lambda x: self.integral(lower, x),
2111+
inputs=self.__inputs__,
2112+
outputs=[o + " Integral" for o in self.__outputs__],
21722113
)
21732114

2174-
def inverseFunction(self):
2115+
def inverseFunction(self, approxFunc=None, tol=1e-4):
21752116
"""
2176-
Returns the inverse of the Function. The inverse function of F is a function
2177-
that undoes the operation of F. The inverse of F exists if and only if F is
2178-
bijective. Makes the domain the range and the range the domain.
2117+
Returns the inverse of the Function. The inverse function of F is a function that undoes the operation of F. The
2118+
inverse of F exists if and only if F is bijective. Makes the domain the range and the range the domain.
21792119
2180-
If the Function is given by a list of points, its bijectivity is checked and an
2181-
error is raised if it is not bijective.
2182-
If the Function is given by a function, its bijectivity is not checked and may
2183-
lead to innacuracies outside of its bijective region.
2120+
Parameters
2121+
----------
2122+
lower : float
2123+
Lower limit of the new domain. Only required if the Function's source is a callable instead of a list of points.
2124+
upper : float
2125+
Upper limit of the new domain. Only required if the Function's source is a callable instead of a list of points.
21842126
21852127
Returns
21862128
-------
21872129
result : Function
21882130
A Function whose domain and range have been inverted.
21892131
"""
21902132
if isinstance(self.source, np.ndarray):
2191-
if self.isStrictlyBijective():
2192-
# Swap the columns
2193-
source = np.concatenate(
2194-
([self.source[:, 1]], [self.source[:, 0]])
2195-
).transpose()
2196-
2197-
return Function(
2198-
source,
2199-
inputs=self.__outputs__,
2200-
outputs=self.__inputs__,
2201-
interpolation=self.__interpolation__,
2202-
)
2203-
else:
2204-
raise ValueError(
2205-
"Function is not bijective, so it does not have an inverse."
2206-
)
2133+
# Swap the columns
2134+
source = np.concatenate(
2135+
([self.source[:, 1]], [self.source[:, 0]])
2136+
).transpose()
2137+
2138+
return Function(
2139+
source,
2140+
inputs=self.__outputs__,
2141+
outputs=self.__inputs__,
2142+
interpolation=self.__interpolation__,
2143+
)
22072144
else:
2145+
if approxFunc:
2146+
source = lambda x: self.findInput(x, approxFunc(x), tol)
2147+
else:
2148+
source = lambda x: self.findInput(x, tol=tol)
22082149
return Function(
2209-
lambda x: self.findInput(x),
2150+
source,
22102151
inputs=self.__outputs__,
22112152
outputs=self.__inputs__,
22122153
interpolation=self.__interpolation__,
22132154
)
22142155

2215-
def findInput(self, val):
2156+
def findInput(self, val, start=0, tol=1e-4):
22162157
"""
2217-
Finds the input for a given output.
2158+
Finds the optimal input for a given output.
22182159
22192160
Parameters
22202161
----------
@@ -2226,6 +2167,12 @@ def findInput(self, val):
22262167
result : ndarray
22272168
The value of the input which gives the output closest to val.
22282169
"""
2170+
return optimize.root(
2171+
lambda x: self.getValue(x) - val,
2172+
start,
2173+
tol=tol,
2174+
).x
2175+
22292176
def average(self, lower, upper):
22302177
"""
22312178
Returns the average of the function.
@@ -2314,10 +2261,9 @@ def __new__(
23142261
datapoints=50,
23152262
):
23162263
"""
2317-
Creates a piecewise function from a dictionary of functions. The keys of the
2318-
dictionary must be tuples that represent the domain of the function. The domains
2319-
must be disjoint. The piecewise function will be evaluated at datapoints points
2320-
to create Function object.
2264+
Creates a piecewise function from a dictionary of functions. The keys of the dictionary
2265+
must be tuples that represent the domain of the function. The domains must be disjoint.
2266+
The piecewise function will be evaluated at datapoints points to create Function object.
23212267
23222268
Parameters
23232269
----------
@@ -2338,12 +2284,10 @@ def __new__(
23382284
# Check if source is a dictionary
23392285
if not isinstance(source, dict):
23402286
raise TypeError("source must be a dictionary")
2341-
23422287
# Check if all keys are tuples
23432288
for key in source.keys():
23442289
if not isinstance(key, tuple):
23452290
raise TypeError("keys of source must be tuples")
2346-
23472291
# Check if all domains are disjoint
23482292
for key1 in source.keys():
23492293
for key2 in source.keys():

0 commit comments

Comments
 (0)