Skip to content

Fix: cherry-pick hmac encryption from main branch #3635

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 102 additions & 17 deletions tensorrt_llm/executor/ipc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import hashlib
import hmac
import os
import pickle # nosec B403
import time
import traceback
from queue import Queue
Expand All @@ -23,51 +27,75 @@ class ZeroMqQueue:
}

def __init__(self,
address: Optional[str] = None,
address: Optional[tuple[str, Optional[bytes]]] = None,
*,
socket_type: int = zmq.PAIR,
is_server: bool,
is_async: bool = False,
name: Optional[str] = None):
name: Optional[str] = None,
use_hmac_encryption: bool = True):
'''
Parameters:
address (Tuple[str, str], optional): The address (tcp-ip_port, authkey) for the IPC. Defaults to None.
address (tuple[str, Optional[bytes]], optional): The address (tcp-ip_port, hmac_auth_key) for the IPC. Defaults to None. If hmac_auth_key is None and use_hmac_encryption is False, the queue will not use HMAC encryption.
is_server (bool): Whether the current process is the server or the client.
use_hmac_encryption (bool): Whether to use HMAC encryption for pickled data. Defaults to True.
'''

self.socket_type = socket_type
self.address = address or "tcp://127.0.0.1:*"
self.address_endpoint = address[
0] if address is not None else "tcp://127.0.0.1:*"
self.is_server = is_server
self.context = zmq.Context() if not is_async else zmq.asyncio.Context()
self.poller = None
self.socket = None

self._setup_done = False
self.name = name
self.socket_type = socket_type

self.socket = self.context.socket(socket_type)

self.hmac_key = address[1] if address is not None else None
self.use_hmac_encryption = use_hmac_encryption

# Check HMAC key condition
if self.use_hmac_encryption and self.is_server and self.hmac_key is not None:
raise ValueError(
"Server should not receive HMAC key when encryption is enabled")
elif self.use_hmac_encryption and not self.is_server and self.hmac_key is None:
raise ValueError(
"Client must receive HMAC key when encryption is enabled")
elif not self.use_hmac_encryption and self.hmac_key is not None:
raise ValueError(
"Server and client should not receive HMAC key when encryption is disabled"
)

if (socket_type == zmq.PAIR
and self.is_server) or socket_type == zmq.PULL:
self.socket.bind(
self.address
self.address_endpoint
) # Binds to the address and occupy a port immediately
self.address = self.socket.getsockopt(zmq.LAST_ENDPOINT).decode()
self.address_endpoint = self.socket.getsockopt(
zmq.LAST_ENDPOINT).decode()
print_colored_debug(
f"Server [{name}] bound to {self.address} in {self.socket_type_str[socket_type]}\n",
f"Server [{name}] bound to {self.address_endpoint} in {self.socket_type_str[socket_type]}\n",
"green")

if self.use_hmac_encryption:
# Initialize HMAC key for pickle encryption
logger.info(f"Generating a new HMAC key for server {self.name}")
self.hmac_key = os.urandom(32)

self.address = (self.address_endpoint, self.hmac_key)

def setup_lazily(self):
if self._setup_done:
return
self._setup_done = True

if not self.is_server:
print_colored_debug(
f"Client [{self.name}] connecting to {self.address} in {self.socket_type_str[self.socket_type]}\n",
f"Client [{self.name}] connecting to {self.address_endpoint} in {self.socket_type_str[self.socket_type]}\n",
"green")
self.socket.connect(self.address)
self.socket.connect(self.address_endpoint)

self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
Expand All @@ -88,12 +116,26 @@ def poll(self, timeout: int) -> bool:
def put(self, obj: Any):
self.setup_lazily()
with nvtx_range_debug("send", color="blue", category="IPC"):
self.socket.send_pyobj(obj)
if self.use_hmac_encryption:
# Send pickled data with HMAC appended
data = pickle.dumps(obj) # nosec B301
signed_data = self._sign_data(data)
self.socket.send(signed_data)
else:
# Send data without HMAC
self.socket.send_pyobj(obj)

async def put_async(self, obj: Any):
self.setup_lazily()
try:
await self.socket.send_pyobj(obj)
if self.use_hmac_encryption:
# Send pickled data with HMAC appended
data = pickle.dumps(obj) # nosec B301
signed_data = self._sign_data(data)
await self.socket.send(signed_data)
else:
# Send data without HMAC
await self.socket.send_pyobj(obj)
except TypeError as e:
logger.error(f"Cannot pickle {obj}")
raise e
Expand All @@ -107,12 +149,44 @@ async def put_async(self, obj: Any):
def get(self) -> Any:
self.setup_lazily()

return self.socket.recv_pyobj()
if self.use_hmac_encryption:
# Receive signed data with HMAC
signed_data = self.socket.recv()

# Split data and HMAC
data = signed_data[:-32]
actual_hmac = signed_data[-32:]

# Verify HMAC
if not self._verify_hmac(data, actual_hmac):
raise RuntimeError("HMAC verification failed")

