Upload folder using huggingface_hub
- model_repo_whisper_qwen_trtllm/tensorrt_llm/1/.gitkeep +0 -0
- model_repo_whisper_qwen_trtllm/tensorrt_llm/1/model.py +947 -0
- model_repo_whisper_qwen_trtllm/tensorrt_llm/config.pbtxt +577 -0
- model_repo_whisper_qwen_trtllm/tensorrt_llm/config.template +577 -0
- model_repo_whisper_qwen_trtllm/whisper/0/__pycache__/fbank.cpython-310.pyc +0 -0
- model_repo_whisper_qwen_trtllm/whisper/0/__pycache__/model.cpython-310.pyc +0 -0
- model_repo_whisper_qwen_trtllm/whisper/0/__pycache__/whisper_trtllm.cpython-310.pyc +0 -0
- model_repo_whisper_qwen_trtllm/whisper/0/fbank.py +91 -0
- model_repo_whisper_qwen_trtllm/whisper/0/mel_filters.npz +3 -0
- model_repo_whisper_qwen_trtllm/whisper/0/model.py +346 -0
- model_repo_whisper_qwen_trtllm/whisper/0/whisper_trtllm.py +278 -0
- model_repo_whisper_qwen_trtllm/whisper/1/__pycache__/fbank.cpython-310.pyc +0 -0
- model_repo_whisper_qwen_trtllm/whisper/1/__pycache__/model.cpython-310.pyc +0 -0
- model_repo_whisper_qwen_trtllm/whisper/1/__pycache__/whisper_trtllm.cpython-310.pyc +0 -0
- model_repo_whisper_qwen_trtllm/whisper/1/fbank.py +91 -0
- model_repo_whisper_qwen_trtllm/whisper/1/mel_filters.npz +3 -0
- model_repo_whisper_qwen_trtllm/whisper/1/model.py +318 -0
- model_repo_whisper_qwen_trtllm/whisper/1/whisper_trtllm.py +212 -0
- model_repo_whisper_qwen_trtllm/whisper/2/__pycache__/fbank.cpython-310.pyc +0 -0
- model_repo_whisper_qwen_trtllm/whisper/2/__pycache__/model.cpython-310.pyc +0 -0
- model_repo_whisper_qwen_trtllm/whisper/2/__pycache__/whisper_trtllm.cpython-310.pyc +0 -0
- model_repo_whisper_qwen_trtllm/whisper/2/fbank.py +91 -0
- model_repo_whisper_qwen_trtllm/whisper/2/mel_filters.npz +3 -0
- model_repo_whisper_qwen_trtllm/whisper/2/model.py +346 -0
- model_repo_whisper_qwen_trtllm/whisper/2/whisper_trtllm.py +278 -0
- model_repo_whisper_qwen_trtllm/whisper/config.pbtxt +61 -0
model_repo_whisper_qwen_trtllm/tensorrt_llm/1/.gitkeep
ADDED
File without changes
model_repo_whisper_qwen_trtllm/tensorrt_llm/1/model.py
ADDED
@@ -0,0 +1,947 @@
import datetime
import json
import os
import sys
import time
from random import randint
from threading import Lock, Thread

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from torch import from_numpy
from torch.utils.dlpack import from_dlpack

import tensorrt_llm.bindings.executor as trtllm


def get_input_tensor_by_name(request,
                             name,
                             expected_batch_size=None,
                             batch_index=None):
    tensor = pb_utils.get_input_tensor_by_name(request, name)
    if tensor is None:
        return None

    if tensor.is_cpu():
        tensor = tensor.as_numpy()
    else:
        tensor = from_dlpack(tensor.to_dlpack())

    if expected_batch_size is not None and tensor.shape[
            0] != expected_batch_size:
        raise pb_utils.TritonModelException(
            f"Expected batch size doesn't match batch size for tensor {name}. Expected {expected_batch_size} got {tensor.shape[0]}"
        )

    if batch_index is not None and expected_batch_size is not None and batch_index >= expected_batch_size:
        raise pb_utils.TritonModelException(
            f"Invalid batch index in get_input_tensor_by_name for {name}")

    if batch_index is not None:
        # Add leading 1 batch dimension
        if isinstance(tensor, np.ndarray):
            return np.expand_dims(tensor[batch_index], axis=0)
        elif isinstance(tensor, torch.Tensor):
            return torch.unsqueeze(tensor[batch_index], dim=0)
    else:
        return tensor


def get_input_scalar_by_name(request,
                             name,
                             expected_batch_size=1,
                             batch_index=0):
    tensor = pb_utils.get_input_tensor_by_name(request, name)
    if tensor is None:
        return None
    tensor = tensor.as_numpy()

    if tensor.size != expected_batch_size:
        raise pb_utils.TritonModelException(
            f"Expected a scalar tensor for tensor {name}")

    return tensor.item(batch_index)


def read_parameter_as_type(value, name, pytype=str):
    if value == "":
        return None
    if value.startswith("${") and value.endswith("}"):
        return None
    if pytype is bool:
        return value.lower() in ["1", "true"]
    try:
        result = pytype(value)
        return result
    except:
        pb_utils.Logger.log_warning(
            f"Could not read parameter '{name}' with value '{value}', will use default."
        )
        return None


def get_parameter(model_config, name, pytype=str):
    if name not in model_config['parameters']:
        return None
    return read_parameter_as_type(
        model_config['parameters'][name]['string_value'], name, pytype)


def convert_word_list(word_list):
    if word_list is None:
        return None
    word_list = word_list.tolist()
    if len(word_list) == 0 or len(word_list[0]) != 2:
        raise pb_utils.TritonModelException(f"Invalid format for word list.")
    words, indices = word_list[0]
    result = []
    current_index = 0
    for i in indices:
        if i == -1:
            continue
        if i > len(words):
            raise pb_utils.TritonModelException(
                f"Invalid format for word list.")
        current_word = []
        while current_index < i:
            current_word.append(words[current_index])
            current_index += 1
        result.append(current_word)
    return result


def parse_medusa_choices(medusa_choices):
    if medusa_choices is None:
        return None
    try:
        result = json.loads(
            "[" + medusa_choices.replace("{", "[").replace("}", "]") + "]")
        assert isinstance(result, list) and len(result) > 0
        assert all([isinstance(x, list) for x in result])
        assert all([isinstance(y, int) for x in result for y in x])
    except Exception:
        raise pb_utils.TritonModelException(
            "Invalid format for medusa_choices")
    return result


def get_sampling_config_from_request(request, batch_size=1, batch_index=0):
    kwargs = {}
    kwargs['beam_width'] = get_input_scalar_by_name(
        request, 'beam_width', batch_size, batch_index) or 1
    kwargs['top_k'] = get_input_scalar_by_name(request, 'runtime_top_k',
                                               batch_size, batch_index)
    kwargs['top_p'] = get_input_scalar_by_name(request, 'runtime_top_p',
                                               batch_size, batch_index)
    kwargs['top_p'] = None if kwargs['top_p'] is None or kwargs[
        'top_p'] <= 0 else kwargs['top_p']
    kwargs['random_seed'] = get_input_scalar_by_name(request, 'random_seed',
                                                     batch_size, batch_index)
    kwargs['temperature'] = get_input_scalar_by_name(request, 'temperature',
                                                     batch_size, batch_index)
    kwargs['min_length'] = get_input_scalar_by_name(request, 'min_length',
                                                    batch_size, batch_index)
    kwargs['repetition_penalty'] = get_input_scalar_by_name(
        request, 'repetition_penalty', batch_size, batch_index)
    kwargs['presence_penalty'] = get_input_scalar_by_name(
        request, 'presence_penalty', batch_size, batch_index)
    kwargs['frequency_penalty'] = get_input_scalar_by_name(
        request, 'frequency_penalty', batch_size, batch_index)
    kwargs['length_penalty'] = get_input_scalar_by_name(
        request, 'len_penalty', batch_size, batch_index)
    kwargs['top_p_min'] = get_input_scalar_by_name(request,
                                                   'runtime_top_p_min',
                                                   batch_size, batch_index)
    kwargs['top_p_reset_ids'] = get_input_scalar_by_name(
        request, 'runtime_top_p_reset_ids', batch_size, batch_index)
    kwargs['top_p_decay'] = get_input_scalar_by_name(request,
                                                     'runtime_top_p_decay',
                                                     batch_size, batch_index)
    kwargs['beam_search_diversity_rate'] = get_input_scalar_by_name(
        request, 'beam_search_diversity_rate', batch_size, batch_index)
    kwargs['early_stopping'] = get_input_scalar_by_name(
        request, 'early_stopping', batch_size, batch_index)
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    return trtllm.SamplingConfig(**kwargs)


def get_output_config_from_request(request,
                                   exclude_input_from_output,
                                   batch_size=1,
                                   batch_index=0):
    kwargs = {}
    kwargs["return_log_probs"] = get_input_scalar_by_name(
        request, 'return_log_probs', batch_size, batch_index)
    kwargs["return_context_logits"] = get_input_scalar_by_name(
        request, 'return_context_logits', batch_size, batch_index)
    kwargs["return_generation_logits"] = get_input_scalar_by_name(
        request, 'return_generation_logits', batch_size, batch_index)
    kwargs["exclude_input_from_output"] = exclude_input_from_output
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    return trtllm.OutputConfig(**kwargs)


def get_external_draft_tokens_config_from_request(request,
                                                  batch_size=1,
                                                  batch_index=0):
    kwargs = {}
    draft_input_ids = get_input_tensor_by_name(request, 'draft_input_ids',
                                               batch_size, batch_index)
    if draft_input_ids is not None:
        kwargs['tokens'] = draft_input_ids[0].tolist()
    draft_logits = get_input_tensor_by_name(request, 'draft_logits',
                                            batch_size, batch_index)
    if draft_logits is not None:
        kwargs['logits'] = from_numpy(draft_logits).squeeze()
    kwargs['acceptance_threshold'] = get_input_scalar_by_name(
        request, 'draft_acceptance_threshold', batch_size, batch_index)
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    if len(kwargs) > 0:
        return trtllm.ExternalDraftTokensConfig(**kwargs)
    return None


def get_prompt_tuning_config_from_request(request,
                                          batch_size=1,
                                          batch_index=0):
    # prompt_vocab_size is unused by executor.
    kwargs = {}
    prompt_embedding_table = get_input_tensor_by_name(
        request, 'prompt_embedding_table', batch_size, batch_index)
    if prompt_embedding_table is not None:
        if isinstance(prompt_embedding_table, np.ndarray):
            kwargs["embedding_table"] = from_numpy(
                prompt_embedding_table).squeeze()
        elif isinstance(prompt_embedding_table, torch.Tensor):
            kwargs["embedding_table"] = from_dlpack(
                prompt_embedding_table.to_dlpack()).squeeze(dim=0)
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    if len(kwargs) > 0:
        return trtllm.PromptTuningConfig(**kwargs)
    return None


def get_lora_config_from_request(request, batch_size=1, batch_index=0):
    kwargs = {}
    kwargs["task_id"] = get_input_scalar_by_name(request, 'lora_task_id',
                                                 batch_size, batch_index)
    lora_weights = get_input_tensor_by_name(request, 'lora_weights',
                                            batch_size, batch_index)
    if lora_weights is not None:
        kwargs["weights"] = from_numpy(lora_weights).squeeze()
    lora_config = get_input_tensor_by_name(request, 'lora_config', batch_size,
                                           batch_index)
    if lora_config is not None:
        kwargs["config"] = from_numpy(lora_config).squeeze()
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    if len(kwargs) > 0:
        return trtllm.LoraConfig(**kwargs)
    return None


def convert_request(request, exclude_input_from_output, decoupled):
    inputs = {}
    input_token_ids = get_input_tensor_by_name(request, 'input_ids')
    if input_token_ids is None:
        raise pb_utils.TritonModelException(
            "A value is required for input_ids")
    if len(input_token_ids.shape) != 2:
        raise pb_utils.TritonModelException(f"Invalid format for input_ids")
    batch_size = input_token_ids.shape[0]
    requests = []
    for batch_index in range(0, batch_size):
        input_token_ids = get_input_tensor_by_name(request, 'input_ids',
                                                   batch_size, batch_index)[0]
        if input_token_ids is None:
            raise pb_utils.TritonModelException(
                "A value is required for input_ids")
        input_token_ids = input_token_ids.tolist()
        if len(input_token_ids) == 0:
            raise pb_utils.TritonModelException(
                f"Invalid format for input_ids")

        input_length = get_input_scalar_by_name(request, 'input_lengths',
                                                batch_size, batch_index)
        if input_length is None:
            input_length = len(input_token_ids)
        # Trim input token ids with input_lengths
        inputs['input_token_ids'] = input_token_ids[0:input_length]

        inputs['max_new_tokens'] = get_input_scalar_by_name(
            request, 'request_output_len', batch_size, batch_index)
        if inputs['max_new_tokens'] is None:
            raise pb_utils.TritonModelException(
                "A value is required for request_output_len")
        inputs['streaming'] = get_input_scalar_by_name(request, 'streaming',
                                                       batch_size, batch_index)
        if inputs['streaming'] and not decoupled:
            raise pb_utils.TritonModelException(
                "Streaming is only supported in decoupled mode.")
        inputs['end_id'] = get_input_scalar_by_name(request, 'end_id',
                                                    batch_size, batch_index)
        inputs['pad_id'] = get_input_scalar_by_name(request, 'pad_id',
                                                    batch_size, batch_index)
        inputs['stop_words'] = convert_word_list(
            get_input_tensor_by_name(request, 'stop_words_list', batch_size,
                                     batch_index))
        inputs['bad_words'] = convert_word_list(
            get_input_tensor_by_name(request, 'bad_words_list', batch_size,
                                     batch_index))
        embedding_bias = get_input_tensor_by_name(request, 'embedding_bias',
                                                  batch_size, batch_index)
        if embedding_bias is not None and embedding_bias.size != 0:
            inputs['embedding_bias'] = from_numpy(embedding_bias).squeeze()

        sampling_config = get_sampling_config_from_request(
            request, batch_size, batch_index)
        output_config = get_output_config_from_request(
            request, exclude_input_from_output, batch_size, batch_index)
        external_draft_tokens_config = get_external_draft_tokens_config_from_request(
            request, batch_size, batch_index)
        prompt_tuning_config = get_prompt_tuning_config_from_request(
            request, batch_size, batch_index)
        lora_config = get_lora_config_from_request(request, batch_size,
                                                   batch_index)

        requests.append(
            trtllm.Request(
                **inputs,
                sampling_config=sampling_config,
                output_config=output_config,
                external_draft_tokens_config=external_draft_tokens_config,
                prompt_tuning_config=prompt_tuning_config,
                lora_config=lora_config,
            ))
    return requests


def convert_response(response, batch_index):
    if response.has_error():
        return pb_utils.InferenceResponse(output_tensors=[],
                                          error=pb_utils.TritonError(
                                              response.error_msg)), True
    result = response.result
    beam_lengths = np.expand_dims(
        np.array([len(beam) for beam in result.output_token_ids], np.int32), 0)
    max_beam_length = max([len(beam) for beam in result.output_token_ids])
    output_ids = np.full((1, len(result.output_token_ids), max_beam_length),
                         -1, np.int32)
    for idx, beam in enumerate(result.output_token_ids):
        output_ids[0, idx, :len(beam)] = beam
    output_tensors = [
        pb_utils.Tensor("output_ids", output_ids),
        pb_utils.Tensor("sequence_length", beam_lengths),
    ]
    output_tensors.append(
        pb_utils.Tensor(
            "cum_log_probs",
            np.expand_dims(np.array(result.cum_log_probs, np.float32), 0)
            if result.cum_log_probs is not None else np.zeros(
                (1, 1), np.float32)))
    output_tensors.append(
        pb_utils.Tensor(
            "output_log_probs",
            np.expand_dims(np.array(result.log_probs, np.float32), 0) if
            result.log_probs is not None else np.zeros((1, 1, 1), np.float32)))
    output_tensors.append(
        pb_utils.Tensor(
            "context_logits",
            np.expand_dims(np.array(result.context_logits, np.float32), 0)
            if result.context_logits is not None else np.zeros(
                (1, 1, 1), np.float32)))
    output_tensors.append(
        pb_utils.Tensor(
            "generation_logits",
            np.expand_dims(np.array(result.generation_logits, np.float32), 0)
            if result.generation_logits is not None else np.zeros(
                (1, 1, 1, 1), np.float32)))
    output_tensors.append(
        pb_utils.Tensor("batch_index",
                        np.expand_dims(np.array([batch_index], np.int32), 0)))

    return pb_utils.InferenceResponse(output_tensors), result.is_final


def convert_scheduler_policy(batch_scheduler_policy: str):
    if batch_scheduler_policy.lower() == "max_utilization":
        return trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION
    elif batch_scheduler_policy.lower() == "guaranteed_no_evict":
        return trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
    raise pb_utils.TritonModelException(
        f"batch_scheduler_policy value of '{batch_scheduler_policy}' is not supported."
    )


def convert_batching_type(gpt_model_type: str):
    if gpt_model_type is None:
        return None
    if gpt_model_type.lower(
    ) == "inflight_fused_batching" or gpt_model_type.lower(
    ) == "inflight_batching":
        return trtllm.BatchingType.INFLIGHT
    elif gpt_model_type.lower() == "v1":
        return trtllm.BatchingType.STATIC
    raise pb_utils.TritonModelException(
        f"gpt_model_type value of '{gpt_model_type}' is not supported.")


def convert_decoding_mode(decoding_mode: str):
    if decoding_mode is None:
        return None
    elif decoding_mode == "auto":
        return trtllm.DecodingMode.Auto()
    elif decoding_mode == "top_k":
        return trtllm.DecodingMode.TopK()
    elif decoding_mode == "top_p":
        return trtllm.DecodingMode.TopP()
    elif decoding_mode == "top_k_top_p":
        return trtllm.DecodingMode.TopKTopP()
    elif decoding_mode == "beam_search":
        return trtllm.DecodingMode.BeamSearch()
    elif decoding_mode == "medusa":
        return trtllm.DecodingMode.Medusa()
    raise pb_utils.TritonModelException(
        f"decoding_mode value of '{decoding_mode}' is not supported.")


def convert_timestamp_to_seconds(timestamp: str):
    return int(
        datetime.datetime.strptime(timestamp,
                                   "%m-%d-%Y %H:%M:%S.%f").timestamp())


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def get_scheduler_config(self, model_config):
        batch_scheduler_policy = get_parameter(model_config,
                                               "batch_scheduler_policy")
        if batch_scheduler_policy is None:
            return trtllm.SchedulerConfig()
        return trtllm.SchedulerConfig(
            convert_scheduler_policy(batch_scheduler_policy))

    def get_kv_cache_config(self, model_config):
        kwargs = {
            "enable_block_reuse":
            get_parameter(model_config, "enable_kv_cache_reuse", bool),
            "max_tokens":
            get_parameter(model_config, "max_tokens_in_paged_kv_cache", int),
            "sink_token_length":
            get_parameter(model_config, "sink_token_length", int),
            "free_gpu_memory_fraction":
            get_parameter(model_config, "kv_cache_free_gpu_mem_fraction",
                          float),
            "host_cache_size":
            get_parameter(model_config, "kv_cache_host_memory_bytes", int),
            "onboard_blocks":
            get_parameter(model_config, "kv_cache_onboard_blocks", bool),
        }
        max_attention_window_size = get_parameter(model_config,
                                                  "max_attention_window_size")
        if max_attention_window_size:
            kwargs["max_attention_window"] = [
                int(x) for x in max_attention_window_size.split(",")
            ]
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.KvCacheConfig(**kwargs)

    def get_parallel_config(self, model_config):
        kwargs = {}
        gpu_device_ids = get_parameter(model_config, "gpu_device_ids")
        if gpu_device_ids:
            kwargs["device_ids"] = [int(x) for x in gpu_device_ids.split(",")]
        self.use_orchestrator_mode = os.environ.get("TRTLLM_ORCHESTRATOR",
                                                    "0") == "1"
        if self.use_orchestrator_mode:
            kwargs[
                "communication_mode"] = trtllm.CommunicationMode.ORCHESTRATOR
            worker_path = get_parameter(model_config, "worker_path")
            if worker_path is not None:
                raise pb_utils.TritonModelException(
                    "worker_path parameter is specified, but this is no longer supported. Please specify executor_worker_path instead to specify the location of the trtllmExecutorWorker executable."
                )
            executor_worker_path = get_parameter(model_config,
                                                 "executor_worker_path")
            kwargs["orchestrator_config"] = trtllm.OrchestratorConfig(
                True, executor_worker_path)
        if len(kwargs) > 0:
            return trtllm.ParallelConfig(**kwargs)
        return None

    def get_peft_cache_config(self, model_config):
        kwargs = {
            "optimal_adapter_size":
            get_parameter(model_config, "lora_cache_optimal_adapter_size",
                          int),
            "max_adapter_size":
            get_parameter(model_config, "lora_cache_max_adapter_size", int),
            "device_cache_percent":
            get_parameter(model_config, "lora_cache_gpu_memory_fraction",
                          float),
            "host_cache_size":
            get_parameter(model_config, "lora_cache_host_memory_bytes", int),
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.PeftCacheConfig(**kwargs)

    def get_decoding_config(self, model_config):
        kwargs = {
            "medusa_choices":
            parse_medusa_choices(get_parameter(model_config,
                                               "medusa_choices")),
            "decoding_mode":
            convert_decoding_mode(get_parameter(model_config,
                                                "decoding_mode")),
        }
        print(kwargs)
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.DecodingConfig(**kwargs)

    def get_extended_runtime_perf_knob_config(self, model_config):
        kwargs = {
            "multi_block_mode":
            get_parameter(model_config, "multi_block_mode", bool),
            "enable_context_fmha_fp32_acc":
            get_parameter(model_config, "enable_context_fmha_fp32_acc", bool)
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.ExtendedRuntimePerfKnobConfig(**kwargs)

    def get_executor_config(self, model_config):
        kwargs = {
            "max_beam_width":
            get_parameter(model_config, "max_beam_width", int),
            "scheduler_config":
            self.get_scheduler_config(model_config),
            "kv_cache_config":
            self.get_kv_cache_config(model_config),
            "enable_chunked_context":
            get_parameter(model_config, "enable_chunked_context", bool),
            "normalize_log_probs":
            get_parameter(model_config, "normalize_log_probs", bool),
            "batching_type":
            convert_batching_type(get_parameter(model_config,
                                                "gpt_model_type")),
            "parallel_config":
            self.get_parallel_config(model_config),
            "peft_cache_config":
            self.get_peft_cache_config(model_config),
            "decoding_config":
            self.get_decoding_config(model_config),
            "max_queue_size":
            model_config.get(
                "dynamic_batching",
                {},
            ).get(
                "default_queue_policy",
                {},
            ).get("max_queue_size"),
            "extended_runtime_perf_knob_config":
            self.get_extended_runtime_perf_knob_config(model_config)
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.ExecutorConfig(**kwargs)

    def create_metrics(self, model: str, version: str, is_v1_model: bool):
        self.request_metric_family = pb_utils.MetricFamily(
            name="nv_trt_llm_request_metrics",
            description="TRT LLM request metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
        self.runtime_memory_metric_family = pb_utils.MetricFamily(
            name="nv_trt_llm_runtime_memory_metrics",
            description="TRT LLM runtime memory metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
        self.kv_cache_metric_family = pb_utils.MetricFamily(
            name="nv_trt_llm_kv_cache_block_metrics",
            description="TRT LLM KV cache block metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
        model_type = "v1" if is_v1_model else "inflight_batcher"
        self.model_type_metric_family = pb_utils.MetricFamily(
            name=f"nv_trt_llm_{model_type}_metrics",
            description=f"TRT LLM {model_type}-specific metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
        self.general_metric_family = pb_utils.MetricFamily(
            name="nv_trt_llm_general_metrics",
            description="General TRT LLM metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
        common_labels = {"model": model, "version": version}
        self.all_metrics = {
            # Request metrics
            "num_active_requests":
            self.request_metric_family.Metric(labels={
                "request_type": "active",
                **common_labels
            }),
            "max_num_active_requests":
            self.request_metric_family.Metric(labels={
                "request_type": "max",
                **common_labels
            }),
            "num_scheduled_requests":
            self.request_metric_family.Metric(labels={
                "request_type": "scheduled",
                **common_labels
            }),
            "num_context_requests":
            self.request_metric_family.Metric(labels={
                "request_type": "context",
                **common_labels
            }),
            # Runtime metrics
            "cpu_mem_usage":
            self.runtime_memory_metric_family.Metric(labels={
                "memory_type": "cpu",
                **common_labels
            }),
            "gpu_mem_usage":
            self.runtime_memory_metric_family.Metric(labels={
                "memory_type": "gpu",
                **common_labels
            }),
            "pinned_mem_usage":
            self.runtime_memory_metric_family.Metric(labels={
                "memory_type": "pinned",
                **common_labels
            }),
            # KV cache metrics
            "max_num_blocks":
            self.kv_cache_metric_family.Metric(labels={
                "kv_cache_block_type": "max",
                **common_labels
            }),
            "free_num_blocks":
            self.kv_cache_metric_family.Metric(labels={
                "kv_cache_block_type": "free",
                **common_labels
            }),
            "used_num_blocks":
            self.kv_cache_metric_family.Metric(labels={
                "kv_cache_block_type": "used",
                **common_labels
            }),
            "tokens_per_block":
            self.kv_cache_metric_family.Metric(labels={
                "kv_cache_block_type": "tokens_per",
                **common_labels
            }),
            # General metrics
            "timestamp":
            self.general_metric_family.Metric(labels={
                "general_type": "timestamp",
                **common_labels
            }),
            "iter":
            self.general_metric_family.Metric(labels={
                "general_type": "iteration_counter",
                **common_labels
            }),
        }
        if is_v1_model:
            self.all_metrics.update({
                "num_ctx_tokens":
                self.model_type_metric_family.Metric(labels={
                    "v1_specific_metric": "total_context_tokens",
                    **common_labels
                }),
                "num_gen_tokens":
                self.model_type_metric_family.Metric(
                    labels={
                        "v1_specific_metric": "total_generation_tokens",
                        **common_labels
                    }),
                "empty_gen_slots":
                self.model_type_metric_family.Metric(
                    labels={
                        "v1_specific_metric": "empty_generation_slots",
                        **common_labels
                    }),
            })
        else:
            self.all_metrics.update({
                "num_ctx_tokens":
                self.model_type_metric_family.Metric(
                    labels={
                        "inflight_batcher_specific_metric":
                        "total_context_tokens",
                        **common_labels
                    }),
                "num_gen_requests":
                self.model_type_metric_family.Metric(
                    labels={
                        "inflight_batcher_specific_metric":
                        "generation_requests",
                        **common_labels
                    }),
                "micro_batch_id":
                self.model_type_metric_family.Metric(
                    labels={
                        "inflight_batcher_specific_metric": "micro_batch_id",
                        **common_labels
                    }),
                "num_paused_requests":
                self.model_type_metric_family.Metric(
                    labels={
                        "inflight_batcher_specific_metric": "paused_requests",
                        **common_labels
                    }),
            })

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        model_config = json.loads(args['model_config'])
        gpt_model_path = get_parameter(model_config, "gpt_model_path")
        if get_parameter(model_config, "enable_trt_overlap", bool):
            raise pb_utils.TritonModelException(
                f"enable_trt_overlap=true is not supported.")
        self.exclude_input_from_output = get_parameter(
            model_config, "exclude_input_in_output", bool)
        executor_config = self.get_executor_config(model_config)
        self.executor = trtllm.Executor(gpt_model_path,
                                        trtllm.ModelType.DECODER_ONLY,
                                        executor_config)
        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(
            model_config)
        self.cancellation_check_period_ms = get_parameter(
            model_config, "cancellation_check_period_ms", int) or 100
        self.stats_check_period_ms = get_parameter(
            model_config, "stats_check_period_ms", int) or 100

        if not self.decoupled:
            raise pb_utils.TritonModelException(
                "Please enable decoupled transaction policy in the model configuration to serve this model"
            )

        self.create_metrics(args["model_name"],
                            args["model_version"],
                            is_v1_model=executor_config.batching_type ==
                            trtllm.BatchingType.STATIC)
        self.triton_user_id_to_req_ids = {}
        self.triton_req_id_to_req_ids = {}
        self.req_id_to_request_data = {}
        self.lock = Lock()
        self.running = False
        self.awaiter_thread = Thread(target=self.awaiter_loop)
        self.cancellation_thread = Thread(target=self.cancellation_loop)
        self.metrics_thread = Thread(target=self.metrics_loop)
        if self.executor.can_enqueue_requests():
            self.running = True
            self.awaiter_thread.start()
            self.cancellation_thread.start()
            self.metrics_thread.start()
        else:
            # In leader mode, worker ranks will wait here until leader is done.
            self.executor.shutdown()

    def handle_stop_request(self, triton_user_id, response_sender):
        if triton_user_id is None or triton_user_id == "":
            response_sender.send(
                pb_utils.InferenceResponse(error=pb_utils.TritonError(
                    "A request id must be provided for request cancellation")),
                flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
            return

        with self.lock:
            if triton_user_id in self.triton_user_id_to_req_ids:
                req_ids = self.triton_user_id_to_req_ids[triton_user_id]
                for req_id in req_ids:
                    self.executor.cancel_request(req_id)

        response_sender.send(
            pb_utils.InferenceResponse(),
            flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)

    def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference is requested
        for this model.

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        if not self.executor.can_enqueue_requests():
            return

        # Convert to executor requests.

        triton_requests = []
        executor_requests = []
        batch_indices = []
        triton_user_ids = []
        triton_req_ids = []

        for request in requests:

            triton_user_id = request.request_id()

            response_sender = request.get_response_sender()
            stop = get_input_scalar_by_name(request, 'stop')

            if stop:
                self.handle_stop_request(triton_user_id, response_sender)
            else:
                # Unique request id used to identify each triton request
                triton_req_id = str(randint(0, sys.maxsize))
                self.triton_req_id_to_req_ids[triton_req_id] = set()
                if triton_user_id is not None and triton_user_id != "":
                    self.triton_user_id_to_req_ids[triton_user_id] = set()

                try:
                    converted_reqs = convert_request(
                        request, self.exclude_input_from_output,
                        self.decoupled)
                except Exception as e:
                    response_sender.send(
                        pb_utils.InferenceResponse(error=pb_utils.TritonError(
                            f"An error occurred when processing the input values for request id {request.request_id()}, the error was '{e}'"
                        )),
                        flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                else:
                    for batch_index, converted_req in enumerate(
                            converted_reqs):
                        triton_requests.append(request)
                        executor_requests.append(converted_req)
                        triton_user_ids.append(triton_user_id)
                        triton_req_ids.append(triton_req_id)
                        batch_indices.append(batch_index)

        with self.lock:
            request_ids = self.executor.enqueue_requests(executor_requests)
            for req_id, triton_req_id, triton_user_id, triton_request, batch_index in zip(
                    request_ids, triton_req_ids, triton_user_ids,
                    triton_requests, batch_indices):
                self.req_id_to_request_data[
                    req_id] = triton_req_id, triton_user_id, batch_index, triton_request.get_response_sender(
                    )
                self.triton_req_id_to_req_ids[triton_req_id].add(req_id)
                if triton_user_id is not None and triton_user_id != "":
                    self.triton_user_id_to_req_ids[triton_user_id].add(req_id)

        return None

    def awaiter_loop(self):
        """Gets responses from executor and returns the results."""
        while self.running:
            for response in self.executor.await_responses(
                    timeout=datetime.timedelta(milliseconds=1)):
                req_id = response.request_id
                with self.lock:
                    if req_id not in self.req_id_to_request_data:
                        continue
                    triton_req_id, triton_user_id, batch_index, response_sender = self.req_id_to_request_data[
                        req_id]

                triton_response, is_final = convert_response(
                    response, batch_index)

                triton_request_final = False
                if is_final:
                    with self.lock:
                        # Check if all executor requests part of that triton request are finished
                        self.triton_req_id_to_req_ids[triton_req_id].remove(
                            req_id)
                        if len(self.triton_req_id_to_req_ids[triton_req_id]
                               ) == 0:
                            pb_utils.Logger.log_info(
                                f"DELETING Req id {req_id}, triton_req_id {triton_req_id} "
                            )
                            triton_request_final = True
                            del self.triton_req_id_to_req_ids[triton_req_id]
                            if triton_user_id is not None and triton_user_id != "":
                                del self.triton_user_id_to_req_ids[
                                    triton_user_id]
                        del self.req_id_to_request_data[req_id]

                response_sender.send(
                    triton_response,
                    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                    if triton_request_final else 0)

                # Remove local reference so response_sender can be cleaned properly.
                del response_sender

    def cancellation_loop(self):
        """Checks if any pending requests have been cancelled."""
        while self.running:
            time.sleep(self.cancellation_check_period_ms / 1000.0)
            with self.lock:
                for req_id, (triton_req_id, triton_user_id, batch_index,
                             response_sender
                             ) in self.req_id_to_request_data.items():
                    if response_sender.is_cancelled():
                        self.executor.cancel_request(req_id)
                    # Remove local reference so response_sender can be cleaned properly.
                    del response_sender

    def metrics_loop(self):
        """Updates triton metrics using stats from the executor."""
        while self.running:
            time.sleep(self.stats_check_period_ms / 1000.0)
            for stat in self.executor.get_latest_iteration_stats():
                try:
                    for key, metric in self.all_metrics.items():
                        value = None
                        if hasattr(stat, key):
                            value = getattr(stat, key)
                        elif stat.kv_cache_stats is not None and hasattr(
                                stat.kv_cache_stats, key):
                            value = getattr(stat.kv_cache_stats, key)
                        elif stat.static_batching_stats is not None and hasattr(
                                stat.static_batching_stats, key):
                            value = getattr(stat.static_batching_stats, key)
                        elif stat.inflight_batching_stats is not None and hasattr(
                                stat.inflight_batching_stats, key):
                            value = getattr(stat.inflight_batching_stats, key)
                        if value is not None:
                            if key == "timestamp":
                                value = convert_timestamp_to_seconds(value)
                            metric.set(value)
                        else:
                            pb_utils.Logger.log_warn(
                                f"Metric \"{key}\" not found.")
                except Exception as e:
                    pb_utils.Logger.log_warn(
                        f"Error while processing metrics: {e}")

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing `finalize` function is optional. This function allows
        the model to perform any necessary clean ups before exit.
        """
        if self.executor.can_enqueue_requests():
            self.running = False
            self.awaiter_thread.join()
            self.cancellation_thread.join()
            self.metrics_thread.join()
            self.executor.shutdown()
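
The `model.py` above accepts tokenized prompts and streams responses back, and its `initialize` refuses to run unless the decoupled transaction policy is enabled, so a client exercises it through Triton's gRPC streaming API. The sketch below is illustrative only and not part of this commit: the endpoint `localhost:8001`, the token ids, and the request id are placeholder assumptions, and a real client would run a tokenizer and drain the stream until the final response.

# Minimal streaming client sketch for the model above (assumptions: a
# Triton server at localhost:8001 serving "tensorrt_llm" in decoupled
# mode, and placeholder token ids in place of a real tokenizer's output).
import queue

import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

responses = queue.Queue()


def on_response(result, error):
    # Decoupled models deliver partial results through this stream callback.
    responses.put(error if error is not None else result)


input_ids = np.array([[1, 2, 3]], dtype=np.int32)  # placeholder tokens
tensors = []
for name, data in [
    ("input_ids", input_ids),
    ("input_lengths", np.array([[input_ids.shape[1]]], dtype=np.int32)),
    ("request_output_len", np.array([[64]], dtype=np.int32)),
    ("streaming", np.array([[True]], dtype=bool)),
]:
    t = grpcclient.InferInput(name, list(data.shape),
                              np_to_triton_dtype(data.dtype))
    t.set_data_from_numpy(data)
    tensors.append(t)

client = grpcclient.InferenceServerClient("localhost:8001")
client.start_stream(callback=on_response)
client.async_stream_infer("tensorrt_llm", tensors, request_id="0")
# Each partial response carries "output_ids" and "sequence_length"; a real
# client keeps draining the queue until the final-flagged response arrives.
first = responses.get()
if not isinstance(first, Exception):
    print(first.as_numpy("output_ids"))
client.stop_stream()
client.close()

The `request_id` matters beyond bookkeeping: `execute` maps it to executor request ids, so a later request with the same id and `stop=True` cancels the in-flight generation via `handle_stop_request`.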
model_repo_whisper_qwen_trtllm/tensorrt_llm/config.pbtxt
ADDED
@@ -0,0 +1,577 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

name: "tensorrt_llm"
backend: "tensorrtllm"
max_batch_size: 8

model_transaction_policy {
  decoupled: false
}

dynamic_batching {
  preferred_batch_size: [ 8 ]
  max_queue_delay_microseconds: 0
  default_queue_policy: { max_queue_size: 0 }
}

input [
  {
    name: "input_ids"
    data_type: TYPE_INT32
    dims: [ -1 ]
    allow_ragged_batch: true
    optional: true
  },
  {
    name: "encoder_input_features"
    data_type: TYPE_FP16
    dims: [ -1, -1 ]
    allow_ragged_batch: true
    optional: true
  },
  {
    name: "encoder_output_lengths"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "input_lengths"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
  },
  {
    name: "request_output_len"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
  },
  {
    name: "draft_input_ids"
    data_type: TYPE_INT32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "decoder_input_ids"
    data_type: TYPE_INT32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "decoder_input_lengths"
    data_type: TYPE_INT32
    dims: [ 1 ]
    optional: true
    reshape: { shape: [ ] }
  },
  {
    name: "draft_logits"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "draft_acceptance_threshold"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "end_id"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "pad_id"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "stop_words_list"
    data_type: TYPE_INT32
    dims: [ 2, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "bad_words_list"
    data_type: TYPE_INT32
    dims: [ 2, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "embedding_bias"
    data_type: TYPE_FP32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "beam_width"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "temperature"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_k"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p_min"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p_decay"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p_reset_ids"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "len_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "early_stopping"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "repetition_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "min_length"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "beam_search_diversity_rate"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "presence_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "frequency_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "random_seed"
    data_type: TYPE_UINT64
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "return_log_probs"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "return_context_logits"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "return_generation_logits"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "stop"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "streaming"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "prompt_embedding_table"
    data_type: TYPE_FP16
    dims: [ -1, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "prompt_vocab_size"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  # The unique task ID for the given LoRA.
  # To perform inference with a specific LoRA for the first time, `lora_task_id`, `lora_weights`, and `lora_config` must all be given.
  # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
  # If the cache is full, the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
  {
    name: "lora_task_id"
    data_type: TYPE_UINT64
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  # Weights for a LoRA adapter, shape [ num_lora_modules_layers, D x Hi + Ho x D ],
  # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer;
  # each of the in / out tensors is first flattened and then concatenated together in the format above.
  # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
  {
    name: "lora_weights"
    data_type: TYPE_FP16
    dims: [ -1, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  # Module identifier (same size as the first dimension of lora_weights).
  # See LoraModule::ModuleType for the module id mapping:
  #
  # "attn_qkv": 0     # combined qkv adapter
  # "attn_q": 1       # q adapter
  # "attn_k": 2       # k adapter
  # "attn_v": 3       # v adapter
  # "attn_dense": 4   # adapter for the dense layer in attention
  # "mlp_h_to_4h": 5  # for llama2, adapter for the gated mlp layer after attention / RMSNorm: up projection
  # "mlp_4h_to_h": 6  # for llama2, adapter for the gated mlp layer after attention / RMSNorm: down projection
  # "mlp_gate": 7     # for llama2, adapter for the gated mlp layer after attention / RMSNorm: gate
  #
  # The last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
  {
    name: "lora_config"
    data_type: TYPE_INT32
    dims: [ -1, 3 ]
    optional: true
    allow_ragged_batch: true
  }
]
output [
  {
    name: "output_ids"
    data_type: TYPE_INT32
    dims: [ -1, -1 ]
  },
  {
    name: "sequence_length"
    data_type: TYPE_INT32
    dims: [ -1 ]
  },
  {
    name: "cum_log_probs"
    data_type: TYPE_FP32
    dims: [ -1 ]
  },
  {
    name: "output_log_probs"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
  },
  {
    name: "context_logits"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
  },
  {
    name: "generation_logits"
    data_type: TYPE_FP32
    dims: [ -1, -1, -1 ]
  },
  {
    name: "batch_index"
    data_type: TYPE_INT32
    dims: [ 1 ]
  }
]
instance_group [
  {
    count: 1
    kind : KIND_CPU
  }
]
parameters: {
  key: "max_beam_width"
  value: {
    string_value: "1"
  }
}
parameters: {
  key: "FORCE_CPU_ONLY_INPUT_TENSORS"
  value: {
    string_value: "no"
  }
}
parameters: {
  key: "gpt_model_type"
  value: {
    string_value: "inflight_fused_batching"
  }
}
parameters: {
  key: "gpt_model_path"
  value: {
    string_value: "/home/scratch.yuekaiz_wwfo_1/tekit/examples/qwen/qwen2_1.5B_instruct_fp16_merged_max_prompt_embedding_table_size_256"
  }
}
parameters: {
  key: "encoder_model_path"
  value: {
    string_value: "${encoder_engine_dir}"
  }
}
parameters: {
  key: "max_tokens_in_paged_kv_cache"
  value: {
    string_value: "2560"
  }
}
parameters: {
  key: "max_attention_window_size"
  value: {
    string_value: "2000"
  }
}
parameters: {
  key: "sink_token_length"
  value: {
    string_value: "${sink_token_length}"
  }
}
parameters: {
434 |
+
key: "batch_scheduler_policy"
|
435 |
+
value: {
|
436 |
+
string_value: "${batch_scheduler_policy}"
|
437 |
+
}
|
438 |
+
}
|
439 |
+
parameters: {
|
440 |
+
key: "kv_cache_free_gpu_mem_fraction"
|
441 |
+
value: {
|
442 |
+
string_value: "0.5"
|
443 |
+
}
|
444 |
+
}
|
445 |
+
parameters: {
|
446 |
+
key: "kv_cache_host_memory_bytes"
|
447 |
+
value: {
|
448 |
+
string_value: "${kv_cache_host_memory_bytes}"
|
449 |
+
}
|
450 |
+
}
|
451 |
+
parameters: {
|
452 |
+
key: "kv_cache_onboard_blocks"
|
453 |
+
value: {
|
454 |
+
string_value: "${kv_cache_onboard_blocks}"
|
455 |
+
}
|
456 |
+
}
|
457 |
+
# enable_trt_overlap is deprecated and doesn't have any effect on the runtime
|
458 |
+
# parameters: {
|
459 |
+
# key: "enable_trt_overlap"
|
460 |
+
# value: {
|
461 |
+
# string_value: "${enable_trt_overlap}"
|
462 |
+
# }
|
463 |
+
# }
|
464 |
+
parameters: {
|
465 |
+
key: "exclude_input_in_output"
|
466 |
+
value: {
|
467 |
+
string_value: "True"
|
468 |
+
}
|
469 |
+
}
|
470 |
+
parameters: {
|
471 |
+
key: "cancellation_check_period_ms"
|
472 |
+
value: {
|
473 |
+
string_value: "${cancellation_check_period_ms}"
|
474 |
+
}
|
475 |
+
}
|
476 |
+
parameters: {
|
477 |
+
key: "stats_check_period_ms"
|
478 |
+
value: {
|
479 |
+
string_value: "${stats_check_period_ms}"
|
480 |
+
}
|
481 |
+
}
|
482 |
+
parameters: {
|
483 |
+
key: "iter_stats_max_iterations"
|
484 |
+
value: {
|
485 |
+
string_value: "${iter_stats_max_iterations}"
|
486 |
+
}
|
487 |
+
}
|
488 |
+
parameters: {
|
489 |
+
key: "request_stats_max_iterations"
|
490 |
+
value: {
|
491 |
+
string_value: "${request_stats_max_iterations}"
|
492 |
+
}
|
493 |
+
}
|
494 |
+
parameters: {
|
495 |
+
key: "enable_kv_cache_reuse"
|
496 |
+
value: {
|
497 |
+
string_value: "False"
|
498 |
+
}
|
499 |
+
}
|
500 |
+
parameters: {
|
501 |
+
key: "normalize_log_probs"
|
502 |
+
value: {
|
503 |
+
string_value: "${normalize_log_probs}"
|
504 |
+
}
|
505 |
+
}
|
506 |
+
parameters: {
|
507 |
+
key: "enable_chunked_context"
|
508 |
+
value: {
|
509 |
+
string_value: "${enable_chunked_context}"
|
510 |
+
}
|
511 |
+
}
|
512 |
+
parameters: {
|
513 |
+
key: "gpu_device_ids"
|
514 |
+
value: {
|
515 |
+
string_value: "${gpu_device_ids}"
|
516 |
+
}
|
517 |
+
}
|
518 |
+
parameters: {
|
519 |
+
key: "lora_cache_optimal_adapter_size"
|
520 |
+
value: {
|
521 |
+
string_value: "${lora_cache_optimal_adapter_size}"
|
522 |
+
}
|
523 |
+
}
|
524 |
+
parameters: {
|
525 |
+
key: "lora_cache_max_adapter_size"
|
526 |
+
value: {
|
527 |
+
string_value: "${lora_cache_max_adapter_size}"
|
528 |
+
}
|
529 |
+
}
|
530 |
+
parameters: {
|
531 |
+
key: "lora_cache_gpu_memory_fraction"
|
532 |
+
value: {
|
533 |
+
string_value: "${lora_cache_gpu_memory_fraction}"
|
534 |
+
}
|
535 |
+
}
|
536 |
+
parameters: {
|
537 |
+
key: "lora_cache_host_memory_bytes"
|
538 |
+
value: {
|
539 |
+
string_value: "${lora_cache_host_memory_bytes}"
|
540 |
+
}
|
541 |
+
}
|
542 |
+
parameters: {
|
543 |
+
key: "decoding_mode"
|
544 |
+
value: {
|
545 |
+
string_value: "${decoding_mode}"
|
546 |
+
}
|
547 |
+
}
|
548 |
+
parameters: {
|
549 |
+
key: "executor_worker_path"
|
550 |
+
value: {
|
551 |
+
string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
|
552 |
+
}
|
553 |
+
}
|
554 |
+
parameters: {
|
555 |
+
key: "medusa_choices"
|
556 |
+
value: {
|
557 |
+
string_value: "${medusa_choices}"
|
558 |
+
}
|
559 |
+
}
|
560 |
+
parameters: {
|
561 |
+
key: "gpu_weights_percent"
|
562 |
+
value: {
|
563 |
+
string_value: "${gpu_weights_percent}"
|
564 |
+
}
|
565 |
+
}
|
566 |
+
parameters: {
|
567 |
+
key: "enable_context_fmha_fp32_acc"
|
568 |
+
value: {
|
569 |
+
string_value: "${enable_context_fmha_fp32_acc}"
|
570 |
+
}
|
571 |
+
}
|
572 |
+
parameters: {
|
573 |
+
key: "multi_block_mode"
|
574 |
+
value: {
|
575 |
+
string_value: "${multi_block_mode}"
|
576 |
+
}
|
577 |
+
}
|
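For orientation, here is a minimal sketch (not part of this commit) of how a client could exercise the interface defined in the config above. Everything concrete in it (host/port, token ids, the zero-filled embedding table) is an assumption: only `input_ids`, `input_lengths`, and `request_output_len` are required, the sampling tensors are optional, and the FP16 `prompt_embedding_table` plus `prompt_vocab_size` carry the per-request speech embeddings that extended-vocabulary ids resolve against.

```python
# Hypothetical client sketch for the "tensorrt_llm" model above; the server
# address and all tensor values are assumptions for illustration only.
import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient("localhost:8001")

input_ids = np.array([[151644, 872, 198]], dtype=np.int32)  # made-up token ids
table = np.zeros((1, 187, 1536), dtype=np.float16)          # speech embeddings

def make_input(name, arr, dtype):
    t = grpcclient.InferInput(name, list(arr.shape), dtype)
    t.set_data_from_numpy(arr)
    return t

inputs = [
    make_input("input_ids", input_ids, "INT32"),
    make_input("input_lengths", np.array([[3]], dtype=np.int32), "INT32"),
    make_input("request_output_len", np.array([[200]], dtype=np.int32), "INT32"),
    make_input("runtime_top_k", np.array([[1]], dtype=np.int32), "INT32"),  # optional
    make_input("prompt_embedding_table", table, "FP16"),
    make_input("prompt_vocab_size", np.array([[187]], dtype=np.int32), "INT32"),
]
result = client.infer("tensorrt_llm", inputs)
print(result.as_numpy("output_ids"))
```

Because `max_batch_size` is set in the config, every tensor carries a leading batch dimension of 1.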
model_repo_whisper_qwen_trtllm/tensorrt_llm/config.template
ADDED
@@ -0,0 +1,577 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

name: "tensorrt_llm"
backend: "${triton_backend}"
max_batch_size: ${triton_max_batch_size}

model_transaction_policy {
  decoupled: ${decoupled_mode}
}

dynamic_batching {
  preferred_batch_size: [ ${triton_max_batch_size} ]
  max_queue_delay_microseconds: ${max_queue_delay_microseconds}
  default_queue_policy: { max_queue_size: ${max_queue_size} }
}

input [
  {
    name: "input_ids"
    data_type: TYPE_INT32
    dims: [ -1 ]
    allow_ragged_batch: true
    optional: true
  },
  {
    name: "encoder_input_features"
    data_type: TYPE_FP16
    dims: [ -1, -1 ]
    allow_ragged_batch: true
    optional: true
  },
  {
    name: "encoder_output_lengths"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "input_lengths"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
  },
  {
    name: "request_output_len"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
  },
  {
    name: "draft_input_ids"
    data_type: TYPE_INT32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "decoder_input_ids"
    data_type: TYPE_INT32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "decoder_input_lengths"
    data_type: TYPE_INT32
    dims: [ 1 ]
    optional: true
    reshape: { shape: [ ] }
  },
  {
    name: "draft_logits"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "draft_acceptance_threshold"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "end_id"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "pad_id"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "stop_words_list"
    data_type: TYPE_INT32
    dims: [ 2, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "bad_words_list"
    data_type: TYPE_INT32
    dims: [ 2, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "embedding_bias"
    data_type: TYPE_FP32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "beam_width"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "temperature"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_k"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p_min"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p_decay"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p_reset_ids"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "len_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "early_stopping"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "repetition_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "min_length"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "beam_search_diversity_rate"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "presence_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "frequency_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "random_seed"
    data_type: TYPE_UINT64
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "return_log_probs"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "return_context_logits"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "return_generation_logits"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "stop"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "streaming"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "prompt_embedding_table"
    data_type: TYPE_FP16
    dims: [ -1, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "prompt_vocab_size"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  # the unique task ID for the given LoRA.
  # To perform inference with a specific LoRA for the first time, `lora_task_id`, `lora_weights`, and `lora_config` must all be given.
  # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
  # If the cache is full, the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
  {
    name: "lora_task_id"
    data_type: TYPE_UINT64
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  # weights for a LoRA adapter, shape [ num_lora_modules_layers, D x Hi + Ho x D ]
  # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer;
  # each of the in / out tensors is first flattened and then concatenated together in the format above.
  # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
  {
    name: "lora_weights"
    data_type: TYPE_FP16
    dims: [ -1, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  # module identifier (same size as the first dimension of lora_weights)
  # See LoraModule::ModuleType for the module id mapping:
  #
  #   "attn_qkv": 0     # combined qkv adapter
  #   "attn_q": 1       # q adapter
  #   "attn_k": 2       # k adapter
  #   "attn_v": 3       # v adapter
  #   "attn_dense": 4   # adapter for the dense layer in attention
  #   "mlp_h_to_4h": 5  # for llama2, adapter for the gated mlp layer after attention / RMSNorm: up projection
  #   "mlp_4h_to_h": 6  # for llama2, adapter for the gated mlp layer after attention / RMSNorm: down projection
  #   "mlp_gate": 7     # for llama2, adapter for the gated mlp layer after attention / RMSNorm: gate
  #
  # the last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
  {
    name: "lora_config"
    data_type: TYPE_INT32
    dims: [ -1, 3 ]
    optional: true
    allow_ragged_batch: true
  }
]
output [
  {
    name: "output_ids"
    data_type: TYPE_INT32
    dims: [ -1, -1 ]
  },
  {
    name: "sequence_length"
    data_type: TYPE_INT32
    dims: [ -1 ]
  },
  {
    name: "cum_log_probs"
    data_type: TYPE_FP32
    dims: [ -1 ]
  },
  {
    name: "output_log_probs"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
  },
  {
    name: "context_logits"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
  },
  {
    name: "generation_logits"
    data_type: TYPE_FP32
    dims: [ -1, -1, -1 ]
  },
  {
    name: "batch_index"
    data_type: TYPE_INT32
    dims: [ 1 ]
  }
]
instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]
parameters: {
  key: "max_beam_width"
  value: {
    string_value: "${max_beam_width}"
  }
}
parameters: {
  key: "FORCE_CPU_ONLY_INPUT_TENSORS"
  value: {
    string_value: "no"
  }
}
parameters: {
  key: "gpt_model_type"
  value: {
    string_value: "${batching_strategy}"
  }
}
parameters: {
  key: "gpt_model_path"
  value: {
    string_value: "${engine_dir}"
  }
}
parameters: {
  key: "encoder_model_path"
  value: {
    string_value: "${encoder_engine_dir}"
  }
}
parameters: {
  key: "max_tokens_in_paged_kv_cache"
  value: {
    string_value: "${max_tokens_in_paged_kv_cache}"
  }
}
parameters: {
  key: "max_attention_window_size"
  value: {
    string_value: "${max_attention_window_size}"
  }
}
parameters: {
  key: "sink_token_length"
  value: {
    string_value: "${sink_token_length}"
  }
}
parameters: {
  key: "batch_scheduler_policy"
  value: {
    string_value: "${batch_scheduler_policy}"
  }
}
parameters: {
  key: "kv_cache_free_gpu_mem_fraction"
  value: {
    string_value: "${kv_cache_free_gpu_mem_fraction}"
  }
}
parameters: {
  key: "kv_cache_host_memory_bytes"
  value: {
    string_value: "${kv_cache_host_memory_bytes}"
  }
}
parameters: {
  key: "kv_cache_onboard_blocks"
  value: {
    string_value: "${kv_cache_onboard_blocks}"
  }
}
# enable_trt_overlap is deprecated and doesn't have any effect on the runtime
# parameters: {
#   key: "enable_trt_overlap"
#   value: {
#     string_value: "${enable_trt_overlap}"
#   }
# }
parameters: {
  key: "exclude_input_in_output"
  value: {
    string_value: "${exclude_input_in_output}"
  }
}
parameters: {
  key: "cancellation_check_period_ms"
  value: {
    string_value: "${cancellation_check_period_ms}"
  }
}
parameters: {
  key: "stats_check_period_ms"
  value: {
    string_value: "${stats_check_period_ms}"
  }
}
parameters: {
  key: "iter_stats_max_iterations"
  value: {
    string_value: "${iter_stats_max_iterations}"
  }
}
parameters: {
  key: "request_stats_max_iterations"
  value: {
    string_value: "${request_stats_max_iterations}"
  }
}
parameters: {
  key: "enable_kv_cache_reuse"
  value: {
    string_value: "${enable_kv_cache_reuse}"
  }
}
parameters: {
  key: "normalize_log_probs"
  value: {
    string_value: "${normalize_log_probs}"
  }
}
parameters: {
  key: "enable_chunked_context"
  value: {
    string_value: "${enable_chunked_context}"
  }
}
parameters: {
  key: "gpu_device_ids"
  value: {
    string_value: "${gpu_device_ids}"
  }
}
parameters: {
  key: "lora_cache_optimal_adapter_size"
  value: {
    string_value: "${lora_cache_optimal_adapter_size}"
  }
}
parameters: {
  key: "lora_cache_max_adapter_size"
  value: {
    string_value: "${lora_cache_max_adapter_size}"
  }
}
parameters: {
  key: "lora_cache_gpu_memory_fraction"
  value: {
    string_value: "${lora_cache_gpu_memory_fraction}"
  }
}
parameters: {
  key: "lora_cache_host_memory_bytes"
  value: {
    string_value: "${lora_cache_host_memory_bytes}"
  }
}
parameters: {
  key: "decoding_mode"
  value: {
    string_value: "${decoding_mode}"
  }
}
parameters: {
  key: "executor_worker_path"
  value: {
    string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
  }
}
parameters: {
  key: "medusa_choices"
  value: {
    string_value: "${medusa_choices}"
  }
}
parameters: {
  key: "gpu_weights_percent"
  value: {
    string_value: "${gpu_weights_percent}"
  }
}
parameters: {
  key: "enable_context_fmha_fp32_acc"
  value: {
    string_value: "${enable_context_fmha_fp32_acc}"
  }
}
parameters: {
  key: "multi_block_mode"
  value: {
    string_value: "${multi_block_mode}"
  }
}
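Every `${...}` placeholder in the template above must be substituted before Triton can load the model; the tensorrtllm_backend project ships its own fill-template tool for this. Purely as an illustration, Python's built-in `string.Template` understands the same `${var}` syntax (and leaves pbtxt braces alone, since it only triggers on `$`), so a minimal stand-in could look like this. All the paths and values in the mapping are assumptions.

```python
# Minimal placeholder-filling sketch; treat it as an illustrative stand-in,
# not the project's actual tooling.
from string import Template

values = {
    "triton_backend": "tensorrtllm",
    "triton_max_batch_size": "4",
    "decoupled_mode": "False",
    "max_queue_delay_microseconds": "0",
    "max_queue_size": "0",
    "engine_dir": "/path/to/decoder_engine",        # assumption
    "encoder_engine_dir": "/path/to/encoder_engine",  # assumption
    "batching_strategy": "inflight_fused_batching",
    "max_beam_width": "1",
}

with open("tensorrt_llm/config.template") as f:
    template = Template(f.read())

# safe_substitute leaves any ${...} not in the mapping untouched
# instead of raising, which is convenient for partial fills.
with open("tensorrt_llm/config.pbtxt", "w") as f:
    f.write(template.safe_substitute(values))
```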
model_repo_whisper_qwen_trtllm/whisper/0/__pycache__/fbank.cpython-310.pyc
ADDED
Binary file (3.07 kB).
model_repo_whisper_qwen_trtllm/whisper/0/__pycache__/model.cpython-310.pyc
ADDED
Binary file (10.9 kB).
model_repo_whisper_qwen_trtllm/whisper/0/__pycache__/whisper_trtllm.cpython-310.pyc
ADDED
Binary file (9.21 kB).
model_repo_whisper_qwen_trtllm/whisper/0/fbank.py
ADDED
@@ -0,0 +1,91 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Reference: https://github.com/openai/whisper/blob/main/whisper/audio.py
import os

import numpy as np
import torch
import torch.nn.functional as F


def mel_filters(device, n_mels: int = 128) -> torch.Tensor:
    """
    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling the librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    """
    assert n_mels == 80 or n_mels == 128, f"Unsupported n_mels: {n_mels}"
    with np.load(
        os.path.join(os.path.dirname(__file__), "mel_filters.npz")
    ) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(
    audio: torch.Tensor,
    filters: torch.Tensor,
    n_mels: int = 128,
    n_fft: int = 400,
    hop_length: int = 160,
):
    """
    Compute the log-Mel spectrogram of the given audio.

    Parameters
    ----------
    audio: torch.Tensor, shape = (*)
        A Tensor containing the audio waveform sampled at 16 kHz

    filters: torch.Tensor
        The mel filterbank matrix returned by `mel_filters`

    n_mels: int
        The number of Mel-frequency filters; only 80 or 128 is supported

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames)
        A Tensor that contains the log-Mel spectrogram, cast to float16
    """
    window = torch.hann_window(n_fft).to(audio.device)
    stft = torch.stft(audio, n_fft, hop_length, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2

    mel_spec = filters @ magnitudes
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    # cast to float16
    log_spec = log_spec.half()
    return log_spec


class FeatureExtractor(torch.nn.Module):
    """Computes log-Mel fbank features on the GPU for Whisper-style models."""

    def __init__(self, n_mels: int = 128):
        super().__init__()
        self.device = torch.device("cuda")
        self.n_mels = n_mels
        self.filters = mel_filters(self.device, n_mels=self.n_mels)

    def compute_feature(self, wav, target: int = 3000):
        mel = log_mel_spectrogram(wav, self.filters)
        assert mel.shape[1] <= target, f"{mel.shape[1]} > {target}, audio is too long"
        if mel.shape[1] < target:
            mel = F.pad(mel, (0, target - mel.shape[1]), mode='constant')
        mel = mel.unsqueeze(0)
        return mel
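A quick way to sanity-check the feature extractor above in isolation is the sketch below; it assumes a CUDA device, the `mel_filters.npz` shipped next to this file, and that you run it from this directory so `fbank` is importable. The zero waveform is a stand-in for real 16 kHz audio.

```python
# Sketch of driving FeatureExtractor directly (assumes CUDA + mel_filters.npz).
import torch
from fbank import FeatureExtractor

extractor = FeatureExtractor(n_mels=128)
wav = torch.zeros(16000 * 10, device="cuda")  # 10 s of "audio" at 16 kHz
mel = extractor.compute_feature(wav)          # right-padded to 3000 frames
print(mel.shape, mel.dtype)                   # torch.Size([1, 128, 3000]) torch.float16
```

Ten seconds of 16 kHz audio yield 1000 mel frames at a 160-sample hop, which `compute_feature` pads out to the fixed 3000-frame (30 s) window Whisper expects.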
model_repo_whisper_qwen_trtllm/whisper/0/mel_filters.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7450ae70723a5ef9d341e3cee628c7cb0177f36ce42c44b7ed2bf3325f0f6d4c
size 4271
model_repo_whisper_qwen_trtllm/whisper/0/model.py
ADDED
@@ -0,0 +1,346 @@
# -*- coding: utf-8 -*-
import triton_python_backend_utils as pb_utils
import numpy as np
import json
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import re
import transformers
from transformers import AutoTokenizer
from typing import Dict
from pathlib import Path
import traceback

from .whisper_trtllm import WhisperTRTLLM
from .fbank import FeatureExtractor

DEFAULT_SPEECH_TOKEN = "<speech>"


def preprocess(
    messages,
    tokenizer: transformers.PreTrainedTokenizer,
    max_len: int = 128,
) -> Dict:
    """Preprocesses the data for supervised fine-tuning."""
    texts = []
    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
    for i, msg in enumerate(messages):
        texts.append(
            tokenizer.apply_chat_template(
                msg,
                tokenize=True,
                add_generation_prompt=False,
                chat_template=TEMPLATE,
                padding="longest",
                max_length=max_len,
                truncation=True,
            )
        )
    max_len_texts = max([len(text) for text in texts])
    if tokenizer.padding_side == "right":
        texts = [
            text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
            for text in texts
        ]
    else:
        texts = [
            [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
            for text in texts
        ]

    input_ids = torch.tensor(texts, dtype=torch.int)

    attention_mask = input_ids.ne(tokenizer.pad_token_id)

    return input_ids, attention_mask


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        self.model_config = model_config = json.loads(args['model_config'])

        # Get OUTPUT0 configuration
        output0_config = pb_utils.get_output_config_by_name(
            model_config, "TRANSCRIPTS")
        # Convert Triton types to numpy types
        self.out0_dtype = pb_utils.triton_string_to_numpy(
            output0_config['data_type'])

        # self.tokenizer = get_tokenizer(num_languages=100)
        # self.blank = self.tokenizer.encode(" ", allowed_special=self.tokenizer.special_tokens_set)[0]
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
        tokenizer.padding_side = "left"
        special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
        tokenizer.add_special_tokens(special_tokens_dict)
        self.tokenizer = tokenizer
        self.eos = self.tokenizer.eos_token_id
        self.default_speech_token_id = tokenizer.convert_tokens_to_ids(
            DEFAULT_SPEECH_TOKEN
        )
        self.vocab_size = 151936
        # self.vocab_size = 500000
        # self.vocab_size = 160000

        self.device = torch.device("cuda")
        self.decoupled = False
        self.logger = pb_utils.Logger
        self.init_model(self.model_config['parameters'])

    def init_model(self, parameters):
        for key, value in parameters.items():
            parameters[key] = value["string_value"]
        engine_dir = parameters["engine_dir"]
        n_mels = int(parameters["n_mels"])
        adapter_dir = "/home/scratch.yuekaiz_wwfo_1/icefall_asr_multi-hans_whisper_qwen2_1.5B/epoch-2-avg-6.pt"
        checkpoint = torch.load(
            adapter_dir, map_location="cpu"
        )
        self.model = WhisperTRTLLM(engine_dir)
        missing_keys, _ = self.model.load_state_dict(checkpoint, strict=False)
        # print(f"Missing keys: {missing_keys}")
        self.feature_extractor = FeatureExtractor(n_mels=n_mels)

    def _tokenize(self, prompt=None, num_speech_tokens=187):
        if prompt is None:
            prompts = [
                [
                    {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
                    {"role": "assistant", "content": ""},
                ]
            ]
            # prompts = [
            #     [
            #         {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}你好,你是谁?"},
            #         {"role": "assistant", "content": ""},
            #     ]
            # ]

        input_ids, _ = preprocess(prompts, self.tokenizer, max_len=128)
        input_ids = input_ids.tolist()[0]
        speech_token_index = input_ids.index(self.default_speech_token_id)
        # replace the <speech> token (id 151646) with list(range(self.vocab_size, self.vocab_size + num_speech_tokens))
        prompt_ids = input_ids[:speech_token_index] + list(range(self.vocab_size, self.vocab_size + num_speech_tokens)) + input_ids[speech_token_index + 1:]
        # prompt_ids = input_ids[:speech_token_index] + input_ids[speech_token_index + 1:]
        return prompt_ids

    def _prepare_inputs(self, request, speech_embeddings, input_ids):
        """
        Prepares inputs for the language model based on the parameters in the
        request, the speech embeddings, and the tokenized prompt. Additional
        parameters could be extracted from the request, e.g.:
            - max_tokens: Maximum number of tokens to generate
            - temperature: Controls randomness in generation
            - top_k: Top K sampling parameter
            - frequency_penalty: Penalizes frequent tokens
            - seed: Random seed for generation
        (this simplified implementation hardcodes max_tokens=200 and top_k=1).

        The final llm input dictionary is combined out of all processed
        parameters, the prompt's tokens and the speech embeddings. The latter
        are passed to the llm through `prompt_embedding_table`.

        Parameters
        ----------
        - request: The original request object containing additional parameters.
        - speech_embeddings (torch.Tensor): The projected speech features.
        - input_ids (list): The tokenized prompt with extended-vocabulary ids.

        Returns
        -------
        - list: A list of pb_utils.Tensor inputs for the language model.
        """
        input_ids = np.array(input_ids, dtype=np.int32)
        max_tokens = 200
        input_len = input_ids.shape[0]

        assert speech_embeddings.shape[1] == 187, "Only support 187 speech tokens"
        embedding_args = {
            "prompt_vocab_size": np.array(
                [[speech_embeddings.shape[1]]], dtype=np.int32
            ),
            "prompt_embedding_table": speech_embeddings.detach().cpu().numpy(),
        }
        # TODO: the results are identical whether or not this is added???
        # And input_ids beyond the max vocab size do not raise an error???
        input_dict = {
            "input_ids": np.expand_dims(input_ids, 0),
            "input_lengths": np.array([[input_len]], dtype=np.int32),
            "request_output_len": np.array([[max_tokens]], dtype=np.int32),
            "runtime_top_k": np.array([[1]], dtype=np.int32),
            "end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32),
            "pad_id": np.array([[self.tokenizer.pad_token_id]], dtype=np.int32),
            "streaming": np.array([[0]], dtype=np.bool_),
            **embedding_args,
        }

        print(input_ids)
        for key, value in input_dict.items():
            print(key, value.shape)

        input_tensor_list = [pb_utils.Tensor(k, v) for k, v in input_dict.items()]
        return input_tensor_list

    def _prepare_llm_response(self, llm_request_inputs):
        """
        Prepares the response from the language model based on the provided
        inputs. Creates a `pb_utils.InferenceRequest` object with passed
        `llm_request_inputs` to send to a decoupled TensorRTLLM model.
        For each response from the language model:
            - Checks for errors and raises an exception if any are found.
            - Extracts the "output_ids" tensor from the response.
            - Determines the finish reason based on the presence of the
              end-of-sequence token or reaching the maximum length.
            - Appends the generated token IDs to `output_ids`.
            - If the finish reason is determined, decodes the output IDs to text
              and prepares the final response.

        The final response includes the generated text, finish reason,
        completion tokens, prompt tokens, and total tokens.

        Parameters
        ----------
        - llm_request_inputs (dict): A dictionary containing the inputs for the language model.

        Returns
        -------
        - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
        """

        llm_request = pb_utils.InferenceRequest(
            model_name="tensorrt_llm",
            requested_output_names=["output_ids", "sequence_length"],
            inputs=llm_request_inputs,
        )
        output_ids, output_len = [], 0
        responses = llm_request.exec(decoupled=False)
        responses = [responses]
        for llm_response in responses:
            if llm_response.has_error():
                raise pb_utils.TritonModelException(llm_response.error().message())
            stream_output_ids = (
                pb_utils.get_output_tensor_by_name(llm_response, "output_ids")
                .as_numpy()
                .flatten()
                .tolist()
            )
            finish_reason = "test"
            if len(stream_output_ids) == 0 or (
                len(stream_output_ids) != 0
                and stream_output_ids[-1] == self.eos
            ):
                finish_reason = "stop"

            output_ids += stream_output_ids

            last_response = finish_reason != ""
            output_len = len(output_ids)
            if last_response:
                print("final_output_ids", output_ids)
                output_text = self.tokenizer.decode(output_ids).strip()
                # print(output_text)
                # output_text = re.sub(r'<\|.*?\|>', '', output_text)
                response = pb_utils.InferenceResponse(
                    output_tensors=[
                        pb_utils.Tensor("TRANSCRIPTS", np.array([output_text], np.object_)),
                    ]
                )
                yield response

    def _extract_speech_embeddings(self, mel):
        return self.model.process_batch(mel)

    def execute(self, requests):

        responses = []

        for request in requests:
            wav = pb_utils.get_input_tensor_by_name(request, "WAV").as_numpy()
            assert wav.shape[0] == 1, "Only support batch size 1"
            # To support batch > 1:
            # cat mel and text_prompt; also, decoder_input_len would need to be added as a triton input
            wav = torch.from_numpy(wav[0]).to(self.device)
            # mel shape [1, 80, 3000] for remove_input_padding=False
            mel = self.feature_extractor.compute_feature(wav)
            print("==========================================================")
            messages = [
                [
                    {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
                    {"role": "assistant", "content": ""},
                ]
            ] * len(mel)

            input_ids, attention_mask = preprocess(messages, self.tokenizer, max_len=128)

            generated_ids = self.model.decode(
                mel, input_ids.to(self.device, dtype=torch.long), attention_mask.to(self.device)
            )
            print("pytorch model", generated_ids)
            print("--------------------------------------------------------------------------")

            speech_embeddings = self._extract_speech_embeddings(mel)
            input_ids = self._tokenize()

            if self.decoupled:
                response_sender = request.get_response_sender()
            try:

                llm_request_inputs = self._prepare_inputs(
                    request, speech_embeddings, input_ids
                )
                if isinstance(llm_request_inputs, pb_utils.TritonError):
                    error = pb_utils.InferenceResponse(error=llm_request_inputs)
                    if self.decoupled:
                        response_sender.send(
                            error, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                        )
                    else:
                        responses.append(error)
                llm_responses = self._prepare_llm_response(llm_request_inputs)

                for triton_response in llm_responses:
                    if self.decoupled:
                        response_sender.send(triton_response)
                    else:
                        responses.append(triton_response)

                if self.decoupled:
                    response_sender.send(
                        flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)

            except Exception:
                self.logger.log_error(traceback.format_exc())
                # If encountering an error, send a response with the error message
                error_response = pb_utils.InferenceResponse(
                    output_tensors=[],
                    error=pb_utils.TritonError(traceback.format_exc()))

                if self.decoupled:
                    response_sender.send(error_response)
                    response_sender.send(
                        flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                else:
                    responses.append(error_response)

        if self.decoupled:
            return None
        else:
            assert len(responses) == len(requests)
            return responses
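The key trick in `_tokenize` above is the prompt-table splice: the single `<speech>` placeholder token is replaced by `num_speech_tokens` consecutive ids starting at `vocab_size`, and the TensorRT-LLM runtime resolves ids at or above the normal vocabulary through the request's `prompt_embedding_table` rather than the model's embedding matrix. A toy replay of that splice (the four-token prompt is made up for illustration; 151646 is the id this model assigns to `<speech>`):

```python
# Toy illustration of the splice done in _tokenize: ids >= vocab_size select
# rows of prompt_embedding_table instead of the regular embedding matrix.
vocab_size, num_speech_tokens = 151936, 187
speech_token_id = 151646                    # "<speech>" after add_special_tokens

input_ids = [151644, 872, 151646, 14880]    # hypothetical chat-template ids
i = input_ids.index(speech_token_id)
prompt_ids = (input_ids[:i]
              + list(range(vocab_size, vocab_size + num_speech_tokens))
              + input_ids[i + 1:])
assert len(prompt_ids) == len(input_ids) - 1 + num_speech_tokens  # 190
```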
model_repo_whisper_qwen_trtllm/whisper/0/whisper_trtllm.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import json
|
16 |
+
from collections import OrderedDict
|
17 |
+
from pathlib import Path
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
import tensorrt_llm
|
23 |
+
import tensorrt_llm.logger as logger
|
24 |
+
from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
|
25 |
+
trt_dtype_to_torch)
|
26 |
+
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
|
27 |
+
from tensorrt_llm.runtime.session import Session, TensorInfo
|
28 |
+
from transformers.trainer_pt_utils import LabelSmoother
|
29 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
30 |
+
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
31 |
+
|
32 |
+
DEFAULT_SPEECH_TOKEN = "<speech>"
|
33 |
+
def remove_tensor_padding(input_tensor, input_tensor_lengths=None, pad_value=0):
|
34 |
+
if input_tensor.dim() == 2:
|
35 |
+
# Text tensor case: batch, seq_len
|
36 |
+
assert torch.all(
|
37 |
+
input_tensor[:, 0] != pad_value
|
38 |
+
), "First token in each sequence should not be pad_value"
|
39 |
+
assert input_tensor_lengths is None
|
40 |
+
|
41 |
+
# Create a mask for all non-pad tokens
|
42 |
+
mask = input_tensor != pad_value
|
43 |
+
|
44 |
+
# Apply the mask to input_tensor to remove pad tokens
|
45 |
+
output_tensor = input_tensor[mask].view(1, -1)
|
46 |
+
|
47 |
+
elif input_tensor.dim() == 3:
|
48 |
+
# Audio tensor case: batch, seq_len, feature_len
|
49 |
+
assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
|
50 |
+
batch_size, seq_len, feature_len = input_tensor.shape
|
51 |
+
|
52 |
+
# Initialize a list to collect valid sequences
|
53 |
+
valid_sequences = []
|
54 |
+
|
55 |
+
for i in range(batch_size):
|
56 |
+
valid_length = input_tensor_lengths[i]
|
57 |
+
valid_sequences.append(input_tensor[i, :valid_length, :])
|
58 |
+
|
59 |
+
# Concatenate all valid sequences along the batch dimension
|
60 |
+
output_tensor = torch.cat(valid_sequences, dim=0)
|
61 |
+
|
62 |
+
else:
|
63 |
+
raise ValueError("Input tensor must have 2 or 3 dimensions")
|
64 |
+
|
65 |
+
return output_tensor
|
66 |
+
|
67 |
+
def read_config(component, engine_dir):
|
68 |
+
config_path = engine_dir / component / 'config.json'
|
69 |
+
with open(config_path, 'r') as f:
|
70 |
+
config = json.load(f)
|
71 |
+
model_config = OrderedDict()
|
72 |
+
model_config.update(config['pretrained_config'])
|
73 |
+
model_config.update(config['build_config'])
|
74 |
+
return model_config
|
75 |
+
|
76 |
+
class WhisperEncoding:
|
77 |
+
def __init__(self, engine_dir):
|
78 |
+
self.session = self.get_session(engine_dir)
|
79 |
+
config = read_config('encoder', engine_dir)
|
80 |
+
self.n_mels = config['n_mels']
|
81 |
+
self.dtype = config['dtype']
|
82 |
+
self.num_languages = config['num_languages']
|
83 |
+
self.encoder_config = config
|
84 |
+
|
85 |
+
def get_session(self, engine_dir):
|
86 |
+
serialize_path = engine_dir / 'encoder' / 'rank0.engine'
|
87 |
+
with open(serialize_path, 'rb') as f:
|
88 |
+
session = Session.from_serialized_engine(f.read())
|
89 |
+
return session
|
90 |
+
|
91 |
+
def get_audio_features(self,
|
92 |
+
mel):
|
93 |
+
mel_input_lengths = torch.tensor(
|
94 |
+
[mel.shape[2] for _ in range(mel.shape[0])],
|
95 |
+
dtype=torch.int32,
|
96 |
+
device=mel.device)
|
97 |
+
if self.encoder_config['plugin_config']['remove_input_padding']:
|
98 |
+
# mel B,D,T -> B,T,D -> BxT, D
|
99 |
+
mel = mel.transpose(1, 2)
|
100 |
+
mel = remove_tensor_padding(mel, mel_input_lengths)
|
101 |
+
|
102 |
+
inputs = OrderedDict()
|
103 |
+
inputs['input_features'] = mel
|
104 |
+
inputs['input_lengths'] = mel_input_lengths
|
105 |
+
|
106 |
+
output_list = [
|
107 |
+
TensorInfo('input_features', str_dtype_to_trt(self.dtype),
|
108 |
+
mel.shape),
|
109 |
+
TensorInfo('input_lengths', str_dtype_to_trt('int32'),
|
110 |
+
mel_input_lengths.shape)
|
111 |
+
]
|
112 |
+
|
113 |
+
output_info = (self.session).infer_shapes(output_list)
        logger.debug(f'output info {output_info}')
        outputs = {
            t.name: torch.empty(tuple(t.shape),
                                dtype=trt_dtype_to_torch(t.dtype),
                                device='cuda')
            for t in output_info
        }
        stream = torch.cuda.current_stream()
        ok = self.session.run(inputs=inputs,
                              outputs=outputs,
                              stream=stream.cuda_stream)
        assert ok, 'Engine execution failed'
        stream.synchronize()
        encoder_output = outputs['encoder_output']
        # The encoder's conv front end halves the time axis; the lengths are
        # computed here but not returned, since downstream code assumes
        # fixed 30 s inputs.
        encoder_output_lengths = mel_input_lengths // 2

        return encoder_output


class EncoderProjector(torch.nn.Module):
    """
    The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
    Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
    Args:
        encoder_dim (:obj:`int`): The dimension of the encoder outputs.
        llm_dim (:obj:`int`): The dimension of the language model.
        downsample_rate (:obj:`int`, `optional`, defaults to 8): The downsample rate to use.
    """

    def __init__(self, encoder_dim=1280, llm_dim=1536, downsample_rate=8):
        super().__init__()
        self.downsample_rate = downsample_rate
        self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(llm_dim, llm_dim)

    def forward(self, x):
        batch_size, seq_len, feat_dim = x.size()
        num_frames_to_discard = seq_len % self.downsample_rate
        if num_frames_to_discard > 0:
            x = x[:, :-num_frames_to_discard, :]
        seq_len = x.size(1)

        x = x.contiguous()
        x = x.view(
            batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
        )

        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x


class SPEECH_LLM(nn.Module):
    """
    The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
    The encoder is used to extract speech features from the input speech signal.
    The encoder projector is used to project the encoder outputs to the same dimension as the language model.
    The language model is used to generate the text from the speech features.
    Args:
        encoder (:obj:`nn.Module`): The encoder module.
        llm (:obj:`nn.Module`): The language model module.
        encoder_projector (:obj:`nn.Module`): The encoder projector module.
    """

    def __init__(
        self,
        encoder: nn.Module,
        llm: nn.Module,
        encoder_projector: nn.Module,
    ):
        super().__init__()
        self.encoder = encoder
        self.llm = llm
        self.encoder_projector = encoder_projector


class WhisperTRTLLM(nn.Module):

    def __init__(self, engine_dir):
        super().__init__()
        world_size = 1
        runtime_rank = tensorrt_llm.mpi_rank()
        runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
        torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
        engine_dir = Path(engine_dir)

        self.encoder = WhisperEncoding(engine_dir)
        self.encoder_projector = EncoderProjector()
        self.encoder_projector = self.encoder_projector.half().to("cuda")

        llm = AutoModelForCausalLM.from_pretrained(
            "/home/scratch.yuekaiz_wwfo_1/Qwen2_1.5B_merged",
            attn_implementation="flash_attention_2",
            torch_dtype=torch.float16,
        )
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
        tokenizer.padding_side = "left"
        special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
        tokenizer.add_special_tokens(special_tokens_dict)
        llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
        llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

        llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
            DEFAULT_SPEECH_TOKEN
        )
        self.llm = llm.half().to("cuda")
        # Print the LLM embedding layer shape as a sanity check.
        print("llm embedding layer shape", self.llm.get_input_embeddings().weight.shape)

    def process_batch(
            self,
            mel,
            decoder_input_ids=None,
            eot_id=50257,
            max_new_tokens=96,
            num_beams=1):
        encoder_outputs = self.encoder.get_audio_features(mel)
        speech_features = self.encoder_projector(encoder_outputs)
        speech_features = speech_features.to(torch.float16)
        # Shape: [1, 187, 1536]
        return speech_features

    def decode(
        self,
        fbank: torch.Tensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):
        encoder_outs = self.encoder.get_audio_features(fbank)
        speech_features = self.encoder_projector(encoder_outs)
        speech_features = speech_features.to(torch.float16)
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
        # 151646 is the id of the added <speech> placeholder token.
        speech_token_index = input_ids.tolist()[0].index(151646)
        print("speech_token_index", speech_token_index,
              "speech_features_shape", speech_features.shape,
              "input_ids_shape", input_ids.shape,
              "inputs_embeds_shape", inputs_embeds.shape)

        # Splice the 187 speech embeddings in place of the single <speech>
        # token, which sits at position 3 of the chat-formatted prompt.
        new_length = inputs_embeds.shape[1] + speech_features.shape[1] - 1
        new_inputs_embeds = torch.zeros(1, new_length, 1536).to(inputs_embeds.device).half()
        new_inputs_embeds[:, :3, :] = inputs_embeds[:, :3, :]
        new_inputs_embeds[:, 3:3 + 187, :] = speech_features
        new_inputs_embeds[:, 3 + 187:, :] = inputs_embeds[:, 4:, :]

        inputs_embeds = new_inputs_embeds
        generated_ids = self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=kwargs.get("max_new_tokens", 200),
            num_beams=kwargs.get("num_beams", 1),
            do_sample=kwargs.get("do_sample", False),
            min_length=kwargs.get("min_length", 1),
            top_p=kwargs.get("top_p", 1.0),
            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
            length_penalty=kwargs.get("length_penalty", 1.0),
            temperature=kwargs.get("temperature", 1.0),
            bos_token_id=self.llm.config.bos_token_id,
            eos_token_id=self.llm.config.eos_token_id,
            pad_token_id=self.llm.config.pad_token_id,
        )

        return generated_ids
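The splice in `decode` above hard-codes the placeholder at position 3 and a 187-frame segment. A minimal, shape-generic sketch of the same operation (the helper name is hypothetical, not part of this repo):

import torch

def splice_speech_embeddings(inputs_embeds, speech_features, input_ids, speech_token_id):
    # inputs_embeds: [1, T, D], speech_features: [1, S, D], input_ids: [1, T]
    idx = input_ids.tolist()[0].index(speech_token_id)  # position of the placeholder
    return torch.cat(
        [
            inputs_embeds[:, :idx, :],      # prompt tokens before <speech>
            speech_features,                # S projected encoder frames
            inputs_embeds[:, idx + 1:, :],  # prompt tokens after <speech>
        ],
        dim=1,
    )  # [1, T - 1 + S, D]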
model_repo_whisper_qwen_trtllm/whisper/1/__pycache__/fbank.cpython-310.pyc
ADDED
Binary file (3.07 kB).

model_repo_whisper_qwen_trtllm/whisper/1/__pycache__/model.cpython-310.pyc
ADDED
Binary file (10.4 kB).

model_repo_whisper_qwen_trtllm/whisper/1/__pycache__/whisper_trtllm.cpython-310.pyc
ADDED
Binary file (6.2 kB).

model_repo_whisper_qwen_trtllm/whisper/1/fbank.py
ADDED
@@ -0,0 +1,91 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Reference: https://github.com/openai/whisper/blob/main/whisper/audio.py
import numpy as np
import torch
import torch.nn.functional as F
import os


def mel_filters(device, n_mels: int = 128) -> torch.Tensor:
    """
    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling the librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    """
    assert n_mels == 80 or n_mels == 128, f"Unsupported n_mels: {n_mels}"
    with np.load(
        os.path.join(os.path.dirname(__file__), "mel_filters.npz")
    ) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(
    audio: torch.Tensor,
    filters: torch.Tensor,
    n_mels: int = 128,
    n_fft: int = 400,
    hop_length: int = 160,
):
    """
    Compute the log-Mel spectrogram of the input waveform.

    Parameters
    ----------
    audio: torch.Tensor, shape = (*)
        A Tensor containing the audio waveform at 16 kHz

    filters: torch.Tensor
        The mel filterbank matrix returned by `mel_filters`

    n_mels: int
        The number of Mel-frequency filters, only 80 or 128 is supported

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    window = torch.hann_window(n_fft).to(audio.device)
    stft = torch.stft(audio, n_fft, hop_length, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2

    mel_spec = filters @ magnitudes
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    # Cast to float16 to match the encoder engine's input dtype.
    log_spec = log_spec.half()
    return log_spec


class FeatureExtractor(torch.nn.Module):
    """Compute Whisper-style log-Mel features on GPU and pad them to a
    fixed number of frames (30 s of audio at a 10 ms hop).
    """

    def __init__(self, n_mels: int = 128):
        super().__init__()
        self.device = torch.device("cuda")
        self.n_mels = n_mels
        self.filters = mel_filters(self.device, n_mels=self.n_mels)

    def compute_feature(self, wav, target: int = 3000):
        mel = log_mel_spectrogram(wav, self.filters)
        assert mel.shape[1] <= target, f"{mel.shape[1]} > {target}, audio is too long"
        if mel.shape[1] < target:
            mel = F.pad(mel, (0, target - mel.shape[1]), mode='constant')
        mel = mel.unsqueeze(0)
        return mel
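A minimal usage sketch for the feature extractor above, assuming a CUDA device is available; the random tensor stands in for a real 16 kHz waveform:

import torch
from fbank import FeatureExtractor  # hypothetical import path

extractor = FeatureExtractor(n_mels=128)
wav = torch.randn(16000 * 5, device="cuda")  # 5 s of dummy 16 kHz audio
mel = extractor.compute_feature(wav)
print(mel.shape)  # torch.Size([1, 128, 3000]) -- padded to 30 s of frames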
model_repo_whisper_qwen_trtllm/whisper/1/mel_filters.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7450ae70723a5ef9d341e3cee628c7cb0177f36ce42c44b7ed2bf3325f0f6d4c
size 4271

model_repo_whisper_qwen_trtllm/whisper/1/model.py
ADDED
@@ -0,0 +1,318 @@
# -*- coding: utf-8 -*-
import triton_python_backend_utils as pb_utils
import numpy as np
import json
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import re
import transformers
from transformers import AutoTokenizer
from typing import Dict
from pathlib import Path
import traceback

from .whisper_trtllm import WhisperTRTLLM
from .fbank import FeatureExtractor

DEFAULT_SPEECH_TOKEN = "<speech>"


def preprocess(
    messages,
    tokenizer: transformers.PreTrainedTokenizer,
    max_len: int = 128,
) -> Dict:
    """Preprocesses the data for supervised fine-tuning."""
    texts = []
    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
    for i, msg in enumerate(messages):
        texts.append(
            tokenizer.apply_chat_template(
                msg,
                tokenize=True,
                add_generation_prompt=False,
                chat_template=TEMPLATE,
                padding="longest",
                max_length=max_len,
                truncation=True,
            )
        )
    max_len_texts = max([len(text) for text in texts])
    if tokenizer.padding_side == "right":
        texts = [
            text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
            for text in texts
        ]
    else:
        texts = [
            [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
            for text in texts
        ]

    input_ids = torch.tensor(texts, dtype=torch.int)

    attention_mask = input_ids.ne(tokenizer.pad_token_id)

    return input_ids, attention_mask


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        self.model_config = model_config = json.loads(args['model_config'])

        # Get OUTPUT0 configuration
        output0_config = pb_utils.get_output_config_by_name(
            model_config, "TRANSCRIPTS")
        # Convert Triton types to numpy types
        self.out0_dtype = pb_utils.triton_string_to_numpy(
            output0_config['data_type'])

        # self.tokenizer = get_tokenizer(num_languages=100)
        # self.blank = self.tokenizer.encode(" ", allowed_special=self.tokenizer.special_tokens_set)[0]
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
        tokenizer.padding_side = "left"
        special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
        tokenizer.add_special_tokens(special_tokens_dict)
        self.tokenizer = tokenizer
        self.eos = self.tokenizer.eos_token_id
        self.default_speech_token_id = tokenizer.convert_tokens_to_ids(
            DEFAULT_SPEECH_TOKEN
        )
        self.vocab_size = 151936

        self.device = torch.device("cuda")
        self.decoupled = False
        self.logger = pb_utils.Logger
        self.init_model(self.model_config['parameters'])

    def init_model(self, parameters):
        for key, value in parameters.items():
            parameters[key] = value["string_value"]
        engine_dir = parameters["engine_dir"]
        n_mels = int(parameters["n_mels"])
        adapter_dir = "/home/scratch.yuekaiz_wwfo_1/icefall_asr_multi-hans_whisper_qwen2_1.5B/epoch-2-avg-6.pt"
        checkpoint = torch.load(
            adapter_dir, map_location="cpu"
        )
        self.model = WhisperTRTLLM(engine_dir)
        missing_keys, _ = self.model.load_state_dict(checkpoint, strict=False)
        print(f"Missing keys: {missing_keys}")
        self.feature_extractor = FeatureExtractor(n_mels=n_mels)

    def _tokenize(self, prompt=None, num_speech_tokens=187):
        if prompt is None:
            # The default prompt asks the model to transcribe the audio
            # ("请转写音频为文字" = "please transcribe the audio to text").
            prompts = [
                [
                    {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
                    {"role": "assistant", "content": ""},
                ]
            ]

        input_ids, _ = preprocess(prompts, self.tokenizer, max_len=128)
        print("tokenized input_ids", input_ids)
        input_ids = input_ids.tolist()[0]
        speech_token_index = input_ids.index(self.default_speech_token_id)
        # Replace the single <speech> token (id 151646) with
        # list(range(self.vocab_size, self.vocab_size + num_speech_tokens)),
        # i.e. virtual ids resolved against the prompt embedding table.
        prompt_ids = input_ids[:speech_token_index] + list(range(self.vocab_size, self.vocab_size + num_speech_tokens)) + input_ids[speech_token_index + 1:]
        print("prompt_ids", prompt_ids)
        return prompt_ids

    def _prepare_inputs(self, request, speech_embeddings, input_ids):
        """
        Prepares inputs for the language model based on the parameters in the
        request, the speech embeddings, and the tokenized prompt. Additional
        generation parameters can be extracted from the request:
        - max_tokens: Maximum number of tokens to generate (default: 50)
        - temperature: Controls randomness in generation (default: 0.5)
        - top_k: Top K sampling parameter (default: 1)
        - frequency_penalty: Penalizes frequent tokens (default: 0.7)
        - seed: Random seed for generation (default: 10)

        The final LLM input dictionary is combined out of all processed
        parameters, the prompt's tokens and the speech features. The latter
        will be passed to the LLM through `prompt_embedding_table`.

        Parameters
        ----------
        - request: The original request object containing additional parameters.
        - speech_embeddings (torch.Tensor): The projected speech features.
        - input_ids (list): The tokenized prompt with virtual speech-token ids.

        Returns
        -------
        - list: The prepared input tensors for the language model.
        """
        input_ids = np.array(input_ids, dtype=np.int32)
        max_tokens = 50
        input_len = input_ids.shape[0]
        print("speech_embeddings shape", speech_embeddings.shape)
        assert speech_embeddings.shape[1] == 187, "Only support 187 speech tokens"
        embedding_args = {
            "prompt_vocab_size": np.array(
                [[speech_embeddings.shape[1]]], dtype=np.int32
            ),
            "prompt_embedding_table": speech_embeddings.detach().cpu().numpy(),
        }

        input_dict = {
            "input_ids": np.expand_dims(input_ids, 0),
            "input_lengths": np.array([[input_len]], dtype=np.int32),
            "request_output_len": np.array([[max_tokens]], dtype=np.int32),
            "end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32),
            "streaming": np.array([[0]], dtype=np.bool_),
            **embedding_args,
        }

        input_tensor_list = [pb_utils.Tensor(k, v) for k, v in input_dict.items()]
        return input_tensor_list

    def _prepare_llm_response(self, llm_request_inputs):
        """
        Prepares the response from the language model based on the provided
        inputs. Creates a `pb_utils.InferenceRequest` object with passed
        `llm_request_inputs` to send to a decoupled TensorRTLLM model.
        For each response from the language model:
        - Checks for errors and raises an exception if any are found.
        - Extracts the "output_ids" tensor from the response.
        - Determines the finish reason based on the presence of the
          end-of-sequence token or reaching the maximum length.
        - Appends the generated token IDs to `output_ids`.
        - If the finish reason is determined, decodes the output IDs to text
          and prepares the final response.

        The final response includes the generated text, finish reason,
        completion tokens, prompt tokens, and total tokens.

        Parameters
        ----------
        - llm_request_inputs (dict): A dictionary containing the inputs for the language model.

        Returns
        -------
        - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
        """

        llm_request = pb_utils.InferenceRequest(
            model_name="tensorrt_llm",
            requested_output_names=["output_ids", "sequence_length"],
            inputs=llm_request_inputs,
        )
        output_ids, output_len = [], 0
        responses = llm_request.exec(decoupled=False)
        responses = [responses]
        for llm_response in responses:
            if llm_response.has_error():
                raise pb_utils.TritonModelException(llm_response.error().message())
            stream_output_ids = (
                pb_utils.get_output_tensor_by_name(llm_response, "output_ids")
                .as_numpy()
                .flatten()
                .tolist()
            )
            finish_reason = "test"
            if len(stream_output_ids) == 0 or (
                len(stream_output_ids) != 0
                and stream_output_ids[-1] == self.eos
            ):
                finish_reason = "stop"

            output_ids += stream_output_ids

            last_response = finish_reason != ""
            output_len = len(output_ids)
            if last_response:
                print(output_ids)
                output_text = self.tokenizer.decode(output_ids).strip()
                # print(output_text)
                # output_text = re.sub(r'<\|.*?\|>', '', output_text)
                response = pb_utils.InferenceResponse(
                    output_tensors=[
                        pb_utils.Tensor("TRANSCRIPTS", np.array([output_text], np.object_)),
                    ]
                )
                yield response

    def _extract_speech_embeddings(self, mel):
        return self.model.process_batch(mel)

    def execute(self, requests):

        responses = []

        for request in requests:
            wav = pb_utils.get_input_tensor_by_name(request, "WAV").as_numpy()
            assert wav.shape[0] == 1, "Only support batch size 1"
            # To support batch > 1:
            # cat mel and text_prompt; also need to pass decoder_input_len as a Triton input.
            wav = torch.from_numpy(wav[0]).to(self.device)
            # mel shape [1, 80, 3000] for remove_input_padding=False
            mel = self.feature_extractor.compute_feature(wav)

            speech_embeddings = self._extract_speech_embeddings(mel)
            print("speech_embeddings shape", speech_embeddings.shape)
            input_ids = self._tokenize()
            print("input_ids", input_ids)

            if self.decoupled:
                response_sender = request.get_response_sender()
            try:

                llm_request_inputs = self._prepare_inputs(
                    request, speech_embeddings, input_ids
                )
                if isinstance(llm_request_inputs, pb_utils.TritonError):
                    error = pb_utils.InferenceResponse(error=llm_request_inputs)
                    if self.decoupled:
                        response_sender.send(
                            error, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                        )
                    else:
                        responses.append(error)
                llm_responses = self._prepare_llm_response(llm_request_inputs)

                for triton_response in llm_responses:
                    if self.decoupled:
                        response_sender.send(triton_response)
                    else:
                        responses.append(triton_response)

                if self.decoupled:
                    response_sender.send(
                        flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)

            except Exception:
                self.logger.log_error(traceback.format_exc())
                # If encountering an error, send a response with the error message
                error_response = pb_utils.InferenceResponse(
                    output_tensors=[],
                    error=pb_utils.TritonError(traceback.format_exc()))

                if self.decoupled:
                    response_sender.send(error_response)
                    response_sender.send(
                        flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                else:
                    responses.append(error_response)

        if self.decoupled:
            return None
        else:
            assert len(responses) == len(requests)
            return responses
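The `_tokenize` helper above leans on TensorRT-LLM's prompt-tuning path: ids at or above the real vocabulary size are resolved against `prompt_embedding_table` rather than the LLM's embedding matrix. A self-contained illustration of the id rewriting; the short id list is made up for the example (nothing here touches the engine):

vocab_size = 151936          # Qwen2 vocabulary size
speech_token_id = 151646     # id assigned to the added <speech> token
num_speech_tokens = 187      # frames produced by the encoder projector

input_ids = [151644, 872, 198, 151646, 14880]   # illustrative chat-template ids
idx = input_ids.index(speech_token_id)
prompt_ids = (input_ids[:idx]
              + list(range(vocab_size, vocab_size + num_speech_tokens))
              + input_ids[idx + 1:])
assert len(prompt_ids) == len(input_ids) - 1 + num_speech_tokens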
model_repo_whisper_qwen_trtllm/whisper/1/whisper_trtllm.py
ADDED
@@ -0,0 +1,212 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from collections import OrderedDict
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import tensorrt_llm
import tensorrt_llm.logger as logger
from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
                                 trt_dtype_to_torch)
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
from tensorrt_llm.runtime.session import Session, TensorInfo


def remove_tensor_padding(input_tensor, input_tensor_lengths=None, pad_value=0):
    if input_tensor.dim() == 2:
        # Text tensor case: batch, seq_len
        assert torch.all(
            input_tensor[:, 0] != pad_value
        ), "First token in each sequence should not be pad_value"
        assert input_tensor_lengths is None

        # Create a mask for all non-pad tokens
        mask = input_tensor != pad_value

        # Apply the mask to input_tensor to remove pad tokens
        output_tensor = input_tensor[mask].view(1, -1)

    elif input_tensor.dim() == 3:
        # Audio tensor case: batch, seq_len, feature_len
        assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
        batch_size, seq_len, feature_len = input_tensor.shape

        # Initialize a list to collect valid sequences
        valid_sequences = []

        for i in range(batch_size):
            valid_length = input_tensor_lengths[i]
            valid_sequences.append(input_tensor[i, :valid_length, :])

        # Concatenate all valid sequences along the batch dimension
        output_tensor = torch.cat(valid_sequences, dim=0)

    else:
        raise ValueError("Input tensor must have 2 or 3 dimensions")

    return output_tensor


def read_config(component, engine_dir):
    config_path = engine_dir / component / 'config.json'
    with open(config_path, 'r') as f:
        config = json.load(f)
    model_config = OrderedDict()
    model_config.update(config['pretrained_config'])
    model_config.update(config['build_config'])
    return model_config


class WhisperEncoding:
    def __init__(self, engine_dir):
        self.session = self.get_session(engine_dir)
        config = read_config('encoder', engine_dir)
        self.n_mels = config['n_mels']
        self.dtype = config['dtype']
        self.num_languages = config['num_languages']
        self.encoder_config = config

    def get_session(self, engine_dir):
        serialize_path = engine_dir / 'encoder' / 'rank0.engine'
        with open(serialize_path, 'rb') as f:
            session = Session.from_serialized_engine(f.read())
        return session

    def get_audio_features(self, mel):
        mel_input_lengths = torch.tensor(
            [mel.shape[2] for _ in range(mel.shape[0])],
            dtype=torch.int32,
            device=mel.device)
        if self.encoder_config['plugin_config']['remove_input_padding']:
            # mel B,D,T -> B,T,D -> BxT, D
            mel = mel.transpose(1, 2)
            mel = remove_tensor_padding(mel, mel_input_lengths)

        inputs = OrderedDict()
        inputs['input_features'] = mel
        inputs['input_lengths'] = mel_input_lengths

        output_list = [
            TensorInfo('input_features', str_dtype_to_trt(self.dtype),
                       mel.shape),
            TensorInfo('input_lengths', str_dtype_to_trt('int32'),
                       mel_input_lengths.shape)
        ]

        output_info = self.session.infer_shapes(output_list)

        logger.debug(f'output info {output_info}')
        outputs = {
            t.name: torch.empty(tuple(t.shape),
                                dtype=trt_dtype_to_torch(t.dtype),
                                device='cuda')
            for t in output_info
        }
        stream = torch.cuda.current_stream()
        ok = self.session.run(inputs=inputs,
                              outputs=outputs,
                              stream=stream.cuda_stream)
        assert ok, 'Engine execution failed'
        stream.synchronize()
        encoder_output = outputs['encoder_output']
        encoder_output_lengths = mel_input_lengths // 2

        return encoder_output


class EncoderProjector(torch.nn.Module):
    """
    The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
    Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
    Args:
        encoder_dim (:obj:`int`): The dimension of the encoder outputs.
        llm_dim (:obj:`int`): The dimension of the language model.
        downsample_rate (:obj:`int`, `optional`, defaults to 8): The downsample rate to use.
    """

    def __init__(self, encoder_dim=1280, llm_dim=1536, downsample_rate=8):
        super().__init__()
        self.downsample_rate = downsample_rate
        self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(llm_dim, llm_dim)

    def forward(self, x):
        batch_size, seq_len, feat_dim = x.size()
        num_frames_to_discard = seq_len % self.downsample_rate
        if num_frames_to_discard > 0:
            x = x[:, :-num_frames_to_discard, :]
        seq_len = x.size(1)

        x = x.contiguous()
        x = x.view(
            batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
        )

        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x


# class SPEECH_LLM(nn.Module):
#     """
#     The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
#     The encoder is used to extract speech features from the input speech signal.
#     The encoder projector is used to project the encoder outputs to the same dimension as the language model.
#     The language model is used to generate the text from the speech features.
#     Args:
#         encoder (:obj:`nn.Module`): The encoder module.
#         llm (:obj:`nn.Module`): The language model module.
#         encoder_projector (:obj:`nn.Module`): The encoder projector module.
#     """

#     def __init__(
#         self,
#         encoder: nn.Module = None,
#         llm: nn.Module = None,
#         encoder_projector: nn.Module = None,
#     ):
#         super().__init__()
#         self.encoder = encoder
#         self.llm = llm
#         self.encoder_projector = encoder_projector


class WhisperTRTLLM(nn.Module):

    def __init__(self, engine_dir):
        super().__init__()
        world_size = 1
        runtime_rank = tensorrt_llm.mpi_rank()
        runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
        torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
        engine_dir = Path(engine_dir)

        self.encoder = WhisperEncoding(engine_dir)
        self.encoder_projector = EncoderProjector()
        self.encoder_projector = self.encoder_projector.half().to("cuda")

    def process_batch(
            self,
            mel,
            decoder_input_ids=None,
            eot_id=50257,
            max_new_tokens=96,
            num_beams=1):
        encoder_outputs = self.encoder.get_audio_features(mel)
        speech_features = self.encoder_projector(encoder_outputs)
        speech_features = speech_features.to(torch.float16)
        print("speech_features shape", speech_features.shape)
        return speech_features
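The fixed "187 speech tokens" asserted throughout the repo falls out of the projector's downsampling; a pure-Python sketch of the arithmetic (no engine needed):

mel_frames = 3000                           # 30 s at a 10 ms hop
encoder_frames = mel_frames // 2            # the encoder's conv front end halves time
downsample_rate = 8
discard = encoder_frames % downsample_rate  # 1500 % 8 = 4 trailing frames dropped
speech_tokens = (encoder_frames - discard) // downsample_rate
print(speech_tokens)                        # 187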
model_repo_whisper_qwen_trtllm/whisper/2/__pycache__/fbank.cpython-310.pyc
ADDED
Binary file (3.07 kB).

model_repo_whisper_qwen_trtllm/whisper/2/__pycache__/model.cpython-310.pyc
ADDED
Binary file (10.4 kB).

model_repo_whisper_qwen_trtllm/whisper/2/__pycache__/whisper_trtllm.cpython-310.pyc
ADDED
Binary file (7.37 kB).

model_repo_whisper_qwen_trtllm/whisper/2/fbank.py
ADDED
@@ -0,0 +1,91 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Reference: https://github.com/openai/whisper/blob/main/whisper/audio.py
import numpy as np
import torch
import torch.nn.functional as F
import os


def mel_filters(device, n_mels: int = 128) -> torch.Tensor:
    """
    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling the librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    """
    assert n_mels == 80 or n_mels == 128, f"Unsupported n_mels: {n_mels}"
    with np.load(
        os.path.join(os.path.dirname(__file__), "mel_filters.npz")
    ) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(
    audio: torch.Tensor,
    filters: torch.Tensor,
    n_mels: int = 128,
    n_fft: int = 400,
    hop_length: int = 160,
):
    """
    Compute the log-Mel spectrogram of the input waveform.

    Parameters
    ----------
    audio: torch.Tensor, shape = (*)
        A Tensor containing the audio waveform at 16 kHz

    filters: torch.Tensor
        The mel filterbank matrix returned by `mel_filters`

    n_mels: int
        The number of Mel-frequency filters, only 80 or 128 is supported

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    window = torch.hann_window(n_fft).to(audio.device)
    stft = torch.stft(audio, n_fft, hop_length, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2

    mel_spec = filters @ magnitudes
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    # Cast to float16 to match the encoder engine's input dtype.
    log_spec = log_spec.half()
    return log_spec


class FeatureExtractor(torch.nn.Module):
    """Compute Whisper-style log-Mel features on GPU and pad them to a
    fixed number of frames (30 s of audio at a 10 ms hop).
    """

    def __init__(self, n_mels: int = 128):
        super().__init__()
        self.device = torch.device("cuda")
        self.n_mels = n_mels
        self.filters = mel_filters(self.device, n_mels=self.n_mels)

    def compute_feature(self, wav, target: int = 3000):
        mel = log_mel_spectrogram(wav, self.filters)
        assert mel.shape[1] <= target, f"{mel.shape[1]} > {target}, audio is too long"
        if mel.shape[1] < target:
            mel = F.pad(mel, (0, target - mel.shape[1]), mode='constant')
        mel = mel.unsqueeze(0)
        return mel
model_repo_whisper_qwen_trtllm/whisper/2/mel_filters.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7450ae70723a5ef9d341e3cee628c7cb0177f36ce42c44b7ed2bf3325f0f6d4c
size 4271

model_repo_whisper_qwen_trtllm/whisper/2/model.py
ADDED
@@ -0,0 +1,346 @@
# -*- coding: utf-8 -*-
import triton_python_backend_utils as pb_utils
import numpy as np
import json
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import re
import transformers
from transformers import AutoTokenizer
from typing import Dict
from pathlib import Path
import traceback

from .whisper_trtllm import WhisperTRTLLM
from .fbank import FeatureExtractor

DEFAULT_SPEECH_TOKEN = "<speech>"


def preprocess(
    messages,
    tokenizer: transformers.PreTrainedTokenizer,
    max_len: int = 128,
) -> Dict:
    """Preprocesses the data for supervised fine-tuning."""
    texts = []
    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
    for i, msg in enumerate(messages):
        texts.append(
            tokenizer.apply_chat_template(
                msg,
                tokenize=True,
                add_generation_prompt=False,
                chat_template=TEMPLATE,
                padding="longest",
                max_length=max_len,
                truncation=True,
            )
        )
    max_len_texts = max([len(text) for text in texts])
    if tokenizer.padding_side == "right":
        texts = [
            text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
            for text in texts
        ]
    else:
        texts = [
            [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
            for text in texts
        ]

    input_ids = torch.tensor(texts, dtype=torch.int)

    attention_mask = input_ids.ne(tokenizer.pad_token_id)

    return input_ids, attention_mask


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        self.model_config = model_config = json.loads(args['model_config'])

        # Get OUTPUT0 configuration
        output0_config = pb_utils.get_output_config_by_name(
            model_config, "TRANSCRIPTS")
        # Convert Triton types to numpy types
        self.out0_dtype = pb_utils.triton_string_to_numpy(
            output0_config['data_type'])

        # self.tokenizer = get_tokenizer(num_languages=100)
        # self.blank = self.tokenizer.encode(" ", allowed_special=self.tokenizer.special_tokens_set)[0]
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
        tokenizer.padding_side = "left"
        special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
        tokenizer.add_special_tokens(special_tokens_dict)
        self.tokenizer = tokenizer
        self.eos = self.tokenizer.eos_token_id
        self.default_speech_token_id = tokenizer.convert_tokens_to_ids(
            DEFAULT_SPEECH_TOKEN
        )
        self.vocab_size = 151936
        # self.vocab_size = 500000
        # self.vocab_size = 160000

        self.device = torch.device("cuda")
        self.decoupled = False
        self.logger = pb_utils.Logger
        self.init_model(self.model_config['parameters'])

    def init_model(self, parameters):
        for key, value in parameters.items():
            parameters[key] = value["string_value"]
        engine_dir = parameters["engine_dir"]
        n_mels = int(parameters["n_mels"])
        adapter_dir = "/home/scratch.yuekaiz_wwfo_1/icefall_asr_multi-hans_whisper_qwen2_1.5B/epoch-2-avg-6.pt"
        checkpoint = torch.load(
            adapter_dir, map_location="cpu"
        )
        self.model = WhisperTRTLLM(engine_dir)
        missing_keys, _ = self.model.load_state_dict(checkpoint, strict=False)
        # print(f"Missing keys: {missing_keys}")
        self.feature_extractor = FeatureExtractor(n_mels=n_mels)

    def _tokenize(self, prompt=None, num_speech_tokens=187):
        if prompt is None:
            # The default prompt asks the model to transcribe the audio
            # ("请转写音频为文字" = "please transcribe the audio to text").
            prompts = [
                [
                    {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
                    {"role": "assistant", "content": ""},
                ]
            ]
            # Alternative chat prompt used during debugging
            # ("你好,你是谁?" = "Hello, who are you?"):
            # prompts = [
            #     [
            #         {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}你好,你是谁?"},
            #         {"role": "assistant", "content": ""},
            #     ]
            # ]

        input_ids, _ = preprocess(prompts, self.tokenizer, max_len=128)
        input_ids = input_ids.tolist()[0]
        speech_token_index = input_ids.index(self.default_speech_token_id)
        # Replace the single <speech> token (id 151646) with
        # list(range(self.vocab_size, self.vocab_size + num_speech_tokens)),
        # i.e. virtual ids resolved against the prompt embedding table.
        prompt_ids = input_ids[:speech_token_index] + list(range(self.vocab_size, self.vocab_size + num_speech_tokens)) + input_ids[speech_token_index + 1:]
        # prompt_ids = input_ids[:speech_token_index] + input_ids[speech_token_index + 1:]
        return prompt_ids

    def _prepare_inputs(self, request, speech_embeddings, input_ids):
        """
        Prepares inputs for the language model based on the parameters in the
        request, the speech embeddings, and the tokenized prompt. Additional
        generation parameters can be extracted from the request:
        - max_tokens: Maximum number of tokens to generate (default: 50)
        - temperature: Controls randomness in generation (default: 0.5)
        - top_k: Top K sampling parameter (default: 1)
        - frequency_penalty: Penalizes frequent tokens (default: 0.7)
        - seed: Random seed for generation (default: 10)

        The final LLM input dictionary is combined out of all processed
        parameters, the prompt's tokens and the speech features. The latter
        will be passed to the LLM through `prompt_embedding_table`.

        Parameters
        ----------
        - request: The original request object containing additional parameters.
        - speech_embeddings (torch.Tensor): The projected speech features.
        - input_ids (list): The tokenized prompt with virtual speech-token ids.

        Returns
        -------
        - list: The prepared input tensors for the language model.
        """
        input_ids = np.array(input_ids, dtype=np.int32)
        max_tokens = 200
        input_len = input_ids.shape[0]

        assert speech_embeddings.shape[1] == 187, "Only support 187 speech tokens"
        embedding_args = {
            "prompt_vocab_size": np.array(
                [[speech_embeddings.shape[1]]], dtype=np.int32
            ),
            "prompt_embedding_table": speech_embeddings.detach().cpu().numpy(),
        }
        # TODO: the output is identical with or without the embedding table???
        # And input_ids beyond the maximum vocab size do not raise an error???
        input_dict = {
            "input_ids": np.expand_dims(input_ids, 0),
            "input_lengths": np.array([[input_len]], dtype=np.int32),
            "request_output_len": np.array([[max_tokens]], dtype=np.int32),
            "runtime_top_k": np.array([[1]], dtype=np.int32),
            "end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32),
            "pad_id": np.array([[self.tokenizer.pad_token_id]], dtype=np.int32),
            "streaming": np.array([[0]], dtype=np.bool_),
            **embedding_args,
        }

        # print(input_ids)
        # for key, value in input_dict.items():
        #     print(key, value.shape)

        input_tensor_list = [pb_utils.Tensor(k, v) for k, v in input_dict.items()]
        return input_tensor_list

    def _prepare_llm_response(self, llm_request_inputs):
        """
        Prepares the response from the language model based on the provided
        inputs. Creates a `pb_utils.InferenceRequest` object with passed
        `llm_request_inputs` to send to a decoupled TensorRTLLM model.
        For each response from the language model:
        - Checks for errors and raises an exception if any are found.
        - Extracts the "output_ids" tensor from the response.
        - Determines the finish reason based on the presence of the
          end-of-sequence token or reaching the maximum length.
        - Appends the generated token IDs to `output_ids`.
        - If the finish reason is determined, decodes the output IDs to text
          and prepares the final response.

        The final response includes the generated text, finish reason,
        completion tokens, prompt tokens, and total tokens.

        Parameters
        ----------
        - llm_request_inputs (dict): A dictionary containing the inputs for the language model.

        Returns
        -------
        - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
        """

        llm_request = pb_utils.InferenceRequest(
            model_name="tensorrt_llm",
            requested_output_names=["output_ids", "sequence_length"],
            inputs=llm_request_inputs,
        )
        output_ids, output_len = [], 0
        responses = llm_request.exec(decoupled=False)
        responses = [responses]
        for llm_response in responses:
            if llm_response.has_error():
                raise pb_utils.TritonModelException(llm_response.error().message())
            stream_output_ids = (
                pb_utils.get_output_tensor_by_name(llm_response, "output_ids")
                .as_numpy()
                .flatten()
                .tolist()
            )
            finish_reason = "test"
            if len(stream_output_ids) == 0 or (
                len(stream_output_ids) != 0
                and stream_output_ids[-1] == self.eos
            ):
                finish_reason = "stop"

            output_ids += stream_output_ids

            last_response = finish_reason != ""
            output_len = len(output_ids)
            if last_response:
                print("final_output_ids", output_ids)
                output_text = self.tokenizer.decode(output_ids).strip()
                # print(output_text)
                # output_text = re.sub(r'<\|.*?\|>', '', output_text)
                response = pb_utils.InferenceResponse(
                    output_tensors=[
                        pb_utils.Tensor("TRANSCRIPTS", np.array([output_text], np.object_)),
                    ]
                )
                yield response

    def _extract_speech_embeddings(self, mel):
        return self.model.process_batch(mel)

    def execute(self, requests):

        responses = []

        for request in requests:
            wav = pb_utils.get_input_tensor_by_name(request, "WAV").as_numpy()
            assert wav.shape[0] == 1, "Only support batch size 1"
            # To support batch > 1:
            # cat mel and text_prompt; also need to pass decoder_input_len as a Triton input.
            wav = torch.from_numpy(wav[0]).to(self.device)
            # mel shape [1, 80, 3000] for remove_input_padding=False
            mel = self.feature_extractor.compute_feature(wav)
            # Debug path: compare against the pure-PyTorch decode.
            # ==========================================================
            # messages = [
            #     [
            #         {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
            #         {"role": "assistant", "content": ""},
            #     ]
            # ] * len(mel)

            # input_ids, attention_mask = preprocess(messages, self.tokenizer, max_len=128)

            # generated_ids = self.model.decode(
            #     mel, input_ids.to(self.device, dtype=torch.long), attention_mask.to(self.device)
            # )
            # print("pytorch model", generated_ids)
            # --------------------------------------------------------------------------

            speech_embeddings = self._extract_speech_embeddings(mel)
            input_ids = self._tokenize()

            if self.decoupled:
                response_sender = request.get_response_sender()
            try:

                llm_request_inputs = self._prepare_inputs(
                    request, speech_embeddings, input_ids
                )
                if isinstance(llm_request_inputs, pb_utils.TritonError):
                    error = pb_utils.InferenceResponse(error=llm_request_inputs)
                    if self.decoupled:
                        response_sender.send(
                            error, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                        )
                    else:
                        responses.append(error)
                llm_responses = self._prepare_llm_response(llm_request_inputs)

                for triton_response in llm_responses:
                    if self.decoupled:
                        response_sender.send(triton_response)
                    else:
                        responses.append(triton_response)

                if self.decoupled:
                    response_sender.send(
                        flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)

            except Exception:
                self.logger.log_error(traceback.format_exc())
                # If encountering an error, send a response with the error message
                error_response = pb_utils.InferenceResponse(
                    output_tensors=[],
                    error=pb_utils.TritonError(traceback.format_exc()))

                if self.decoupled:
                    response_sender.send(error_response)
                    response_sender.send(
                        flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                else:
                    responses.append(error_response)

        if self.decoupled:
            return None
        else:
            assert len(responses) == len(requests)
            return responses
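Since the Python backend above reads a single WAV input and returns a TRANSCRIPTS string, the deployment can be exercised end to end with a small gRPC client. A sketch assuming a server at localhost:8001, a model named "whisper", a float32 WAV input, and a hypothetical 16 kHz mono file sample_16k.wav; check whisper/config.pbtxt for the exact tensor names, dtypes, and any additional inputs:

import numpy as np
import soundfile as sf
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

client = grpcclient.InferenceServerClient("localhost:8001")
wav, sr = sf.read("sample_16k.wav", dtype="float32")  # hypothetical test file
assert sr == 16000
wav = wav[np.newaxis, :]  # batch size 1, as the model asserts

wav_in = grpcclient.InferInput("WAV", wav.shape, np_to_triton_dtype(wav.dtype))
wav_in.set_data_from_numpy(wav)
result = client.infer(
    model_name="whisper",
    inputs=[wav_in],
    outputs=[grpcclient.InferRequestedOutput("TRANSCRIPTS")],
)
print(result.as_numpy("TRANSCRIPTS")[0].decode("utf-8"))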
model_repo_whisper_qwen_trtllm/whisper/2/whisper_trtllm.py
ADDED
@@ -0,0 +1,278 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import json
|
16 |
+
from collections import OrderedDict
|
17 |
+
from pathlib import Path
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
import tensorrt_llm
|
23 |
+
import tensorrt_llm.logger as logger
|
24 |
+
from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
|
25 |
+
trt_dtype_to_torch)
|
26 |
+
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
|
27 |
+
from tensorrt_llm.runtime.session import Session, TensorInfo
|
28 |
+
from transformers.trainer_pt_utils import LabelSmoother
|
29 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
30 |
+
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
31 |
+
|
32 |
+
DEFAULT_SPEECH_TOKEN = "<speech>"
|

def remove_tensor_padding(input_tensor, input_tensor_lengths=None, pad_value=0):
    if input_tensor.dim() == 2:
        # Text tensor case: batch, seq_len
        assert torch.all(
            input_tensor[:, 0] != pad_value
        ), "First token in each sequence should not be pad_value"
        assert input_tensor_lengths is None

        # Create a mask for all non-pad tokens
        mask = input_tensor != pad_value

        # Apply the mask to input_tensor to remove pad tokens
        output_tensor = input_tensor[mask].view(1, -1)

    elif input_tensor.dim() == 3:
        # Audio tensor case: batch, seq_len, feature_len
        assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
        batch_size, seq_len, feature_len = input_tensor.shape

        # Initialize a list to collect valid sequences
        valid_sequences = []

        for i in range(batch_size):
            valid_length = input_tensor_lengths[i]
            valid_sequences.append(input_tensor[i, :valid_length, :])

        # Concatenate all valid sequences along the batch dimension
        output_tensor = torch.cat(valid_sequences, dim=0)

    else:
        raise ValueError("Input tensor must have 2 or 3 dimensions")

    return output_tensor
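
# Illustrative usage of remove_tensor_padding (not part of the original file;
# shapes are made-up assumptions). A padded 3D feature batch is packed by
# dropping the padded frames of each item and concatenating what remains:
#
#   feats = torch.randn(2, 10, 80)                # (batch, seq_len, feature_len)
#   lens = torch.tensor([10, 6], dtype=torch.int32)
#   packed = remove_tensor_padding(feats, lens)   # packed.shape == (16, 80)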

def read_config(component, engine_dir):
    config_path = engine_dir / component / 'config.json'
    with open(config_path, 'r') as f:
        config = json.load(f)
    model_config = OrderedDict()
    model_config.update(config['pretrained_config'])
    model_config.update(config['build_config'])
    return model_config

class WhisperEncoding:

    def __init__(self, engine_dir):
        self.session = self.get_session(engine_dir)
        config = read_config('encoder', engine_dir)
        self.n_mels = config['n_mels']
        self.dtype = config['dtype']
        self.num_languages = config['num_languages']
        self.encoder_config = config

    def get_session(self, engine_dir):
        serialize_path = engine_dir / 'encoder' / 'rank0.engine'
        with open(serialize_path, 'rb') as f:
            session = Session.from_serialized_engine(f.read())
        return session

    def get_audio_features(self, mel):
        mel_input_lengths = torch.tensor(
            [mel.shape[2] for _ in range(mel.shape[0])],
            dtype=torch.int32,
            device=mel.device)
        if self.encoder_config['plugin_config']['remove_input_padding']:
            # mel B,D,T -> B,T,D -> BxT, D
            mel = mel.transpose(1, 2)
            mel = remove_tensor_padding(mel, mel_input_lengths)

        inputs = OrderedDict()
        inputs['input_features'] = mel
        inputs['input_lengths'] = mel_input_lengths

        output_list = [
            TensorInfo('input_features', str_dtype_to_trt(self.dtype),
                       mel.shape),
            TensorInfo('input_lengths', str_dtype_to_trt('int32'),
                       mel_input_lengths.shape)
        ]

        output_info = self.session.infer_shapes(output_list)

        logger.debug(f'output info {output_info}')
        outputs = {
            t.name: torch.empty(tuple(t.shape),
                                dtype=trt_dtype_to_torch(t.dtype),
                                device='cuda')
            for t in output_info
        }
        stream = torch.cuda.current_stream()
        ok = self.session.run(inputs=inputs,
                              outputs=outputs,
                              stream=stream.cuda_stream)
        assert ok, 'Engine execution failed'
        stream.synchronize()
        encoder_output = outputs['encoder_output']
        # The encoder's conv front-end halves the time axis; the lengths are
        # computed here but are not returned to the caller.
        encoder_output_lengths = mel_input_lengths // 2

        return encoder_output
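
# Illustrative call (assumed shapes, not part of the original file): a 30 s
# segment with n_mels=80 gives mel of shape (batch, 80, 3000); the engine
# returns encoder features of roughly (batch, 1500, 1280) after the conv
# front-end halves the time axis.
#
#   encoding = WhisperEncoding(Path("/path/to/engine_dir"))  # hypothetical path
#   features = encoding.get_audio_features(mel.to("cuda"))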

class EncoderProjector(torch.nn.Module):
    """
    The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
    Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
    Args:
        encoder_dim (:obj:`int`): The dimension of the encoder outputs.
        llm_dim (:obj:`int`): The dimension of the language model.
        downsample_rate (:obj:`int`, `optional`, defaults to 8): The downsample rate to use.
    """

    def __init__(self, encoder_dim=1280, llm_dim=1536, downsample_rate=8):
        super().__init__()
        self.downsample_rate = downsample_rate
        self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(llm_dim, llm_dim)

    def forward(self, x):
        batch_size, seq_len, feat_dim = x.size()
        num_frames_to_discard = seq_len % self.downsample_rate
        if num_frames_to_discard > 0:
            x = x[:, :-num_frames_to_discard, :]
        seq_len = x.size(1)

        x = x.contiguous()
        x = x.view(
            batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
        )

        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x
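
# Illustrative shapes (assumptions, not part of the original file): with
# downsample_rate=8, 1500 encoder frames of dim 1280 become 187 stacked
# frames of dim 1280 * 8 (the 4 leftover frames are discarded), which two
# linear layers then map to llm_dim=1536.
#
#   proj = EncoderProjector(encoder_dim=1280, llm_dim=1536, downsample_rate=8)
#   out = proj(torch.randn(1, 1500, 1280))   # out.shape == (1, 187, 1536)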

class SPEECH_LLM(nn.Module):
    """
    The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
    The encoder is used to extract speech features from the input speech signal.
    The encoder projector is used to project the encoder outputs to the same dimension as the language model.
    The language model is used to generate the text from the speech features.
    Args:
        encoder (:obj:`nn.Module`): The encoder module.
        llm (:obj:`nn.Module`): The language model module.
        encoder_projector (:obj:`nn.Module`): The encoder projector module.
    """

    def __init__(
        self,
        encoder: nn.Module,
        llm: nn.Module,
        encoder_projector: nn.Module,
    ):
        super().__init__()
        self.encoder = encoder
        self.llm = llm
        self.encoder_projector = encoder_projector

class WhisperTRTLLM(nn.Module):

    def __init__(self, engine_dir):
        super().__init__()
        world_size = 1
        runtime_rank = tensorrt_llm.mpi_rank()
        runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
        torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
        engine_dir = Path(engine_dir)

        self.encoder = WhisperEncoding(engine_dir)
        self.encoder_projector = EncoderProjector()
        self.encoder_projector = self.encoder_projector.half().to("cuda")

        # llm = AutoModelForCausalLM.from_pretrained(
        #     "/home/scratch.yuekaiz_wwfo_1/Qwen2_1.5B_merged",
        #     attn_implementation="flash_attention_2",
        #     torch_dtype=torch.float16,
        # )
        # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
        # tokenizer.padding_side = "left"
        # special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
        # tokenizer.add_special_tokens(special_tokens_dict)
        # llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
        # llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        # llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        # llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
        #     DEFAULT_SPEECH_TOKEN
        # )
        # self.llm = llm.half().to("cuda")
        # # print llm embedding layer shape
        # print("llm embedding layer shape", self.llm.get_input_embeddings().weight.shape)

    def process_batch(self,
                      mel,
                      decoder_input_ids=None,
                      eot_id=50257,
                      max_new_tokens=96,
                      num_beams=1):
        encoder_outputs = self.encoder.get_audio_features(mel)
        speech_features = self.encoder_projector(encoder_outputs)
        speech_features = speech_features.to(torch.float16)
        # e.g. [1, 187, 1536]: (batch, downsampled frames, llm_dim)
        return speech_features

    # def decode(
    #     self,
    #     fbank: torch.Tensor = None,
    #     input_ids: torch.LongTensor = None,
    #     attention_mask: torch.Tensor = None,
    #     **kwargs,
    # ):
    #     encoder_outs = self.encoder.get_audio_features(fbank)
    #     speech_features = self.encoder_projector(encoder_outs)
    #     speech_features = speech_features.to(torch.float16)
    #     inputs_embeds = self.llm.get_input_embeddings()(input_ids)
    #     speech_token_index = input_ids.tolist()[0].index(151646)
    #     print("speech_token_index", speech_token_index,
    #           "speech_features_shape", speech_features.shape,
    #           "input_ids_shape", input_ids.shape,
    #           "inputs_embeds_shape", inputs_embeds.shape)
    #
    #     new_length = inputs_embeds.shape[1] + speech_features.shape[1] - 1
    #     new_inputs_embeds = torch.zeros(1, new_length, 1536).to(inputs_embeds.device).half()
    #     new_inputs_embeds[:, :3, :] = inputs_embeds[:, :3, :]
    #     new_inputs_embeds[:, 3:3 + 187, :] = speech_features
    #     new_inputs_embeds[:, 3 + 187:, :] = inputs_embeds[:, 4:, :]
    #
    #     inputs_embeds = new_inputs_embeds
    #     generated_ids = self.llm.generate(
    #         inputs_embeds=inputs_embeds,
    #         max_new_tokens=kwargs.get("max_new_tokens", 200),
    #         num_beams=kwargs.get("num_beams", 1),
    #         do_sample=kwargs.get("do_sample", False),
    #         min_length=kwargs.get("min_length", 1),
    #         top_p=kwargs.get("top_p", 1.0),
    #         repetition_penalty=kwargs.get("repetition_penalty", 1.0),
    #         length_penalty=kwargs.get("length_penalty", 1.0),
    #         temperature=kwargs.get("temperature", 1.0),
    #         bos_token_id=self.llm.config.bos_token_id,
    #         eos_token_id=self.llm.config.eos_token_id,
    #         pad_token_id=self.llm.config.pad_token_id,
    #     )
    #
    #     return generated_ids
model_repo_whisper_qwen_trtllm/whisper/config.pbtxt
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: "whisper"
backend: "python"
max_batch_size: 8

parameters [
  {
    key: "n_mels",
    value: {string_value:"80"} # 128 dim for large-v3, 80 dim for large-v2
  },
  {
    key: "engine_dir"
    value: { string_value: "/home/scratch.yuekaiz_wwfo_1/tekit/examples/whisper/whisper_multi_zh"}
  }
]

input [
  {
    name: "TEXT_PREFIX"
    data_type: TYPE_STRING
    dims: [1]
  },
  {
    name: "WAV"
    data_type: TYPE_FP32
    dims: [-1]
  }
]

output [
  {
    name: "TRANSCRIPTS"
    data_type: TYPE_STRING
    dims: [1]
  }
]

dynamic_batching {
  preferred_batch_size: [ 4, 8 ]
  max_queue_delay_microseconds: 1000
}
instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]
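
# Illustrative client call (an assumption, not part of the original config):
# with tritonclient's HTTP API, the "whisper" model above can be invoked
# roughly as follows (TYPE_STRING maps to "BYTES" on the client side).
#
#   import numpy as np
#   import tritonclient.http as httpclient
#   client = httpclient.InferenceServerClient("localhost:8000")
#   wav = np.zeros((1, 16000 * 30), dtype=np.float32)   # 30 s of 16 kHz audio
#   prefix = np.array([[b"<|startoftranscript|>"]], dtype=object)
#   inputs = [httpclient.InferInput("WAV", wav.shape, "FP32"),
#             httpclient.InferInput("TEXT_PREFIX", prefix.shape, "BYTES")]
#   inputs[0].set_data_from_numpy(wav)
#   inputs[1].set_data_from_numpy(prefix)
#   result = client.infer("whisper", inputs)
#   print(result.as_numpy("TRANSCRIPTS"))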