Skip to content

Develop queuing algorithms and some relevant functionalities for asynchronous REXEE #64

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ name: lint

on:
push:
pull_request:

jobs:

Expand Down
130 changes: 130 additions & 0 deletions ensemble_md/cli/run_AREXEE.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
####################################################################
# #
# ensemble_md, #
# a python package for running GROMACS simulation ensembles #
# #
# Written by Wei-Tse Hsu <[email protected]> #
# Copyright (c) 2022 University of Colorado Boulder #
# #
####################################################################
import os
import sys
import copy
import time
import argparse
import warnings
import traceback
from datetime import datetime
from mpi4py import MPI

from ensemble_md.utils import utils
from ensemble_md.cli.run_REXEE import initialize
from ensemble_md.replica_exchange_EE import ReplicaExchangeEE

# This warning is a module-level statement, so it is emitted as soon as the module is imported/executed.
warnings.warn('This module is only for experimental purposes and still in progress. Please do not use it for any production research.', UserWarning) # noqa: E501

# NOTE: Because the string below comes after the imports, Python treats it as a plain expression statement
# rather than as the module docstring (it does not populate __doc__). It only documents the intent of this CLI.
"""
Currently, this CLI still uses MPI to run REXEE simulations, but it tries to mock some behaviors of asynchronous REXEE in the following way:
1. Finish an iteration of the REXEE simulation.
2. Based on the time it took for each simulation to finish, figure out the order in which the replicas should be added to the queue.
3. Apply a queueing algorithm to figure out what replicas to swap first.

Eventually, we would like to get rid of the use of MPI and really rely on asynchronous parallelization schemes. The most likely
direction is to use functionalities in airflowHPC to manage the queueing and launching of replicas. If possible, this CLI should be
integrated into the CLI run_REXEE.
"""


