Feature Extraction
Transformers
Safetensors
diva
custom_code
Helw150 committed on
Commit
49f38f9
1 Parent(s): b79a517

Add Batch Support

Browse files
Files changed (2)
  1. modeling_diva.py +58 -26
  2. test.py +28 -0
modeling_diva.py CHANGED
@@ -44,7 +44,7 @@ class WhisperConnector(nn.Module):
44
 
45
  class DiVAModel(PreTrainedModel):
46
  config_class = DiVAConfig
47
-
48
  def __init__(
49
  self, via_path=None, config_dict={}, device_map=None, speech_encoder_device=None
50
  ):
@@ -105,10 +105,9 @@ class DiVAModel(PreTrainedModel):
105
  )
106
  self.speech_encoder_device = speech_encoder_device
107
 
108
-
109
- def can_generate(cls):
110
  return False
111
-
112
  @classmethod
113
  def from_pretrained(
114
  cls,
@@ -182,8 +181,14 @@ class DiVAModel(PreTrainedModel):
182
 
183
  return outputs
184
 
 
185
  def generate(
186
- self, audio, text_prompt, do_sample=False, logits_processor=None, max_new_tokens=128
 
 
 
 
 
187
  ):
188
  inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
189
  input_features = inputs.input_features.to(self.speech_encoder_device)
@@ -193,29 +198,45 @@ class DiVAModel(PreTrainedModel):
193
  virt_tokens = self.connector(
194
  hidden_states,
195
  output_device=self.llama_decoder.model.embed_tokens.weight.device,
196
- ).squeeze()
 
197
 
198
  if text_prompt != None and text_prompt != "":
199
  user_prompt_text = torch.tensor(
200
- self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"],
 
 
 
 
 
201
  device=self.pre_user_suffix.device,
202
  )
203
  prefix = torch.cat(
204
- [self.pre_user_suffix, user_prompt_text, self.prefix], axis=0
 
 
 
 
 
 
 
 
 
 
 
205
  )
206
  else:
207
  prefix = self.prefix
208
- prefix_embed = self.llama_decoder.model.embed_tokens(prefix)
209
  suffix = self.final_header
210
- suffix_embed = self.llama_decoder.model.embed_tokens(suffix)
211
- inputs_embeds = torch.cat(
212
- [prefix_embed, virt_tokens, suffix_embed], axis=0
213
- ).unsqueeze(0)
214
- outs = []
215
  outputs = None
216
  greedy = 1
217
  i = 0
218
- while greedy != 128009 and len(outs) < max_new_tokens:
219
  past_key_values = outputs.past_key_values if outputs else None
220
  outputs = self.llama_decoder(
221
  inputs_embeds=inputs_embeds.to(
@@ -225,7 +246,7 @@ class DiVAModel(PreTrainedModel):
225
  output_hidden_states=True,
226
  past_key_values=past_key_values,
227
  )
228
- next_token_logits = outputs.logits[-1, -1, :]
229
 
230
  if logits_processor:
231
  local_outs = torch.tensor(outs) if outs != [] else suffix
@@ -240,16 +261,23 @@ class DiVAModel(PreTrainedModel):
240
  probs = F.softmax(logits, dim=-1)
241
  greedy = torch.multinomial(probs, num_samples=1)[0]
242
  else:
243
- greedy = next_token_logits.argmax()
244
- outs.append(greedy)
245
- next_embed = self.llama_decoder.model.embed_tokens(greedy.reshape(1, 1))
 
 
 
 
246
  inputs_embeds = next_embed
247
- return self.tokenizer.decode(outs, skip_special_tokens=True).replace(
248
- "<|eot_id|>", ""
249
- )
250
 
251
  def generate_stream(
252
- self, audio, text_prompt, do_sample=False, logits_processor=None, max_new_tokens=128
 
 
 
 
 
253
  ):
254
  inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
255
  input_features = inputs.input_features.to(self.whisper_encoder.device)
@@ -284,7 +312,7 @@ class DiVAModel(PreTrainedModel):
284
  while greedy != 128009 and len(outs) < max_new_tokens:
285
  past_key_values = outputs.past_key_values if outputs else None
286
  outputs = self.llama_decoder(
287
- inputs_embeds=inputs_embeds.to(
288
  self.llama_decoder.model.embed_tokens.weight.device
289
  ).half(),
290
  return_dict=True,
@@ -310,5 +338,9 @@ class DiVAModel(PreTrainedModel):
310
  outs.append(greedy)
311
  next_embed = self.llama_decoder.model.embed_tokens(greedy.reshape(1, 1))
312
  inputs_embeds = next_embed
313
- yield self.tokenizer.decode(outs, skip_special_tokens=True).replace("<|eot_id|>", "")
314
- return self.tokenizer.decode(outs, skip_special_tokens=True).replace("<|eot_id|>", "")
 
 
 
 
 
44
 
45
  class DiVAModel(PreTrainedModel):
46
  config_class = DiVAConfig
47
+
48
  def __init__(
49
  self, via_path=None, config_dict={}, device_map=None, speech_encoder_device=None
50
  ):
 
105
  )
106
  self.speech_encoder_device = speech_encoder_device
107
 
108
+ def can_generate(cls):
 
109
  return False
110
+
111
  @classmethod
112
  def from_pretrained(
113
  cls,
 
181
 
182
  return outputs
183
 
184
+ @torch.no_grad()
185
  def generate(
186
+ self,
187
+ audio,
188
+ text_prompt=None,
189
+ do_sample=False,
190
+ logits_processor=None,
191
+ max_new_tokens=128,
192
  ):
193
  inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
194
  input_features = inputs.input_features.to(self.speech_encoder_device)
 
198
  virt_tokens = self.connector(
199
  hidden_states,
200
  output_device=self.llama_decoder.model.embed_tokens.weight.device,
201
+ )
202
+ bsz = virt_tokens.shape[0]
203
 
204
  if text_prompt != None and text_prompt != "":
205
  user_prompt_text = torch.tensor(
206
+ self.tokenizer(
207
+ text_prompt,
208
+ add_special_tokens=False,
209
+ padding=True,
210
+ padding_side="right",
211
+ )["input_ids"],
212
  device=self.pre_user_suffix.device,
213
  )
214
  prefix = torch.cat(
215
+ [
216
+ self.pre_user_suffix.expand(
217
+ bsz,
218
+ -1,
219
+ ),
220
+ user_prompt_text,
221
+ self.prefix.expand(
222
+ bsz,
223
+ -1,
224
+ ),
225
+ ],
226
+ axis=1,
227
  )
228
  else:
229
  prefix = self.prefix
230
+ prefix_embed = self.llama_decoder.model.embed_tokens(prefix).expand(bsz, -1, -1)
231
  suffix = self.final_header
232
+ suffix_embed = self.llama_decoder.model.embed_tokens(suffix).expand(bsz, -1, -1)
233
+ inputs_embeds = torch.cat([prefix_embed, virt_tokens, suffix_embed], axis=1)
234
+ outs = [[] for i in range(bsz)]
235
+ complete = [False] * bsz
 
236
  outputs = None
237
  greedy = 1
238
  i = 0
239
+ while not all(complete) and len(outs[0]) < max_new_tokens:
240
  past_key_values = outputs.past_key_values if outputs else None
241
  outputs = self.llama_decoder(
242
  inputs_embeds=inputs_embeds.to(
 
246
  output_hidden_states=True,
247
  past_key_values=past_key_values,
248
  )
249
+ next_token_logits = outputs.logits[:, -1, :]
250
 
251
  if logits_processor:
252
  local_outs = torch.tensor(outs) if outs != [] else suffix
 
261
  probs = F.softmax(logits, dim=-1)
262
  greedy = torch.multinomial(probs, num_samples=1)[0]
263
  else:
264
+ greedy = next_token_logits.argmax(dim=-1)
265
+ for token_index, out in enumerate(greedy.flatten().tolist()):
266
+ outs[token_index].append(out)
267
+ if out == 128009:
268
+ complete[token_index] = True
269
+
270
+ next_embed = self.llama_decoder.model.embed_tokens(greedy.reshape(-1, 1))
271
  inputs_embeds = next_embed
272
+ return self.tokenizer.batch_decode(outs, skip_special_tokens=True)
 
 
273
 
274
  def generate_stream(
275
+ self,
276
+ audio,
277
+ text_prompt,
278
+ do_sample=False,
279
+ logits_processor=None,
280
+ max_new_tokens=128,
281
  ):
282
  inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
283
  input_features = inputs.input_features.to(self.whisper_encoder.device)
 
312
  while greedy != 128009 and len(outs) < max_new_tokens:
313
  past_key_values = outputs.past_key_values if outputs else None
314
  outputs = self.llama_decoder(
315
+ inputs_embeds=inputs_embeds.to(
316
  self.llama_decoder.model.embed_tokens.weight.device
317
  ).half(),
318
  return_dict=True,
 
338
  outs.append(greedy)
339
  next_embed = self.llama_decoder.model.embed_tokens(greedy.reshape(1, 1))
340
  inputs_embeds = next_embed
341
+ yield self.tokenizer.decode(outs, skip_special_tokens=True).replace(
342
+ "<|eot_id|>", ""
343
+ )
344
+ return self.tokenizer.decode(outs, skip_special_tokens=True).replace(
345
+ "<|eot_id|>", ""
346
+ )
test.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel
2
+ import librosa
3
+ import wget
4
+ from modeling_diva import DiVAModel
5
+
6
+ filename = wget.download(
7
+ "https://github.com/ffaisal93/SD-QA/raw/refs/heads/master/dev/eng/irl/wav_eng/-1008642825401516622.wav"
8
+ )
9
+
10
+ speech_data, _ = librosa.load(filename, sr=16_000)
11
+
12
+ model = DiVAModel.from_pretrained("./")
13
+
14
+ print(model.generate([speech_data]))
15
+ print(model.generate([speech_data], ["Reply Briefly Like A Pirate"]))
16
+
17
+ filename = wget.download(
18
+ "https://github.com/ffaisal93/SD-QA/raw/refs/heads/master/dev/eng/irl/wav_eng/-2426554427049983479.wav"
19
+ )
20
+
21
+ speech_data2, _ = librosa.load(filename, sr=16_000)
22
+
23
+ print(
24
+ model.generate(
25
+ [speech_data, speech_data2],
26
+ ["Reply Briefly Like A Pirate", "Reply Briefly Like A New Yorker"],
27
+ )
28
+ )