
Commit 7fb2d35

Merge pull request #368 from makaveli10/upgrade_trt_v0_18

Upgrade tensorrt_llm to v0.18.2

2 parents: af50fed + 47ee035

8 files changed (+134, -57 lines)

README.md (+2 -7)

````diff
@@ -141,16 +141,11 @@ client(hls_url="http://as-hls-ww-live.akamaized.net/pool_904/live/ww/bbc_1xtra/b
 
 ## Browser Extensions
 - Run the server with your desired backend as shown [here](https://github.com/collabora/WhisperLive?tab=readme-ov-file#running-the-server).
-- Transcribe audio directly from your browser using our Chrome or Firefox extensions. Refer to [Audio-Transcription-Chrome](https://github.com/collabora/whisper-live/tree/main/Audio-Transcription-Chrome#readme) and [Audio-Transcription-Firefox](https://github.com/collabora/whisper-live/tree/main/Audio-Transcription-Firefox#readme) for setup instructions.
-
-## Whisper Live Server in Docker
-- GPU
-  - Faster-Whisper
-  ```bash
+- Transcribe audio directly from your browser using our Chrome or Firefox extensions. Refer to [Audio-Transcription-Chrome](https://github.com/collabora/whisper-live/tree/main/Audio-Transcription-Chrome#readme) and https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md
 
   docker run -it --gpus all -p 9090:9090 ghcr.io/collabora/whisperlive-gpu:latest
   ```
 
-  - TensorRT.
+  - TensorRT. Refer to [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) for setup and more tensorrt backend configurations.
   ```bash
   docker run -p 9090:9090 --runtime=nvidia --gpus all --entrypoint /bin/bash -it ghcr.io/collabora/whisperlive-tensorrt
 
````

TensorRT_whisper.md (+9 -1)

````diff
@@ -1,6 +1,6 @@
 # WhisperLive-TensorRT
 We have only tested the TensorRT backend in docker so, we recommend docker for a smooth TensorRT backend setup.
-**Note**: We use `tensorrt_llm==0.15.0.dev2024111200`
+**Note**: We use `tensorrt_llm==0.18.2`
 
 ## Installation
 - Install [docker](https://docs.docker.com/engine/install/)
@@ -36,3 +36,11 @@ python3 run_server.py --port 9090 \
                   --trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_float16" \
                   --trt_multilingual
 ```
+
+By default trt_backend uses cpp_session, to use python session pass `--trt_py_session` to run_server.py
+```bash
+python3 run_server.py --port 9090 \
+                  --backend tensorrt \
+                  --trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_float16" \
+                  --trt_py_session
+```
````
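For reference, the same server start can be done from Python — a minimal sketch against the `TranscriptionServer.run()` signature this PR updates in `whisper_live/server.py`; host/port mirror the CLI above, and the exact call shape is an assumption rather than documented API:

```python
# Minimal sketch: programmatic equivalent of the run_server.py command above,
# based on the TranscriptionServer.run() signature updated in this PR.
from whisper_live.server import TranscriptionServer

server = TranscriptionServer()
server.run(
    "0.0.0.0",
    port=9090,
    backend="tensorrt",
    whisper_tensorrt_path="/app/TensorRT-LLM-examples/whisper/whisper_small_float16",
    trt_py_session=True,  # new in this PR; omit to keep the default C++ session
)
```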

docker/Dockerfile.tensorrt (+7 -8)

```diff
@@ -1,19 +1,19 @@
-FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS base
+FROM nvidia/cuda:12.8.1-base-ubuntu22.04 AS base
 
 ARG DEBIAN_FRONTEND=noninteractive
 
 RUN apt-get update && apt-get install -y \
     python3.10 python3-pip openmpi-bin libopenmpi-dev git git-lfs wget \
+    && apt install python-is-python3 \
+    && pip install --upgrade pip setuptools \
     && rm -rf /var/lib/apt/lists/*
 
 FROM base AS devel
-RUN pip3 install --no-cache-dir -U tensorrt_llm==0.15.0.dev2024111200 --extra-index-url https://pypi.nvidia.com
+RUN pip install --no-cache-dir -U tensorrt_llm==0.18.2 --extra-index-url https://pypi.nvidia.com
 WORKDIR /app
-RUN git clone https://github.com/NVIDIA/TensorRT-LLM.git && cd TensorRT-LLM && \
-    git checkout c629546ce429623c8a163633095230154a6f0574 && cd ../ && \
-    mv TensorRT-LLM/examples ./TensorRT-LLM-examples && \
-    rm -rf TensorRT-LLM
-
+RUN git clone -b v0.18.2 https://github.com/NVIDIA/TensorRT-LLM.git \
+    && mv TensorRT-LLM/examples ./TensorRT-LLM-examples \
+    && rm -rf TensorRT-LLM
 
 FROM devel AS release
 WORKDIR /app
@@ -25,7 +25,6 @@ RUN apt update && bash setup.sh && rm setup.sh
 
 COPY requirements/server.txt .
 RUN pip install --no-cache-dir -r server.txt && rm server.txt
-RUN pip install pynvml==11.5.0
 COPY whisper_live ./whisper_live
 COPY scripts/build_whisper_tensorrt.sh .
 COPY run_server.py .
```

run_server.py (+4)

```diff
@@ -21,6 +21,9 @@
     parser.add_argument('--trt_multilingual', '-m',
                         action="store_true",
                         help='Boolean only for TensorRT model. True if multilingual.')
+    parser.add_argument('--trt_py_session',
+                        action="store_true",
+                        help='Boolean only for TensorRT model. Use python session or cpp session, By default uses Cpp.')
     parser.add_argument('--omp_num_threads', '-omp',
                         type=int,
                         default=1,
@@ -46,5 +49,6 @@
         faster_whisper_custom_model_path=args.faster_whisper_custom_model_path,
         whisper_tensorrt_path=args.trt_model_path,
         trt_multilingual=args.trt_multilingual,
+        trt_py_session=args.trt_py_session,
         single_model=not args.no_single_model,
     )
```
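Since the new flag is a plain `store_true` switch, it defaults to `False`, which is what keeps the C++ session the default — a small standalone sketch of just that behaviour:

```python
import argparse

# The new --trt_py_session switch defaults to False (C++ session) and flips
# to True only when passed explicitly on the command line.
parser = argparse.ArgumentParser()
parser.add_argument('--trt_py_session', action="store_true")

print(parser.parse_args([]).trt_py_session)                    # False -> C++ session
print(parser.parse_args(['--trt_py_session']).trt_py_session)  # True  -> python session
```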

scripts/build_whisper_tensorrt.sh (+3 -5)

```diff
@@ -54,7 +54,7 @@ download_and_build_model() {
     local inference_precision="float16"
     local weight_only_precision="${2:-float16}"
     local max_beam_width=4
-    local max_batch_size=1
+    local max_batch_size=4
 
     echo "Downloading $model_name..."
     # wget --directory-prefix=assets "$model_url"
@@ -80,7 +80,6 @@ download_and_build_model() {
         --checkpoint_dir "${checkpoint_dir}/encoder" \
         --output_dir "${output_dir}/encoder" \
         --moe_plugin disable \
-        --enable_xqa disable \
         --max_batch_size "$max_batch_size" \
         --gemm_plugin disable \
         --bert_attention_plugin "$inference_precision" \
@@ -92,11 +91,10 @@ download_and_build_model() {
         --checkpoint_dir "${checkpoint_dir}/decoder" \
         --output_dir "${output_dir}/decoder" \
         --moe_plugin disable \
-        --enable_xqa disable \
         --max_beam_width "$max_beam_width" \
         --max_batch_size "$max_batch_size" \
-        --max_seq_len 200 \
-        --max_input_len 14 \
+        --max_seq_len 225 \
+        --max_input_len 32 \
         --max_encoder_input_len 3000 \
         --gemm_plugin "$inference_precision" \
         --bert_attention_plugin "$inference_precision" \
```

whisper_live/backend/trt_backend.py (+22 -6)

```diff
@@ -11,7 +11,18 @@ class ServeClientTensorRT(ServeClientBase):
     SINGLE_MODEL = None
     SINGLE_MODEL_LOCK = threading.Lock()
 
-    def __init__(self, websocket, task="transcribe", multilingual=False, language=None, client_uid=None, model=None, single_model=False):
+    def __init__(
+        self,
+        websocket,
+        task="transcribe",
+        multilingual=False,
+        language=None,
+        client_uid=None,
+        model=None,
+        single_model=False,
+        use_py_session=False,
+        max_new_tokens=225,
+    ):
         """
         Initialize a ServeClient instance.
         The Whisper model is initialized based on the client's language and device availability.
@@ -26,21 +37,24 @@ def __init__(self, websocket, task="transcribe", multilingual=False, language=No
             language (str, optional): The language for transcription. Defaults to None.
             client_uid (str, optional): A unique identifier for the client. Defaults to None.
             single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False.
+            use_py_session (bool, optional): Use python session or cpp session. Defaults to Cpp Session.
+            max_new_tokens (int, optional): Max number of tokens to generate.
 
         """
         super().__init__(client_uid, websocket)
         self.language = language if multilingual else "en"
         self.task = task
         self.eos = False
+        self.max_new_tokens = max_new_tokens
 
         if single_model:
             if ServeClientTensorRT.SINGLE_MODEL is None:
-                self.create_model(model, multilingual)
+                self.create_model(model, multilingual, use_py_session=use_py_session)
                 ServeClientTensorRT.SINGLE_MODEL = self.transcriber
             else:
                 self.transcriber = ServeClientTensorRT.SINGLE_MODEL
         else:
-            self.create_model(model, multilingual)
+            self.create_model(model, multilingual, use_py_session=use_py_session)
 
         # threading
         self.trans_thread = threading.Thread(target=self.speech_to_text)
@@ -52,7 +66,7 @@ def __init__(self, websocket, task="transcribe", multilingual=False, language=No
             "backend": "tensorrt"
         }))
 
-    def create_model(self, model, multilingual, warmup=True):
+    def create_model(self, model, multilingual, warmup=True, use_py_session=False):
         """
         Instantiates a new model, sets it as the transcriber and does warmup if desired.
         """
@@ -62,7 +76,9 @@ def create_model(self, model, multilingual, warmup=True):
             device="cuda",
             is_multilingual=multilingual,
             language=self.language,
-            task=self.task
+            task=self.task,
+            use_py_session=use_py_session,
+            max_output_len=self.max_new_tokens,
         )
         if warmup:
             self.warmup()
@@ -117,7 +133,7 @@ def transcribe_audio(self, input_bytes):
         mel, duration = self.transcriber.log_mel_spectrogram(input_bytes)
         last_segment = self.transcriber.transcribe(
             mel,
-            text_prefix=f"<|startoftranscript|><|{self.language}|><|{self.task}|><|notimestamps|>"
+            text_prefix=f"<|startoftranscript|><|{self.language}|><|{self.task}|><|notimestamps|>",
         )
         if ServeClientTensorRT.SINGLE_MODEL:
             ServeClientTensorRT.SINGLE_MODEL_LOCK.release()
```
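To make the new constructor surface concrete, here is a hypothetical direct call with the two added keyword arguments. In practice the server builds this object in `initialize_client()`, and `websocket` stands for an already-accepted connection, so this is illustrative rather than standalone:

```python
from whisper_live.backend.trt_backend import ServeClientTensorRT

def make_trt_client(websocket):
    # Sketch only: mirrors the updated __init__ signature from this diff.
    return ServeClientTensorRT(
        websocket,
        multilingual=False,
        client_uid="demo-uid",  # hypothetical uid; the server takes it from client options
        model="/app/TensorRT-LLM-examples/whisper/whisper_small_float16",
        single_model=True,      # share one engine across clients (SINGLE_MODEL + lock)
        use_py_session=False,   # new: False -> C++ session (default), True -> python session
        max_new_tokens=225,     # new: forwarded to WhisperTRTLLM as max_output_len
    )
```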

whisper_live/server.py (+11 -7)

```diff
@@ -153,7 +153,7 @@ def __init__(self):
 
     def initialize_client(
         self, websocket, options, faster_whisper_custom_model_path,
-        whisper_tensorrt_path, trt_multilingual
+        whisper_tensorrt_path, trt_multilingual, trt_py_session=False,
     ):
         client: Optional[ServeClientBase] = None
 
@@ -168,6 +168,7 @@ def initialize_client(
                     client_uid=options["uid"],
                     model=whisper_tensorrt_path,
                     single_model=self.single_model,
+                    use_py_session=trt_py_session,
                 )
                 logging.info("Running TensorRT backend.")
             except Exception as e:
@@ -248,7 +249,7 @@ def get_audio_from_websocket(self, websocket):
         return np.frombuffer(frame_data, dtype=np.float32)
 
     def handle_new_connection(self, websocket, faster_whisper_custom_model_path,
-                              whisper_tensorrt_path, trt_multilingual):
+                              whisper_tensorrt_path, trt_multilingual, trt_py_session=False):
         try:
             logging.info("New client connected")
             options = websocket.recv()
@@ -267,7 +268,7 @@ def handle_new_connection(self, websocket, faster_whisper_custom_model_path,
             if self.backend.is_tensorrt():
                 self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)
             self.initialize_client(websocket, options, faster_whisper_custom_model_path,
-                                   whisper_tensorrt_path, trt_multilingual)
+                                   whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session)
             return True
         except json.JSONDecodeError:
             logging.error("Failed to decode JSON from client")
@@ -299,11 +300,12 @@ def process_audio_frames(self, websocket):
         return True
 
     def recv_audio(self,
-                   websocket,
+                   websocket,
                    backend: BackendType = BackendType.FASTER_WHISPER,
                    faster_whisper_custom_model_path=None,
                    whisper_tensorrt_path=None,
-                   trt_multilingual=False):
+                   trt_multilingual=False,
+                   trt_py_session=False):
         """
         Receive audio chunks from a client in an infinite loop.
 
@@ -330,7 +332,7 @@ def recv_audio(self,
         """
         self.backend = backend
         if not self.handle_new_connection(websocket, faster_whisper_custom_model_path,
-                                          whisper_tensorrt_path, trt_multilingual):
+                                          whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session):
             return
 
         try:
@@ -354,6 +356,7 @@ def run(self,
             faster_whisper_custom_model_path=None,
             whisper_tensorrt_path=None,
             trt_multilingual=False,
+            trt_py_session=False,
             single_model=False):
         """
         Run the transcription server.
@@ -381,7 +384,8 @@
                 backend=BackendType(backend),
                 faster_whisper_custom_model_path=faster_whisper_custom_model_path,
                 whisper_tensorrt_path=whisper_tensorrt_path,
-                trt_multilingual=trt_multilingual
+                trt_multilingual=trt_multilingual,
+                trt_py_session=trt_py_session,
             ),
             host,
             port
```
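A quick way to exercise a server started with `trt_py_session=True` is the project's Python client — a sketch following the client examples in the README (the parameters shown are assumptions based on those examples; adjust host, port, and file to your setup):

```python
from whisper_live.client import TranscriptionClient

# Connect to the running server and stream a local audio file for transcription.
client = TranscriptionClient("localhost", 9090, lang="en", translate=False)
client("tests/jfk.flac")
# Or transcribe an HLS stream, as in the README: client(hls_url="...")
```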