obj = pickle.loads(data) # nosec B301
else:
# Receive data without HMAC
obj = self.socket.recv_pyobj()
return obj

async def get_async(self) -> Any:
self.setup_lazily()

return await self.socket.recv_pyobj()
if self.use_hmac_encryption:
# Receive signed data with HMAC
signed_data = await self.socket.recv()

# Split data and HMAC
data = signed_data[:-32]
actual_hmac = signed_data[-32:]

# Verify HMAC
if not self._verify_hmac(data, actual_hmac):
raise RuntimeError("HMAC verification failed")

obj = pickle.loads(data) # nosec B301
else:
# Receive data without HMAC
obj = await self.socket.recv_pyobj()
return obj

def close(self):
if self.socket:
Expand All @@ -122,6 +196,17 @@ def close(self):
self.context.term()
self.context = None

def _verify_hmac(self, data: bytes, actual_hmac: bytes) -> bool:
"""Verify the HMAC of received pickle data."""
expected_hmac = hmac.new(self.hmac_key, data, hashlib.sha256).digest()
return hmac.compare_digest(expected_hmac, actual_hmac)

def _sign_data(self, data_before_encoding: bytes) -> bytes:
"""Generate HMAC for data."""
hmac_signature = hmac.new(self.hmac_key, data_before_encoding,
hashlib.sha256).digest()
return data_before_encoding + hmac_signature

def __del__(self):
self.close()

Expand All @@ -133,7 +218,7 @@ class FusedIpcQueue:
''' A Queue-like container for IPC with optional message batched. '''

