Skip to content

Commit 69ffc01

Browse files
committed
ENH: implement setDiscreteBasedOnModel
1 parent ce5a730 commit 69ffc01

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

rocketpy/Function.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,54 @@ def setDiscrete(
473473
self.__interpolation__ = "shepard"
474474
return self
475475

476+
def setDiscreteBasedOnModel(self, modelFunction, oneByOne=True):
477+
"""This method transforms function defined Functions into list
478+
defined Functions. It evaluates the function at certain points
479+
(sampling range) and stores the results in a list, which is converted
480+
into a Function and then returned. The original Function object is
481+
replaced by the new one.
482+
483+
Parameters
484+
----------
485+
modelFunction : Function
486+
Function object that will be used to define the sampling points,
487+
interpolation method and extrapolation method.
488+
Must be a Function whose source attribute is a list (i.e. a list based
489+
Function instance).
490+
Must have the same domain dimension as the Function to be discretized.
491+
492+
oneByOne : boolean, optional
493+
If True, evaluate Function in each sample point separately. If
494+
False, evaluates Function in vectorized form. Default is True.
495+
496+
Returns
497+
-------
498+
self : Function
499+
"""
500+
if not isinstance(modelFunction.source, np.ndarray):
501+
raise TypeError("modelFunction must be a list based Function.")
502+
if modelFunction.__domDim__ != self.__domDim__:
503+
raise ValueError("modelFunction must have the same domain dimension.")
504+
505+
if self.__domDim__ == 1:
506+
Xs = modelFunction.source[:, 0]
507+
Ys = self.getValue(Xs.tolist()) if oneByOne else self.getValue(Xs)
508+
self.source = np.concatenate(([Xs], [Ys])).transpose()
509+
elif self.__domDim__ == 2:
510+
# Create nodes to evaluate function
511+
Xs = modelFunction.source[:, 0]
512+
Ys = modelFunction.source[:, 1]
513+
Xs, Ys = np.meshgrid(Xs, Ys)
514+
Xs, Ys = Xs.flatten(), Ys.flatten()
515+
mesh = [[Xs[i], Ys[i]] for i in range(len(Xs))]
516+
# Evaluate function at all mesh nodes and convert it to matrix
517+
Zs = np.array(self.getValue(mesh))
518+
self.source = np.concatenate(([Xs], [Ys], [Zs])).transpose()
519+
520+
self.setInterpolation(modelFunction.__interpolation__)
521+
self.setExtrapolation(modelFunction.__extrapolation__)
522+
return self
523+
476524
# Define all get methods
477525
def getInputs(self):
478526
"Return tuple of inputs of the function."

0 commit comments

Comments
 (0)