Skip to content

Commit 98ce6ba

Browse files
committed
Centralized simulation control in SimCounter
1 parent 38a29b1 commit 98ce6ba

File tree

1 file changed

+61
-28
lines changed

1 file changed

+61
-28
lines changed

rocketpy/simulation/monte_carlo.py

Lines changed: 61 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,9 @@ def __run_in_parallel(self, append, light_mode, n_workers=None):
327327
)
328328
# open files in write/append mode
329329
with open(self._input_file, mode=open_mode) as f:
330-
pass # initialize file
330+
pass # initialize file
331331
with open(self._output_file, mode=open_mode) as f:
332-
pass # initialize file
332+
pass # initialize file
333333

334334
else:
335335
# Change file extensions to .h5
@@ -353,16 +353,7 @@ def __run_in_parallel(self, append, light_mode, n_workers=None):
353353
pass # initialize file
354354

355355
# Initialize simulation counter
356-
sim_counter = manager.SimCounter(idx_i)
357-
358-
queue = manager.JoinableQueue()
359-
360-
# Initialize queue
361-
for _ in range(self.number_of_simulations - sim_counter.get_count()):
362-
queue.put("RUN")
363-
364-
for _ in range(n_workers):
365-
queue.put("STOP")
356+
sim_counter = manager.SimCounter(idx_i, self.number_of_simulations)
366357

367358
print("\nStarting monte carlo analysis", end="\r")
368359
print(f"Number of simulations: {self.number_of_simulations}")
@@ -379,7 +370,6 @@ def __run_in_parallel(self, append, light_mode, n_workers=None):
379370
inputs_lock,
380371
outputs_lock,
381372
errors_lock,
382-
queue,
383373
light_mode,
384374
file_paths,
385375
),
@@ -413,12 +403,11 @@ def __run_simulation_worker(
413403
inputs_lock,
414404
outputs_lock,
415405
errors_lock,
416-
queue,
417406
light_mode,
418407
file_paths,
419408
):
420409
"""
421-
Runs a simulation from a queue.
410+
Worker code to execute a simulation in a process.
422411
423412
Parameters
424413
----------
@@ -452,11 +441,9 @@ def __run_simulation_worker(
452441
"""
453442
try:
454443
while True:
455-
if queue.get() == "STOP":
456-
break
457-
458444
sim_idx = sim_counter.increment()
459-
sim_start = time()
445+
if sim_idx == -1:
446+
break
460447

461448
env = sto_env.create_object()
462449
rocket = sto_rocket.create_object()
@@ -534,12 +521,19 @@ def __run_simulation_worker(
534521
MonteCarlo.__dict_to_h5(h5_file, '/', export_outputs)
535522
outputs_lock.release()
536523

537-
sim_end = time()
524+
average_time = (
525+
time() - sim_counter.get_intial_time()
526+
) / sim_counter.get_count()
527+
estimated_time = int(
528+
(sim_counter.get_n_simulations() - sim_counter.get_count())
529+
* average_time
530+
)
538531

539-
print(
540-
f"Current iteration: {sim_idx} | "
541-
f"Time of Iteration {sim_end - sim_start:.2f} seconds to run.",
542-
end="\n",
532+
sim_counter.reprint(
533+
f"Current iteration: {sim_idx:06d} | "
534+
f"Average Time per Iteration: {average_time:.3f} s | "
535+
f"Estimated time left: {estimated_time} s",
536+
end="\r",
543537
flush=True,
544538
)
545539

@@ -1126,7 +1120,7 @@ def __get_light_indexes(input_file, output_file, append):
11261120
idx_i = MonteCarlo.__get_initial_sim_idx(f, light_mode=True)
11271121
with open(output_file, 'r', encoding="utf-8") as f:
11281122
idx_o = MonteCarlo.__get_initial_sim_idx(f, light_mode=True)
1129-
except OSError: # File not found, return 0
1123+
except OSError: # File not found, return 0
11301124
idx_i = 0
11311125
idx_o = 0
11321126
else:
@@ -1182,12 +1176,51 @@ def __init__(self):
11821176

11831177

11841178
class SimCounter:
1185-
def __init__(self, initial_count):
1179+
def __init__(self, initial_count, n_simulations):
11861180
self.count = initial_count
1181+
self.n_simulations = n_simulations
1182+
self._last_print_len = 0 # used to print on the same line
1183+
self.initial_time = time()
1184+
1185+
def increment(self):
1186+
if self.count >= self.n_simulations:
1187+
return -1
11871188

1188-
def increment(self) -> int:
11891189
self.count += 1
11901190
return self.count - 1
11911191

1192-
def get_count(self) -> int:
1192+
def get_count(self):
11931193
return self.count
1194+
1195+
def get_n_simulations(self):
1196+
return self.n_simulations
1197+
1198+
def get_intial_time(self):
1199+
return self.initial_time
1200+
1201+
def reprint(self, msg, end="\n", flush=False):
1202+
"""Prints a message on the same line as the previous one and replaces
1203+
the previous message with the new one, deleting the extra characters
1204+
from the previous message.
1205+
1206+
Parameters
1207+
----------
1208+
msg : str
1209+
Message to be printed.
1210+
end : str, optional
1211+
String appended after the message. Default is a new line.
1212+
flush : bool, optional
1213+
If True, the output is flushed. Default is False.
1214+
1215+
Returns
1216+
-------
1217+
None
1218+
"""
1219+
1220+
len_msg = len(msg)
1221+
if len_msg < self._last_print_len:
1222+
msg += " " * (self._last_print_len - len_msg)
1223+
else:
1224+
self._last_print_len = len_msg
1225+
1226+
print(msg, end=end, flush=flush)

0 commit comments

Comments
 (0)