nicolaus625 committed
Commit 6c06494
1 Parent(s): 709655d

update readme.md with 1 sample inference code

Files changed (1)
1. README.md (+144, -39)
README.md CHANGED
@@ -42,6 +42,115 @@ from transformers import Wav2Vec2FeatureExtractor
 from transformers import StoppingCriteria, StoppingCriteriaList
 
 
+def load_audio(
+        file_path,
+        target_sr,
+        is_mono=True,
+        is_normalize=False,
+        crop_to_length_in_sec=None,
+        crop_to_length_in_sample_points=None,
+        crop_randomly=False,
+        pad=False,
+        return_start=False,
+        device=torch.device('cpu')
+):
+    """Load an audio file and convert it to the target sample rate.
+    Supports cropping and padding.
+
+    Args:
+        file_path (str): path to the audio file
+        target_sr (int): target sample rate; if it differs from the file's sample rate, the waveform is resampled to target_sr
+        is_mono (bool, optional): convert to mono. Defaults to True.
+        is_normalize (bool, optional): normalize to [-1, 1]. Defaults to False.
+        crop_to_length_in_sec (float, optional): crop to the specified length in seconds. Defaults to None.
+        crop_to_length_in_sample_points (int, optional): crop to the specified length in sample points. Defaults to None. Note that the crop length in sample points is applied before resampling.
+        crop_randomly (bool, optional): crop at a random offset. Defaults to False.
+        pad (bool, optional): pad to the specified length if the waveform is shorter. Defaults to False.
+        device (torch.device, optional): device to use for resampling. Defaults to torch.device('cpu').
+
+    Returns:
+        torch.Tensor: waveform of shape (1, n_sample)
+    """
+    # TODO: deal with target_depth
+    try:
+        waveform, sample_rate = torchaudio.load(file_path)
+    except Exception as e:
+        waveform, sample_rate = torchaudio.backend.soundfile_backend.load(file_path)
+    if waveform.shape[0] > 1:
+        if is_mono:
+            waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+    if is_normalize:
+        waveform = waveform / waveform.abs().max()
+
+    waveform, start = crop_audio(
+        waveform,
+        sample_rate,
+        crop_to_length_in_sec=crop_to_length_in_sec,
+        crop_to_length_in_sample_points=crop_to_length_in_sample_points,
+        crop_randomly=crop_randomly,
+        pad=pad,
+    )
+
+    if sample_rate != target_sr:
+        resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
+        waveform = waveform.to(device)
+        resampler = resampler.to(device)
+        waveform = resampler(waveform)
+
+    if return_start:
+        return waveform, start
+    return waveform
+
+
+def crop_audio(
+        waveform,
+        sample_rate,
+        crop_to_length_in_sec=None,
+        crop_to_length_in_sample_points=None,
+        crop_randomly=False,
+        pad=False,
+):
+    """Crop a waveform to a specified length in seconds or sample points.
+    Supports random cropping and padding.
+
+    Args:
+        waveform (torch.Tensor): waveform of shape (1, n_sample)
+        sample_rate (int): sample rate of the waveform
+        crop_to_length_in_sec (float, optional): crop to the specified length in seconds. Defaults to None.
+        crop_to_length_in_sample_points (int, optional): crop to the specified length in sample points. Defaults to None.
+        crop_randomly (bool, optional): crop at a random offset. Defaults to False.
+        pad (bool, optional): pad to the specified length if the waveform is shorter. Defaults to False.
+
+    Returns:
+        torch.Tensor: cropped waveform
+        int: start index of the crop in the original waveform
+    """
+    assert crop_to_length_in_sec is None or crop_to_length_in_sample_points is None, \
+        "Only one of crop_to_length_in_sec and crop_to_length_in_sample_points can be specified"
+
+    # convert crop length to sample points
+    crop_duration_in_sample = None
+    if crop_to_length_in_sec:
+        crop_duration_in_sample = int(sample_rate * crop_to_length_in_sec)
+    elif crop_to_length_in_sample_points:
+        crop_duration_in_sample = crop_to_length_in_sample_points
+
+    # crop
+    start = 0
+    if crop_duration_in_sample:
+        if waveform.shape[-1] > crop_duration_in_sample:
+            if crop_randomly:
+                start = random.randint(0, waveform.shape[-1] - crop_duration_in_sample)
+            waveform = waveform[..., start:start + crop_duration_in_sample]
+
+        elif waveform.shape[-1] < crop_duration_in_sample:
+            if pad:
+                waveform = torch.nn.functional.pad(waveform, (0, crop_duration_in_sample - waveform.shape[-1]))
+
+    return waveform, start
+
+
 
 class StoppingCriteriaSub(StoppingCriteria):
     def __init__(self, stops=[], encounters=1):
@@ -53,27 +162,36 @@ class StoppingCriteriaSub(StoppingCriteria):
                 return True
         return False
 
-def answer(self, samples, stopping, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5,
-           repetition_penalty=1.0, length_penalty=1, temperature=0.1, max_length=2000):
-    audio = samples["audio"].cuda()
-    audio_embeds, atts_audio = self.encode_audio(audio)
-    if 'instruction_input' in samples:  # instruction dataset
-        # print('Instruction Batch')
-        instruction_prompt = []
-        for instruction in samples['instruction_input']:
-            prompt = '<Audio><AudioHere></Audio> ' + instruction
-            instruction_prompt.append(self.prompt_template.format(prompt))
-        audio_embeds, atts_audio = self.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
-    self.llama_tokenizer.padding_side = "right"
+def get_musilingo_pred(model, text, audio_path, stopping, length_penalty=1, temperature=0.1,
+                       max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5, repetition_penalty=1.0):
+
+    audio = load_audio(audio_path, target_sr=24000,
+                       is_mono=True,
+                       is_normalize=False,
+                       crop_to_length_in_sample_points=int(30*16000)+1,
+                       crop_randomly=True,
+                       pad=False).cuda()
+    processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)
+    audio = processor(audio,
+                      sampling_rate=24000,
+                      return_tensors="pt")['input_values'][0].cuda()
+
+    audio_embeds, atts_audio = model.encode_audio(audio)
+
+    prompt = '<Audio><AudioHere></Audio> ' + text
+    instruction_prompt = [model.prompt_template.format(prompt)]
+    audio_embeds, atts_audio = model.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
+
+    model.llama_tokenizer.padding_side = "right"
     batch_size = audio_embeds.shape[0]
     bos = torch.ones([batch_size, 1],
                      dtype=torch.long,
-                     device=torch.device('cuda')) * self.llama_tokenizer.bos_token_id
-    bos_embeds = self.llama_model.model.embed_tokens(bos)
-    atts_bos = atts_audio[:, :1]
+                     device=torch.device('cuda')) * model.llama_tokenizer.bos_token_id
+    bos_embeds = model.llama_model.model.embed_tokens(bos)
+    # atts_bos = atts_audio[:, :1]
     inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
-    attention_mask = torch.cat([atts_bos, atts_audio], dim=1)
-    outputs = self.llama_model.generate(
+    # attention_mask = torch.cat([atts_bos, atts_audio], dim=1)
+    outputs = model.llama_model.generate(
         inputs_embeds=inputs_embeds,
         max_new_tokens=max_new_tokens,
         stopping_criteria=stopping,
@@ -90,34 +208,21 @@ def answer(self, samples, stopping, max_new_tokens=300, num_beams=1, min_length=
         output_token = output_token[1:]
     if output_token[0] == 1:  # if there is a start token <s> at the beginning, remove it
         output_token = output_token[1:]
-    output_text = self.llama_tokenizer.decode(output_token, add_special_tokens=False)
+    output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
     output_text = output_text.split('###')[0]  # remove the stop sign '###'
     output_text = output_text.split('Assistant:')[-1].strip()
     return output_text
 
-processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)
-ds = MusicQADataset(processor, f'{path}/data/music_data', 'Eval')
-dl = DataLoader(
-    ds,
-    batch_size=1,
-    num_workers=0,
-    pin_memory=True,
-    shuffle=False,
-    drop_last=True,
-    collate_fn=ds.collater
-)
+musilingo = AutoModel.from_pretrained("m-a-p/MusiLingo-musicqa-v1", trust_remote_code=True)
+musilingo.to("cuda")
+musilingo.eval()
 
+prompt = "this is the task instruction and input question for MusiLingo model"
+audio_path = "/path/to/the/24kHz-audio"
 stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
-                                 torch.tensor([2277, 29937]).cuda()])])
-
-from transformers import AutoModel
-model_musicqa = AutoModel.from_pretrained("m-a-p/MusiLingo-musicqa-v1")
-
-for idx, sample in tqdm(enumerate(dl)):
-    ans = answer(model_musicqa.model, sample, stopping, length_penalty=100, temperature=0.1)
-    txt = sample['text_input'][0]
-    print(txt)
-    print(ans)
+                                                      torch.tensor([2277, 29937]).cuda()])])
+response = get_musilingo_pred(musilingo.model, prompt, audio_path, stopping, length_penalty=100, temperature=0.1)
+
 ```
 
 # Citing This Work
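Beyond the two `transformers` imports visible in the hunk context, the added code also calls `torch`, `torchaudio`, `random`, and `AutoModel`, whose imports live above old line 42 and therefore outside this diff. A minimal, hypothetical preamble for running the updated README example (the exact import list is an assumption, not part of the commit):

```python
# Hypothetical preamble for the README example added by this commit.
# Only the two `transformers` imports are shown in the diff context;
# the rest are inferred from what the added code actually calls.
import random

import torch
import torchaudio
from transformers import AutoModel, Wav2Vec2FeatureExtractor
from transformers import StoppingCriteria, StoppingCriteriaList
```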
 
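As a quick sanity check of the new `crop_audio` helper, here is a minimal sketch, assuming the definitions from the diff (and the preamble above) are in scope; the shapes and the 24 kHz rate are illustrative only:

```python
# A fake 5-second mono waveform at 24 kHz: shape (1, 120000).
waveform = torch.randn(1, 5 * 24000)

# Crop 2 seconds at a random offset; `start` reports where the crop began.
cropped, start = crop_audio(waveform, sample_rate=24000,
                            crop_to_length_in_sec=2.0, crop_randomly=True)
assert cropped.shape == (1, 48000)
assert 0 <= start <= 120000 - 48000

# A waveform shorter than the requested length is right-padded with zeros
# when pad=True, and `start` stays 0.
short = torch.randn(1, 24000)
padded, start = crop_audio(short, sample_rate=24000,
                           crop_to_length_in_sec=2.0, pad=True)
assert padded.shape == (1, 48000) and start == 0
```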
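The example hard-codes the stop sequences `[835]` and `[2277, 29937]`. In the LLaMA tokenizer these should correspond to the `'###'` delimiter that `get_musilingo_pred` later strips from the decoded text (`'###'` can be tokenized in more than one way). A quick, hypothetical way to check this against the loaded model:

```python
# Hypothetical check: decode the hard-coded stop ids with the model's
# LLaMA tokenizer; both should render as the '###' delimiter that
# get_musilingo_pred strips from the output.
tok = musilingo.model.llama_tokenizer  # attribute name taken from the diff
print(tok.decode([835]))          # expected: '###'
print(tok.decode([2277, 29937]))  # expected: '###'
```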