def main():
    """
    Run the experimental, MPI-based mock-up of an asynchronous REXEE simulation.

    The function parses the command-line arguments with ``initialize``, redirects
    ``sys.stdout``/``sys.stderr`` to the log file given by ``args.output``, sets up
    and runs iteration 0, and then loops over the remaining iterations. In each of
    them, rank 0 derives a swapping pattern from the dhdl/log files of the previous
    iteration, the pattern is broadcast to all ranks, and every rank calls
    ``REXEE.run_REXEE``. Rank 0 prints a short summary at the end.

    Checkpointing, weight combination/correction, coordinate modification and data
    saving are deliberately left out of this minimal CLI (see the inline notes).
    """
    t1 = time.time()
    args = initialize(sys.argv[1:])
    sys.stdout = utils.Logger(logfile=args.output)
    sys.stderr = utils.Logger(logfile=args.output)

    # Step 1: Set up MPI rank and instantiate ReplicaExchangeEE to set up REXEE parameters
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()  # rank of this process within COMM_WORLD

    if rank == 0:
        print(f'Current time: {datetime.now().strftime("%d/%m/%Y %H:%M:%S")}')
        print(f'Command line: {" ".join(sys.argv)}\n')

    REXEE = ReplicaExchangeEE(args.yaml)

    if rank == 0:
        # Print out simulation parameters
        REXEE.print_params()

        # Print out warnings and fail if needed
        for i in REXEE.warnings:
            print(f'\n{i}\n')

        if len(REXEE.warnings) > args.maxwarn:
            print("The execution failed due to warning(s) about parameter spcificaiton. Check the warnings, or consider setting maxwarn in the input YAML file if you find them harmless.")  # noqa: E501
            comm.Abort(101)

    # Step 2: If there is no checkpoint file found/provided, perform the 1st iteration (index 0)

    # Note that here we assume no checkpoint files just to minimize this CLI.
    # We also leave out Step 2-3 since we won't be using this CLI to test calculations with any restraints.
    start_idx = 1

    # 2-1. Set up input files for all simulations
    # Only rank 0 touches the file system here so that the directories/CSV files are created exactly once.
    if rank == 0:
        for i in range(REXEE.n_sim):
            os.mkdir(f'{REXEE.working_dir}/sim_{i}')
            os.mkdir(f'{REXEE.working_dir}/sim_{i}/iteration_0')
            MDP = REXEE.initialize_MDP(i)
            MDP.write(f"{REXEE.working_dir}/sim_{i}/iteration_0/expanded.mdp", skipempty=True)
        if REXEE.modify_coords == 'default' and (not os.path.exists('residue_connect.csv') or not os.path.exists('residue_swap_map.csv')):  # noqa: E501
            REXEE.process_top()

    # 2-2. Run the first set of simulations
    REXEE.run_REXEE(0)

    for i in range(start_idx, REXEE.n_iter):
        try:
            if rank == 0:
                # Step 3: Swap the coordinates
                # Note that here we leave out Steps 3-3 and 3-4, which are for weight combination/correction and coordinate modification, respectively.  # noqa: E501

                # 3-1. Extract the final dhdl and log files from the previous iteration
                dhdl_files = [f'{REXEE.working_dir}/sim_{j}/iteration_{i - 1}/dhdl.xvg' for j in range(REXEE.n_sim)]
                log_files = [f'{REXEE.working_dir}/sim_{j}/iteration_{i - 1}/md.log' for j in range(REXEE.n_sim)]
                states_ = REXEE.extract_final_dhdl_info(dhdl_files)
                wl_delta, weights_, counts_ = REXEE.extract_final_log_info(log_files)
                print()

                # 3-2. Identify swappable pairs, propose swap(s), calculate P_acc, and accept/reject swap(s)
                # The working copies below mirror the CLI run_REXEE; they are not consumed yet in this reduced CLI.
                states = copy.deepcopy(states_)  # noqa: F841
                weights = copy.deepcopy(weights_)  # noqa: F841
                counts = copy.deepcopy(counts_)  # noqa: F841
                swap_pattern, swap_list = REXEE.get_swapping_pattern(dhdl_files, states_)  # swap_list will only be used for modify_coords # noqa: E501
            else:
                swap_pattern, swap_list = None, None

        except Exception:
            # Abort every rank instead of letting the other ranks hang in the broadcast below.
            print('\n--------------------------------------------------------------------------\n')
            print(f'An error occurred on rank 0:\n{traceback.format_exc()}')
            comm.Abort(1)

        # Note that we leave out the block for exiting the for loop when the weights got equilibrated, as this CLI
        # won't be tested for weight-updating simulations for now.

        # Step 4: Perform another iteration
        # Here we leave out the block that uses swap_list, which is only for coordinate modifications.
        swap_pattern = comm.bcast(swap_pattern, root=0)

        # Here we run another set of simulations (i.e. Step 4-2 in CLI run_REXEE)
        REXEE.run_REXEE(i, swap_pattern)

        # Here we leave out the block for saving data (i.e. Step 4-3 in CLI run_REXEE) since we won't run for too many iterations when testing this CLI.  # noqa: E501

    # Step 5: Write a summary for the simulation ensemble
    if rank == 0:
        print('\nSummary of the simulation ensemble')
        print('==================================')

        # We leave out the section showing the simulation status.
        print(f'\n{REXEE.n_empty_swappable} out of {REXEE.n_iter}, or {REXEE.n_empty_swappable / REXEE.n_iter * 100:.1f}% iterations had an empty list of swappable pairs.')  # noqa: E501
        if REXEE.n_swap_attempts != 0:
            print(f'{REXEE.n_rejected} out of {REXEE.n_swap_attempts}, or {REXEE.n_rejected / REXEE.n_swap_attempts * 100:.1f}% of attempted exchanges were rejected.')  # noqa: E501

        print(f'\nTime elapsed: {utils.format_time(time.time() - t1)}')

    MPI.Finalize()
2 changes: 2 additions & 0 deletions ensemble_md/cli/run_REXEE.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def main():
os.mkdir(f'{REXEE.working_dir}/sim_{i}/iteration_0')
MDP = REXEE.initialize_MDP(i)
MDP.write(f"{REXEE.working_dir}/sim_{i}/iteration_0/expanded.mdp", skipempty=True)
if REXEE.modify_coords == 'default' and (not os.path.exists('residue_connect.csv') or not os.path.exists('residue_swap_map.csv')): # noqa: E501
REXEE.process_top()

