-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathgenerate.py
133 lines (115 loc) · 4.99 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
import argparse
from pprint import pprint
import torch
import os
from byprot import utils
from byprot.models.lm.dplm import DiffusionProteinLanguageModel
def format_check(args):
seq_list = args.cond_seq
cond_position = args.cond_position
assert len(seq_list) == len(
cond_position
), "The length of cond_seq and cond_position does not match."
position_list = []
for pos in cond_position:
pos = pos.split("-")
assert (
len(pos) == 2
), "The format of position is illegal, which is not correctly splited by '-'"
start_pos, end_pos = int(pos[0]), int(pos[1])
assert end_pos >= start_pos, "The end position is smaller than start position."
position_list.append((start_pos, end_pos))
# check if position segment has overlap
temp_position_list = [pos for tup in position_list for pos in tup]
for i in range(1, len(temp_position_list) - 2, 2):
assert (
temp_position_list[i + 1] > temp_position_list[i]
), "The position segment has overlap, which is not supported"
# check if the length of each position segment and seq segment matches
for i, (start_pos, end_pos) in enumerate(position_list):
assert len(seq_list[i]) == (
end_pos - start_pos + 1
), "The length of each position segment and seq segment does not match."
return seq_list, position_list
def get_initial(args, length, tokenizer, device):
seq = ["<mask>"] * length
if args.cond_seq is not None:
# Inpainting generation, conditioned on some sequence segments
seq_segment_list, position_list = format_check(args)
for i, (start_pos, end_pos) in enumerate(position_list):
seq[start_pos : end_pos + 1] = [char for char in seq_segment_list[i]]
seq = ["".join(seq)]
init_seq = seq * args.num_seqs
batch = tokenizer.batch_encode_plus(
init_seq, add_special_tokens=True, padding="longest", return_tensors="pt"
)
batch = {
"input_ids": batch["input_ids"],
"input_mask": batch["attention_mask"].bool(),
}
# if cond_seq is None:
# batch['input_ids'], _ = _full_mask(batch['input_ids'].clone(), collater.alphabet)
batch = utils.recursive_to(batch, device)
pprint(batch)
return batch
def generate(args):
model = DiffusionProteinLanguageModel.from_pretrained(args.model_name)
tokenizer = model.tokenizer
model = model.eval()
model = model.cuda()
device = next(model.parameters()).device
for seq_len in args.seq_lens:
max_iter = args.max_iter
batch = get_initial(args, seq_len, tokenizer, device)
partial_mask = batch["input_ids"].ne(model.mask_id).type_as(batch["input_mask"])
with torch.cuda.amp.autocast():
outputs = model.generate(
batch=batch,
tokenizer=tokenizer,
max_iter=max_iter,
sampling_strategy=args.sampling_strategy,
partial_masks=partial_mask,
disable_resample=args.disable_resample,
resample_ratio=args.resample_ratio,
temperature=args.temperature,
)
output_tokens = outputs[0]
print("final:")
output_results = [
"".join(seq.split(" "))
for seq in tokenizer.batch_decode(output_tokens, skip_special_tokens=True)
]
pprint(output_results)
os.makedirs(args.saveto, exist_ok=True)
saveto_name = os.path.join(args.saveto, f"iter_{max_iter}_L_{seq_len}.fasta")
fp_save = open(saveto_name, "w")
for idx, seq in enumerate(output_results):
fp_save.write(f">SEQUENCE_{idx}_L={seq_len}\n")
fp_save.write(f"{seq}\n")
fp_save.close()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--model_name", type=str, default="airkingbd/dplm_150m")
parser.add_argument("--num_seqs", type=int, default=40)
parser.add_argument("--seq_lens", nargs="*", type=int)
parser.add_argument("--saveto", type=str, default="gen.fasta")
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--sampling_strategy", type=str, default="gumbel_argmax")
parser.add_argument("--max_iter", type=int, default=500)
parser.add_argument("--disable_resample", action="store_true", default=False)
parser.add_argument("--resample_ratio", type=float, default=0.25)
# inpainting
# Note: the format of --cond_position and --cond_seq should split by ','
# the number and the length of segments should match.
# Like this:
# --cond_position 1-4 8-10 (position starts from 0)
# --cond_seq ALVE EME
parser.add_argument("--cond_position", nargs="*", type=str)
parser.add_argument("--cond_seq", nargs="*", type=str)
args = parser.parse_args()
generate(args)
if __name__ == "__main__":
main()