def __init__(self,
address: Optional[str] = None,
address: Optional[tuple[str, Optional[bytes]]] = None,
*,
is_server: bool,
fuse_message=False,
Expand Down Expand Up @@ -186,7 +271,7 @@ def get(self) -> Any:
return self.queue.get()

@property
def address(self) -> str:
def address(self) -> tuple[str, Optional[bytes]]:
return self.queue.address

def __del__(self):
Expand Down
11 changes: 6 additions & 5 deletions tensorrt_llm/executor/postproc_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,16 @@ class Output(NamedTuple):

def __init__(
self,
pull_pipe_addr: str,
push_pipe_addr: str,
pull_pipe_addr: tuple[str, Optional[bytes]],
push_pipe_addr: tuple[str, Optional[bytes]],
tokenizer_dir: str,
record_creator: Callable[
["PostprocWorker.Input", TransformersTokenizer], Any],
):
'''
Args:
pull_pipe_addr (str): The address of the input IPC.
push_pipe_addr (str): The address of the output IPC.
pull_pipe_addr (tuple[str, Optional[bytes]]): The address and HMAC key of the input IPC.
push_pipe_addr (tuple[str, Optional[bytes]]): The address and HMAC key of the output IPC.
tokenizer_dir (str): The directory to load tokenizer.
record_creator (Callable[["ResponsePostprocessWorker.Input"], Any]): A creator for creating a record for a request.
result_handler (Optional[Callable[[GenerationResultBase], Any]]): A callback handles the final result.
Expand Down Expand Up @@ -210,7 +210,8 @@ async def main():


@print_traceback_on_error
def postproc_worker_main(feedin_ipc_addr: str, feedout_ipc_addr: str,
def postproc_worker_main(feedin_ipc_addr: tuple[str, Optional[bytes]],
feedout_ipc_addr: tuple[str, Optional[bytes]],
tokenizer_dir: str, record_creator: Callable):
worker = PostprocWorker(feedin_ipc_addr,
feedout_ipc_addr,
Expand Down
20 changes: 12 additions & 8 deletions tensorrt_llm/executor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from concurrent.futures import ProcessPoolExecutor
from queue import Empty, Queue
from typing import Any, Callable, List, NamedTuple
from typing import Any, Callable, List, NamedTuple, Optional

from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.llmapi.utils import print_colored_debug
Expand Down Expand Up @@ -41,8 +41,12 @@ def create_mpi_comm_session(
print_colored_debug(
f"Using RemoteMpiPoolSessionClient to bind to external MPI processes at {get_spawn_proxy_process_ipc_addr_env()}\n",
"yellow")
hmac_key = os.getenv("TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY")
# Convert the hex string to bytes
if hmac_key is not None:
hmac_key = bytes.fromhex(hmac_key)
return RemoteMpiCommSessionClient(
get_spawn_proxy_process_ipc_addr_env())
addr=get_spawn_proxy_process_ipc_addr_env(), hmac_key=hmac_key)
else:
print_colored_debug(
f"Using MpiCommSession to bind to external MPI processes\n",
Expand Down Expand Up @@ -123,9 +127,9 @@ def poll(self, timeout=None) -> bool:


class WorkerCommIpcAddrs(NamedTuple):
''' IPC addresses for communication with the worker processes. '''
request_queue_addr: str
request_error_queue_addr: str
result_queue_addr: str
stats_queue_addr: str
kv_cache_events_queue_addr: str
''' IPC addresses (str) and HMAC keys (bytes) for communication with the worker processes. '''
request_queue_addr: tuple[str, Optional[bytes]]
request_error_queue_addr: tuple[str, Optional[bytes]]
result_queue_addr: tuple[str, Optional[bytes]]
stats_queue_addr: tuple[str, Optional[bytes]]
kv_cache_events_queue_addr: tuple[str, Optional[bytes]]
5 changes: 3 additions & 2 deletions tensorrt_llm/executor/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,13 +586,14 @@ def notify_proxy_threads_to_quit():
if is_leader and postproc_worker_config.enabled:
print_colored_debug(f"initiate postprocess workers...", "yellow")

proxy_result_queue: str = worker_queues.result_queue_addr
proxy_result_queue: tuple[
str, Optional[bytes]] = worker_queues.result_queue_addr

assert result_queues is not None
assert postproc_worker_config.postprocess_tokenizer_dir is not None
postproc_worker_pool = ProcessPoolExecutor(
max_workers=postproc_worker_config.num_postprocess_workers)
assert isinstance(proxy_result_queue, str)
assert isinstance(proxy_result_queue, tuple)
for i in range(postproc_worker_config.num_postprocess_workers):
fut = postproc_worker_pool.submit(
postproc_worker_main,
Expand Down
11 changes: 8 additions & 3 deletions tensorrt_llm/llmapi/mpi_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,15 @@ class RemoteMpiCommSessionClient():
RemoteMpiCommSessionClient is a variant of MpiCommSession that is used to connect to a remote MPI pool.
'''

def __init__(self, addr: str):
def __init__(self, addr: str, hmac_key: Optional[bytes] = None):
# FIXME: this is a hack to avoid circular import, resolve later
from tensorrt_llm.executor.ipc import ZeroMqQueue
self.addr = addr
print_colored_debug(
f"RemoteMpiCommSessionClient connecting to {addr}\n", "yellow")
self.queue = ZeroMqQueue(addr, is_server=False)
self.queue = ZeroMqQueue((addr, hmac_key),
is_server=False,
use_hmac_encryption=bool(hmac_key))
self._is_shutdown = False

def submit(self, task: Callable[..., T], *args, **kwargs) -> list:
Expand Down Expand Up @@ -311,12 +313,15 @@ class RemoteMpiCommSessionServer():
def __init__(self,
n_workers: int = 0,
addr: str = f'tcp://127.0.0.1:*',
hmac_key: Optional[bytes] = None,
comm=None,
is_comm: bool = False):
# FIXME: this is a hack to avoid circular import, resolve later
from tensorrt_llm.executor.ipc import ZeroMqQueue
self.addr = addr
self.queue = ZeroMqQueue(addr, is_server=True)
self.queue = ZeroMqQueue((addr, hmac_key),
is_server=True,
use_hmac_encryption=bool(hmac_key))
self.comm = comm

if self.comm is not None:
Expand Down
19 changes: 15 additions & 4 deletions tensorrt_llm/llmapi/trtllm-llmapi-launch
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,28 @@ function mpi_world_size {
fi
}

function export_free_tcp_addr {
# find free port with python -c, the port should start from 10000
local free_port=$(python -c 'import socket; s=socket.socket(); s.bind(("", 10000)); print(s.getsockname()[1]); s.close()')
function export_free_tcp_addr_for_spawn_proxy_process {
# find free port starting from 10012
local free_port=$(python -c 'import socket; s=socket.socket();
port = 10012
while True:
try:
s.bind(("", port))
break
except OSError:
port += 1
print(port); s.close()')
export TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR="tcp://127.0.0.1:${free_port}"
log_stderr "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR: $TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR"

export TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY=$(openssl rand -hex 32)
}


export tllm_mpi_size=$(mpi_world_size)
log_stderr "tllm_mpi_size: $tllm_mpi_size"

export_free_tcp_addr
export_free_tcp_addr_for_spawn_proxy_process

if [ -z "$mpi_rank" ] || [ "$mpi_rank" -eq 0 ]; then
log_stderr "rank${mpi_rank} run ${task_with_command} in background"
Expand Down
3 changes: 0 additions & 3 deletions tests/unittest/llmapi/test_llm_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,6 @@ def test_llm_multi_node(engine_from_checkpoint: tempfile.TemporaryDirectory):
run_command(command)


@pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5223608: timeout")
@skip_single_gpu
def test_llm_multi_node_pytorch():
nworkers = 2
Expand All @@ -313,10 +312,8 @@ def test_llm_multi_node_pytorch():
run_command(command)


@pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5223608: timeout")
@skip_single_gpu
def test_llm_multi_node_with_postproc():
# TODO[chunweiy]: reactivate this later
nworkers = 2
test_case_file = os.path.join(os.path.dirname(__file__),
"run_llm_with_postproc.py")
Expand Down