# 2-2. Run the first set of simulations
REXEE.run_REXEE(0)
Expand Down
160 changes: 150 additions & 10 deletions ensemble_md/replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import warnings
import importlib
import subprocess
import mdtraj as md
import pandas as pd
import numpy as np
from mpi4py import MPI
from itertools import combinations
Expand All @@ -29,6 +31,7 @@
import ensemble_md
from ensemble_md.utils import utils
from ensemble_md.utils import gmx_parser
from ensemble_md.utils import coordinate_swap
from ensemble_md.utils.exceptions import ParameterError

comm = MPI.COMM_WORLD
Expand Down Expand Up @@ -157,6 +160,8 @@ def set_params(self, analysis):
optional_args = {
"add_swappables": None,
"modify_coords": None,
"resname_list": None,
"swap_rep_pattern": None,
"nst_sim": None,
"proposal": 'exhaustive',
"w_combine": False,
Expand Down Expand Up @@ -254,6 +259,10 @@ def set_params(self, analysis):
raise ParameterError(f"The parameter '{i}' should be a boolean variable.")

params_list = ['add_swappables', 'df_ref']
if self.resname_list is not None:
params_list.append('resname_list')
if self.swap_rep_pattern is not None:
params_list.append('swap_rep_pattern')
for i in params_list:
if getattr(self, i) is not None and not isinstance(getattr(self, i), list):
raise ParameterError(f"The parameter '{i}' should be a list.")
Expand Down Expand Up @@ -441,17 +450,24 @@ def set_params(self, analysis):

# 7-12. External module for coordinate modification
if self.modify_coords is not None:
module_file = os.path.basename(self.modify_coords)
module_dir = os.path.dirname(self.modify_coords)
if module_dir not in sys.path:
sys.path.append(module_dir) # so that the module can be imported
module_name = os.path.splitext(module_file)[0]
module = importlib.import_module(module_name)
if not hasattr(module, module_name):
err_msg = f'The module for coordinate manipulation (specified through the parameter) must have a function with the same name as the module, i.e., {module_name}.' # noqa: E501
raise ParameterError(err_msg)
if self.modify_coords == 'default':
if self.swap_rep_pattern is None and (not os.path.exists('residue_connect.csv') or not os.path.exists('residue_swap_map.csv')): # noqa: E501
raise Exception('swap_rep_pattern option must be filled in if using default swapping function and not swap guide') # noqa: E501
if self.resname_list is None and (not os.path.exists('residue_connect.csv') or not os.path.exists('residue_swap_map.csv')): # noqa: E501
raise Exception('resname_list option must be filled in if using default swapping function and not swap guide') # noqa: E501
self.modify_coords_fn = self.default_coords_fn
else:
self.modify_coords_fn = getattr(module, module_name)
module_file = os.path.basename(self.modify_coords)
module_dir = os.path.dirname(self.modify_coords)
if module_dir not in sys.path:
sys.path.append(module_dir) # so that the module can be imported
module_name = os.path.splitext(module_file)[0]
module = importlib.import_module(module_name)
if not hasattr(module, module_name):
err_msg = f'The module for coordinate manipulation (specified through the parameter) must have a function with the same name as the module, i.e., {module_name}.' # noqa: E501
raise ParameterError(err_msg)
else:
self.modify_coords_fn = getattr(module, module_name)
else:
self.modify_coords_fn = None

Expand Down Expand Up @@ -1496,3 +1512,127 @@ def run_REXEE(self, n, swap_pattern=None):
# want it to start parsing the dhdl file (in the if condition of if rank == 0) of simulation 3 being run by
# rank 3 that has not been generated, which will lead to an I/O error.
comm.barrier()

def default_coords_fn(self, molA_file_name, molB_file_name):
    """
    Swaps coordinates between two GRO files.

    The last frame of ``traj.trr`` located next to each GRO file is loaded, the
    swapping guides ``residue_connect.csv`` and ``residue_swap_map.csv`` are read
    from the current working directory, and two new GRO files are written to
    temporary files before being moved to ``confout.gro`` in the directories of
    the two input files.

    Parameters
    ----------
    molA_file_name : str
        GRO file name for the molecule to be swapped.
    molB_file_name : str
        GRO file name for the other molecule to be swapped.
    """
    # Directories holding the two GRO files. os.path.dirname also copes with a bare
    # file name (no '/'), for which it returns '' (i.e., the current directory).
    molA_dir = os.path.dirname(molA_file_name)
    molB_dir = os.path.dirname(molB_file_name)

    # Load trajectory trr for higher precision coordinates
    molA = md.load_trr(os.path.join(molA_dir, 'traj.trr'), top=molA_file_name).slice(-1)  # Load last frame of trr trajectory  # noqa: E501
    molB = md.load_trr(os.path.join(molB_dir, 'traj.trr'), top=molB_file_name).slice(-1)

    # Load the coordinate swapping map
    connection_map = pd.read_csv('residue_connect.csv')
    swap_map = pd.read_csv('residue_swap_map.csv')

    # Step 1: Read the GRO input coordinate files (the handles are closed right after reading)
    with open(molA_file_name, 'r') as f_in:
        molA_file = f_in.readlines()
    with open(molB_file_name, 'r') as f_in:
        molB_file = f_in.readlines()
    molB_new_file_name = 'B_hybrid_swap.gro'
    molA_new_file_name = 'A_hybrid_swap.gro'

    # Step 2: Determine atoms for alignment and swapping
    swap_resnames = swap_map['Swap A'].to_list() + swap_map['Swap B'].to_list()
    nameA = coordinate_swap.identify_res(molA.topology, swap_resnames)
    nameB = coordinate_swap.identify_res(molB.topology, swap_resnames)
    df_atom_swap = coordinate_swap.find_common(molA_file, molB_file, nameA, nameB)

    # Step 3: Fix break if present for solvated systems only
    if len(molA.topology.select('water')) != 0:
        A_dimensions = coordinate_swap.get_dimensions(molA_file)
        B_dimensions = coordinate_swap.get_dimensions(molB_file)
        molA = coordinate_swap.fix_break(molA, nameA, A_dimensions, connection_map[connection_map['Resname'] == nameA])  # noqa: E501
        molB = coordinate_swap.fix_break(molB, nameB, B_dimensions, connection_map[connection_map['Resname'] == nameB])  # noqa: E501

    # Step 4: Determine coordinates of atoms which need to be reconstructed as we swap coordinates between molecules  # noqa: E501
    miss_B = df_atom_swap[(df_atom_swap['Swap'] == 'B2A') & (df_atom_swap['Direction'] == 'miss')]['Name'].to_list()  # noqa: E501
    miss_A = df_atom_swap[(df_atom_swap['Swap'] == 'A2B') & (df_atom_swap['Direction'] == 'miss')]['Name'].to_list()  # noqa: E501
    if len(miss_B) != 0:
        df_atom_swap = coordinate_swap.get_miss_coord(molB, molA, nameB, nameA, df_atom_swap, 'B2A', swap_map[(swap_map['Swap A'] == nameB) & (swap_map['Swap B'] == nameA)])  # noqa: E501
    if len(miss_A) != 0:
        df_atom_swap = coordinate_swap.get_miss_coord(molA, molB, nameA, nameB, df_atom_swap, 'A2B', swap_map[(swap_map['Swap A'] == nameA) & (swap_map['Swap B'] == nameB)])  # noqa: E501

    # Step 5: Write the two new GRO files. The context manager guarantees that both temporary
    # files are flushed and closed before they are renamed below.
    with open(molB_new_file_name, 'w') as molB_new, open(molA_new_file_name, 'w') as molA_new:
        # Reprint preamble text
        line_start = coordinate_swap.print_preamble(molA_file, molB_new, len(miss_B), len(miss_A))

        # Print new coordinates to file for molB
        coordinate_swap.write_new_file(df_atom_swap, 'A2B', 'B2A', line_start, molA_file, molB_new, nameA, nameB, copy.deepcopy(molA.xyz[0]), miss_A)  # noqa: E501

        # Reprint preamble text
        line_start = coordinate_swap.print_preamble(molB_file, molA_new, len(miss_A), len(miss_B))

        # Print new coordinates for molA
        coordinate_swap.write_new_file(df_atom_swap, 'B2A', 'A2B', line_start, molB_file, molA_new, nameB, nameA, copy.deepcopy(molB.xyz[0]), miss_B)  # noqa: E501

    # Rename temp files
    os.rename(molA_new_file_name, os.path.join(molB_dir, 'confout.gro'))
    os.rename(molB_new_file_name, os.path.join(molA_dir, 'confout.gro'))

def process_top(self):
    """
    Processes the input topologies in order to determine the atoms for alignment in the default GRO swapping
    function. Output as csv files to prevent needing to re-run this step.

    Two files are produced in the current working directory if they do not exist yet:
    ``residue_connect.csv`` (built from ``self.top`` and ``self.resname_list``) and
    ``residue_swap_map.csv`` (built from ``self.swap_rep_pattern``). An existing
    ``residue_connect.csv`` is read back instead of being regenerated.
    """
    if not os.path.exists('residue_connect.csv'):
        top_frames = []
        for f, file_name in enumerate(self.top):
            resname = self.resname_list[f]

            # Read file
            input_file = coordinate_swap.read_top(file_name, resname)

            # Determine the atom names corresponding to the atom numbers
            start_line, atom_name, state = coordinate_swap.get_names(input_file)

            # Determine the connectivity of all atoms
            connect_1, connect_2, state_1, state_2 = [], [], [], []  # Atom 1 and atom 2 which are connected and which state they are dummy atoms # noqa: E501
            for line in input_file[start_line:]:
                # str.split() without arguments tolerates repeated spaces, tabs and trailing newlines
                fields = line.split()
                if not fields:
                    break  # a blank line terminates the section being parsed
                if fields[0].startswith(';'):
                    continue  # comment line
                idx_1, idx_2 = int(fields[0]) - 1, int(fields[1]) - 1  # atom numbers in the file are 1-based
                connect_1.append(atom_name[idx_1])
                connect_2.append(atom_name[idx_2])
                state_1.append(state[idx_1])
                state_2.append(state[idx_2])
            top_frames.append(pd.DataFrame({'Resname': resname, 'Connect 1': connect_1, 'Connect 2': connect_2, 'State 1': state_1, 'State 2': state_2}))  # noqa: E501

        # Concatenate once at the end instead of growing a DataFrame inside the loop
        df_top = pd.concat(top_frames) if top_frames else pd.DataFrame()
        df_top.to_csv('residue_connect.csv')
    else:
        df_top = pd.read_csv('residue_connect.csv')

    if not os.path.exists('residue_swap_map.csv'):
        map_frames = []
        for swap in self.swap_rep_pattern:
            # Determine atoms not present in both molecules
            X, Y = int(swap[0][0]), int(swap[1][0])
            lam = {X: int(swap[0][1]), Y: int(swap[1][1])}
            for A, B in ((X, Y), (Y, X)):
                input_A = coordinate_swap.read_top(self.top[A], self.resname_list[A])
                _, A_name, _ = coordinate_swap.get_names(input_A)
                input_B = coordinate_swap.read_top(self.top[B], self.resname_list[B])
                _, B_name, _ = coordinate_swap.get_names(input_B)

                A_only = [x for x in A_name if x not in B_name]
                B_only = [x for x in B_name if x not in A_name]

                # Separate real to dummy switches
                map_frames.append(coordinate_swap.determine_connection(A_only, B_only, self.resname_list[A], self.resname_list[B], df_top, lam[A]))  # noqa: E501

        df_map = pd.concat(map_frames) if map_frames else pd.DataFrame()
        df_map.to_csv('residue_swap_map.csv')
Loading
Loading