yuekai committed on
Commit
6fdf6ff
1 Parent(s): a063960

Upload folder using huggingface_hub

Files changed (26)
  1. model_repo_whisper_qwen_trtllm/tensorrt_llm/1/.gitkeep +0 -0
  2. model_repo_whisper_qwen_trtllm/tensorrt_llm/1/model.py +947 -0
  3. model_repo_whisper_qwen_trtllm/tensorrt_llm/config.pbtxt +577 -0
  4. model_repo_whisper_qwen_trtllm/tensorrt_llm/config.template +577 -0
  5. model_repo_whisper_qwen_trtllm/whisper/0/__pycache__/fbank.cpython-310.pyc +0 -0
  6. model_repo_whisper_qwen_trtllm/whisper/0/__pycache__/model.cpython-310.pyc +0 -0
  7. model_repo_whisper_qwen_trtllm/whisper/0/__pycache__/whisper_trtllm.cpython-310.pyc +0 -0
  8. model_repo_whisper_qwen_trtllm/whisper/0/fbank.py +91 -0
  9. model_repo_whisper_qwen_trtllm/whisper/0/mel_filters.npz +3 -0
  10. model_repo_whisper_qwen_trtllm/whisper/0/model.py +346 -0
  11. model_repo_whisper_qwen_trtllm/whisper/0/whisper_trtllm.py +278 -0
  12. model_repo_whisper_qwen_trtllm/whisper/1/__pycache__/fbank.cpython-310.pyc +0 -0
  13. model_repo_whisper_qwen_trtllm/whisper/1/__pycache__/model.cpython-310.pyc +0 -0
  14. model_repo_whisper_qwen_trtllm/whisper/1/__pycache__/whisper_trtllm.cpython-310.pyc +0 -0
  15. model_repo_whisper_qwen_trtllm/whisper/1/fbank.py +91 -0
  16. model_repo_whisper_qwen_trtllm/whisper/1/mel_filters.npz +3 -0
  17. model_repo_whisper_qwen_trtllm/whisper/1/model.py +318 -0
  18. model_repo_whisper_qwen_trtllm/whisper/1/whisper_trtllm.py +212 -0
  19. model_repo_whisper_qwen_trtllm/whisper/2/__pycache__/fbank.cpython-310.pyc +0 -0
  20. model_repo_whisper_qwen_trtllm/whisper/2/__pycache__/model.cpython-310.pyc +0 -0
  21. model_repo_whisper_qwen_trtllm/whisper/2/__pycache__/whisper_trtllm.cpython-310.pyc +0 -0
  22. model_repo_whisper_qwen_trtllm/whisper/2/fbank.py +91 -0
  23. model_repo_whisper_qwen_trtllm/whisper/2/mel_filters.npz +3 -0
  24. model_repo_whisper_qwen_trtllm/whisper/2/model.py +346 -0
  25. model_repo_whisper_qwen_trtllm/whisper/2/whisper_trtllm.py +278 -0
  26. model_repo_whisper_qwen_trtllm/whisper/config.pbtxt +61 -0
model_repo_whisper_qwen_trtllm/tensorrt_llm/1/.gitkeep ADDED
File without changes
model_repo_whisper_qwen_trtllm/tensorrt_llm/1/model.py ADDED
@@ -0,0 +1,947 @@
1
+ import datetime
2
+ import json
3
+ import os
4
+ import sys
5
+ import time
6
+ from random import randint
7
+ from threading import Lock, Thread
8
+
9
+ import numpy as np
10
+ import torch
11
+ import triton_python_backend_utils as pb_utils
12
+ from torch import from_numpy
13
+ from torch.utils.dlpack import from_dlpack
14
+
15
+ import tensorrt_llm.bindings.executor as trtllm
16
+
17
+
18
+ def get_input_tensor_by_name(request,
19
+ name,
20
+ expected_batch_size=None,
21
+ batch_index=None):
22
+ tensor = pb_utils.get_input_tensor_by_name(request, name)
23
+ if tensor is None:
24
+ return None
25
+
26
+ if tensor.is_cpu():
27
+ tensor = tensor.as_numpy()
28
+ else:
29
+ tensor = from_dlpack(tensor.to_dlpack())
30
+
31
+ if expected_batch_size is not None and tensor.shape[
32
+ 0] != expected_batch_size:
33
+ raise pb_utils.TritonModelException(
34
+ f"Expected batch size doesn't match batch size for tensor {name}. Expected {expected_batch_size} got {tensor.shape[0]}"
35
+ )
36
+
37
+ if batch_index is not None and expected_batch_size is not None and batch_index >= expected_batch_size:
38
+ raise pb_utils.TritonModelException(
39
+ f"Invalid batch index in get_input_tensor_by_name for {name}")
40
+
41
+ if batch_index is not None:
42
+ # Add leading 1 batch dimension
43
+ if isinstance(tensor, np.ndarray):
44
+ return np.expand_dims(tensor[batch_index], axis=0)
45
+ elif isinstance(tensor, torch.Tensor):
46
+ return torch.unsqueeze(tensor[batch_index], dim=0)
47
+ else:
48
+ return tensor
49
+
50
+
51
+ def get_input_scalar_by_name(request,
52
+ name,
53
+ expected_batch_size=1,
54
+ batch_index=0):
55
+ tensor = pb_utils.get_input_tensor_by_name(request, name)
56
+ if tensor is None:
57
+ return None
58
+ tensor = tensor.as_numpy()
59
+
60
+ if tensor.size != expected_batch_size:
61
+ raise pb_utils.TritonModelException(
62
+ f"Expected a scalar tensor for tensor {name}")
63
+
64
+ return tensor.item(batch_index)
65
+
66
+
67
+ def read_parameter_as_type(value, name, pytype=str):
68
+ if value == "":
69
+ return None
70
+ if value.startswith("${") and value.endswith("}"):
71
+ return None
72
+ if pytype is bool:
73
+ return value.lower() in ["1", "true"]
74
+ try:
75
+ result = pytype(value)
76
+ return result
77
+ except:
78
+ pb_utils.Logger.log_warning(
79
+ f"Could not read parameter '{name}' with value '{value}', will use default."
80
+ )
81
+ return None
82
+
83
+
84
+ def get_parameter(model_config, name, pytype=str):
85
+ if name not in model_config['parameters']:
86
+ return None
87
+ return read_parameter_as_type(
88
+ model_config['parameters'][name]['string_value'], name, pytype)
89
+
90
+
91
+ def convert_word_list(word_list):
92
+ if word_list is None:
93
+ return None
94
+ word_list = word_list.tolist()
95
+ if len(word_list) == 0 or len(word_list[0]) != 2:
96
+ raise pb_utils.TritonModelException(f"Invalid format for word list.")
97
+ words, indices = word_list[0]
98
+ result = []
99
+ current_index = 0
100
+ for i in indices:
101
+ if i == -1:
102
+ continue
103
+ if i > len(words):
104
+ raise pb_utils.TritonModelException(
105
+ f"Invalid format for word list.")
106
+ current_word = []
107
+ while current_index < i:
108
+ current_word.append(words[current_index])
109
+ current_index += 1
110
+ result.append(current_word)
111
+ return result
112
+
113
+
114
+ def parse_medusa_choices(medusa_choices):
115
+ if medusa_choices is None:
116
+ return None
117
+ try:
118
+ result = json.loads(
119
+ "[" + medusa_choices.replace("{", "[").replace("}", "]") + "]")
120
+ assert isinstance(result, list) and len(result) > 0
121
+ assert all([isinstance(x, list) for x in result])
122
+ assert all([isinstance(y, int) for x in result for y in x])
123
+ except Exception:
124
+ raise pb_utils.TritonModelException(
125
+ "Invalid format for medusa_choices")
126
+ return result
127
+
128
+
129
+ def get_sampling_config_from_request(request, batch_size=1, batch_index=0):
130
+ kwargs = {}
131
+ kwargs['beam_width'] = get_input_scalar_by_name(
132
+ request, 'beam_width', batch_size, batch_index) or 1
133
+ kwargs['top_k'] = get_input_scalar_by_name(request, 'runtime_top_k',
134
+ batch_size, batch_index)
135
+ kwargs['top_p'] = get_input_scalar_by_name(request, 'runtime_top_p',
136
+ batch_size, batch_index)
137
+ kwargs['top_p'] = None if kwargs['top_p'] is None or kwargs[
138
+ 'top_p'] <= 0 else kwargs['top_p']
139
+ kwargs['random_seed'] = get_input_scalar_by_name(request, 'random_seed',
140
+ batch_size, batch_index)
141
+ kwargs['temperature'] = get_input_scalar_by_name(request, 'temperature',
142
+ batch_size, batch_index)
143
+ kwargs['min_length'] = get_input_scalar_by_name(request, 'min_length',
144
+ batch_size, batch_index)
145
+ kwargs['repetition_penalty'] = get_input_scalar_by_name(
146
+ request, 'repetition_penalty', batch_size, batch_index)
147
+ kwargs['presence_penalty'] = get_input_scalar_by_name(
148
+ request, 'presence_penalty', batch_size, batch_index)
149
+ kwargs['frequency_penalty'] = get_input_scalar_by_name(
150
+ request, 'frequency_penalty', batch_size, batch_index)
151
+ kwargs['length_penalty'] = get_input_scalar_by_name(
152
+ request, 'len_penalty', batch_size, batch_index)
153
+ kwargs['top_p_min'] = get_input_scalar_by_name(request,
154
+ 'runtime_top_p_min',
155
+ batch_size, batch_index)
156
+ kwargs['top_p_reset_ids'] = get_input_scalar_by_name(
157
+ request, 'runtime_top_p_reset_ids', batch_size, batch_index)
158
+ kwargs['top_p_decay'] = get_input_scalar_by_name(request,
159
+ 'runtime_top_p_decay',
160
+ batch_size, batch_index)
161
+ kwargs['beam_search_diversity_rate'] = get_input_scalar_by_name(
162
+ request, 'beam_search_diversity_rate', batch_size, batch_index)
163
+ kwargs['early_stopping'] = get_input_scalar_by_name(
164
+ request, 'early_stopping', batch_size, batch_index)
165
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
166
+ return trtllm.SamplingConfig(**kwargs)
167
+
168
+
169
+ def get_output_config_from_request(request,
170
+ exclude_input_from_output,
171
+ batch_size=1,
172
+ batch_index=0):
173
+ kwargs = {}
174
+ kwargs["return_log_probs"] = get_input_scalar_by_name(
175
+ request, 'return_log_probs', batch_size, batch_index)
176
+ kwargs["return_context_logits"] = get_input_scalar_by_name(
177
+ request, 'return_context_logits', batch_size, batch_index)
178
+ kwargs["return_generation_logits"] = get_input_scalar_by_name(
179
+ request, 'return_generation_logits', batch_size, batch_index)
180
+ kwargs["exclude_input_from_output"] = exclude_input_from_output
181
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
182
+ return trtllm.OutputConfig(**kwargs)
183
+
184
+
185
+ def get_external_draft_tokens_config_from_request(request,
186
+ batch_size=1,
187
+ batch_index=0):
188
+ kwargs = {}
189
+ draft_input_ids = get_input_tensor_by_name(request, 'draft_input_ids',
190
+ batch_size, batch_index)
191
+ if draft_input_ids is not None:
192
+ kwargs['tokens'] = draft_input_ids[0].tolist()
193
+ draft_logits = get_input_tensor_by_name(request, 'draft_logits',
194
+ batch_size, batch_index)
195
+ if draft_logits is not None:
196
+ kwargs['logits'] = from_numpy(draft_logits).squeeze()
197
+ kwargs['acceptance_threshold'] = get_input_scalar_by_name(
198
+ request, 'draft_acceptance_threshold', batch_size, batch_index)
199
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
200
+ if len(kwargs) > 0:
201
+ return trtllm.ExternalDraftTokensConfig(**kwargs)
202
+ return None
203
+
204
+
205
+ def get_prompt_tuning_config_from_request(request,
206
+ batch_size=1,
207
+ batch_index=0):
208
+ # prompt_vocab_size is unused by executor.
209
+ kwargs = {}
210
+ prompt_embedding_table = get_input_tensor_by_name(
211
+ request, 'prompt_embedding_table', batch_size, batch_index)
212
+ if prompt_embedding_table is not None:
213
+ if isinstance(prompt_embedding_table, np.ndarray):
214
+ kwargs["embedding_table"] = from_numpy(
215
+ prompt_embedding_table).squeeze()
216
+ elif isinstance(prompt_embedding_table, torch.Tensor):
217
+ kwargs["embedding_table"] = from_dlpack(
218
+ prompt_embedding_table.to_dlpack()).squeeze(dim=0)
219
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
220
+ if len(kwargs) > 0:
221
+ return trtllm.PromptTuningConfig(**kwargs)
222
+ return None
223
+
224
+
225
+ def get_lora_config_from_request(request, batch_size=1, batch_index=0):
226
+ kwargs = {}
227
+ kwargs["task_id"] = get_input_scalar_by_name(request, 'lora_task_id',
228
+ batch_size, batch_index)
229
+ lora_weights = get_input_tensor_by_name(request, 'lora_weights',
230
+ batch_size, batch_index)
231
+ if lora_weights is not None:
232
+ kwargs["weights"] = from_numpy(lora_weights).squeeze()
233
+ lora_config = get_input_tensor_by_name(request, 'lora_config', batch_size,
234
+ batch_index)
235
+ if lora_config is not None:
236
+ kwargs["config"] = from_numpy(lora_config).squeeze()
237
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
238
+ if len(kwargs) > 0:
239
+ return trtllm.LoraConfig(**kwargs)
240
+ return None
241
+
242
+
243
+ def convert_request(request, exclude_input_from_output, decoupled):
244
+ inputs = {}
245
+ input_token_ids = get_input_tensor_by_name(request, 'input_ids')
246
+ if input_token_ids is None:
247
+ raise pb_utils.TritonModelException(
248
+ "A value is required for input_ids")
249
+ if len(input_token_ids.shape) != 2:
250
+ raise pb_utils.TritonModelException(f"Invalid format for input_ids")
251
+ batch_size = input_token_ids.shape[0]
252
+ requests = []
253
+ for batch_index in range(0, batch_size):
254
+ input_token_ids = get_input_tensor_by_name(request, 'input_ids',
255
+ batch_size, batch_index)[0]
256
+ if input_token_ids is None:
257
+ raise pb_utils.TritonModelException(
258
+ "A value is required for input_ids")
259
+ input_token_ids = input_token_ids.tolist()
260
+ if len(input_token_ids) == 0:
261
+ raise pb_utils.TritonModelException(
262
+ f"Invalid format for input_ids")
263
+
264
+ input_length = get_input_scalar_by_name(request, 'input_lengths',
265
+ batch_size, batch_index)
266
+ if input_length is None:
267
+ input_length = len(input_token_ids)
268
+ # Trim input token ids with input_lengths
269
+ inputs['input_token_ids'] = input_token_ids[0:input_length]
270
+
271
+ inputs['max_new_tokens'] = get_input_scalar_by_name(
272
+ request, 'request_output_len', batch_size, batch_index)
273
+ if inputs['max_new_tokens'] is None:
274
+ raise pb_utils.TritonModelException(
275
+ "A value is required for request_output_len")
276
+ inputs['streaming'] = get_input_scalar_by_name(request, 'streaming',
277
+ batch_size, batch_index)
278
+ if inputs['streaming'] and not decoupled:
279
+ raise pb_utils.TritonModelException(
280
+ "Streaming is only supported in decoupled mode.")
281
+ inputs['end_id'] = get_input_scalar_by_name(request, 'end_id',
282
+ batch_size, batch_index)
283
+ inputs['pad_id'] = get_input_scalar_by_name(request, 'pad_id',
284
+ batch_size, batch_index)
285
+ inputs['stop_words'] = convert_word_list(
286
+ get_input_tensor_by_name(request, 'stop_words_list', batch_size,
287
+ batch_index))
288
+ inputs['bad_words'] = convert_word_list(
289
+ get_input_tensor_by_name(request, 'bad_words_list', batch_size,
290
+ batch_index))
291
+ embedding_bias = get_input_tensor_by_name(request, 'embedding_bias',
292
+ batch_size, batch_index)
293
+ if embedding_bias is not None and embedding_bias.size != 0:
294
+ inputs['embedding_bias'] = from_numpy(embedding_bias).squeeze()
295
+
296
+ sampling_config = get_sampling_config_from_request(
297
+ request, batch_size, batch_index)
298
+ output_config = get_output_config_from_request(
299
+ request, exclude_input_from_output, batch_size, batch_index)
300
+ external_draft_tokens_config = get_external_draft_tokens_config_from_request(
301
+ request, batch_size, batch_index)
302
+ prompt_tuning_config = get_prompt_tuning_config_from_request(
303
+ request, batch_size, batch_index)
304
+ lora_config = get_lora_config_from_request(request, batch_size,
305
+ batch_index)
306
+
307
+ requests.append(
308
+ trtllm.Request(
309
+ **inputs,
310
+ sampling_config=sampling_config,
311
+ output_config=output_config,
312
+ external_draft_tokens_config=external_draft_tokens_config,
313
+ prompt_tuning_config=prompt_tuning_config,
314
+ lora_config=lora_config,
315
+ ))
316
+ return requests
317
+
318
+
319
+ def convert_response(response, batch_index):
320
+ if response.has_error():
321
+ return pb_utils.InferenceResponse(output_tensors=[],
322
+ error=pb_utils.TritonError(
323
+ response.error_msg)), True
324
+ result = response.result
325
+ beam_lengths = np.expand_dims(
326
+ np.array([len(beam) for beam in result.output_token_ids], np.int32), 0)
327
+ max_beam_length = max([len(beam) for beam in result.output_token_ids])
328
+ output_ids = np.full((1, len(result.output_token_ids), max_beam_length),
329
+ -1, np.int32)
330
+ for idx, beam in enumerate(result.output_token_ids):
331
+ output_ids[0, idx, :len(beam)] = beam
332
+ output_tensors = [
333
+ pb_utils.Tensor("output_ids", output_ids),
334
+ pb_utils.Tensor("sequence_length", beam_lengths),
335
+ ]
336
+ output_tensors.append(
337
+ pb_utils.Tensor(
338
+ "cum_log_probs",
339
+ np.expand_dims(np.array(result.cum_log_probs, np.float32), 0)
340
+ if result.cum_log_probs is not None else np.zeros(
341
+ (1, 1), np.float32)))
342
+ output_tensors.append(
343
+ pb_utils.Tensor(
344
+ "output_log_probs",
345
+ np.expand_dims(np.array(result.log_probs, np.float32), 0) if
346
+ result.log_probs is not None else np.zeros((1, 1, 1), np.float32)))
347
+ output_tensors.append(
348
+ pb_utils.Tensor(
349
+ "context_logits",
350
+ np.expand_dims(np.array(result.context_logits, np.float32), 0)
351
+ if result.context_logits is not None else np.zeros(
352
+ (1, 1, 1), np.float32)))
353
+ output_tensors.append(
354
+ pb_utils.Tensor(
355
+ "generation_logits",
356
+ np.expand_dims(np.array(result.generation_logits, np.float32), 0)
357
+ if result.generation_logits is not None else np.zeros(
358
+ (1, 1, 1, 1), np.float32)))
359
+ output_tensors.append(
360
+ pb_utils.Tensor("batch_index",
361
+ np.expand_dims(np.array([batch_index], np.int32), 0)))
362
+
363
+ return pb_utils.InferenceResponse(output_tensors), result.is_final
364
+
365
+
366
+ def convert_scheduler_policy(batch_scheduler_policy: str):
367
+ if batch_scheduler_policy.lower() == "max_utilization":
368
+ return trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION
369
+ elif batch_scheduler_policy.lower() == "guaranteed_no_evict":
370
+ return trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
371
+ raise pb_utils.TritonModelException(
372
+ f"batch_scheduler_policy value of '{batch_scheduler_policy}' is not supported."
373
+ )
374
+
375
+
376
+ def convert_batching_type(gpt_model_type: str):
377
+ if gpt_model_type is None:
378
+ return None
379
+ if gpt_model_type.lower(
380
+ ) == "inflight_fused_batching" or gpt_model_type.lower(
381
+ ) == "inflight_batching":
382
+ return trtllm.BatchingType.INFLIGHT
383
+ elif gpt_model_type.lower() == "v1":
384
+ return trtllm.BatchingType.STATIC
385
+ raise pb_utils.TritonModelException(
386
+ f"gpt_model_type value of '{gpt_model_type}' is not supported.")
387
+
388
+
389
+ def convert_decoding_mode(decoding_mode: str):
390
+ if decoding_mode is None:
391
+ return None
392
+ elif decoding_mode == "auto":
393
+ return trtllm.DecodingMode.Auto()
394
+ elif decoding_mode == "top_k":
395
+ return trtllm.DecodingMode.TopK()
396
+ elif decoding_mode == "top_p":
397
+ return trtllm.DecodingMode.TopP()
398
+ elif decoding_mode == "top_k_top_p":
399
+ return trtllm.DecodingMode.TopKTopP()
400
+ elif decoding_mode == "beam_search":
401
+ return trtllm.DecodingMode.BeamSearch()
402
+ elif decoding_mode == "medusa":
403
+ return trtllm.DecodingMode.Medusa()
404
+ raise pb_utils.TritonModelException(
405
+ f"decoding_mode value of '{decoding_mode}' is not supported.")
406
+
407
+
408
+ def convert_timestamp_to_seconds(timestamp: str):
409
+ return int(
410
+ datetime.datetime.strptime(timestamp,
411
+ "%m-%d-%Y %H:%M:%S.%f").timestamp())
412
+
413
+
414
+ class TritonPythonModel:
415
+ """Your Python model must use the same class name. Every Python model
416
+ that is created must have "TritonPythonModel" as the class name.
417
+ """
418
+
419
+ def get_scheduler_config(self, model_config):
420
+ batch_scheduler_policy = get_parameter(model_config,
421
+ "batch_scheduler_policy")
422
+ if batch_scheduler_policy is None:
423
+ return trtllm.SchedulerConfig()
424
+ return trtllm.SchedulerConfig(
425
+ convert_scheduler_policy(batch_scheduler_policy))
426
+
427
+ def get_kv_cache_config(self, model_config):
428
+ kwargs = {
429
+ "enable_block_reuse":
430
+ get_parameter(model_config, "enable_kv_cache_reuse", bool),
431
+ "max_tokens":
432
+ get_parameter(model_config, "max_tokens_in_paged_kv_cache", int),
433
+ "sink_token_length":
434
+ get_parameter(model_config, "sink_token_length", int),
435
+ "free_gpu_memory_fraction":
436
+ get_parameter(model_config, "kv_cache_free_gpu_mem_fraction",
437
+ float),
438
+ "host_cache_size":
439
+ get_parameter(model_config, "kv_cache_host_memory_bytes", int),
440
+ "onboard_blocks":
441
+ get_parameter(model_config, "kv_cache_onboard_blocks", bool),
442
+ }
443
+ max_attention_window_size = get_parameter(model_config,
444
+ "max_attention_window_size")
445
+ if max_attention_window_size:
446
+ kwargs["max_attention_window"] = [
447
+ int(x) for x in max_attention_window_size.split(",")
448
+ ]
449
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
450
+ return trtllm.KvCacheConfig(**kwargs)
451
+
452
+ def get_parallel_config(self, model_config):
453
+ kwargs = {}
454
+ gpu_device_ids = get_parameter(model_config, "gpu_device_ids")
455
+ if gpu_device_ids:
456
+ kwargs["device_ids"] = [int(x) for x in gpu_device_ids.split(",")]
457
+ self.use_orchestrator_mode = os.environ.get("TRTLLM_ORCHESTRATOR",
458
+ "0") == "1"
459
+ if self.use_orchestrator_mode:
460
+ kwargs[
461
+ "communication_mode"] = trtllm.CommunicationMode.ORCHESTRATOR
462
+ worker_path = get_parameter(model_config, "worker_path")
463
+ if worker_path is not None:
464
+ raise pb_utils.TritonModelException(
465
+ "worker_path parameter is specified, but this is no longer supported. Please specify executor_worker_path instead to specify the location of the trtllmExecutorWorker executable."
466
+ )
467
+ executor_worker_path = get_parameter(model_config,
468
+ "executor_worker_path")
469
+ kwargs["orchestrator_config"] = trtllm.OrchestratorConfig(
470
+ True, executor_worker_path)
471
+ if len(kwargs) > 0:
472
+ return trtllm.ParallelConfig(**kwargs)
473
+ return None
474
+
475
+ def get_peft_cache_config(self, model_config):
476
+ kwargs = {
477
+ "optimal_adapter_size":
478
+ get_parameter(model_config, "lora_cache_optimal_adapter_size",
479
+ int),
480
+ "max_adapter_size":
481
+ get_parameter(model_config, "lora_cache_max_adapter_size", int),
482
+ "device_cache_percent":
483
+ get_parameter(model_config, "lora_cache_gpu_memory_fraction",
484
+ float),
485
+ "host_cache_size":
486
+ get_parameter(model_config, "lora_cache_host_memory_bytes", int),
487
+ }
488
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
489
+ return trtllm.PeftCacheConfig(**kwargs)
490
+
491
+ def get_decoding_config(self, model_config):
492
+ kwargs = {
493
+ "medusa_choices":
494
+ parse_medusa_choices(get_parameter(model_config,
495
+ "medusa_choices")),
496
+ "decoding_mode":
497
+ convert_decoding_mode(get_parameter(model_config,
498
+ "decoding_mode")),
499
+ }
500
+ print(kwargs)
501
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
502
+ return trtllm.DecodingConfig(**kwargs)
503
+
504
+ def get_extended_runtime_perf_knob_config(self, model_config):
505
+ kwargs = {
506
+ "multi_block_mode":
507
+ get_parameter(model_config, "multi_block_mode", bool),
508
+ "enable_context_fmha_fp32_acc":
509
+ get_parameter(model_config, "enable_context_fmha_fp32_acc", bool)
510
+ }
511
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
512
+ return trtllm.ExtendedRuntimePerfKnobConfig(**kwargs)
513
+
514
+ def get_executor_config(self, model_config):
515
+ kwargs = {
516
+ "max_beam_width":
517
+ get_parameter(model_config, "max_beam_width", int),
518
+ "scheduler_config":
519
+ self.get_scheduler_config(model_config),
520
+ "kv_cache_config":
521
+ self.get_kv_cache_config(model_config),
522
+ "enable_chunked_context":
523
+ get_parameter(model_config, "enable_chunked_context", bool),
524
+ "normalize_log_probs":
525
+ get_parameter(model_config, "normalize_log_probs", bool),
526
+ "batching_type":
527
+ convert_batching_type(get_parameter(model_config,
528
+ "gpt_model_type")),
529
+ "parallel_config":
530
+ self.get_parallel_config(model_config),
531
+ "peft_cache_config":
532
+ self.get_peft_cache_config(model_config),
533
+ "decoding_config":
534
+ self.get_decoding_config(model_config),
535
+ "max_queue_size":
536
+ model_config.get(
537
+ "dynamic_batching",
538
+ {},
539
+ ).get(
540
+ "default_queue_policy",
541
+ {},
542
+ ).get("max_queue_size"),
543
+ "extended_runtime_perf_knob_config":
544
+ self.get_extended_runtime_perf_knob_config(model_config)
545
+ }
546
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
547
+ return trtllm.ExecutorConfig(**kwargs)
548
+
549
+ def create_metrics(self, model: str, version: str, is_v1_model: bool):
550
+ self.request_metric_family = pb_utils.MetricFamily(
551
+ name="nv_trt_llm_request_metrics",
552
+ description="TRT LLM request metrics",
553
+ kind=pb_utils.MetricFamily.GAUGE,
554
+ )
555
+ self.runtime_memory_metric_family = pb_utils.MetricFamily(
556
+ name="nv_trt_llm_runtime_memory_metrics",
557
+ description="TRT LLM runtime memory metrics",
558
+ kind=pb_utils.MetricFamily.GAUGE,
559
+ )
560
+ self.kv_cache_metric_family = pb_utils.MetricFamily(
561
+ name="nv_trt_llm_kv_cache_block_metrics",
562
+ description="TRT LLM KV cache block metrics",
563
+ kind=pb_utils.MetricFamily.GAUGE,
564
+ )
565
+ model_type = "v1" if is_v1_model else "inflight_batcher"
566
+ self.model_type_metric_family = pb_utils.MetricFamily(
567
+ name=f"nv_trt_llm_{model_type}_metrics",
568
+ description=f"TRT LLM {model_type}-specific metrics",
569
+ kind=pb_utils.MetricFamily.GAUGE,
570
+ )
571
+ self.general_metric_family = pb_utils.MetricFamily(
572
+ name="nv_trt_llm_general_metrics",
573
+ description="General TRT LLM metrics",
574
+ kind=pb_utils.MetricFamily.GAUGE,
575
+ )
576
+ common_labels = {"model": model, "version": version}
577
+ self.all_metrics = {
578
+ # Request metrics
579
+ "num_active_requests":
580
+ self.request_metric_family.Metric(labels={
581
+ "request_type": "active",
582
+ **common_labels
583
+ }),
584
+ "max_num_active_requests":
585
+ self.request_metric_family.Metric(labels={
586
+ "request_type": "max",
587
+ **common_labels
588
+ }),
589
+ "num_scheduled_requests":
590
+ self.request_metric_family.Metric(labels={
591
+ "request_type": "scheduled",
592
+ **common_labels
593
+ }),
594
+ "num_context_requests":
595
+ self.request_metric_family.Metric(labels={
596
+ "request_type": "context",
597
+ **common_labels
598
+ }),
599
+ # Runtime metrics
600
+ "cpu_mem_usage":
601
+ self.runtime_memory_metric_family.Metric(labels={
602
+ "memory_type": "cpu",
603
+ **common_labels
604
+ }),
605
+ "gpu_mem_usage":
606
+ self.runtime_memory_metric_family.Metric(labels={
607
+ "memory_type": "gpu",
608
+ **common_labels
609
+ }),
610
+ "pinned_mem_usage":
611
+ self.runtime_memory_metric_family.Metric(labels={
612
+ "memory_type": "pinned",
613
+ **common_labels
614
+ }),
615
+ # KV cache metrics
616
+ "max_num_blocks":
617
+ self.kv_cache_metric_family.Metric(labels={
618
+ "kv_cache_block_type": "max",
619
+ **common_labels
620
+ }),
621
+ "free_num_blocks":
622
+ self.kv_cache_metric_family.Metric(labels={
623
+ "kv_cache_block_type": "free",
624
+ **common_labels
625
+ }),
626
+ "used_num_blocks":
627
+ self.kv_cache_metric_family.Metric(labels={
628
+ "kv_cache_block_type": "used",
629
+ **common_labels
630
+ }),
631
+ "tokens_per_block":
632
+ self.kv_cache_metric_family.Metric(labels={
633
+ "kv_cache_block_type": "tokens_per",
634
+ **common_labels
635
+ }),
636
+ # General metrics
637
+ "timestamp":
638
+ self.general_metric_family.Metric(labels={
639
+ "general_type": "timestamp",
640
+ **common_labels
641
+ }),
642
+ "iter":
643
+ self.general_metric_family.Metric(labels={
644
+ "general_type": "iteration_counter",
645
+ **common_labels
646
+ }),
647
+ }
648
+ if is_v1_model:
649
+ self.all_metrics.update({
650
+ "num_ctx_tokens":
651
+ self.model_type_metric_family.Metric(labels={
652
+ "v1_specific_metric": "total_context_tokens",
653
+ **common_labels
654
+ }),
655
+ "num_gen_tokens":
656
+ self.model_type_metric_family.Metric(
657
+ labels={
658
+ "v1_specific_metric": "total_generation_tokens",
659
+ **common_labels
660
+ }),
661
+ "empty_gen_slots":
662
+ self.model_type_metric_family.Metric(
663
+ labels={
664
+ "v1_specific_metric": "empty_generation_slots",
665
+ **common_labels
666
+ }),
667
+ })
668
+ else:
669
+ self.all_metrics.update({
670
+ "num_ctx_tokens":
671
+ self.model_type_metric_family.Metric(
672
+ labels={
673
+ "inflight_batcher_specific_metric":
674
+ "total_context_tokens",
675
+ **common_labels
676
+ }),
677
+ "num_gen_requests":
678
+ self.model_type_metric_family.Metric(
679
+ labels={
680
+ "inflight_batcher_specific_metric":
681
+ "generation_requests",
682
+ **common_labels
683
+ }),
684
+ "micro_batch_id":
685
+ self.model_type_metric_family.Metric(
686
+ labels={
687
+ "inflight_batcher_specific_metric": "micro_batch_id",
688
+ **common_labels
689
+ }),
690
+ "num_paused_requests":
691
+ self.model_type_metric_family.Metric(
692
+ labels={
693
+ "inflight_batcher_specific_metric": "paused_requests",
694
+ **common_labels
695
+ }),
696
+ })
697
+
698
+ def initialize(self, args):
699
+ """`initialize` is called only once when the model is being loaded.
700
+ Implementing `initialize` function is optional. This function allows
701
+ the model to initialize any state associated with this model.
702
+
703
+ Parameters
704
+ ----------
705
+ args : dict
706
+ Both keys and values are strings. The dictionary keys and values are:
707
+ * model_config: A JSON string containing the model configuration
708
+ * model_instance_kind: A string containing model instance kind
709
+ * model_instance_device_id: A string containing model instance device ID
710
+ * model_repository: Model repository path
711
+ * model_version: Model version
712
+ * model_name: Model name
713
+ """
714
+ model_config = json.loads(args['model_config'])
715
+ gpt_model_path = get_parameter(model_config, "gpt_model_path")
716
+ if get_parameter(model_config, "enable_trt_overlap", bool):
717
+ raise pb_utils.TritonModelException(
718
+ f"enable_trt_overlap=true is not supported.")
719
+ self.exclude_input_from_output = get_parameter(
720
+ model_config, "exclude_input_in_output", bool)
721
+ executor_config = self.get_executor_config(model_config)
722
+ self.executor = trtllm.Executor(gpt_model_path,
723
+ trtllm.ModelType.DECODER_ONLY,
724
+ executor_config)
725
+ self.decoupled = pb_utils.using_decoupled_model_transaction_policy(
726
+ model_config)
727
+ self.cancellation_check_period_ms = get_parameter(
728
+ model_config, "cancellation_check_period_ms", int) or 100
729
+ self.stats_check_period_ms = get_parameter(
730
+ model_config, "stats_check_period_ms", int) or 100
731
+
732
+ if not self.decoupled:
733
+ raise pb_utils.TritonModelException(
734
+ "Please enable decoupled transaction policy in the model configuration to serve this model"
735
+ )
736
+
737
+ self.create_metrics(args["model_name"],
738
+ args["model_version"],
739
+ is_v1_model=executor_config.batching_type ==
740
+ trtllm.BatchingType.STATIC)
741
+ self.triton_user_id_to_req_ids = {}
742
+ self.triton_req_id_to_req_ids = {}
743
+ self.req_id_to_request_data = {}
744
+ self.lock = Lock()
745
+ self.running = False
746
+ self.awaiter_thread = Thread(target=self.awaiter_loop)
747
+ self.cancellation_thread = Thread(target=self.cancellation_loop)
748
+ self.metrics_thread = Thread(target=self.metrics_loop)
749
+ if self.executor.can_enqueue_requests():
750
+ self.running = True
751
+ self.awaiter_thread.start()
752
+ self.cancellation_thread.start()
753
+ self.metrics_thread.start()
754
+ else:
755
+ # In leader mode, worker ranks will wait here until leader is done.
756
+ self.executor.shutdown()
757
+
758
+ def handle_stop_request(self, triton_user_id, response_sender):
759
+ if triton_user_id is None or triton_user_id == "":
760
+ response_sender.send(
761
+ pb_utils.InferenceResponse(error=pb_utils.TritonError(
762
+ "A request id must be provided for request cancellation")),
763
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
764
+ return
765
+
766
+ with self.lock:
767
+ if triton_user_id in self.triton_user_id_to_req_ids:
768
+ req_ids = self.triton_user_id_to_req_ids[triton_user_id]
769
+ for req_id in req_ids:
770
+ self.executor.cancel_request(req_id)
771
+
772
+ response_sender.send(
773
+ pb_utils.InferenceResponse(),
774
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
775
+
776
+ def execute(self, requests):
777
+ """`execute` must be implemented in every Python model. `execute`
778
+ function receives a list of pb_utils.InferenceRequest as the only
779
+ argument. This function is called when an inference is requested
780
+ for this model.
781
+
782
+ Parameters
783
+ ----------
784
+ requests : list
785
+ A list of pb_utils.InferenceRequest
786
+
787
+ Returns
788
+ -------
789
+ list
790
+ A list of pb_utils.InferenceResponse. The length of this list must
791
+ be the same as `requests`
792
+ """
793
+ if not self.executor.can_enqueue_requests():
794
+ return
795
+
796
+ # Convert to executor requests.
797
+
798
+ triton_requests = []
799
+ executor_requests = []
800
+ batch_indices = []
801
+ triton_user_ids = []
802
+ triton_req_ids = []
803
+
804
+ for request in requests:
805
+
806
+ triton_user_id = request.request_id()
807
+
808
+ response_sender = request.get_response_sender()
809
+ stop = get_input_scalar_by_name(request, 'stop')
810
+
811
+ if stop:
812
+ self.handle_stop_request(triton_user_id, response_sender)
813
+ else:
814
+ #Unique request id used to identify each triton request
815
+ triton_req_id = str(randint(0, sys.maxsize))
816
+ self.triton_req_id_to_req_ids[triton_req_id] = set()
817
+ if triton_user_id is not None and triton_user_id != "":
818
+ self.triton_user_id_to_req_ids[triton_user_id] = set()
819
+
820
+ try:
821
+ converted_reqs = convert_request(
822
+ request, self.exclude_input_from_output,
823
+ self.decoupled)
824
+ except Exception as e:
825
+ response_sender.send(
826
+ pb_utils.InferenceResponse(error=pb_utils.TritonError(
827
+ f"An error occurred when processing the input values for request id {request.request_id()}, the error was '{e}'"
828
+ )),
829
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
830
+ else:
831
+ for batch_index, converted_req in enumerate(
832
+ converted_reqs):
833
+ triton_requests.append(request)
834
+ executor_requests.append(converted_req)
835
+ triton_user_ids.append(triton_user_id)
836
+ triton_req_ids.append(triton_req_id)
837
+ batch_indices.append(batch_index)
838
+
839
+ with self.lock:
840
+ request_ids = self.executor.enqueue_requests(executor_requests)
841
+ for req_id, triton_req_id, triton_user_id, triton_request, batch_index in zip(
842
+ request_ids, triton_req_ids, triton_user_ids,
843
+ triton_requests, batch_indices):
844
+ self.req_id_to_request_data[
845
+ req_id] = triton_req_id, triton_user_id, batch_index, triton_request.get_response_sender(
846
+ )
847
+ self.triton_req_id_to_req_ids[triton_req_id].add(req_id)
848
+ if triton_user_id is not None and triton_user_id != "":
849
+ self.triton_user_id_to_req_ids[triton_user_id].add(req_id)
850
+
851
+ return None
852
+
853
+ def awaiter_loop(self):
854
+ """Gets responses from executor and returns the results."""
855
+ while self.running:
856
+ for response in self.executor.await_responses(
857
+ timeout=datetime.timedelta(milliseconds=1)):
858
+ req_id = response.request_id
859
+ with self.lock:
860
+ if req_id not in self.req_id_to_request_data:
861
+ continue
862
+ triton_req_id, triton_user_id, batch_index, response_sender = self.req_id_to_request_data[
863
+ req_id]
864
+
865
+ triton_response, is_final = convert_response(
866
+ response, batch_index)
867
+
868
+ triton_request_final = False
869
+ if is_final:
870
+ with self.lock:
871
+ # Check if all executor requests part of that triton request are finished
872
+ self.triton_req_id_to_req_ids[triton_req_id].remove(
873
+ req_id)
874
+ if len(self.triton_req_id_to_req_ids[triton_req_id]
875
+ ) == 0:
876
+ pb_utils.Logger.log_info(
877
+ f"DELETING Req id {req_id}, triton_req_id {triton_req_id} "
878
+ )
879
+ triton_request_final = True
880
+ del self.triton_req_id_to_req_ids[triton_req_id]
881
+ if triton_user_id is not None and triton_user_id != "":
882
+ del self.triton_user_id_to_req_ids[
883
+ triton_user_id]
884
+ del self.req_id_to_request_data[req_id]
885
+
886
+ response_sender.send(
887
+ triton_response,
888
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
889
+ if triton_request_final else 0)
890
+
891
+ # Remove local reference so response_sender can be cleaned properly.
892
+ del response_sender
893
+
894
+ def cancellation_loop(self):
895
+ """Checks if any pending requests have been cancelled."""
896
+ while self.running:
897
+ time.sleep(self.cancellation_check_period_ms / 1000.0)
898
+ with self.lock:
899
+ for req_id, (triton_req_id, triton_user_id, batch_index,
900
+ response_sender
901
+ ) in self.req_id_to_request_data.items():
902
+ if response_sender.is_cancelled():
903
+ self.executor.cancel_request(req_id)
904
+ # Remove local reference so response_sender can be cleaned properly.
905
+ del response_sender
906
+
907
+ def metrics_loop(self):
908
+ """Updates triton metrics using stats from the executor."""
909
+ while self.running:
910
+ time.sleep(self.stats_check_period_ms / 1000.0)
911
+ for stat in self.executor.get_latest_iteration_stats():
912
+ try:
913
+ for key, metric in self.all_metrics.items():
914
+ value = None
915
+ if hasattr(stat, key):
916
+ value = getattr(stat, key)
917
+ elif stat.kv_cache_stats is not None and hasattr(
918
+ stat.kv_cache_stats, key):
919
+ value = getattr(stat.kv_cache_stats, key)
920
+ elif stat.static_batching_stats is not None and hasattr(
921
+ stat.static_batching_stats, key):
922
+ value = getattr(stat.static_batching_stats, key)
923
+ elif stat.inflight_batching_stats is not None and hasattr(
924
+ stat.inflight_batching_stats, key):
925
+ value = getattr(stat.inflight_batching_stats, key)
926
+ if value is not None:
927
+ if key == "timestamp":
928
+ value = convert_timestamp_to_seconds(value)
929
+ metric.set(value)
930
+ else:
931
+ pb_utils.Logger.log_warn(
932
+ f"Metric \"{key}\" not found.")
933
+ except Exception as e:
934
+ pb_utils.Logger.log_warn(
935
+ f"Error while processing metrics: {e}")
936
+
937
+ def finalize(self):
938
+ """`finalize` is called only once when the model is being unloaded.
939
+ Implementing `finalize` function is optional. This function allows
940
+ the model to perform any necessary clean ups before exit.
941
+ """
942
+ if self.executor.can_enqueue_requests():
943
+ self.running = False
944
+ self.awaiter_thread.join()
945
+ self.cancellation_thread.join()
946
+ self.metrics_thread.join()
947
+ self.executor.shutdown()
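The execute/convert_request path above consumes Triton tensors named input_ids, input_lengths, and request_output_len. For reference only (not part of this upload), a minimal, hypothetical gRPC client sketch against a non-decoupled tensorrt_llm model such as the one configured in config.pbtxt below could look like this; the server URL and prompt token ids are placeholders:

```python
# Hypothetical client sketch; assumes a Triton server on localhost:8001 serving
# the "tensorrt_llm" model with the tensor names used in convert_request above.
import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")

token_ids = np.array([[1, 2, 3]], dtype=np.int32)  # placeholder prompt token ids
inputs = [
    grpcclient.InferInput("input_ids", list(token_ids.shape), "INT32"),
    grpcclient.InferInput("input_lengths", [1, 1], "INT32"),
    grpcclient.InferInput("request_output_len", [1, 1], "INT32"),
]
inputs[0].set_data_from_numpy(token_ids)
inputs[1].set_data_from_numpy(np.array([[token_ids.shape[1]]], dtype=np.int32))
inputs[2].set_data_from_numpy(np.array([[64]], dtype=np.int32))  # max new tokens

result = client.infer("tensorrt_llm", inputs)
print(result.as_numpy("output_ids"))
```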
model_repo_whisper_qwen_trtllm/tensorrt_llm/config.pbtxt ADDED
@@ -0,0 +1,577 @@
1
+ # Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+
27
+ name: "tensorrt_llm"
28
+ backend: "tensorrtllm"
29
+ max_batch_size: 8
30
+
31
+ model_transaction_policy {
32
+ decoupled: false
33
+ }
34
+
35
+ dynamic_batching {
36
+ preferred_batch_size: [ 8 ]
37
+ max_queue_delay_microseconds: 0
38
+ default_queue_policy: { max_queue_size: 0 }
39
+ }
40
+
41
+ input [
42
+ {
43
+ name: "input_ids"
44
+ data_type: TYPE_INT32
45
+ dims: [ -1 ]
46
+ allow_ragged_batch: true
47
+ optional: true
48
+ },
49
+ {
50
+ name: "encoder_input_features"
51
+ data_type: TYPE_FP16
52
+ dims: [ -1, -1 ]
53
+ allow_ragged_batch: true
54
+ optional: true
55
+ },
56
+ {
57
+ name: "encoder_output_lengths"
58
+ data_type: TYPE_INT32
59
+ dims: [ 1 ]
60
+ reshape: { shape: [ ] }
61
+ optional: true
62
+ },
63
+ {
64
+ name: "input_lengths"
65
+ data_type: TYPE_INT32
66
+ dims: [ 1 ]
67
+ reshape: { shape: [ ] }
68
+ },
69
+ {
70
+ name: "request_output_len"
71
+ data_type: TYPE_INT32
72
+ dims: [ 1 ]
73
+ reshape: { shape: [ ] }
74
+ },
75
+ {
76
+ name: "draft_input_ids"
77
+ data_type: TYPE_INT32
78
+ dims: [ -1 ]
79
+ optional: true
80
+ allow_ragged_batch: true
81
+ },
82
+ {
83
+ name: "decoder_input_ids"
84
+ data_type: TYPE_INT32
85
+ dims: [ -1 ]
86
+ optional: true
87
+ allow_ragged_batch: true
88
+ },
89
+ {
90
+ name: "decoder_input_lengths"
91
+ data_type: TYPE_INT32
92
+ dims: [ 1 ]
93
+ optional: true
94
+ reshape: { shape: [ ] }
95
+ },
96
+ {
97
+ name: "draft_logits"
98
+ data_type: TYPE_FP32
99
+ dims: [ -1, -1 ]
100
+ optional: true
101
+ allow_ragged_batch: true
102
+ },
103
+ {
104
+ name: "draft_acceptance_threshold"
105
+ data_type: TYPE_FP32
106
+ dims: [ 1 ]
107
+ reshape: { shape: [ ] }
108
+ optional: true
109
+ },
110
+ {
111
+ name: "end_id"
112
+ data_type: TYPE_INT32
113
+ dims: [ 1 ]
114
+ reshape: { shape: [ ] }
115
+ optional: true
116
+ },
117
+ {
118
+ name: "pad_id"
119
+ data_type: TYPE_INT32
120
+ dims: [ 1 ]
121
+ reshape: { shape: [ ] }
122
+ optional: true
123
+ },
124
+ {
125
+ name: "stop_words_list"
126
+ data_type: TYPE_INT32
127
+ dims: [ 2, -1 ]
128
+ optional: true
129
+ allow_ragged_batch: true
130
+ },
131
+ {
132
+ name: "bad_words_list"
133
+ data_type: TYPE_INT32
134
+ dims: [ 2, -1 ]
135
+ optional: true
136
+ allow_ragged_batch: true
137
+ },
138
+ {
139
+ name: "embedding_bias"
140
+ data_type: TYPE_FP32
141
+ dims: [ -1 ]
142
+ optional: true
143
+ allow_ragged_batch: true
144
+ },
145
+ {
146
+ name: "beam_width"
147
+ data_type: TYPE_INT32
148
+ dims: [ 1 ]
149
+ reshape: { shape: [ ] }
150
+ optional: true
151
+ },
152
+ {
153
+ name: "temperature"
154
+ data_type: TYPE_FP32
155
+ dims: [ 1 ]
156
+ reshape: { shape: [ ] }
157
+ optional: true
158
+ },
159
+ {
160
+ name: "runtime_top_k"
161
+ data_type: TYPE_INT32
162
+ dims: [ 1 ]
163
+ reshape: { shape: [ ] }
164
+ optional: true
165
+ },
166
+ {
167
+ name: "runtime_top_p"
168
+ data_type: TYPE_FP32
169
+ dims: [ 1 ]
170
+ reshape: { shape: [ ] }
171
+ optional: true
172
+ },
173
+ {
174
+ name: "runtime_top_p_min"
175
+ data_type: TYPE_FP32
176
+ dims: [ 1 ]
177
+ reshape: { shape: [ ] }
178
+ optional: true
179
+ },
180
+ {
181
+ name: "runtime_top_p_decay"
182
+ data_type: TYPE_FP32
183
+ dims: [ 1 ]
184
+ reshape: { shape: [ ] }
185
+ optional: true
186
+ },
187
+ {
188
+ name: "runtime_top_p_reset_ids"
189
+ data_type: TYPE_INT32
190
+ dims: [ 1 ]
191
+ reshape: { shape: [ ] }
192
+ optional: true
193
+ },
194
+ {
195
+ name: "len_penalty"
196
+ data_type: TYPE_FP32
197
+ dims: [ 1 ]
198
+ reshape: { shape: [ ] }
199
+ optional: true
200
+ },
201
+ {
202
+ name: "early_stopping"
203
+ data_type: TYPE_BOOL
204
+ dims: [ 1 ]
205
+ reshape: { shape: [ ] }
206
+ optional: true
207
+ },
208
+ {
209
+ name: "repetition_penalty"
210
+ data_type: TYPE_FP32
211
+ dims: [ 1 ]
212
+ reshape: { shape: [ ] }
213
+ optional: true
214
+ },
215
+ {
216
+ name: "min_length"
217
+ data_type: TYPE_INT32
218
+ dims: [ 1 ]
219
+ reshape: { shape: [ ] }
220
+ optional: true
221
+ },
222
+ {
223
+ name: "beam_search_diversity_rate"
224
+ data_type: TYPE_FP32
225
+ dims: [ 1 ]
226
+ reshape: { shape: [ ] }
227
+ optional: true
228
+ },
229
+ {
230
+ name: "presence_penalty"
231
+ data_type: TYPE_FP32
232
+ dims: [ 1 ]
233
+ reshape: { shape: [ ] }
234
+ optional: true
235
+ },
236
+ {
237
+ name: "frequency_penalty"
238
+ data_type: TYPE_FP32
239
+ dims: [ 1 ]
240
+ reshape: { shape: [ ] }
241
+ optional: true
242
+ },
243
+ {
244
+ name: "random_seed"
245
+ data_type: TYPE_UINT64
246
+ dims: [ 1 ]
247
+ reshape: { shape: [ ] }
248
+ optional: true
249
+ },
250
+ {
251
+ name: "return_log_probs"
252
+ data_type: TYPE_BOOL
253
+ dims: [ 1 ]
254
+ reshape: { shape: [ ] }
255
+ optional: true
256
+ },
257
+ {
258
+ name: "return_context_logits"
259
+ data_type: TYPE_BOOL
260
+ dims: [ 1 ]
261
+ reshape: { shape: [ ] }
262
+ optional: true
263
+ },
264
+ {
265
+ name: "return_generation_logits"
266
+ data_type: TYPE_BOOL
267
+ dims: [ 1 ]
268
+ reshape: { shape: [ ] }
269
+ optional: true
270
+ },
271
+ {
272
+ name: "stop"
273
+ data_type: TYPE_BOOL
274
+ dims: [ 1 ]
275
+ reshape: { shape: [ ] }
276
+ optional: true
277
+ },
278
+ {
279
+ name: "streaming"
280
+ data_type: TYPE_BOOL
281
+ dims: [ 1 ]
282
+ reshape: { shape: [ ] }
283
+ optional: true
284
+ },
285
+ {
286
+ name: "prompt_embedding_table"
287
+ data_type: TYPE_FP16
288
+ dims: [ -1, -1 ]
289
+ optional: true
290
+ allow_ragged_batch: true
291
+ },
292
+ {
293
+ name: "prompt_vocab_size"
294
+ data_type: TYPE_INT32
295
+ dims: [ 1 ]
296
+ reshape: { shape: [ ] }
297
+ optional: true
298
+ },
299
+ # the unique task ID for the given LoRA.
300
+ # To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given.
301
+ # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
302
+ # If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
303
+ {
304
+ name: "lora_task_id"
305
+ data_type: TYPE_UINT64
306
+ dims: [ 1 ]
307
+ reshape: { shape: [ ] }
308
+ optional: true
309
+ },
310
+ # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
311
+ # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
312
+ # each of the in / out tensors are first flattened and then concatenated together in the format above.
313
+ # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
314
+ {
315
+ name: "lora_weights"
316
+ data_type: TYPE_FP16
317
+ dims: [ -1, -1 ]
318
+ optional: true
319
+ allow_ragged_batch: true
320
+ },
321
+ # module identifier (same size as the first dimension of lora_weights)
322
+ # See LoraModule::ModuleType for model id mapping
323
+ #
324
+ # "attn_qkv": 0 # combined qkv adapter
325
+ # "attn_q": 1 # q adapter
326
+ # "attn_k": 2 # k adapter
327
+ # "attn_v": 3 # v adapter
328
+ # "attn_dense": 4 # adapter for the dense layer in attention
329
+ # "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
330
+ # "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
331
+ # "mlp_gate": 7 # for llama2 adapter for gated mlp layer after attention / RMSNorm: gate
332
+ #
333
+ # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
334
+ {
335
+ name: "lora_config"
336
+ data_type: TYPE_INT32
337
+ dims: [ -1, 3 ]
338
+ optional: true
339
+ allow_ragged_batch: true
340
+ }
341
+ ]
342
+ output [
343
+ {
344
+ name: "output_ids"
345
+ data_type: TYPE_INT32
346
+ dims: [ -1, -1 ]
347
+ },
348
+ {
349
+ name: "sequence_length"
350
+ data_type: TYPE_INT32
351
+ dims: [ -1 ]
352
+ },
353
+ {
354
+ name: "cum_log_probs"
355
+ data_type: TYPE_FP32
356
+ dims: [ -1 ]
357
+ },
358
+ {
359
+ name: "output_log_probs"
360
+ data_type: TYPE_FP32
361
+ dims: [ -1, -1 ]
362
+ },
363
+ {
364
+ name: "context_logits"
365
+ data_type: TYPE_FP32
366
+ dims: [ -1, -1 ]
367
+ },
368
+ {
369
+ name: "generation_logits"
370
+ data_type: TYPE_FP32
371
+ dims: [ -1, -1, -1 ]
372
+ },
373
+ {
374
+ name: "batch_index"
375
+ data_type: TYPE_INT32
376
+ dims: [ 1 ]
377
+ }
378
+ ]
379
+ instance_group [
380
+ {
381
+ count: 1
382
+ kind : KIND_CPU
383
+ }
384
+ ]
385
+ parameters: {
386
+ key: "max_beam_width"
387
+ value: {
388
+ string_value: "1"
389
+ }
390
+ }
391
+ parameters: {
392
+ key: "FORCE_CPU_ONLY_INPUT_TENSORS"
393
+ value: {
394
+ string_value: "no"
395
+ }
396
+ }
397
+ parameters: {
398
+ key: "gpt_model_type"
399
+ value: {
400
+ string_value: "inflight_fused_batching"
401
+ }
402
+ }
403
+ parameters: {
404
+ key: "gpt_model_path"
405
+ value: {
406
+ string_value: "/home/scratch.yuekaiz_wwfo_1/tekit/examples/qwen/qwen2_1.5B_instruct_fp16_merged_max_prompt_embedding_table_size_256"
407
+ }
408
+ }
409
+ parameters: {
410
+ key: "encoder_model_path"
411
+ value: {
412
+ string_value: "${encoder_engine_dir}"
413
+ }
414
+ }
415
+ parameters: {
416
+ key: "max_tokens_in_paged_kv_cache"
417
+ value: {
418
+ string_value: "2560"
419
+ }
420
+ }
421
+ parameters: {
422
+ key: "max_attention_window_size"
423
+ value: {
424
+ string_value: "2000"
425
+ }
426
+ }
427
+ parameters: {
428
+ key: "sink_token_length"
429
+ value: {
430
+ string_value: "${sink_token_length}"
431
+ }
432
+ }
433
+ parameters: {
434
+ key: "batch_scheduler_policy"
435
+ value: {
436
+ string_value: "${batch_scheduler_policy}"
437
+ }
438
+ }
439
+ parameters: {
440
+ key: "kv_cache_free_gpu_mem_fraction"
441
+ value: {
442
+ string_value: "0.5"
443
+ }
444
+ }
445
+ parameters: {
446
+ key: "kv_cache_host_memory_bytes"
447
+ value: {
448
+ string_value: "${kv_cache_host_memory_bytes}"
449
+ }
450
+ }
451
+ parameters: {
452
+ key: "kv_cache_onboard_blocks"
453
+ value: {
454
+ string_value: "${kv_cache_onboard_blocks}"
455
+ }
456
+ }
457
+ # enable_trt_overlap is deprecated and doesn't have any effect on the runtime
458
+ # parameters: {
459
+ # key: "enable_trt_overlap"
460
+ # value: {
461
+ # string_value: "${enable_trt_overlap}"
462
+ # }
463
+ # }
464
+ parameters: {
465
+ key: "exclude_input_in_output"
466
+ value: {
467
+ string_value: "True"
468
+ }
469
+ }
470
+ parameters: {
471
+ key: "cancellation_check_period_ms"
472
+ value: {
473
+ string_value: "${cancellation_check_period_ms}"
474
+ }
475
+ }
476
+ parameters: {
477
+ key: "stats_check_period_ms"
478
+ value: {
479
+ string_value: "${stats_check_period_ms}"
480
+ }
481
+ }
482
+ parameters: {
483
+ key: "iter_stats_max_iterations"
484
+ value: {
485
+ string_value: "${iter_stats_max_iterations}"
486
+ }
487
+ }
488
+ parameters: {
489
+ key: "request_stats_max_iterations"
490
+ value: {
491
+ string_value: "${request_stats_max_iterations}"
492
+ }
493
+ }
494
+ parameters: {
495
+ key: "enable_kv_cache_reuse"
496
+ value: {
497
+ string_value: "False"
498
+ }
499
+ }
500
+ parameters: {
501
+ key: "normalize_log_probs"
502
+ value: {
503
+ string_value: "${normalize_log_probs}"
504
+ }
505
+ }
506
+ parameters: {
507
+ key: "enable_chunked_context"
508
+ value: {
509
+ string_value: "${enable_chunked_context}"
510
+ }
511
+ }
512
+ parameters: {
513
+ key: "gpu_device_ids"
514
+ value: {
515
+ string_value: "${gpu_device_ids}"
516
+ }
517
+ }
518
+ parameters: {
519
+ key: "lora_cache_optimal_adapter_size"
520
+ value: {
521
+ string_value: "${lora_cache_optimal_adapter_size}"
522
+ }
523
+ }
524
+ parameters: {
525
+ key: "lora_cache_max_adapter_size"
526
+ value: {
527
+ string_value: "${lora_cache_max_adapter_size}"
528
+ }
529
+ }
530
+ parameters: {
531
+ key: "lora_cache_gpu_memory_fraction"
532
+ value: {
533
+ string_value: "${lora_cache_gpu_memory_fraction}"
534
+ }
535
+ }
536
+ parameters: {
537
+ key: "lora_cache_host_memory_bytes"
538
+ value: {
539
+ string_value: "${lora_cache_host_memory_bytes}"
540
+ }
541
+ }
542
+ parameters: {
543
+ key: "decoding_mode"
544
+ value: {
545
+ string_value: "${decoding_mode}"
546
+ }
547
+ }
548
+ parameters: {
549
+ key: "executor_worker_path"
550
+ value: {
551
+ string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
552
+ }
553
+ }
554
+ parameters: {
555
+ key: "medusa_choices"
556
+ value: {
557
+ string_value: "${medusa_choices}"
558
+ }
559
+ }
560
+ parameters: {
561
+ key: "gpu_weights_percent"
562
+ value: {
563
+ string_value: "${gpu_weights_percent}"
564
+ }
565
+ }
566
+ parameters: {
567
+ key: "enable_context_fmha_fp32_acc"
568
+ value: {
569
+ string_value: "${enable_context_fmha_fp32_acc}"
570
+ }
571
+ }
572
+ parameters: {
573
+ key: "multi_block_mode"
574
+ value: {
575
+ string_value: "${multi_block_mode}"
576
+ }
577
+ }
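Several string_value fields above (and the whole config.template that follows) use ${...} placeholders. As a rough, hypothetical sketch of how they could be filled without extra tooling, Python's string.Template uses the same ${name} syntax; the concrete values below are only examples taken from this config.pbtxt:

```python
# Hypothetical helper (not part of this commit): fill ${...} placeholders in the
# template with concrete values; placeholders without a value are left untouched.
from pathlib import Path
from string import Template

values = {
    "triton_backend": "tensorrtllm",
    "triton_max_batch_size": "8",
    "decoupled_mode": "false",
    "max_queue_delay_microseconds": "0",
    "max_queue_size": "0",
}

template_path = Path("model_repo_whisper_qwen_trtllm/tensorrt_llm/config.template")
filled = Template(template_path.read_text()).safe_substitute(values)
Path("model_repo_whisper_qwen_trtllm/tensorrt_llm/config.pbtxt").write_text(filled)
```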
model_repo_whisper_qwen_trtllm/tensorrt_llm/config.template ADDED
@@ -0,0 +1,577 @@
1
+ # Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+
27
+ name: "tensorrt_llm"
28
+ backend: "${triton_backend}"
29
+ max_batch_size: ${triton_max_batch_size}
30
+
31
+ model_transaction_policy {
32
+ decoupled: ${decoupled_mode}
33
+ }
34
+
35
+ dynamic_batching {
36
+ preferred_batch_size: [ ${triton_max_batch_size} ]
37
+ max_queue_delay_microseconds: ${max_queue_delay_microseconds}
38
+ default_queue_policy: { max_queue_size: ${max_queue_size} }
39
+ }
40
+
41
+ input [
42
+ {
43
+ name: "input_ids"
44
+ data_type: TYPE_INT32
45
+ dims: [ -1 ]
46
+ allow_ragged_batch: true
47
+ optional: true
48
+ },
49
+ {
50
+ name: "encoder_input_features"
51
+ data_type: TYPE_FP16
52
+ dims: [ -1, -1 ]
53
+ allow_ragged_batch: true
54
+ optional: true
55
+ },
56
+ {
57
+ name: "encoder_output_lengths"
58
+ data_type: TYPE_INT32
59
+ dims: [ 1 ]
60
+ reshape: { shape: [ ] }
61
+ optional: true
62
+ },
63
+ {
64
+ name: "input_lengths"
65
+ data_type: TYPE_INT32
66
+ dims: [ 1 ]
67
+ reshape: { shape: [ ] }
68
+ },
69
+ {
70
+ name: "request_output_len"
71
+ data_type: TYPE_INT32
72
+ dims: [ 1 ]
73
+ reshape: { shape: [ ] }
74
+ },
75
+ {
76
+ name: "draft_input_ids"
77
+ data_type: TYPE_INT32
78
+ dims: [ -1 ]
79
+ optional: true
80
+ allow_ragged_batch: true
81
+ },
82
+ {
83
+ name: "decoder_input_ids"
84
+ data_type: TYPE_INT32
85
+ dims: [ -1 ]
86
+ optional: true
87
+ allow_ragged_batch: true
88
+ },
89
+ {
90
+ name: "decoder_input_lengths"
91
+ data_type: TYPE_INT32
92
+ dims: [ 1 ]
93
+ optional: true
94
+ reshape: { shape: [ ] }
95
+ },
96
+ {
97
+ name: "draft_logits"
98
+ data_type: TYPE_FP32
99
+ dims: [ -1, -1 ]
100
+ optional: true
101
+ allow_ragged_batch: true
102
+ },
103
+ {
104
+ name: "draft_acceptance_threshold"
105
+ data_type: TYPE_FP32
106
+ dims: [ 1 ]
107
+ reshape: { shape: [ ] }
108
+ optional: true
109
+ },
110
+ {
111
+ name: "end_id"
112
+ data_type: TYPE_INT32
113
+ dims: [ 1 ]
114
+ reshape: { shape: [ ] }
115
+ optional: true
116
+ },
117
+ {
118
+ name: "pad_id"
119
+ data_type: TYPE_INT32
120
+ dims: [ 1 ]
121
+ reshape: { shape: [ ] }
122
+ optional: true
123
+ },
124
+ {
125
+ name: "stop_words_list"
126
+ data_type: TYPE_INT32
127
+ dims: [ 2, -1 ]
128
+ optional: true
129
+ allow_ragged_batch: true
130
+ },
131
+ {
132
+ name: "bad_words_list"
133
+ data_type: TYPE_INT32
134
+ dims: [ 2, -1 ]
135
+ optional: true
136
+ allow_ragged_batch: true
137
+ },
138
+ {
139
+ name: "embedding_bias"
140
+ data_type: TYPE_FP32
141
+ dims: [ -1 ]
142
+ optional: true
143
+ allow_ragged_batch: true
144
+ },
145
+ {
146
+ name: "beam_width"
147
+ data_type: TYPE_INT32
148
+ dims: [ 1 ]
149
+ reshape: { shape: [ ] }
150
+ optional: true
151
+ },
152
+ {
153
+ name: "temperature"
154
+ data_type: TYPE_FP32
155
+ dims: [ 1 ]
156
+ reshape: { shape: [ ] }
157
+ optional: true
158
+ },
159
+ {
160
+ name: "runtime_top_k"
161
+ data_type: TYPE_INT32
162
+ dims: [ 1 ]
163
+ reshape: { shape: [ ] }
164
+ optional: true
165
+ },
166
+ {
167
+ name: "runtime_top_p"
168
+ data_type: TYPE_FP32
169
+ dims: [ 1 ]
170
+ reshape: { shape: [ ] }
171
+ optional: true
172
+ },
173
+ {
174
+ name: "runtime_top_p_min"
175
+ data_type: TYPE_FP32
176
+ dims: [ 1 ]
177
+ reshape: { shape: [ ] }
178
+ optional: true
179
+ },
180
+ {
181
+ name: "runtime_top_p_decay"
182
+ data_type: TYPE_FP32
183
+ dims: [ 1 ]
184
+ reshape: { shape: [ ] }
185
+ optional: true
186
+ },
187
+ {
188
+ name: "runtime_top_p_reset_ids"
189
+ data_type: TYPE_INT32
190
+ dims: [ 1 ]
191
+ reshape: { shape: [ ] }
192
+ optional: true
193
+ },
194
+ {
195
+ name: "len_penalty"
196
+ data_type: TYPE_FP32
197
+ dims: [ 1 ]
198
+ reshape: { shape: [ ] }
199
+ optional: true
200
+ },
201
+ {
202
+ name: "early_stopping"
203
+ data_type: TYPE_BOOL
204
+ dims: [ 1 ]
205
+ reshape: { shape: [ ] }
206
+ optional: true
207
+ },
208
+ {
209
+ name: "repetition_penalty"
210
+ data_type: TYPE_FP32
211
+ dims: [ 1 ]
212
+ reshape: { shape: [ ] }
213
+ optional: true
214
+ },
215
+ {
216
+ name: "min_length"
217
+ data_type: TYPE_INT32
218
+ dims: [ 1 ]
219
+ reshape: { shape: [ ] }
220
+ optional: true
221
+ },
222
+ {
223
+ name: "beam_search_diversity_rate"
224
+ data_type: TYPE_FP32
225
+ dims: [ 1 ]
226
+ reshape: { shape: [ ] }
227
+ optional: true
228
+ },
229
+ {
230
+ name: "presence_penalty"
231
+ data_type: TYPE_FP32
232
+ dims: [ 1 ]
233
+ reshape: { shape: [ ] }
234
+ optional: true
235
+ },
236
+ {
237
+ name: "frequency_penalty"
238
+ data_type: TYPE_FP32
239
+ dims: [ 1 ]
240
+ reshape: { shape: [ ] }
241
+ optional: true
242
+ },
243
+ {
244
+ name: "random_seed"
245
+ data_type: TYPE_UINT64
246
+ dims: [ 1 ]
247
+ reshape: { shape: [ ] }
248
+ optional: true
249
+ },
250
+ {
251
+ name: "return_log_probs"
252
+ data_type: TYPE_BOOL
253
+ dims: [ 1 ]
254
+ reshape: { shape: [ ] }
255
+ optional: true
256
+ },
257
+ {
258
+ name: "return_context_logits"
259
+ data_type: TYPE_BOOL
260
+ dims: [ 1 ]
261
+ reshape: { shape: [ ] }
262
+ optional: true
263
+ },
264
+ {
265
+ name: "return_generation_logits"
266
+ data_type: TYPE_BOOL
267
+ dims: [ 1 ]
268
+ reshape: { shape: [ ] }
269
+ optional: true
270
+ },
271
+ {
272
+ name: "stop"
273
+ data_type: TYPE_BOOL
274
+ dims: [ 1 ]
275
+ reshape: { shape: [ ] }
276
+ optional: true
277
+ },
278
+ {
279
+ name: "streaming"
280
+ data_type: TYPE_BOOL
281
+ dims: [ 1 ]
282
+ reshape: { shape: [ ] }
283
+ optional: true
284
+ },
285
+ {
286
+ name: "prompt_embedding_table"
287
+ data_type: TYPE_FP16
288
+ dims: [ -1, -1 ]
289
+ optional: true
290
+ allow_ragged_batch: true
291
+ },
292
+ {
293
+ name: "prompt_vocab_size"
294
+ data_type: TYPE_INT32
295
+ dims: [ 1 ]
296
+ reshape: { shape: [ ] }
297
+ optional: true
298
+ },
299
+ # the unique task ID for the given LoRA.
300
+ # To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given.
301
+ # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
302
+ # If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
303
+ {
304
+ name: "lora_task_id"
305
+ data_type: TYPE_UINT64
306
+ dims: [ 1 ]
307
+ reshape: { shape: [ ] }
308
+ optional: true
309
+ },
310
+ # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
311
+ # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
312
+ # each of the in / out tensors are first flattened and then concatenated together in the format above.
313
+ # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
314
+ {
315
+ name: "lora_weights"
316
+ data_type: TYPE_FP16
317
+ dims: [ -1, -1 ]
318
+ optional: true
319
+ allow_ragged_batch: true
320
+ },
321
+ # module identifier (same size as the first dimension of lora_weights)
322
+ # See LoraModule::ModuleType for model id mapping
323
+ #
324
+ # "attn_qkv": 0 # compbined qkv adapter
325
+ # "attn_q": 1 # q adapter
326
+ # "attn_k": 2 # k adapter
327
+ # "attn_v": 3 # v adapter
328
+ # "attn_dense": 4 # adapter for the dense layer in attention
329
+ # "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
330
+ # "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
331
+ # "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
332
+ #
333
+ # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
334
+ {
335
+ name: "lora_config"
336
+ data_type: TYPE_INT32
337
+ dims: [ -1, 3 ]
338
+ optional: true
339
+ allow_ragged_batch: true
340
+ }
341
+ ]
342
+ output [
343
+ {
344
+ name: "output_ids"
345
+ data_type: TYPE_INT32
346
+ dims: [ -1, -1 ]
347
+ },
348
+ {
349
+ name: "sequence_length"
350
+ data_type: TYPE_INT32
351
+ dims: [ -1 ]
352
+ },
353
+ {
354
+ name: "cum_log_probs"
355
+ data_type: TYPE_FP32
356
+ dims: [ -1 ]
357
+ },
358
+ {
359
+ name: "output_log_probs"
360
+ data_type: TYPE_FP32
361
+ dims: [ -1, -1 ]
362
+ },
363
+ {
364
+ name: "context_logits"
365
+ data_type: TYPE_FP32
366
+ dims: [ -1, -1 ]
367
+ },
368
+ {
369
+ name: "generation_logits"
370
+ data_type: TYPE_FP32
371
+ dims: [ -1, -1, -1 ]
372
+ },
373
+ {
374
+ name: "batch_index"
375
+ data_type: TYPE_INT32
376
+ dims: [ 1 ]
377
+ }
378
+ ]
379
+ instance_group [
380
+ {
381
+ count: 1
382
+ kind : KIND_CPU
383
+ }
384
+ ]
385
+ parameters: {
386
+ key: "max_beam_width"
387
+ value: {
388
+ string_value: "${max_beam_width}"
389
+ }
390
+ }
391
+ parameters: {
392
+ key: "FORCE_CPU_ONLY_INPUT_TENSORS"
393
+ value: {
394
+ string_value: "no"
395
+ }
396
+ }
397
+ parameters: {
398
+ key: "gpt_model_type"
399
+ value: {
400
+ string_value: "${batching_strategy}"
401
+ }
402
+ }
403
+ parameters: {
404
+ key: "gpt_model_path"
405
+ value: {
406
+ string_value: "${engine_dir}"
407
+ }
408
+ }
409
+ parameters: {
410
+ key: "encoder_model_path"
411
+ value: {
412
+ string_value: "${encoder_engine_dir}"
413
+ }
414
+ }
415
+ parameters: {
416
+ key: "max_tokens_in_paged_kv_cache"
417
+ value: {
418
+ string_value: "${max_tokens_in_paged_kv_cache}"
419
+ }
420
+ }
421
+ parameters: {
422
+ key: "max_attention_window_size"
423
+ value: {
424
+ string_value: "${max_attention_window_size}"
425
+ }
426
+ }
427
+ parameters: {
428
+ key: "sink_token_length"
429
+ value: {
430
+ string_value: "${sink_token_length}"
431
+ }
432
+ }
433
+ parameters: {
434
+ key: "batch_scheduler_policy"
435
+ value: {
436
+ string_value: "${batch_scheduler_policy}"
437
+ }
438
+ }
439
+ parameters: {
440
+ key: "kv_cache_free_gpu_mem_fraction"
441
+ value: {
442
+ string_value: "${kv_cache_free_gpu_mem_fraction}"
443
+ }
444
+ }
445
+ parameters: {
446
+ key: "kv_cache_host_memory_bytes"
447
+ value: {
448
+ string_value: "${kv_cache_host_memory_bytes}"
449
+ }
450
+ }
451
+ parameters: {
452
+ key: "kv_cache_onboard_blocks"
453
+ value: {
454
+ string_value: "${kv_cache_onboard_blocks}"
455
+ }
456
+ }
457
+ # enable_trt_overlap is deprecated and doesn't have any effect on the runtime
458
+ # parameters: {
459
+ # key: "enable_trt_overlap"
460
+ # value: {
461
+ # string_value: "${enable_trt_overlap}"
462
+ # }
463
+ # }
464
+ parameters: {
465
+ key: "exclude_input_in_output"
466
+ value: {
467
+ string_value: "${exclude_input_in_output}"
468
+ }
469
+ }
470
+ parameters: {
471
+ key: "cancellation_check_period_ms"
472
+ value: {
473
+ string_value: "${cancellation_check_period_ms}"
474
+ }
475
+ }
476
+ parameters: {
477
+ key: "stats_check_period_ms"
478
+ value: {
479
+ string_value: "${stats_check_period_ms}"
480
+ }
481
+ }
482
+ parameters: {
483
+ key: "iter_stats_max_iterations"
484
+ value: {
485
+ string_value: "${iter_stats_max_iterations}"
486
+ }
487
+ }
488
+ parameters: {
489
+ key: "request_stats_max_iterations"
490
+ value: {
491
+ string_value: "${request_stats_max_iterations}"
492
+ }
493
+ }
494
+ parameters: {
495
+ key: "enable_kv_cache_reuse"
496
+ value: {
497
+ string_value: "${enable_kv_cache_reuse}"
498
+ }
499
+ }
500
+ parameters: {
501
+ key: "normalize_log_probs"
502
+ value: {
503
+ string_value: "${normalize_log_probs}"
504
+ }
505
+ }
506
+ parameters: {
507
+ key: "enable_chunked_context"
508
+ value: {
509
+ string_value: "${enable_chunked_context}"
510
+ }
511
+ }
512
+ parameters: {
513
+ key: "gpu_device_ids"
514
+ value: {
515
+ string_value: "${gpu_device_ids}"
516
+ }
517
+ }
518
+ parameters: {
519
+ key: "lora_cache_optimal_adapter_size"
520
+ value: {
521
+ string_value: "${lora_cache_optimal_adapter_size}"
522
+ }
523
+ }
524
+ parameters: {
525
+ key: "lora_cache_max_adapter_size"
526
+ value: {
527
+ string_value: "${lora_cache_max_adapter_size}"
528
+ }
529
+ }
530
+ parameters: {
531
+ key: "lora_cache_gpu_memory_fraction"
532
+ value: {
533
+ string_value: "${lora_cache_gpu_memory_fraction}"
534
+ }
535
+ }
536
+ parameters: {
537
+ key: "lora_cache_host_memory_bytes"
538
+ value: {
539
+ string_value: "${lora_cache_host_memory_bytes}"
540
+ }
541
+ }
542
+ parameters: {
543
+ key: "decoding_mode"
544
+ value: {
545
+ string_value: "${decoding_mode}"
546
+ }
547
+ }
548
+ parameters: {
549
+ key: "executor_worker_path"
550
+ value: {
551
+ string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
552
+ }
553
+ }
554
+ parameters: {
555
+ key: "medusa_choices"
556
+ value: {
557
+ string_value: "${medusa_choices}"
558
+ }
559
+ }
560
+ parameters: {
561
+ key: "gpu_weights_percent"
562
+ value: {
563
+ string_value: "${gpu_weights_percent}"
564
+ }
565
+ }
566
+ parameters: {
567
+ key: "enable_context_fmha_fp32_acc"
568
+ value: {
569
+ string_value: "${enable_context_fmha_fp32_acc}"
570
+ }
571
+ }
572
+ parameters: {
573
+ key: "multi_block_mode"
574
+ value: {
575
+ string_value: "${multi_block_mode}"
576
+ }
577
+ }
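The only real difference between the config.pbtxt above and this config.template is that some of the ${...} placeholders in config.pbtxt have already been substituted (for example, enable_kv_cache_reuse is pinned to "False"). A minimal sketch of how the remaining placeholders could be filled, assuming plain ${key} substitution in the style of the fill_template.py helper from tensorrtllm_backend; the key/value pairs and paths below are illustrative only:

from string import Template

def fill_config(template_path: str, output_path: str, values: dict) -> None:
    # Replace ${key} placeholders that appear in `values`; unknown keys are left as-is.
    with open(template_path) as f:
        text = Template(f.read()).safe_substitute(values)
    with open(output_path, "w") as f:
        f.write(text)

fill_config(
    "model_repo_whisper_qwen_trtllm/tensorrt_llm/config.template",
    "model_repo_whisper_qwen_trtllm/tensorrt_llm/config.pbtxt",
    {
        "triton_backend": "tensorrtllm",
        "triton_max_batch_size": "8",
        "decoupled_mode": "False",
        "max_queue_delay_microseconds": "0",
        "max_queue_size": "0",
        "batching_strategy": "inflight_fused_batching",
        "engine_dir": "/workspace/trt_engines/qwen_decoder",  # illustrative path
        "enable_kv_cache_reuse": "False",
    },
)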
model_repo_whisper_qwen_trtllm/whisper/0/__pycache__/fbank.cpython-310.pyc ADDED
Binary file (3.07 kB).
 
model_repo_whisper_qwen_trtllm/whisper/0/__pycache__/model.cpython-310.pyc ADDED
Binary file (10.9 kB).
 
model_repo_whisper_qwen_trtllm/whisper/0/__pycache__/whisper_trtllm.cpython-310.pyc ADDED
Binary file (9.21 kB).
 
model_repo_whisper_qwen_trtllm/whisper/0/fbank.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Reference: https://github.com/openai/whisper/blob/main/whisper/audio.py
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from typing import Union
19
+ import os
20
+
21
+ def mel_filters(device, n_mels: int = 128) -> torch.Tensor:
22
+ """
23
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
24
+ Allows decoupling librosa dependency; saved using:
25
+
26
+ np.savez_compressed(
27
+ "mel_filters.npz",
28
+ mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
29
+ )
30
+ """
31
+ assert n_mels in (80, 128), f"Unsupported n_mels: {n_mels}"
32
+ with np.load(
33
+ os.path.join(os.path.dirname(__file__), "mel_filters.npz")
34
+ ) as f:
35
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
36
+
37
+
38
+ def log_mel_spectrogram(
39
+ audio: torch.Tensor,
40
+ filters: torch.Tensor,
41
+ n_mels: int = 128,
42
+ n_fft: int = 400,
43
+ hop_length: int = 160,
44
+ ):
45
+ """
46
+ Compute the log-Mel spectrogram of the input audio.
47
+
48
+ Parameters
49
+ ----------
50
+ audio: torch.Tensor, shape = (*)
51
+ A Tensor containing the audio waveform, sampled at 16 kHz
52
+
53
+ n_mels: int
54
+ The number of Mel-frequency filters, only 80 or 128 is supported
55
+
56
+ filters: torch.Tensor
57
+
58
+ Returns
59
+ -------
60
+ torch.Tensor, shape = (128, n_frames)
61
+ A Tensor that contains the Mel spectrogram
62
+ """
63
+ window = torch.hann_window(n_fft).to(audio.device)
64
+ stft = torch.stft(audio, n_fft, hop_length, window=window, return_complex=True)
65
+ magnitudes = stft[..., :-1].abs() ** 2
66
+
67
+ mel_spec = filters @ magnitudes
68
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
69
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
70
+ log_spec = (log_spec + 4.0) / 4.0
71
+ # cast to float 16
72
+ log_spec = log_spec.half()
73
+ return log_spec
74
+
75
+ class FeatureExtractor(torch.nn.Module):
76
+ """Your Python model must use the same class name. Every Python model
77
+ that is created must have "TritonPythonModel" as the class name.
78
+ """
79
+
80
+ def __init__(self, n_mels: int = 128):
81
+ self.device = torch.device("cuda")
82
+ self.n_mels = n_mels
83
+ self.filters = mel_filters(self.device, n_mels=self.n_mels)
84
+
85
+ def compute_feature(self, wav, target: int = 3000):
86
+ mel = log_mel_spectrogram(wav, self.filters)
87
+ assert mel.shape[1] <= target, f"{mel.shape[1]} > {target}, audio is too long"
88
+ if mel.shape[1] < target:
89
+ mel = F.pad(mel, (0, target - mel.shape[1]), mode='constant')
90
+ mel = mel.unsqueeze(0)
91
+ return mel
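A minimal usage sketch of this feature extractor, assuming a CUDA device and a 16 kHz mono waveform (the dummy input below is illustrative only):

import torch
# FeatureExtractor is the class defined above in this fbank.py

extractor = FeatureExtractor(n_mels=128)
wav = torch.zeros(16000 * 10, device="cuda")   # 10 s of 16 kHz audio
mel = extractor.compute_feature(wav)           # padded to 3000 frames (30 s)
print(mel.shape, mel.dtype)                    # torch.Size([1, 128, 3000]) torch.float16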
model_repo_whisper_qwen_trtllm/whisper/0/mel_filters.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7450ae70723a5ef9d341e3cee628c7cb0177f36ce42c44b7ed2bf3325f0f6d4c
3
+ size 4271
model_repo_whisper_qwen_trtllm/whisper/0/model.py ADDED
@@ -0,0 +1,346 @@
1
+ # -*- coding: utf-8 -*-
2
+ import triton_python_backend_utils as pb_utils
3
+ import numpy as np
4
+ import json
5
+ import torch
6
+ from torch.utils.dlpack import from_dlpack, to_dlpack
7
+ import re
8
+ import transformers
9
+ from transformers import AutoTokenizer
10
+ from typing import Dict
11
+ from pathlib import Path
12
+ import traceback
13
+
14
+ from .whisper_trtllm import WhisperTRTLLM
15
+ from .fbank import FeatureExtractor
16
+
17
+ DEFAULT_SPEECH_TOKEN = "<speech>"
18
+ def preprocess(
19
+ messages,
20
+ tokenizer: transformers.PreTrainedTokenizer,
21
+ max_len: int = 128,
22
+ ) -> Dict:
23
+ """Preprocesses the data for supervised fine-tuning."""
24
+ texts = []
25
+ TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
26
+ for i, msg in enumerate(messages):
27
+ texts.append(
28
+ tokenizer.apply_chat_template(
29
+ msg,
30
+ tokenize=True,
31
+ add_generation_prompt=False,
32
+ chat_template=TEMPLATE,
33
+ padding="longest",
34
+ max_length=max_len,
35
+ truncation=True,
36
+ )
37
+ )
38
+ max_len_texts = max([len(text) for text in texts])
39
+ if tokenizer.padding_side == "right":
40
+ texts = [
41
+ text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
42
+ for text in texts
43
+ ]
44
+ else:
45
+ texts = [
46
+ [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
47
+ for text in texts
48
+ ]
49
+
50
+ input_ids = torch.tensor(texts, dtype=torch.int)
51
+
52
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
53
+
54
+ return input_ids, attention_mask
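+ # For the default messages used below, i.e.
+ # [{"role": "user", "content": "<speech>请转写音频为文字"}, {"role": "assistant", "content": ""}],
+ # the chat template above renders (before tokenization) to
+ # "<|im_start|>user\n<speech>请转写音频为文字<|im_end|>\n<|im_start|>assistant\n"
+ # The user turn asks the model to transcribe the audio, and the single <speech>
+ # placeholder is later replaced by the projected Whisper encoder embeddings.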
55
+
56
+ class TritonPythonModel:
57
+ """Your Python model must use the same class name. Every Python model
58
+ that is created must have "TritonPythonModel" as the class name.
59
+ """
60
+
61
+ def initialize(self, args):
62
+ """`initialize` is called only once when the model is being loaded.
63
+ Implementing `initialize` function is optional. This function allows
64
+ the model to initialize any state associated with this model.
65
+
66
+ Parameters
67
+ ----------
68
+ args : dict
69
+ Both keys and values are strings. The dictionary keys and values are:
70
+ * model_config: A JSON string containing the model configuration
71
+ * model_instance_kind: A string containing model instance kind
72
+ * model_instance_device_id: A string containing model instance device ID
73
+ * model_repository: Model repository path
74
+ * model_version: Model version
75
+ * model_name: Model name
76
+ """
77
+ self.model_config = model_config = json.loads(args['model_config'])
78
+
79
+ # Get OUTPUT0 configuration
80
+ output0_config = pb_utils.get_output_config_by_name(
81
+ model_config, "TRANSCRIPTS")
82
+ # Convert Triton types to numpy types
83
+ self.out0_dtype = pb_utils.triton_string_to_numpy(
84
+ output0_config['data_type'])
85
+
86
+ #self.tokenizer = get_tokenizer(num_languages=100)
87
+ #self.blank = self.tokenizer.encode(" ", allowed_special=self.tokenizer.special_tokens_set)[0]
88
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
89
+ tokenizer.padding_side = "left"
90
+ special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
91
+ tokenizer.add_special_tokens(special_tokens_dict)
92
+ self.tokenizer = tokenizer
93
+ self.eos = self.tokenizer.eos_token_id
94
+ self.default_speech_token_id = tokenizer.convert_tokens_to_ids(
95
+ DEFAULT_SPEECH_TOKEN
96
+ )
97
+ self.vocab_size = 151936
98
+ # self.vocab_size = 500000
99
+ # self.vocab_size = 160000
100
+
101
+ self.device = torch.device("cuda")
102
+ self.decoupled = False
103
+ self.logger = pb_utils.Logger
104
+ self.init_model(self.model_config['parameters'])
105
+
106
+ def init_model(self, parameters):
107
+ for key,value in parameters.items():
108
+ parameters[key] = value["string_value"]
109
+ engine_dir = parameters["engine_dir"]
110
+ n_mels = int(parameters["n_mels"])
111
+ adapter_dir="/home/scratch.yuekaiz_wwfo_1/icefall_asr_multi-hans_whisper_qwen2_1.5B/epoch-2-avg-6.pt"
112
+ checkpoint = torch.load(
113
+ adapter_dir, map_location="cpu"
114
+ )
115
+ self.model = WhisperTRTLLM(engine_dir)
116
+ missing_keys, _ = self.model.load_state_dict(checkpoint, strict=False)
117
+ # print(f"Missing keys: {missing_keys}")
118
+ self.feature_extractor = FeatureExtractor(n_mels=n_mels)
119
+
120
+ def _tokenize(self, prompt=None, num_speech_tokens=187):
121
+ if prompt is None:
122
+ prompts = [
123
+ [
124
+ {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
125
+ {"role": "assistant", "content": ""},
126
+ ]
127
+ ]
128
+ # prompts = [
129
+ # [
130
+ # {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}你好,你是谁?"},
131
+ # {"role": "assistant", "content": ""},
132
+ # ]
133
+ # ]
134
+
135
+ input_ids, _ = preprocess(prompts, self.tokenizer, max_len=128)
136
+ input_ids = input_ids.tolist()[0]
137
+ speech_token_index = input_ids.index(self.default_speech_token_id)
138
+ # replace 151646 with list(range(self.vocab_size, self.vocab_size + num_speech_tokens))
139
+ prompt_ids = input_ids[:speech_token_index] + list(range(self.vocab_size, self.vocab_size + num_speech_tokens)) + input_ids[speech_token_index + 1:]
140
+ # prompt_ids = input_ids[:speech_token_index] + input_ids[speech_token_index + 1:]
141
+ return prompt_ids
142
+
143
+ def _prepare_inputs(self, request, speech_embeddings, input_ids):
144
+ """
145
+ Prepares inputs for the language model based on the parameters in the
146
+ request, the speech embeddings, and the prompt. It tokenizes the prompt,
147
+ extracts and processes additional parameters from the request:
148
+ - max_tokens: Maximum number of tokens to generate (default: 50)
149
+ - temperature: Controls randomness in generation (default: 0.5)
150
+ - top_k: Top K sampling parameter (default: 1)
151
+ - frequency_penalty: Penalizes frequent tokens (default: 0.7)
152
+ - seed: Random seed for generation (default: 10)
153
+
154
+ Final llm input dictionary is combined out of all processed parameters,
155
+ the prompt's tokens and the speech embeddings. The latter are passed to the llm
156
+ through `prompt_embedding_table`.
157
+
158
+ Parameters
159
+ ----------
160
+ - request: The original request object containing additional parameters.
161
+ - speech_embeddings (torch.Tensor): The projected Whisper encoder outputs.
162
+ - input_ids (list): The tokenized prompt, with placeholder ids reserved for the speech tokens.
163
+
164
+ Returns
165
+ -------
166
+ - dict: A dictionary containing all the prepared inputs for the language model.
167
+ """
168
+ input_ids = np.array(input_ids, dtype=np.int32)
169
+ max_tokens = 200
170
+ input_len = input_ids.shape[0]
171
+
172
+ assert speech_embeddings.shape[1] == 187, "Only support 187 speech tokens"
173
+ embedding_args = {
174
+ "prompt_vocab_size": np.array(
175
+ [[speech_embeddings.shape[1]]], dtype=np.int32
176
+ ),
177
+ "prompt_embedding_table": speech_embeddings.detach().cpu().numpy(),
178
+ }
179
+ # TODO: the output seems identical whether or not this is added??? input_ids beyond the max vocab size also do not raise an error???
180
+ input_dict = {
181
+ "input_ids": np.expand_dims(input_ids, 0),
182
+ "input_lengths": np.array([[input_len]], dtype=np.int32),
183
+ "request_output_len": np.array([[max_tokens]], dtype=np.int32),
184
+ "runtime_top_k": np.array([[1]], dtype=np.int32),
185
+ "end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32),
186
+ "pad_id": np.array([[self.tokenizer.pad_token_id]], dtype=np.int32),
187
+ "streaming": np.array([[0]], dtype=np.bool_),
188
+ **embedding_args,
189
+ }
190
+
191
+ print(input_ids)
192
+ for key, value in input_dict.items():
193
+ print(key, value.shape)
194
+
195
+ input_tensor_list = [pb_utils.Tensor(k, v) for k, v in input_dict.items()]
196
+ return input_tensor_list
197
+
198
+ def _prepare_llm_response(self, llm_request_inputs):
199
+ """
200
+ Prepares the response from the language model based on the provided
201
+ inputs. Creates a `pb_utils.InferenceRequest` object with passed
202
+ `llm_request_inputs` to send to a decoupled TensorRTLLM model.
203
+ For each response from the language model:
204
+ - Checks for errors and raises an exception if any are found.
205
+ - Extracts the "output_ids" tensor from the response.
206
+ - Determines the finish reason based on the presence of the
207
+ end-of-sequence token or reaching the maximum length.
208
+ - Appends the generated token IDs to `output_ids`.
209
+ - If the finish reason is determined, decodes the output IDs to text
210
+ and prepares the final response.
211
+
212
+ The final response includes the generated text, finish reason,
213
+ completion tokens, prompt tokens, and total tokens.
214
+
215
+ Parameters
216
+ ----------
217
+ - llm_request_inputs (dict): A dictionary containing the inputs for the language model.
218
+
219
+ Returns
220
+ -------
221
+ - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
222
+ """
223
+
224
+ llm_request = pb_utils.InferenceRequest(
225
+ model_name="tensorrt_llm",
226
+ requested_output_names=["output_ids", "sequence_length"],
227
+ inputs=llm_request_inputs,
228
+ )
229
+ output_ids, output_len = [], 0
230
+ responses = llm_request.exec(decoupled=False)
231
+ responses = [responses]
232
+ for llm_response in responses:
233
+ if llm_response.has_error():
234
+ raise pb_utils.TritonModelException(llm_response.error().message())
235
+ stream_output_ids = (
236
+ pb_utils.get_output_tensor_by_name(llm_response, "output_ids")
237
+ .as_numpy()
238
+ .flatten()
239
+ .tolist()
240
+ )
241
+ finish_reason = "test"
242
+ if len(stream_output_ids) == 0 or (
243
+ len(stream_output_ids) != 0
244
+ and stream_output_ids[-1] == self.eos
245
+ ):
246
+ finish_reason = "stop"
247
+
248
+ output_ids += stream_output_ids
249
+
250
+ last_response = finish_reason != ""
251
+ output_len = len(output_ids)
252
+ if last_response:
253
+ print("final_output_ids", output_ids)
254
+ output_text = self.tokenizer.decode(output_ids).strip()
255
+ # print(output_text)
256
+ # output_text = re.sub(r'<\|.*?\|>', '', output_text)
257
+ response = pb_utils.InferenceResponse(
258
+ output_tensors=[
259
+ pb_utils.Tensor("TRANSCRIPTS", np.array([output_text], np.object_)),
260
+ ]
261
+ )
262
+ yield response
263
+
264
+ def _extract_speech_embeddings(self, mel):
265
+ return self.model.process_batch(mel)
266
+
267
+
268
+ def execute(self, requests):
269
+
270
+ responses = []
271
+
272
+ for request in requests:
273
+ wav = pb_utils.get_input_tensor_by_name(request, "WAV").as_numpy()
274
+ assert wav.shape[0] == 1, "Only support batch size 1"
275
+ # To support batch > 1
276
+ # cat mel,text_prompt, also, need to increase decoder_input_len as a triton input
277
+ wav = torch.from_numpy(wav[0]).to(self.device)
278
+ # mel shape [1, 80, 3000] for remove_input_padding=False
279
+ mel = self.feature_extractor.compute_feature(wav)
280
+ print("==========================================================")
281
+ messages = [
282
+ [
283
+ {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
284
+ {"role": "assistant", "content": ""},
285
+ ]
286
+ ] * len(mel)
287
+
288
+ input_ids, attention_mask = preprocess(messages, self.tokenizer, max_len=128)
289
+
290
+ generated_ids = self.model.decode(
291
+ mel, input_ids.to(self.device, dtype=torch.long), attention_mask.to(self.device)
292
+ )
293
+ print("pytorch model", generated_ids)
294
+ print("--------------------------------------------------------------------------")
295
+
296
+
297
+ speech_embeddings = self._extract_speech_embeddings(mel)
298
+ input_ids = self._tokenize()
299
+
300
+
301
+ if self.decoupled:
302
+ response_sender = request.get_response_sender()
303
+ try:
304
+
305
+ llm_request_inputs = self._prepare_inputs(
306
+ request, speech_embeddings, input_ids
307
+ )
308
+ if isinstance(llm_request_inputs, pb_utils.TritonError):
309
+ error = pb_utils.InferenceResponse(error=llm_request_inputs)
310
+ if self.decoupled:
311
+ response_sender.send(
312
+ error, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
313
+ )
314
+ else:
315
+ responses.append(error)
316
+ llm_responses = self._prepare_llm_response(llm_request_inputs)
317
+
318
+ for triton_response in llm_responses:
319
+ if self.decoupled:
320
+ response_sender.send(triton_response)
321
+ else:
322
+ responses.append(triton_response)
323
+
324
+ if self.decoupled:
325
+ response_sender.send(
326
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
327
+
328
+ except Exception:
329
+ self.logger.log_error(traceback.format_exc())
330
+ # If encountering an error, send a response with err msg
331
+ error_response = pb_utils.InferenceResponse(
332
+ output_tensors=[],
333
+ error=pb_utils.TritonError(traceback.format_exc()))
334
+
335
+ if self.decoupled:
336
+ response_sender.send(error_response)
337
+ response_sender.send(
338
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
339
+ else:
340
+ responses.append(error_response)
341
+
342
+ if self.decoupled:
343
+ return None
344
+ else:
345
+ assert len(responses) == len(requests)
346
+ return responses
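_prepare_inputs above relies on TensorRT-LLM's prompt-tuning (p-tuning) convention, which also answers the TODO about oversized input_ids: token ids at or above the base vocabulary size are not looked up in the normal word-embedding matrix; the runtime instead maps (id - vocab_size) into the per-request prompt_embedding_table. A conceptual sketch of that lookup in plain PyTorch, not the engine code; the sizes follow this model (vocab_size 151936, 187 speech tokens, hidden size 1536) and the random tensors are stand-ins:

import torch

def lookup_with_prompt_table(input_ids, word_embeddings, prompt_table, vocab_size):
    # ids < vocab_size index the normal embedding matrix,
    # ids >= vocab_size index row (id - vocab_size) of the prompt table.
    is_prompt = input_ids >= vocab_size
    text_ids = torch.where(is_prompt, torch.zeros_like(input_ids), input_ids)
    out = word_embeddings[text_ids]
    out[is_prompt] = prompt_table[input_ids[is_prompt] - vocab_size]
    return out

vocab_size, hidden = 151936, 1536
word_embeddings = torch.randn(vocab_size, hidden)   # stand-in for the Qwen2 embedding table
prompt_table = torch.randn(187, hidden)             # stand-in for the projected Whisper features
ids = torch.tensor([151644] + list(range(vocab_size, vocab_size + 187)) + [151645])
print(lookup_with_prompt_table(ids, word_embeddings, prompt_table, vocab_size).shape)
# torch.Size([189, 1536])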
model_repo_whisper_qwen_trtllm/whisper/0/whisper_trtllm.py ADDED
@@ -0,0 +1,278 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import json
16
+ from collections import OrderedDict
17
+ from pathlib import Path
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import tensorrt_llm
23
+ from tensorrt_llm.logger import logger
24
+ from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
25
+ trt_dtype_to_torch)
26
+ from tensorrt_llm.runtime import ModelConfig, SamplingConfig
27
+ from tensorrt_llm.runtime.session import Session, TensorInfo
28
+ from transformers.trainer_pt_utils import LabelSmoother
29
+ from transformers import AutoModelForCausalLM, AutoTokenizer
30
+ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
31
+
32
+ DEFAULT_SPEECH_TOKEN = "<speech>"
33
+ def remove_tensor_padding(input_tensor, input_tensor_lengths=None, pad_value=0):
34
+ if input_tensor.dim() == 2:
35
+ # Text tensor case: batch, seq_len
36
+ assert torch.all(
37
+ input_tensor[:, 0] != pad_value
38
+ ), "First token in each sequence should not be pad_value"
39
+ assert input_tensor_lengths is None
40
+
41
+ # Create a mask for all non-pad tokens
42
+ mask = input_tensor != pad_value
43
+
44
+ # Apply the mask to input_tensor to remove pad tokens
45
+ output_tensor = input_tensor[mask].view(1, -1)
46
+
47
+ elif input_tensor.dim() == 3:
48
+ # Audio tensor case: batch, seq_len, feature_len
49
+ assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
50
+ batch_size, seq_len, feature_len = input_tensor.shape
51
+
52
+ # Initialize a list to collect valid sequences
53
+ valid_sequences = []
54
+
55
+ for i in range(batch_size):
56
+ valid_length = input_tensor_lengths[i]
57
+ valid_sequences.append(input_tensor[i, :valid_length, :])
58
+
59
+ # Concatenate all valid sequences along the batch dimension
60
+ output_tensor = torch.cat(valid_sequences, dim=0)
61
+
62
+ else:
63
+ raise ValueError("Input tensor must have 2 or 3 dimensions")
64
+
65
+ return output_tensor
66
+
67
+ def read_config(component, engine_dir):
68
+ config_path = engine_dir / component / 'config.json'
69
+ with open(config_path, 'r') as f:
70
+ config = json.load(f)
71
+ model_config = OrderedDict()
72
+ model_config.update(config['pretrained_config'])
73
+ model_config.update(config['build_config'])
74
+ return model_config
75
+
76
+ class WhisperEncoding:
77
+ def __init__(self, engine_dir):
78
+ self.session = self.get_session(engine_dir)
79
+ config = read_config('encoder', engine_dir)
80
+ self.n_mels = config['n_mels']
81
+ self.dtype = config['dtype']
82
+ self.num_languages = config['num_languages']
83
+ self.encoder_config = config
84
+
85
+ def get_session(self, engine_dir):
86
+ serialize_path = engine_dir / 'encoder' / 'rank0.engine'
87
+ with open(serialize_path, 'rb') as f:
88
+ session = Session.from_serialized_engine(f.read())
89
+ return session
90
+
91
+ def get_audio_features(self,
92
+ mel):
93
+ mel_input_lengths = torch.tensor(
94
+ [mel.shape[2] for _ in range(mel.shape[0])],
95
+ dtype=torch.int32,
96
+ device=mel.device)
97
+ if self.encoder_config['plugin_config']['remove_input_padding']:
98
+ # mel B,D,T -> B,T,D -> BxT, D
99
+ mel = mel.transpose(1, 2)
100
+ mel = remove_tensor_padding(mel, mel_input_lengths)
101
+
102
+ inputs = OrderedDict()
103
+ inputs['input_features'] = mel
104
+ inputs['input_lengths'] = mel_input_lengths
105
+
106
+ output_list = [
107
+ TensorInfo('input_features', str_dtype_to_trt(self.dtype),
108
+ mel.shape),
109
+ TensorInfo('input_lengths', str_dtype_to_trt('int32'),
110
+ mel_input_lengths.shape)
111
+ ]
112
+
113
+ output_info = (self.session).infer_shapes(output_list)
114
+
115
+ logger.debug(f'output info {output_info}')
116
+ outputs = {
117
+ t.name: torch.empty(tuple(t.shape),
118
+ dtype=trt_dtype_to_torch(t.dtype),
119
+ device='cuda')
120
+ for t in output_info
121
+ }
122
+ stream = torch.cuda.current_stream()
123
+ ok = self.session.run(inputs=inputs,
124
+ outputs=outputs,
125
+ stream=stream.cuda_stream)
126
+ assert ok, 'Engine execution failed'
127
+ stream.synchronize()
128
+ encoder_output = outputs['encoder_output']
129
+ encoder_output_lengths = mel_input_lengths // 2
130
+
131
+ return encoder_output
132
+
133
+ class EncoderProjector(torch.nn.Module):
134
+ """
135
+ The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
136
+ Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
137
+ Args:
138
+ encoder_dim (:obj:`int`): The dimension of the encoder outputs.
139
+ llm_dim (:obj:`int`): The dimension of the language model.
140
+ downsample_rate (:obj:`int`, `optional`, defaults to 8): The downsample rate to use.
141
+ """
142
+
143
+ def __init__(self, encoder_dim=1280, llm_dim=1536, downsample_rate=8):
144
+ super().__init__()
145
+ self.downsample_rate = downsample_rate
146
+ self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
147
+ self.relu = nn.ReLU()
148
+ self.linear2 = nn.Linear(llm_dim, llm_dim)
149
+
150
+ def forward(self, x):
151
+
152
+ batch_size, seq_len, feat_dim = x.size()
153
+ num_frames_to_discard = seq_len % self.downsample_rate
154
+ if num_frames_to_discard > 0:
155
+ x = x[:, :-num_frames_to_discard, :]
156
+ seq_len = x.size(1)
157
+
158
+ x = x.contiguous()
159
+ x = x.view(
160
+ batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
161
+ )
162
+
163
+ x = self.linear1(x)
164
+ x = self.relu(x)
165
+ x = self.linear2(x)
166
+ return x
167
+
168
+ class SPEECH_LLM(nn.Module):
169
+ """
170
+ The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
171
+ The encoder is used to extract speech features from the input speech signal.
172
+ The encoder projector is used to project the encoder outputs to the same dimension as the language model.
173
+ The language model is used to generate the text from the speech features.
174
+ Args:
175
+ encoder (:obj:`nn.Module`): The encoder module.
176
+ llm (:obj:`nn.Module`): The language model module.
177
+ encoder_projector (:obj:`nn.Module`): The encoder projector module.
178
+ """
179
+
180
+ def __init__(
181
+ self,
182
+ encoder: nn.Module,
183
+ llm: nn.Module,
184
+ encoder_projector: nn.Module,
185
+ ):
186
+ super().__init__()
187
+ self.encoder = encoder
188
+ self.llm = llm
189
+ self.encoder_projector = encoder_projector
190
+
191
+ class WhisperTRTLLM(nn.Module):
192
+
193
+ def __init__(self, engine_dir):
194
+ super().__init__()
195
+ world_size = 1
196
+ runtime_rank = tensorrt_llm.mpi_rank()
197
+ runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
198
+ torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
199
+ engine_dir = Path(engine_dir)
200
+
201
+ self.encoder = WhisperEncoding(engine_dir)
202
+ self.encoder_projector = EncoderProjector()
203
+ self.encoder_projector = self.encoder_projector.half().to("cuda")
204
+
205
+ llm = AutoModelForCausalLM.from_pretrained(
206
+ "/home/scratch.yuekaiz_wwfo_1/Qwen2_1.5B_merged",
207
+ attn_implementation="flash_attention_2",
208
+ torch_dtype=torch.float16,
209
+ )
210
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
211
+ tokenizer.padding_side = "left"
212
+ special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
213
+ tokenizer.add_special_tokens(special_tokens_dict)
214
+ llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
215
+ llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
216
+ llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
217
+
218
+ llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
219
+ DEFAULT_SPEECH_TOKEN
220
+ )
221
+ self.llm = llm.half().to("cuda")
222
+ # print llm embedding layer shape
223
+ print("llm embedding layer shape", self.llm.get_input_embeddings().weight.shape)
224
+
225
+
226
+
227
+ def process_batch(
228
+ self,
229
+ mel,
230
+ decoder_input_ids=None,
231
+ eot_id=50257,
232
+ max_new_tokens=96,
233
+ num_beams=1):
234
+ encoder_outputs = self.encoder.get_audio_features(mel)
235
+ speech_features = self.encoder_projector(encoder_outputs)
236
+ speech_features = speech_features.to(torch.float16)
237
+ # [1,187,1536]
238
+ return speech_features
239
+
240
+
241
+ def decode(
242
+ self,
243
+ fbank: torch.Tensor = None,
244
+ input_ids: torch.LongTensor = None,
245
+ attention_mask: torch.Tensor = None,
246
+ **kwargs,
247
+ ):
248
+
249
+ encoder_outs = self.encoder.get_audio_features(fbank)
250
+ speech_features = self.encoder_projector(encoder_outs)
251
+ speech_features = speech_features.to(torch.float16)
252
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
253
+ speech_token_index = input_ids.tolist()[0].index(151646)
254
+ print("speech_token_index", speech_token_index, "speech_features_shape", speech_features.shape, "input_ids_shape", input_ids.shape, "inputs_embeds_shape", inputs_embeds.shape)
255
+
256
+ new_length = inputs_embeds.shape[1] + speech_features.shape[1] - 1
257
+ new_inputs_embeds = torch.zeros(1, new_length, 1536).to(inputs_embeds.device).half()
258
+ new_inputs_embeds[:, :3, :] = inputs_embeds[:, :3, :]
259
+ new_inputs_embeds[:, 3:3 + 187, :] = speech_features
260
+ new_inputs_embeds[:, 3 + 187:, :] = inputs_embeds[:, 4:, :]
261
+
262
+ inputs_embeds = new_inputs_embeds
263
+ generated_ids = self.llm.generate(
264
+ inputs_embeds=inputs_embeds,
265
+ max_new_tokens=kwargs.get("max_new_tokens", 200),
266
+ num_beams=kwargs.get("num_beams", 1),
267
+ do_sample=kwargs.get("do_sample", False),
268
+ min_length=kwargs.get("min_length", 1),
269
+ top_p=kwargs.get("top_p", 1.0),
270
+ repetition_penalty=kwargs.get("repetition_penalty", 1.0),
271
+ length_penalty=kwargs.get("length_penalty", 1.0),
272
+ temperature=kwargs.get("temperature", 1.0),
273
+ bos_token_id=self.llm.config.bos_token_id,
274
+ eos_token_id=self.llm.config.eos_token_id,
275
+ pad_token_id=self.llm.config.pad_token_id,
276
+ )
277
+
278
+ return generated_ids
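The hard-coded 187 and 1536 used elsewhere in this repo follow directly from the projector above: a 30 s Whisper segment produces 1500 encoder frames, the projector drops 1500 % 8 = 4 frames, stacks every 8 neighbouring frames (1496 / 8 = 187 speech tokens of width 1280 * 8 = 10240), and the two linear layers map each token to the Qwen2-1.5B hidden size of 1536. A quick shape check with random data:

import torch
# EncoderProjector is the class defined above in this file

proj = EncoderProjector(encoder_dim=1280, llm_dim=1536, downsample_rate=8)
enc_out = torch.randn(1, 1500, 1280)   # Whisper encoder output for one 30 s segment
print(proj(enc_out).shape)             # torch.Size([1, 187, 1536])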
model_repo_whisper_qwen_trtllm/whisper/1/__pycache__/fbank.cpython-310.pyc ADDED
Binary file (3.07 kB).
 
model_repo_whisper_qwen_trtllm/whisper/1/__pycache__/model.cpython-310.pyc ADDED
Binary file (10.4 kB).
 
model_repo_whisper_qwen_trtllm/whisper/1/__pycache__/whisper_trtllm.cpython-310.pyc ADDED
Binary file (6.2 kB).
 
model_repo_whisper_qwen_trtllm/whisper/1/fbank.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Reference: https://github.com/openai/whisper/blob/main/whisper/audio.py
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from typing import Union
19
+ import os
20
+
21
+ def mel_filters(device, n_mels: int = 128) -> torch.Tensor:
22
+ """
23
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
24
+ Allows decoupling librosa dependency; saved using:
25
+
26
+ np.savez_compressed(
27
+ "mel_filters.npz",
28
+ mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
29
+ )
30
+ """
31
+ assert n_mels in (80, 128), f"Unsupported n_mels: {n_mels}"
32
+ with np.load(
33
+ os.path.join(os.path.dirname(__file__), "mel_filters.npz")
34
+ ) as f:
35
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
36
+
37
+
38
+ def log_mel_spectrogram(
39
+ audio: torch.Tensor,
40
+ filters: torch.Tensor,
41
+ n_mels: int = 128,
42
+ n_fft: int = 400,
43
+ hop_length: int = 160,
44
+ ):
45
+ """
46
+ Compute the log-Mel spectrogram of the input audio.
47
+
48
+ Parameters
49
+ ----------
50
+ audio: torch.Tensor, shape = (*)
51
+ A Tensor containing the audio waveform, sampled at 16 kHz
52
+
53
+ n_mels: int
54
+ The number of Mel-frequency filters, only 80 or 128 is supported
55
+
56
+ filters: torch.Tensor
57
+
58
+ Returns
59
+ -------
60
+ torch.Tensor, shape = (128, n_frames)
61
+ A Tensor that contains the Mel spectrogram
62
+ """
63
+ window = torch.hann_window(n_fft).to(audio.device)
64
+ stft = torch.stft(audio, n_fft, hop_length, window=window, return_complex=True)
65
+ magnitudes = stft[..., :-1].abs() ** 2
66
+
67
+ mel_spec = filters @ magnitudes
68
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
69
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
70
+ log_spec = (log_spec + 4.0) / 4.0
71
+ # cast to float 16
72
+ log_spec = log_spec.half()
73
+ return log_spec
74
+
75
+ class FeatureExtractor(torch.nn.Module):
76
+ """Your Python model must use the same class name. Every Python model
77
+ that is created must have "TritonPythonModel" as the class name.
78
+ """
79
+
80
+ def __init__(self, n_mels: int = 128):
81
+ self.device = torch.device("cuda")
82
+ self.n_mels = n_mels
83
+ self.filters = mel_filters(self.device, n_mels=self.n_mels)
84
+
85
+ def compute_feature(self, wav, target: int = 3000):
86
+ mel = log_mel_spectrogram(wav, self.filters)
87
+ assert mel.shape[1] <= target, f"{mel.shape[1]} > {target}, audio is too long"
88
+ if mel.shape[1] < target:
89
+ mel = F.pad(mel, (0, target - mel.shape[1]), mode='constant')
90
+ mel = mel.unsqueeze(0)
91
+ return mel
model_repo_whisper_qwen_trtllm/whisper/1/mel_filters.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7450ae70723a5ef9d341e3cee628c7cb0177f36ce42c44b7ed2bf3325f0f6d4c
3
+ size 4271
model_repo_whisper_qwen_trtllm/whisper/1/model.py ADDED
@@ -0,0 +1,318 @@
1
+ # -*- coding: utf-8 -*-
2
+ import triton_python_backend_utils as pb_utils
3
+ import numpy as np
4
+ import json
5
+ import torch
6
+ from torch.utils.dlpack import from_dlpack, to_dlpack
7
+ import re
8
+ import transformers
9
+ from transformers import AutoTokenizer
10
+ from typing import Dict
11
+ from pathlib import Path
12
+ import traceback
13
+
14
+ from .whisper_trtllm import WhisperTRTLLM
15
+ from .fbank import FeatureExtractor
16
+
17
+ DEFAULT_SPEECH_TOKEN = "<speech>"
18
+ def preprocess(
19
+ messages,
20
+ tokenizer: transformers.PreTrainedTokenizer,
21
+ max_len: int = 128,
22
+ ) -> Dict:
23
+ """Preprocesses the data for supervised fine-tuning."""
24
+ texts = []
25
+ TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
26
+ for i, msg in enumerate(messages):
27
+ texts.append(
28
+ tokenizer.apply_chat_template(
29
+ msg,
30
+ tokenize=True,
31
+ add_generation_prompt=False,
32
+ chat_template=TEMPLATE,
33
+ padding="longest",
34
+ max_length=max_len,
35
+ truncation=True,
36
+ )
37
+ )
38
+ max_len_texts = max([len(text) for text in texts])
39
+ if tokenizer.padding_side == "right":
40
+ texts = [
41
+ text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
42
+ for text in texts
43
+ ]
44
+ else:
45
+ texts = [
46
+ [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
47
+ for text in texts
48
+ ]
49
+
50
+ input_ids = torch.tensor(texts, dtype=torch.int)
51
+
52
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
53
+
54
+ return input_ids, attention_mask
55
+
56
+ class TritonPythonModel:
57
+ """Your Python model must use the same class name. Every Python model
58
+ that is created must have "TritonPythonModel" as the class name.
59
+ """
60
+
61
+ def initialize(self, args):
62
+ """`initialize` is called only once when the model is being loaded.
63
+ Implementing `initialize` function is optional. This function allows
64
+ the model to initialize any state associated with this model.
65
+
66
+ Parameters
67
+ ----------
68
+ args : dict
69
+ Both keys and values are strings. The dictionary keys and values are:
70
+ * model_config: A JSON string containing the model configuration
71
+ * model_instance_kind: A string containing model instance kind
72
+ * model_instance_device_id: A string containing model instance device ID
73
+ * model_repository: Model repository path
74
+ * model_version: Model version
75
+ * model_name: Model name
76
+ """
77
+ self.model_config = model_config = json.loads(args['model_config'])
78
+
79
+ # Get OUTPUT0 configuration
80
+ output0_config = pb_utils.get_output_config_by_name(
81
+ model_config, "TRANSCRIPTS")
82
+ # Convert Triton types to numpy types
83
+ self.out0_dtype = pb_utils.triton_string_to_numpy(
84
+ output0_config['data_type'])
85
+
86
+ #self.tokenizer = get_tokenizer(num_languages=100)
87
+ #self.blank = self.tokenizer.encode(" ", allowed_special=self.tokenizer.special_tokens_set)[0]
88
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
89
+ tokenizer.padding_side = "left"
90
+ special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
91
+ tokenizer.add_special_tokens(special_tokens_dict)
92
+ self.tokenizer = tokenizer
93
+ self.eos = self.tokenizer.eos_token_id
94
+ self.default_speech_token_id = tokenizer.convert_tokens_to_ids(
95
+ DEFAULT_SPEECH_TOKEN
96
+ )
97
+ self.vocab_size = 151936
98
+
99
+ self.device = torch.device("cuda")
100
+ self.decoupled = False
101
+ self.logger = pb_utils.Logger
102
+ self.init_model(self.model_config['parameters'])
103
+
104
+ def init_model(self, parameters):
105
+ for key,value in parameters.items():
106
+ parameters[key] = value["string_value"]
107
+ engine_dir = parameters["engine_dir"]
108
+ n_mels = int(parameters["n_mels"])
109
+ adapter_dir="/home/scratch.yuekaiz_wwfo_1/icefall_asr_multi-hans_whisper_qwen2_1.5B/epoch-2-avg-6.pt"
110
+ checkpoint = torch.load(
111
+ adapter_dir, map_location="cpu"
112
+ )
113
+ self.model = WhisperTRTLLM(engine_dir)
114
+ missing_keys, _ = self.model.load_state_dict(checkpoint, strict=False)
115
+ print(f"Missing keys: {missing_keys}")
116
+ self.feature_extractor = FeatureExtractor(n_mels=n_mels)
117
+
118
+ def _tokenize(self, prompt=None, num_speech_tokens=187):
119
+ if prompt is None:
120
+ prompts = [
121
+ [
122
+ {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
123
+ {"role": "assistant", "content": ""},
124
+ ]
125
+ ]
126
+
127
+ input_ids, _ = preprocess(prompts, self.tokenizer, max_len=128)
128
+ print("prompt input_ids:", input_ids)
129
+ input_ids = input_ids.tolist()[0]
130
+ speech_token_index = input_ids.index(self.default_speech_token_id)
131
+ # replace 151646 with list(range(self.vocab_size, self.vocab_size + num_speech_tokens))
132
+ prompt_ids = input_ids[:speech_token_index] + list(range(self.vocab_size, self.vocab_size + num_speech_tokens)) + input_ids[speech_token_index + 1:]
133
+ print(prompt_ids)
134
+ return prompt_ids
135
+
136
+ def _prepare_inputs(self, request, speech_embeddings, input_ids):
137
+ """
138
+ Prepares inputs for the language model based on the parameters in the
139
+ request, the speech embeddings, and the prompt. It tokenizes the prompt,
140
+ extracts and processes additional parameters from the request:
141
+ - max_tokens: Maximum number of tokens to generate (default: 50)
142
+ - temperature: Controls randomness in generation (default: 0.5)
143
+ - top_k: Top K sampling parameter (default: 1)
144
+ - frequency_penalty: Penalizes frequent tokens (default: 0.7)
145
+ - seed: Random seed for generation (default: 10)
146
+
147
+ Final llm input dictionary is combined out of all processed parameters,
148
+ the prompt's tokens and the speech embeddings. The latter are passed to the llm
149
+ through `prompt_embedding_table`.
150
+
151
+ Parameters
152
+ ----------
153
+ - request: The original request object containing additional parameters.
154
+ - speech_embeddings (torch.Tensor): The projected Whisper encoder outputs.
155
+ - input_ids (list): The tokenized prompt, with placeholder ids reserved for the speech tokens.
156
+
157
+ Returns
158
+ -------
159
+ - dict: A dictionary containing all the prepared inputs for the language model.
160
+ """
161
+ input_ids = np.array(input_ids, dtype=np.int32)
162
+ max_tokens = 50
163
+ input_len = input_ids.shape[0]
164
+ print("speech_embeddings shape:", speech_embeddings.shape)
165
+ assert speech_embeddings.shape[1] == 187, "Only support 187 speech tokens"
166
+ embedding_args = {
167
+ "prompt_vocab_size": np.array(
168
+ [[speech_embeddings.shape[1]]], dtype=np.int32
169
+ ),
170
+ "prompt_embedding_table": speech_embeddings.detach().cpu().numpy(),
171
+ }
172
+
173
+ input_dict = {
174
+ "input_ids": np.expand_dims(input_ids, 0),
175
+ "input_lengths": np.array([[input_len]], dtype=np.int32),
176
+ "request_output_len": np.array([[max_tokens]], dtype=np.int32),
177
+ "end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32),
178
+ "streaming": np.array([[0]], dtype=np.bool_),
179
+ **embedding_args,
180
+ }
181
+
182
+ input_tensor_list = [pb_utils.Tensor(k, v) for k, v in input_dict.items()]
183
+ return input_tensor_list
184
+
185
+ def _prepare_llm_response(self, llm_request_inputs):
186
+ """
187
+ Prepares the response from the language model based on the provided
188
+ inputs. Creates a `pb_utils.InferenceRequest` object with passed
189
+ `llm_request_inputs` to send to a decoupled TensorRTLLM model.
190
+ For each response from the language model:
191
+ - Checks for errors and raises an exception if any are found.
192
+ - Extracts the "output_ids" tensor from the response.
193
+ - Determines the finish reason based on the presence of the
194
+ end-of-sequence token or reaching the maximum length.
195
+ - Appends the generated token IDs to `output_ids`.
196
+ - If the finish reason is determined, decodes the output IDs to text
197
+ and prepares the final response.
198
+
199
+ The final response includes the generated text, finish reason,
200
+ completion tokens, prompt tokens, and total tokens.
201
+
202
+ Parameters
203
+ ----------
204
+ - llm_request_inputs (dict): A dictionary containing the inputs for the language model.
205
+
206
+ Returns
207
+ -------
208
+ - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
209
+ """
210
+
211
+ llm_request = pb_utils.InferenceRequest(
212
+ model_name="tensorrt_llm",
213
+ requested_output_names=["output_ids", "sequence_length"],
214
+ inputs=llm_request_inputs,
215
+ )
216
+ output_ids, output_len = [], 0
217
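+ # execute the request synchronously against the "tensorrt_llm" model (decoupled streaming is not used here)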
+ responses = llm_request.exec(decoupled=False)
218
+ responses = [responses]
219
+ for llm_response in responses:
220
+ if llm_response.has_error():
221
+ raise pb_utils.TritonModelException(llm_response.error().message())
222
+ stream_output_ids = (
223
+ pb_utils.get_output_tensor_by_name(llm_response, "output_ids")
224
+ .as_numpy()
225
+ .flatten()
226
+ .tolist()
227
+ )
228
+ finish_reason = "test"
229
+ if len(stream_output_ids) == 0 or (
230
+ len(stream_output_ids) != 0
231
+ and stream_output_ids[-1] == self.eos
232
+ ):
233
+ finish_reason = "stop"
234
+
235
+ output_ids += stream_output_ids
236
+
237
+ last_response = finish_reason != ""
238
+ output_len = len(output_ids)
239
+ if last_response:
240
+ print(output_ids)
241
+ output_text = self.tokenizer.decode(output_ids).strip()
242
+ # print(output_text)
243
+ # output_text = re.sub(r'<\|.*?\|>', '', output_text)
244
+ response = pb_utils.InferenceResponse(
245
+ output_tensors=[
246
+ pb_utils.Tensor("TRANSCRIPTS", np.array([output_text], np.object_)),
247
+ ]
248
+ )
249
+ yield response
250
+
251
+ def _extract_speech_embeddings(self, mel):
252
+ return self.model.process_batch(mel)
253
+
254
+
255
+ def execute(self, requests):
256
+
257
+ responses = []
258
+
259
+ for request in requests:
260
+ wav = pb_utils.get_input_tensor_by_name(request, "WAV").as_numpy()
261
+ assert wav.shape[0] == 1, "Only support batch size 1"
262
+ # To support batch > 1
263
+ # cat mel,text_prompt, also, need to increase decoder_input_len as a triton input
264
+ wav = torch.from_numpy(wav[0]).to(self.device)
265
+ # mel shape [1, 80, 3000] for remove_input_padding=False
266
+ mel = self.feature_extractor.compute_feature(wav)
267
+
268
+ speech_embeddings = self._extract_speech_embeddings(mel)
269
+ print(speech_embeddings.shape)
270
+ input_ids = self._tokenize()
271
+ print(input_ids)
272
+
273
+ if self.decoupled:
274
+ response_sender = request.get_response_sender()
275
+ try:
276
+
277
+ llm_request_inputs = self._prepare_inputs(
278
+ request, speech_embeddings, input_ids
279
+ )
280
+ if isinstance(llm_request_inputs, pb_utils.TritonError):
281
+ error = pb_utils.InferenceResponse(error=llm_request_inputs)
282
+ if self.decoupled:
283
+ response_sender.send(
284
+ error, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
285
+ )
286
+ else:
287
+ responses.append(error)
288
+ llm_responses = self._prepare_llm_response(llm_request_inputs)
289
+
290
+ for triton_response in llm_responses:
291
+ if self.decoupled:
292
+ response_sender.send(triton_response)
293
+ else:
294
+ responses.append(triton_response)
295
+
296
+ if self.decoupled:
297
+ response_sender.send(
298
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
299
+
300
+ except Exception:
301
+ self.logger.log_error(traceback.format_exc())
302
+ # If encountering an error, send a response with err msg
303
+ error_response = pb_utils.InferenceResponse(
304
+ output_tensors=[],
305
+ error=pb_utils.TritonError(traceback.format_exc()))
306
+
307
+ if self.decoupled:
308
+ response_sender.send(error_response)
309
+ response_sender.send(
310
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
311
+ else:
312
+ responses.append(error_response)
313
+
314
+ if self.decoupled:
315
+ return None
316
+ else:
317
+ assert len(responses) == len(requests)
318
+ return responses
model_repo_whisper_qwen_trtllm/whisper/1/whisper_trtllm.py ADDED
@@ -0,0 +1,212 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import json
16
+ from collections import OrderedDict
17
+ from pathlib import Path
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import tensorrt_llm
23
+ import tensorrt_llm.logger as logger
24
+ from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
25
+ trt_dtype_to_torch)
26
+ from tensorrt_llm.runtime import ModelConfig, SamplingConfig
27
+ from tensorrt_llm.runtime.session import Session, TensorInfo
28
+
29
+ def remove_tensor_padding(input_tensor, input_tensor_lengths=None, pad_value=0):
30
+ if input_tensor.dim() == 2:
31
+ # Text tensor case: batch, seq_len
32
+ assert torch.all(
33
+ input_tensor[:, 0] != pad_value
34
+ ), "First token in each sequence should not be pad_value"
35
+ assert input_tensor_lengths is None
36
+
37
+ # Create a mask for all non-pad tokens
38
+ mask = input_tensor != pad_value
39
+
40
+ # Apply the mask to input_tensor to remove pad tokens
41
+ output_tensor = input_tensor[mask].view(1, -1)
42
+
43
+ elif input_tensor.dim() == 3:
44
+ # Audio tensor case: batch, seq_len, feature_len
45
+ assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
46
+ batch_size, seq_len, feature_len = input_tensor.shape
47
+
48
+ # Initialize a list to collect valid sequences
49
+ valid_sequences = []
50
+
51
+ for i in range(batch_size):
52
+ valid_length = input_tensor_lengths[i]
53
+ valid_sequences.append(input_tensor[i, :valid_length, :])
54
+
55
+ # Concatenate all valid sequences along the batch dimension
56
+ output_tensor = torch.cat(valid_sequences, dim=0)
57
+
58
+ else:
59
+ raise ValueError("Input tensor must have 2 or 3 dimensions")
60
+
61
+ return output_tensor
62
+
63
+ def read_config(component, engine_dir):
64
+ config_path = engine_dir / component / 'config.json'
65
+ with open(config_path, 'r') as f:
66
+ config = json.load(f)
67
+ model_config = OrderedDict()
68
+ model_config.update(config['pretrained_config'])
69
+ model_config.update(config['build_config'])
70
+ return model_config
71
+
72
+ class WhisperEncoding:
73
+ def __init__(self, engine_dir):
74
+ self.session = self.get_session(engine_dir)
75
+ config = read_config('encoder', engine_dir)
76
+ self.n_mels = config['n_mels']
77
+ self.dtype = config['dtype']
78
+ self.num_languages = config['num_languages']
79
+ self.encoder_config = config
80
+
81
+ def get_session(self, engine_dir):
82
+ serialize_path = engine_dir / 'encoder' / 'rank0.engine'
83
+ with open(serialize_path, 'rb') as f:
84
+ session = Session.from_serialized_engine(f.read())
85
+ return session
86
+
87
+ def get_audio_features(self,
88
+ mel):
89
+ mel_input_lengths = torch.tensor(
90
+ [mel.shape[2] for _ in range(mel.shape[0])],
91
+ dtype=torch.int32,
92
+ device=mel.device)
93
+ if self.encoder_config['plugin_config']['remove_input_padding']:
94
+ # mel B,D,T -> B,T,D -> BxT, D
95
+ mel = mel.transpose(1, 2)
96
+ mel = remove_tensor_padding(mel, mel_input_lengths)
97
+
98
+ inputs = OrderedDict()
99
+ inputs['input_features'] = mel
100
+ inputs['input_lengths'] = mel_input_lengths
101
+
102
+ output_list = [
103
+ TensorInfo('input_features', str_dtype_to_trt(self.dtype),
104
+ mel.shape),
105
+ TensorInfo('input_lengths', str_dtype_to_trt('int32'),
106
+ mel_input_lengths.shape)
107
+ ]
108
+
109
+ output_info = (self.session).infer_shapes(output_list)
110
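+ # infer_shapes resolves the engine's output shapes so the CUDA output buffers can be pre-allocated below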
+
111
+ logger.debug(f'output info {output_info}')
112
+ outputs = {
113
+ t.name: torch.empty(tuple(t.shape),
114
+ dtype=trt_dtype_to_torch(t.dtype),
115
+ device='cuda')
116
+ for t in output_info
117
+ }
118
+ stream = torch.cuda.current_stream()
119
+ ok = self.session.run(inputs=inputs,
120
+ outputs=outputs,
121
+ stream=stream.cuda_stream)
122
+ assert ok, 'Engine execution failed'
123
+ stream.synchronize()
124
+ encoder_output = outputs['encoder_output']
125
+ encoder_output_lengths = mel_input_lengths // 2
126
+
127
+ return encoder_output
128
+
129
+ class EncoderProjector(torch.nn.Module):
130
+ """
131
+ The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
132
+ Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
133
+ Args:
134
+ encoder_dim (:obj:`int`): The dimension of the encoder outputs.
135
+ llm_dim (:obj:`int`): The dimension of the language model.
136
+ downsample_rate (:obj:`int`, `optional`, defaults to 8): The downsample rate to use.
137
+ """
138
+
139
+ def __init__(self, encoder_dim=1280, llm_dim=1536, downsample_rate=8):
140
+ super().__init__()
141
+ self.downsample_rate = downsample_rate
142
+ self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
143
+ self.relu = nn.ReLU()
144
+ self.linear2 = nn.Linear(llm_dim, llm_dim)
145
+
146
+ def forward(self, x):
147
+
148
+ batch_size, seq_len, feat_dim = x.size()
149
+ num_frames_to_discard = seq_len % self.downsample_rate
150
+ if num_frames_to_discard > 0:
151
+ x = x[:, :-num_frames_to_discard, :]
152
+ seq_len = x.size(1)
153
+
154
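+ # concatenate every downsample_rate consecutive frames along the feature dim: (B, T, D) -> (B, T // r, D * r)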
+ x = x.contiguous()
155
+ x = x.view(
156
+ batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
157
+ )
158
+
159
+ x = self.linear1(x)
160
+ x = self.relu(x)
161
+ x = self.linear2(x)
162
+ return x
163
+
164
+ # class SPEECH_LLM(nn.Module):
165
+ # """
166
+ # The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
167
+ # The encoder is used to extract speech features from the input speech signal.
168
+ # The encoder projector is used to project the encoder outputs to the same dimension as the language model.
169
+ # The language model is used to generate the text from the speech features.
170
+ # Args:
171
+ # encoder (:obj:`nn.Module`): The encoder module.
172
+ # llm (:obj:`nn.Module`): The language model module.
173
+ # encoder_projector (:obj:`nn.Module`): The encoder projector module.
174
+ # """
175
+
176
+ # def __init__(
177
+ # self,
178
+ # encoder: nn.Module = None,
179
+ # llm: nn.Module = None,
180
+ # encoder_projector: nn.Module = None,
181
+ # ):
182
+ # super().__init__()
183
+ # self.encoder = encoder
184
+ # self.llm = llm
185
+ # self.encoder_projector = encoder_projector
186
+
187
+ class WhisperTRTLLM(nn.Module):
188
+
189
+ def __init__(self, engine_dir):
190
+ super().__init__()
191
+ world_size = 1
192
+ runtime_rank = tensorrt_llm.mpi_rank()
193
+ runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
194
+ torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
195
+ engine_dir = Path(engine_dir)
196
+
197
+ self.encoder = WhisperEncoding(engine_dir)
198
+ self.encoder_projector = EncoderProjector()
199
+ self.encoder_projector = self.encoder_projector.half().to("cuda")
200
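+ # the projector weights are loaded later from the adapter checkpoint in model.py (load_state_dict with strict=False)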
+
201
+ def process_batch(
202
+ self,
203
+ mel,
204
+ decoder_input_ids=None,
205
+ eot_id=50257,
206
+ max_new_tokens=96,
207
+ num_beams=1):
208
+ encoder_outputs = self.encoder.get_audio_features(mel)
209
+ speech_features = self.encoder_projector(encoder_outputs)
210
+ speech_features = speech_features.to(torch.float16)
211
+ print("speech_features shape:", speech_features.shape)
212
+ return speech_features
model_repo_whisper_qwen_trtllm/whisper/2/__pycache__/fbank.cpython-310.pyc ADDED
Binary file (3.07 kB). View file
 
model_repo_whisper_qwen_trtllm/whisper/2/__pycache__/model.cpython-310.pyc ADDED
Binary file (10.4 kB). View file
 
model_repo_whisper_qwen_trtllm/whisper/2/__pycache__/whisper_trtllm.cpython-310.pyc ADDED
Binary file (7.37 kB). View file
 
model_repo_whisper_qwen_trtllm/whisper/2/fbank.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Reference: https://github.com/openai/whisper/blob/main/whisper/audio.py
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from typing import Union
19
+ import os
20
+
21
+ def mel_filters(device, n_mels: int = 128) -> torch.Tensor:
22
+ """
23
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
24
+ Allows decoupling librosa dependency; saved using:
25
+
26
+ np.savez_compressed(
27
+ "mel_filters.npz",
28
+ mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
29
+ )
30
+ """
31
+ assert n_mels == 80 or n_mels == 128 , f"Unsupported n_mels: {n_mels}"
32
+ with np.load(
33
+ os.path.join(os.path.dirname(__file__), "mel_filters.npz")
34
+ ) as f:
35
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
36
+
37
+
38
+ def log_mel_spectrogram(
39
+ audio: torch.Tensor,
40
+ filters: torch.Tensor,
41
+ n_mels: int = 128,
42
+ n_fft: int = 400,
43
+ hop_length: int = 160,
44
+ ):
45
+ """
46
+ Compute the log-Mel spectrogram of an audio waveform.
47
+
48
+ Parameters
49
+ ----------
50
+ audio: torch.Tensor, shape = (*)
51
+ A Tensor containing the audio waveform, sampled at 16 kHz
52
+
53
+ n_mels: int
54
+ The number of Mel-frequency filters, only 80 or 128 is supported
55
+
56
+ filters: torch.Tensor
57
+
58
+ Returns
59
+ -------
60
+ torch.Tensor, shape = (n_mels, n_frames)
61
+ A Tensor that contains the Mel spectrogram
62
+ """
63
+ window = torch.hann_window(n_fft).to(audio.device)
64
+ stft = torch.stft(audio, n_fft, hop_length, window=window, return_complex=True)
65
+ magnitudes = stft[..., :-1].abs() ** 2
66
+
67
+ mel_spec = filters @ magnitudes
68
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
69
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
70
+ log_spec = (log_spec + 4.0) / 4.0
71
+ # cast to float 16
72
+ log_spec = log_spec.half()
73
+ return log_spec
74
+
75
+ class FeatureExtractor(torch.nn.Module):
76
+ """Your Python model must use the same class name. Every Python model
77
+ that is created must have "TritonPythonModel" as the class name.
78
+ """
79
+
80
+ def __init__(self, n_mels: int = 128):
81
+ self.device = torch.device("cuda")
82
+ self.n_mels = n_mels
83
+ self.filters = mel_filters(self.device, n_mels=self.n_mels)
84
+
85
+ def compute_feature(self, wav, target: int = 3000):
86
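+ # target=3000 frames corresponds to Whisper's fixed 30 s window (10 ms hop at 16 kHz); shorter audio is zero-padded below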
+ mel = log_mel_spectrogram(wav, self.filters)
87
+ assert mel.shape[1] <= target, f"{mel.shape[1]} > {target}, audio is too long"
88
+ if mel.shape[1] < target:
89
+ mel = F.pad(mel, (0, target - mel.shape[1]), mode='constant')
90
+ mel = mel.unsqueeze(0)
91
+ return mel
model_repo_whisper_qwen_trtllm/whisper/2/mel_filters.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7450ae70723a5ef9d341e3cee628c7cb0177f36ce42c44b7ed2bf3325f0f6d4c
3
+ size 4271
model_repo_whisper_qwen_trtllm/whisper/2/model.py ADDED
@@ -0,0 +1,346 @@
1
+ # -*- coding: utf-8 -*-
2
+ import triton_python_backend_utils as pb_utils
3
+ import numpy as np
4
+ import json
5
+ import torch
6
+ from torch.utils.dlpack import from_dlpack, to_dlpack
7
+ import re
8
+ import transformers
9
+ from transformers import AutoTokenizer
10
+ from typing import Dict
11
+ from pathlib import Path
12
+ import traceback
13
+
14
+ from .whisper_trtllm import WhisperTRTLLM
15
+ from .fbank import FeatureExtractor
16
+
17
+ DEFAULT_SPEECH_TOKEN = "<speech>"
18
+ def preprocess(
19
+ messages,
20
+ tokenizer: transformers.PreTrainedTokenizer,
21
+ max_len: int = 128,
22
+ ) -> Dict:
23
+ """Preprocesses the data for supervised fine-tuning."""
24
+ texts = []
25
+ TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
26
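+ # the template drops the final <|im_end|> so the prompt ends right after the empty assistant turn, where generation starts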
+ for i, msg in enumerate(messages):
27
+ texts.append(
28
+ tokenizer.apply_chat_template(
29
+ msg,
30
+ tokenize=True,
31
+ add_generation_prompt=False,
32
+ chat_template=TEMPLATE,
33
+ padding="longest",
34
+ max_length=max_len,
35
+ truncation=True,
36
+ )
37
+ )
38
+ max_len_texts = max([len(text) for text in texts])
39
+ if tokenizer.padding_side == "right":
40
+ texts = [
41
+ text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
42
+ for text in texts
43
+ ]
44
+ else:
45
+ texts = [
46
+ [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
47
+ for text in texts
48
+ ]
49
+
50
+ input_ids = torch.tensor(texts, dtype=torch.int)
51
+
52
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
53
+
54
+ return input_ids, attention_mask
55
+
56
+ class TritonPythonModel:
57
+ """Your Python model must use the same class name. Every Python model
58
+ that is created must have "TritonPythonModel" as the class name.
59
+ """
60
+
61
+ def initialize(self, args):
62
+ """`initialize` is called only once when the model is being loaded.
63
+ Implementing `initialize` function is optional. This function allows
64
+ the model to initialize any state associated with this model.
65
+
66
+ Parameters
67
+ ----------
68
+ args : dict
69
+ Both keys and values are strings. The dictionary keys and values are:
70
+ * model_config: A JSON string containing the model configuration
71
+ * model_instance_kind: A string containing model instance kind
72
+ * model_instance_device_id: A string containing model instance device ID
73
+ * model_repository: Model repository path
74
+ * model_version: Model version
75
+ * model_name: Model name
76
+ """
77
+ self.model_config = model_config = json.loads(args['model_config'])
78
+
79
+ # Get OUTPUT0 configuration
80
+ output0_config = pb_utils.get_output_config_by_name(
81
+ model_config, "TRANSCRIPTS")
82
+ # Convert Triton types to numpy types
83
+ self.out0_dtype = pb_utils.triton_string_to_numpy(
84
+ output0_config['data_type'])
85
+
86
+ #self.tokenizer = get_tokenizer(num_languages=100)
87
+ #self.blank = self.tokenizer.encode(" ", allowed_special=self.tokenizer.special_tokens_set)[0]
88
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
89
+ tokenizer.padding_side = "left"
90
+ special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
91
+ tokenizer.add_special_tokens(special_tokens_dict)
92
+ self.tokenizer = tokenizer
93
+ self.eos = self.tokenizer.eos_token_id
94
+ self.default_speech_token_id = tokenizer.convert_tokens_to_ids(
95
+ DEFAULT_SPEECH_TOKEN
96
+ )
97
+ self.vocab_size = 151936
98
+ # self.vocab_size = 500000
99
+ # self.vocab_size = 160000
100
+
101
+ self.device = torch.device("cuda")
102
+ self.decoupled = False
103
+ self.logger = pb_utils.Logger
104
+ self.init_model(self.model_config['parameters'])
105
+
106
+ def init_model(self, parameters):
107
+ for key,value in parameters.items():
108
+ parameters[key] = value["string_value"]
109
+ engine_dir = parameters["engine_dir"]
110
+ n_mels = int(parameters["n_mels"])
111
+ adapter_dir="/home/scratch.yuekaiz_wwfo_1/icefall_asr_multi-hans_whisper_qwen2_1.5B/epoch-2-avg-6.pt"
112
+ checkpoint = torch.load(
113
+ adapter_dir, map_location="cpu"
114
+ )
115
+ self.model = WhisperTRTLLM(engine_dir)
116
+ missing_keys, _ = self.model.load_state_dict(checkpoint, strict=False)
117
+ # print(f"Missing keys: {missing_keys}")
118
+ self.feature_extractor = FeatureExtractor(n_mels=n_mels)
119
+
120
+ def _tokenize(self, prompt=None, num_speech_tokens=187):
121
+ if prompt is None:
122
+ prompts = [
123
+ [
124
+ {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
125
+ {"role": "assistant", "content": ""},
126
+ ]
127
+ ]
128
+ # prompts = [
129
+ # [
130
+ # {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}你好,你是谁?"},
131
+ # {"role": "assistant", "content": ""},
132
+ # ]
133
+ # ]
134
+
135
+ input_ids, _ = preprocess(prompts, self.tokenizer, max_len=128)
136
+ input_ids = input_ids.tolist()[0]
137
+ speech_token_index = input_ids.index(self.default_speech_token_id)
138
+ # replace 151646 with list(range(self.vocab_size, self.vocab_size + num_speech_tokens))
139
+ prompt_ids = input_ids[:speech_token_index] + list(range(self.vocab_size, self.vocab_size + num_speech_tokens)) + input_ids[speech_token_index + 1:]
140
+ # prompt_ids = input_ids[:speech_token_index] + input_ids[speech_token_index + 1:]
141
+ return prompt_ids
142
+
143
+ def _prepare_inputs(self, request, speech_embeddings, input_ids):
144
+ """
+ Prepares the inputs for the TensorRT-LLM model from the tokenized
+ prompt and the projected speech embeddings. The prompt token IDs,
+ their length, the requested output length, the greedy sampling setting
+ (top-k of 1), the end/pad IDs and the streaming flag are packed into
+ Triton tensors; the speech embeddings are passed to the LLM through
+ `prompt_embedding_table`, with `prompt_vocab_size` giving the number
+ of embedding rows.
+
+ Parameters
+ ----------
+ - request: The original request object containing additional parameters.
+ - speech_embeddings (torch.Tensor): Speech embeddings from the encoder projector.
+ - input_ids (list): Prompt token IDs with the speech placeholder expanded.
+
+ Returns
+ -------
+ - list: A list of `pb_utils.Tensor` objects to send to the TensorRT-LLM model.
+ """
168
+ input_ids = np.array(input_ids, dtype=np.int32)
169
+ max_tokens = 200
170
+ input_len = input_ids.shape[0]
171
+
172
+ assert speech_embeddings.shape[1] == 187, "Only support 187 speech tokens"
173
+ embedding_args = {
174
+ "prompt_vocab_size": np.array(
175
+ [[speech_embeddings.shape[1]]], dtype=np.int32
176
+ ),
177
+ "prompt_embedding_table": speech_embeddings.detach().cpu().numpy(),
178
+ }
179
+ # TODO: the result is the same with or without this??? and input_ids beyond the maximum vocab size do not raise an error???
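+ # (likely because TensorRT-LLM treats ids >= vocab_size as prompt-table indices rather than vocabulary entries)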
180
+ input_dict = {
181
+ "input_ids": np.expand_dims(input_ids, 0),
182
+ "input_lengths": np.array([[input_len]], dtype=np.int32),
183
+ "request_output_len": np.array([[max_tokens]], dtype=np.int32),
184
+ "runtime_top_k": np.array([[1]], dtype=np.int32),
185
+ "end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32),
186
+ "pad_id": np.array([[self.tokenizer.pad_token_id]], dtype=np.int32),
187
+ "streaming": np.array([[0]], dtype=np.bool_),
188
+ **embedding_args,
189
+ }
190
+
191
+ # print(input_ids)
192
+ # for key, value in input_dict.items():
193
+ # print(key, value.shape)
194
+
195
+ input_tensor_list = [pb_utils.Tensor(k, v) for k, v in input_dict.items()]
196
+ return input_tensor_list
197
+
198
+ def _prepare_llm_response(self, llm_request_inputs):
199
+ """
200
+ Prepares the response from the language model based on the provided
201
+ inputs. Creates a `pb_utils.InferenceRequest` object with passed
202
+ `llm_request_inputs` to send to a decoupled TensorRTLLM model.
203
+ For each response from the language model:
204
+ - Checks for errors and raises an exception if any are found.
205
+ - Extracts the "output_ids" tensor from the response.
206
+ - Determines the finish reason based on the presence of the
207
+ end-of-sequence token or reaching the maximum length.
208
+ - Appends the generated token IDs to `output_ids`.
209
+ - If the finish reason is determined, decodes the output IDs to text
210
+ and prepares the final response.
211
+
212
+ The final response contains the decoded transcript returned in the
213
+ "TRANSCRIPTS" output tensor.
214
+
215
+ Parameters
216
+ ----------
217
+ - llm_request_inputs (dict): A dictionary containing the inputs for the language model.
218
+
219
+ Returns
220
+ -------
221
+ - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
222
+ """
223
+
224
+ llm_request = pb_utils.InferenceRequest(
225
+ model_name="tensorrt_llm",
226
+ requested_output_names=["output_ids", "sequence_length"],
227
+ inputs=llm_request_inputs,
228
+ )
229
+ output_ids, output_len = [], 0
230
+ responses = llm_request.exec(decoupled=False)
231
+ responses = [responses]
232
+ for llm_response in responses:
233
+ if llm_response.has_error():
234
+ raise pb_utils.TritonModelException(llm_response.error().message())
235
+ stream_output_ids = (
236
+ pb_utils.get_output_tensor_by_name(llm_response, "output_ids")
237
+ .as_numpy()
238
+ .flatten()
239
+ .tolist()
240
+ )
241
+ finish_reason = "test"
242
+ if len(stream_output_ids) == 0 or (
243
+ len(stream_output_ids) != 0
244
+ and stream_output_ids[-1] == self.eos
245
+ ):
246
+ finish_reason = "stop"
247
+
248
+ output_ids += stream_output_ids
249
+
250
+ last_response = finish_reason != ""
251
+ output_len = len(output_ids)
252
+ if last_response:
253
+ print("final_output_ids", output_ids)
254
+ output_text = self.tokenizer.decode(output_ids).strip()
255
+ # print(output_text)
256
+ # output_text = re.sub(r'<\|.*?\|>', '', output_text)
257
+ response = pb_utils.InferenceResponse(
258
+ output_tensors=[
259
+ pb_utils.Tensor("TRANSCRIPTS", np.array([output_text], np.object_)),
260
+ ]
261
+ )
262
+ yield response
263
+
264
+ def _extract_speech_embeddings(self, mel):
265
+ return self.model.process_batch(mel)
266
+
267
+
268
+ def execute(self, requests):
269
+
270
+ responses = []
271
+
272
+ for request in requests:
273
+ wav = pb_utils.get_input_tensor_by_name(request, "WAV").as_numpy()
274
+ assert wav.shape[0] == 1, "Only support batch size 1"
275
+ # To support batch > 1
276
+ # cat mel,text_prompt, also, need to increase decoder_input_len as a triton input
277
+ wav = torch.from_numpy(wav[0]).to(self.device)
278
+ # mel shape [1, 80, 3000] for remove_input_padding=False
279
+ mel = self.feature_extractor.compute_feature(wav)
280
+ # print("==========================================================")
281
+ # messages = [
282
+ # [
283
+ # {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
284
+ # {"role": "assistant", "content": ""},
285
+ # ]
286
+ # ] * len(mel)
287
+
288
+ # input_ids, attention_mask = preprocess(messages, self.tokenizer, max_len=128)
289
+
290
+ # generated_ids = self.model.decode(
291
+ # mel, input_ids.to(self.device, dtype=torch.long), attention_mask.to(self.device)
292
+ # )
293
+ # print("pytorch model", generated_ids)
294
+ # print("--------------------------------------------------------------------------")
295
+
296
+
297
+ speech_embeddings = self._extract_speech_embeddings(mel)
298
+ input_ids = self._tokenize()
299
+
300
+
301
+ if self.decoupled:
302
+ response_sender = request.get_response_sender()
303
+ try:
304
+
305
+ llm_request_inputs = self._prepare_inputs(
306
+ request, speech_embeddings, input_ids
307
+ )
308
+ if isinstance(llm_request_inputs, pb_utils.TritonError):
309
+ error = pb_utils.InferenceResponse(error=llm_request_inputs)
310
+ if self.decoupled:
311
+ response_sender.send(
312
+ error, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
313
+ )
314
+ else:
315
+ responses.append(error)
316
+ llm_responses = self._prepare_llm_response(llm_request_inputs)
317
+
318
+ for triton_response in llm_responses:
319
+ if self.decoupled:
320
+ response_sender.send(triton_response)
321
+ else:
322
+ responses.append(triton_response)
323
+
324
+ if self.decoupled:
325
+ response_sender.send(
326
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
327
+
328
+ except Exception:
329
+ self.logger.log_error(traceback.format_exc())
330
+ # If encountering an error, send a response with err msg
331
+ error_response = pb_utils.InferenceResponse(
332
+ output_tensors=[],
333
+ error=pb_utils.TritonError(traceback.format_exc()))
334
+
335
+ if self.decoupled:
336
+ response_sender.send(error_response)
337
+ response_sender.send(
338
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
339
+ else:
340
+ responses.append(error_response)
341
+
342
+ if self.decoupled:
343
+ return None
344
+ else:
345
+ assert len(responses) == len(requests)
346
+ return responses
model_repo_whisper_qwen_trtllm/whisper/2/whisper_trtllm.py ADDED
@@ -0,0 +1,278 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import json
16
+ from collections import OrderedDict
17
+ from pathlib import Path
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import tensorrt_llm
23
+ import tensorrt_llm.logger as logger
24
+ from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
25
+ trt_dtype_to_torch)
26
+ from tensorrt_llm.runtime import ModelConfig, SamplingConfig
27
+ from tensorrt_llm.runtime.session import Session, TensorInfo
28
+ from transformers.trainer_pt_utils import LabelSmoother
29
+ from transformers import AutoModelForCausalLM, AutoTokenizer
30
+ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
31
+
32
+ DEFAULT_SPEECH_TOKEN = "<speech>"
33
+ def remove_tensor_padding(input_tensor, input_tensor_lengths=None, pad_value=0):
34
+ if input_tensor.dim() == 2:
35
+ # Text tensor case: batch, seq_len
36
+ assert torch.all(
37
+ input_tensor[:, 0] != pad_value
38
+ ), "First token in each sequence should not be pad_value"
39
+ assert input_tensor_lengths is None
40
+
41
+ # Create a mask for all non-pad tokens
42
+ mask = input_tensor != pad_value
43
+
44
+ # Apply the mask to input_tensor to remove pad tokens
45
+ output_tensor = input_tensor[mask].view(1, -1)
46
+
47
+ elif input_tensor.dim() == 3:
48
+ # Audio tensor case: batch, seq_len, feature_len
49
+ assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
50
+ batch_size, seq_len, feature_len = input_tensor.shape
51
+
52
+ # Initialize a list to collect valid sequences
53
+ valid_sequences = []
54
+
55
+ for i in range(batch_size):
56
+ valid_length = input_tensor_lengths[i]
57
+ valid_sequences.append(input_tensor[i, :valid_length, :])
58
+
59
+ # Concatenate all valid sequences along the batch dimension
60
+ output_tensor = torch.cat(valid_sequences, dim=0)
61
+
62
+ else:
63
+ raise ValueError("Input tensor must have 2 or 3 dimensions")
64
+
65
+ return output_tensor
66
+
67
+ def read_config(component, engine_dir):
68
+ config_path = engine_dir / component / 'config.json'
69
+ with open(config_path, 'r') as f:
70
+ config = json.load(f)
71
+ model_config = OrderedDict()
72
+ model_config.update(config['pretrained_config'])
73
+ model_config.update(config['build_config'])
74
+ return model_config
75
+
76
+ class WhisperEncoding:
77
+ def __init__(self, engine_dir):
78
+ self.session = self.get_session(engine_dir)
79
+ config = read_config('encoder', engine_dir)
80
+ self.n_mels = config['n_mels']
81
+ self.dtype = config['dtype']
82
+ self.num_languages = config['num_languages']
83
+ self.encoder_config = config
84
+
85
+ def get_session(self, engine_dir):
86
+ serialize_path = engine_dir / 'encoder' / 'rank0.engine'
87
+ with open(serialize_path, 'rb') as f:
88
+ session = Session.from_serialized_engine(f.read())
89
+ return session
90
+
91
+ def get_audio_features(self,
92
+ mel):
93
+ mel_input_lengths = torch.tensor(
94
+ [mel.shape[2] for _ in range(mel.shape[0])],
95
+ dtype=torch.int32,
96
+ device=mel.device)
97
+ if self.encoder_config['plugin_config']['remove_input_padding']:
98
+ # mel B,D,T -> B,T,D -> BxT, D
99
+ mel = mel.transpose(1, 2)
100
+ mel = remove_tensor_padding(mel, mel_input_lengths)
101
+
102
+ inputs = OrderedDict()
103
+ inputs['input_features'] = mel
104
+ inputs['input_lengths'] = mel_input_lengths
105
+
106
+ output_list = [
107
+ TensorInfo('input_features', str_dtype_to_trt(self.dtype),
108
+ mel.shape),
109
+ TensorInfo('input_lengths', str_dtype_to_trt('int32'),
110
+ mel_input_lengths.shape)
111
+ ]
112
+
113
+ output_info = (self.session).infer_shapes(output_list)
114
+
115
+ logger.debug(f'output info {output_info}')
116
+ outputs = {
117
+ t.name: torch.empty(tuple(t.shape),
118
+ dtype=trt_dtype_to_torch(t.dtype),
119
+ device='cuda')
120
+ for t in output_info
121
+ }
122
+ stream = torch.cuda.current_stream()
123
+ ok = self.session.run(inputs=inputs,
124
+ outputs=outputs,
125
+ stream=stream.cuda_stream)
126
+ assert ok, 'Engine execution failed'
127
+ stream.synchronize()
128
+ encoder_output = outputs['encoder_output']
129
+ encoder_output_lengths = mel_input_lengths // 2
130
+
131
+ return encoder_output
132
+
133
+ class EncoderProjector(torch.nn.Module):
134
+ """
135
+ The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
136
+ Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
137
+ Args:
138
+ encoder_dim (:obj:`int`): The dimension of the encoder outputs.
139
+ llm_dim (:obj:`int`): The dimension of the language model.
140
+ downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use.
141
+ """
142
+
143
+ def __init__(self, encoder_dim=1280, llm_dim=1536, downsample_rate=8):
144
+ super().__init__()
145
+ self.downsample_rate = downsample_rate
146
+ self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
147
+ self.relu = nn.ReLU()
148
+ self.linear2 = nn.Linear(llm_dim, llm_dim)
149
+
150
+ def forward(self, x):
151
+
152
+ batch_size, seq_len, feat_dim = x.size()
153
+ num_frames_to_discard = seq_len % self.downsample_rate
154
+ if num_frames_to_discard > 0:
155
+ x = x[:, :-num_frames_to_discard, :]
156
+ seq_len = x.size(1)
157
+
158
+ x = x.contiguous()
159
+ x = x.view(
160
+ batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
161
+ )
162
+
163
+ x = self.linear1(x)
164
+ x = self.relu(x)
165
+ x = self.linear2(x)
166
+ return x
167
+
168
+ class SPEECH_LLM(nn.Module):
169
+ """
170
+ The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
171
+ The encoder is used to extract speech features from the input speech signal.
172
+ The encoder projector is used to project the encoder outputs to the same dimension as the language model.
173
+ The language model is used to generate the text from the speech features.
174
+ Args:
175
+ encoder (:obj:`nn.Module`): The encoder module.
176
+ llm (:obj:`nn.Module`): The language model module.
177
+ encoder_projector (:obj:`nn.Module`): The encoder projector module.
178
+ """
179
+
180
+ def __init__(
181
+ self,
182
+ encoder: nn.Module,
183
+ llm: nn.Module,
184
+ encoder_projector: nn.Module,
185
+ ):
186
+ super().__init__()
187
+ self.encoder = encoder
188
+ self.llm = llm
189
+ self.encoder_projector = encoder_projector
190
+
191
+ class WhisperTRTLLM(nn.Module):
192
+
193
+ def __init__(self, engine_dir):
194
+ super().__init__()
195
+ world_size = 1
196
+ runtime_rank = tensorrt_llm.mpi_rank()
197
+ runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
198
+ torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
199
+ engine_dir = Path(engine_dir)
200
+
201
+ self.encoder = WhisperEncoding(engine_dir)
202
+ self.encoder_projector = EncoderProjector()
203
+ self.encoder_projector = self.encoder_projector.half().to("cuda")
204
+
205
+ # llm = AutoModelForCausalLM.from_pretrained(
206
+ # "/home/scratch.yuekaiz_wwfo_1/Qwen2_1.5B_merged",
207
+ # attn_implementation="flash_attention_2",
208
+ # torch_dtype=torch.float16,
209
+ # )
210
+ # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
211
+ # tokenizer.padding_side = "left"
212
+ # special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
213
+ # tokenizer.add_special_tokens(special_tokens_dict)
214
+ # llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
215
+ # llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
216
+ # llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
217
+
218
+ # llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
219
+ # DEFAULT_SPEECH_TOKEN
220
+ # )
221
+ # self.llm = llm.half().to("cuda")
222
+ # # print llm embedding layer shape
223
+ # print("llm embedding layer shape", self.llm.get_input_embeddings().weight.shape)
224
+
225
+
226
+
227
+ def process_batch(
228
+ self,
229
+ mel,
230
+ decoder_input_ids=None,
231
+ eot_id=50257,
232
+ max_new_tokens=96,
233
+ num_beams=1):
234
+ encoder_outputs = self.encoder.get_audio_features(mel)
235
+ speech_features = self.encoder_projector(encoder_outputs)
236
+ speech_features = speech_features.to(torch.float16)
237
+ # [1,187,1536]
238
+ return speech_features
239
+
240
+
241
+ # def decode(
242
+ # self,
243
+ # fbank: torch.Tensor = None,
244
+ # input_ids: torch.LongTensor = None,
245
+ # attention_mask: torch.Tensor = None,
246
+ # **kwargs,
247
+ # ):
248
+
249
+ # encoder_outs = self.encoder.get_audio_features(fbank)
250
+ # speech_features = self.encoder_projector(encoder_outs)
251
+ # speech_features = speech_features.to(torch.float16)
252
+ # inputs_embeds = self.llm.get_input_embeddings()(input_ids)
253
+ # speech_token_index = input_ids.tolist()[0].index(151646)
254
+ # print("speech_token_index", speech_token_index, "speech_features_shape", speech_features.shape, "input_ids_shape", input_ids.shape, "inputs_embeds_shape", inputs_embeds.shape)
255
+
256
+ # new_length = inputs_embeds.shape[1] + speech_features.shape[1] - 1
257
+ # new_inputs_embeds = torch.zeros(1, new_length, 1536).to(inputs_embeds.device).half()
258
+ # new_inputs_embeds[:, :3, :] = inputs_embeds[:, :3, :]
259
+ # new_inputs_embeds[:, 3:3 + 187, :] = speech_features
260
+ # new_inputs_embeds[:, 3 + 187:, :] = inputs_embeds[:, 4:, :]
261
+
262
+ # inputs_embeds = new_inputs_embeds
263
+ # generated_ids = self.llm.generate(
264
+ # inputs_embeds=inputs_embeds,
265
+ # max_new_tokens=kwargs.get("max_new_tokens", 200),
266
+ # num_beams=kwargs.get("num_beams", 1),
267
+ # do_sample=kwargs.get("do_sample", False),
268
+ # min_length=kwargs.get("min_length", 1),
269
+ # top_p=kwargs.get("top_p", 1.0),
270
+ # repetition_penalty=kwargs.get("repetition_penalty", 1.0),
271
+ # length_penalty=kwargs.get("length_penalty", 1.0),
272
+ # temperature=kwargs.get("temperature", 1.0),
273
+ # bos_token_id=self.llm.config.bos_token_id,
274
+ # eos_token_id=self.llm.config.eos_token_id,
275
+ # pad_token_id=self.llm.config.pad_token_id,
276
+ # )
277
+
278
+ # return generated_ids
model_repo_whisper_qwen_trtllm/whisper/config.pbtxt ADDED
@@ -0,0 +1,61 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ name: "whisper"
16
+ backend: "python"
17
+ max_batch_size: 8
18
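+ # note: whisper/*/model.py currently asserts a per-request WAV batch size of 1, so dynamic batching only groups separate requests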
+
19
+ parameters [
20
+ {
21
+ key: "n_mels",
22
+ value: {string_value:"80"} # 128 dim for large-v3, 80 dim for large-v2
23
+ },
24
+ {
25
+ key: "engine_dir"
26
+ value: { string_value: "/home/scratch.yuekaiz_wwfo_1/tekit/examples/whisper/whisper_multi_zh"}
27
+ }
28
+ ]
29
+
30
+
31
+ input [
32
+ {
33
+ name: "TEXT_PREFIX"
34
+ data_type: TYPE_STRING
35
+ dims: [1]
36
+ },
37
+ {
38
+ name: "WAV"
39
+ data_type: TYPE_FP32
40
+ dims: [-1]
41
+ }
42
+ ]
43
+
44
+ output [
45
+ {
46
+ name: "TRANSCRIPTS"
47
+ data_type: TYPE_STRING
48
+ dims: [1]
49
+ }
50
+ ]
51
+
52
+ dynamic_batching {
53
+ preferred_batch_size: [ 4, 8]
54
+ max_queue_delay_microseconds: 1000
55
+ }
56
+ instance_group [
57
+ {
58
+ count: 1
59
+ kind: KIND_CPU
60
+ }
61
+ ]