Skip to content

Commit 2e56977

Browse files
committed
added input export to light mode
1 parent cb276de commit 2e56977

File tree

1 file changed

+164
-69
lines changed

1 file changed

+164
-69
lines changed

rocketpy/simulation/monte_carlo.py

Lines changed: 164 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,7 @@ def __run_in_parallel(self, append, light_mode, n_workers=None):
450450
inputs_size,
451451
input_writer_stop_event,
452452
n_sim_memory,
453+
light_mode,
453454
),
454455
)
455456

@@ -463,6 +464,7 @@ def __run_in_parallel(self, append, light_mode, n_workers=None):
463464
results_size,
464465
results_writer_stop_event,
465466
n_sim_memory,
467+
light_mode,
466468
),
467469
)
468470

@@ -607,16 +609,22 @@ def __run_simulation_worker(
607609
)
608610

609611
# Export to file
612+
inputs_dict = dict(
613+
item
614+
for d in [
615+
sto_env.last_rnd_dict,
616+
sto_rocket.last_rnd_dict,
617+
sto_flight.last_rnd_dict,
618+
]
619+
for item in d.items()
620+
)
621+
622+
inputs_dict["idx"] = sim_idx
623+
inputs_dict = MonteCarlo.prepare_export_data(
624+
inputs_dict, export_sample_time, remove_functions=True
625+
)
626+
610627
if light_mode:
611-
inputs_dict = dict(
612-
item
613-
for d in [
614-
sto_env.last_rnd_dict,
615-
sto_rocket.last_rnd_dict,
616-
sto_flight.last_rnd_dict,
617-
]
618-
for item in d.items()
619-
)
620628
# Construct the dict with the results from the flight
621629
results = {
622630
export_item: getattr(monte_carlo_flight, export_item)
@@ -628,15 +636,13 @@ def __run_simulation_worker(
628636

629637
else:
630638
# serialize data
631-
flight_results = MonteCarlo.inspect_object_attributes(
639+
flight_results = MonteCarlo.prepare_export_data(
632640
monte_carlo_flight, sample_time=export_sample_time
633641
)
634642

635643
# place data in dictionary as it will be found in output file
636644
export_inputs = {
637-
str(
638-
sim_idx
639-
): "Currently no addition data is exported. Use the results file.",
645+
str(sim_idx): inputs_dict,
640646
}
641647

642648
export_outputs = {
@@ -737,6 +743,7 @@ def __run_single_simulation(
737743
]
738744
for item in d.items()
739745
)
746+
self._inputs_dict["idx"] = sim_idx
740747

741748
# Export inputs and outputs to file
742749
if light_mode:
@@ -748,17 +755,14 @@ def __run_single_simulation(
748755
)
749756
else:
750757
# serialize data
751-
flight_results = MonteCarlo.inspect_object_attributes(
758+
flight_results = MonteCarlo.prepare_export_data(
752759
monte_carlo_flight, sample_time=self.export_sample_time
753760
)
754761

755762
# place data in dictionary as it will be found in output file
756763
export_inputs = {
757-
str(
758-
sim_idx
759-
): "Currently no addition data is exported. Use the results file.",
764+
str(sim_idx): self._inputs_dict,
760765
}
761-
762766
export_outputs = {
763767
str(sim_idx): flight_results,
764768
}
@@ -780,21 +784,29 @@ def __run_single_simulation(
780784

781785
@staticmethod
782786
def __loop_though_buffer(
783-
h5_file, shared_buffer, go_read_semaphores, go_write_semaphores
787+
file,
788+
shared_buffer,
789+
go_read_semaphores,
790+
go_write_semaphores,
791+
light_mode,
784792
):
785793
"""
786794
Loop through the shared buffer, writing the data to the file.
787795
788796
Parameters
789797
----------
790-
h5_file : h5py.File
798+
file : h5py.File or TextIOWrapper
791799
File object to write the data.
792800
shared_buffer : np.ndarray
793801
Shared memory buffer with the data.
794802
go_read_semaphores : list
795803
List of semaphores to read the data.
796804
go_write_semaphores : list
797805
List of semaphores to write the data.
806+
light_mode : bool
807+
If True, only variables from the export_list will be saved to
808+
the output file as a .txt file. If False, all variables will be
809+
saved to the output file as a .h5 file.
798810
799811
Returns
800812
-------
@@ -806,10 +818,13 @@ def __loop_though_buffer(
806818
if sem.acquire(timeout=1e-3):
807819
# retrieve the data from the shared buffer
808820
data = shared_buffer[i]
809-
data_dict = pickle.loads(bytes(data))
821+
data_deserialized = pickle.loads(bytes(data))
810822

811823
# write data to the file
812-
MonteCarlo.__dict_to_h5(h5_file, "/", data_dict)
824+
if light_mode:
825+
file.write(data_deserialized + "\n")
826+
else:
827+
MonteCarlo.__dict_to_h5(file, "/", data_deserialized)
813828

814829
# release the write semaphore // tell worker it can write again
815830
go_write_semaphores[i].release()
@@ -823,6 +838,7 @@ def _write_data_worker(
823838
data_size,
824839
stop_event,
825840
n_sim_memory,
841+
light_mode,
826842
):
827843
"""
828844
Worker function to write data to the file.
@@ -843,22 +859,55 @@ def _write_data_worker(
843859
Event to stop the worker.
844860
n_sim_memory : int
845861
Number of simulations that can be stored in memory.
862+
light_mode : bool
863+
If True, only variables from the export_list will be saved to
864+
the output file as a .txt file. If False, all variables will be
865+
saved to the output file as a .h5 file.
846866
"""
847867
shm = shared_memory.SharedMemory(shared_name)
848868
shared_buffer = np.ndarray(
849869
(n_sim_memory, data_size), dtype=ctypes.c_ubyte, buffer=shm.buf
850870
)
851-
with h5py.File(file_path, 'a') as h5_file:
852-
# loop until the stop event is set
853-
while not stop_event.is_set():
871+
if light_mode:
872+
with open(file_path, mode="a", encoding="utf-8") as f:
873+
while not stop_event.is_set():
874+
MonteCarlo.__loop_though_buffer(
875+
f,
876+
shared_buffer,
877+
go_read_semaphores,
878+
go_write_semaphores,
879+
light_mode,
880+
)
881+
882+
# loop through the remaining data
854883
MonteCarlo.__loop_though_buffer(
855-
h5_file, shared_buffer, go_read_semaphores, go_write_semaphores
884+
f,
885+
shared_buffer,
886+
go_read_semaphores,
887+
go_write_semaphores,
888+
light_mode,
856889
)
857890

858-
# loop through the remaining data
859-
MonteCarlo.__loop_though_buffer(
860-
h5_file, shared_buffer, go_read_semaphores, go_write_semaphores
861-
)
891+
else:
892+
with h5py.File(file_path, 'a') as h5_file:
893+
# loop until the stop event is set
894+
while not stop_event.is_set():
895+
MonteCarlo.__loop_though_buffer(
896+
h5_file,
897+
shared_buffer,
898+
go_read_semaphores,
899+
go_write_semaphores,
900+
light_mode,
901+
)
902+
903+
# loop through the remaining data
904+
MonteCarlo.__loop_though_buffer(
905+
h5_file,
906+
shared_buffer,
907+
go_read_semaphores,
908+
go_write_semaphores,
909+
light_mode,
910+
)
862911

863912
@staticmethod
864913
def __downsample_recursive(data_dict, max_time, sample_time):
@@ -910,44 +959,66 @@ def __get_export_size(self, light_mode):
910959
dictionary. The purpose is to estimate the size of the exported data.
911960
"""
912961
# Run trajectory simulation
913-
monte_carlo_flight = self.flight.create_object()
962+
env = self.environment.create_object()
963+
rocket = self.rocket.create_object()
964+
rail_length = self.flight._randomize_rail_length()
965+
inclination = self.flight._randomize_inclination()
966+
heading = self.flight._randomize_heading()
967+
initial_solution = self.flight.initial_solution
968+
terminate_on_apogee = self.flight.terminate_on_apogee
969+
970+
monte_carlo_flight = Flight(
971+
rocket=rocket,
972+
environment=env,
973+
rail_length=rail_length,
974+
inclination=inclination,
975+
heading=heading,
976+
initial_solution=initial_solution,
977+
terminate_on_apogee=terminate_on_apogee,
978+
)
914979

915980
if monte_carlo_flight.max_time is None or monte_carlo_flight.max_time <= 0:
916981
raise ValueError(
917982
"The max_time attribute must be greater than zero. To use parallel mode."
918983
)
919984

920985
# Export inputs and outputs to file
986+
export_inputs = dict(
987+
item
988+
for d in [
989+
self.environment.last_rnd_dict,
990+
self.rocket.last_rnd_dict,
991+
self.flight.last_rnd_dict,
992+
]
993+
for item in d.items()
994+
)
995+
export_inputs["idx"] = 123456789
996+
997+
export_inputs = self.prepare_export_data(
998+
export_inputs, self.export_sample_time, remove_functions=True
999+
)
1000+
1001+
export_inputs = self.__downsample_recursive(
1002+
data_dict=export_inputs,
1003+
max_time=monte_carlo_flight.max_time,
1004+
sample_time=self.export_sample_time,
1005+
)
1006+
9211007
if light_mode:
922-
export_inputs = dict(
923-
item
924-
for d in [
925-
self.environment.last_rnd_dict,
926-
self.rocket.last_rnd_dict,
927-
self.flight.last_rnd_dict,
928-
]
929-
for item in d.items()
930-
)
9311008
results = {
9321009
export_item: getattr(monte_carlo_flight, export_item)
9331010
for export_item in self.export_list
9341011
}
1012+
1013+
export_inputs_bytes = json.dumps(export_inputs, cls=RocketPyEncoder)
1014+
results_bytes = json.dumps(results, cls=RocketPyEncoder)
9351015
else:
936-
flight_results = self.inspect_object_attributes(
1016+
flight_results = self.prepare_export_data(
9371017
monte_carlo_flight, self.export_sample_time
9381018
)
939-
940-
export_inputs = {
941-
"probe_flight": "Currently no addition data is exported. Use the results file.",
942-
}
9431019
results = {"probe_flight": flight_results}
9441020

9451021
# downsample the arrays, filling them up to the max time
946-
export_inputs = self.__downsample_recursive(
947-
data_dict=export_inputs,
948-
max_time=monte_carlo_flight.max_time,
949-
sample_time=self.export_sample_time,
950-
)
9511022
results = self.__downsample_recursive(
9521023
data_dict=results,
9531024
max_time=monte_carlo_flight.max_time,
@@ -1452,7 +1523,7 @@ def time_function_serializer(function_object, t_range=None, sample_time=None):
14521523
return source
14531524

14541525
@staticmethod
1455-
def inspect_object_attributes(obj, sample_time=0.1):
1526+
def prepare_export_data(obj, sample_time=0.1, remove_functions=False):
14561527
"""
14571528
Inspects the attributes of an object and returns a dictionary of its
14581529
attributes.
@@ -1472,26 +1543,50 @@ def inspect_object_attributes(obj, sample_time=0.1):
14721543
are integers, floats, dictionaries or Function objects.
14731544
"""
14741545
result = {}
1475-
# Iterate through all attributes of the object
1476-
for attr_name in dir(obj):
1477-
attr_value = getattr(obj, attr_name)
1478-
1479-
# Check if the attribute is of a type we are interested in and not a private attribute
1480-
if isinstance(
1481-
attr_value, (int, float, Function)
1482-
) and not attr_name.startswith('_'):
1483-
if isinstance(attr_value, Function):
1484-
# Serialize the Functions
1485-
result[attr_name] = MonteCarlo.time_function_serializer(
1486-
attr_value, None, sample_time
1487-
)
14881546

1489-
elif isinstance(attr_value, (int, float)):
1490-
result[attr_name] = attr_value
1547+
if isinstance(obj, dict):
1548+
# Iterate through all attributes of the object
1549+
for attr_name, attr_value in obj.items():
1550+
# Filter out private attributes and check if the attribute is of a type we are interested in
1551+
if not attr_name.startswith('_') and isinstance(
1552+
attr_value, (int, float, dict, Function)
1553+
):
1554+
if isinstance(attr_value, (int, float)):
1555+
result[attr_name] = attr_value
1556+
1557+
elif isinstance(attr_value, dict):
1558+
result[attr_name] = MonteCarlo.prepare_export_data(
1559+
attr_value, sample_time
1560+
)
1561+
1562+
elif not remove_functions and isinstance(attr_value, Function):
1563+
# Serialize the Functions
1564+
result[attr_name] = MonteCarlo.time_function_serializer(
1565+
attr_value, None, sample_time
1566+
)
1567+
else:
1568+
# Iterate through all attributes of the object
1569+
for attr_name in dir(obj):
1570+
attr_value = getattr(obj, attr_name)
1571+
1572+
# Filter out private attributes and check if the attribute is of a type we are interested in
1573+
if not attr_name.startswith('_') and isinstance(
1574+
attr_value, (int, float, dict, Function)
1575+
):
1576+
if isinstance(attr_value, (int, float)):
1577+
result[attr_name] = attr_value
1578+
1579+
elif isinstance(attr_value, dict):
1580+
result[attr_name] = MonteCarlo.prepare_export_data(
1581+
attr_value, sample_time
1582+
)
1583+
1584+
elif not remove_functions and isinstance(attr_value, Function):
1585+
# Serialize the Functions
1586+
result[attr_name] = MonteCarlo.time_function_serializer(
1587+
attr_value, None, sample_time
1588+
)
14911589

1492-
else:
1493-
# Should never reach this point
1494-
raise TypeError("Methods should be preprocessed before saving.")
14951590
return result
14961591

14971592
@staticmethod

0 commit comments

Comments
 (0)