---
language:
- en
license: cc-by-nc-4.0
library_name: transformers
tags:
- music
- art
---

<!--
NOTE(review): this card also contained a second metadata block after the front matter:
  license: cc-by-4.0
  language: [en]
  tags: [music, art]
Only the leading front-matter block above is parsed as metadata, and its license (cc-by-nc-4.0)
conflicts with this one (cc-by-4.0). Confirm which license is intended, then delete this comment.
-->

# Model Card for MusiLingo-musicqa-v1

## Model Details

### Model Description

The model consists of three parts:

- a music encoder, `MERT-v1-330M`;
- a natural-language decoder, `vicuna-7b-delta-v0`;
- a linear projection layer between the two.

This MusiLingo checkpoint was trained on the MusicQA dataset. It can answer instructions about raw music audio, such as questions about tempo, emotion, genre, tags, or subjective impressions.

You can try the demo below with audio clips from the MusicQA dataset. For the MusicQA dataset implementation, see our [GitHub repo](https://github.com/zihaod/MusiLingo/blob/main/musilingo/datasets/datasets/musicqa_dataset.py).

### Model Sources

- **Repository:** [GitHub repo](https://github.com/zihaod/MusiLingo)
- **Paper:** __[MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response](https://arxiv.org/abs/2309.08730)__

## Getting Started

The example below assumes a CUDA-capable GPU: the model, audio features, and stop tokens are all moved to `cuda`.

```python
import functools
import random

import torch
import torchaudio
from transformers import (
    AutoModel,
    StoppingCriteria,
    StoppingCriteriaList,
    Wav2Vec2FeatureExtractor,
)


def load_audio(
    file_path,
    target_sr,
    is_mono=True,
    is_normalize=False,
    crop_to_length_in_sec=None,
    crop_to_length_in_sample_points=None,
    crop_randomly=False,
    pad=False,
    return_start=False,
    device=torch.device('cpu'),
):
    """Load an audio file, optionally downmix/normalize/crop it, and resample it to `target_sr`.

    Cropping and padding happen at the file's native sample rate, *before* resampling.

    Args:
        file_path (str): path to the audio file.
        target_sr (int): target sample rate; if it differs from the file's sample rate,
            the waveform is resampled to `target_sr`.
        is_mono (bool, optional): average multi-channel audio down to one channel. Defaults to True.
        is_normalize (bool, optional): divide by the peak absolute amplitude so values lie in
            [-1, 1]; silent (all-zero) audio is left unchanged. Defaults to False.
        crop_to_length_in_sec (float, optional): crop to this many seconds. Defaults to None.
        crop_to_length_in_sample_points (int, optional): crop to this many samples, counted at
            the file's native sample rate (i.e. before resampling). Defaults to None.
        crop_randomly (bool, optional): pick the crop start at random instead of at 0. Defaults to False.
        pad (bool, optional): zero-pad at the end if the waveform is shorter than the crop
            length. Defaults to False.
        return_start (bool, optional): also return the crop start index (in native-rate
            samples). Defaults to False.
        device (torch.device, optional): device used for resampling. Defaults to torch.device('cpu').

    Returns:
        torch.Tensor: waveform of shape (n_channels, n_samples), i.e. (1, n_samples) when mono.
        If `return_start` is True, the tuple `(waveform, start)` is returned instead.
    """
    # TODO: deal with target_depth
    try:
        waveform, sample_rate = torchaudio.load(file_path)
    except Exception:
        # Fallback loader. NOTE(review): `torchaudio.backend` is deprecated and may not exist
        # in newer torchaudio releases.
        waveform, sample_rate = torchaudio.backend.soundfile_backend.load(file_path)

    if is_mono and waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if is_normalize:
        peak = waveform.abs().max()
        if peak > 0:  # dividing silent audio by zero would produce NaNs
            waveform = waveform / peak

    waveform, start = crop_audio(
        waveform,
        sample_rate,
        crop_to_length_in_sec=crop_to_length_in_sec,
        crop_to_length_in_sample_points=crop_to_length_in_sample_points,
        crop_randomly=crop_randomly,
        pad=pad,
    )

    if sample_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sr).to(device)
        waveform = resampler(waveform.to(device))

    if return_start:
        return waveform, start
    return waveform


def crop_audio(
    waveform,
    sample_rate,
    crop_to_length_in_sec=None,
    crop_to_length_in_sample_points=None,
    crop_randomly=False,
    pad=False,
):
    """Crop a waveform along its last axis to a length given in seconds or in samples.

    Supports random cropping and zero-padding at the end.

    Args:
        waveform (torch.Tensor): waveform of shape (..., n_samples).
        sample_rate (int): sample rate of `waveform`.
        crop_to_length_in_sec (float, optional): crop to this many seconds. Defaults to None.
        crop_to_length_in_sample_points (int, optional): crop to this many samples. Defaults to None.
        crop_randomly (bool, optional): pick the crop start at random instead of at 0. Defaults to False.
        pad (bool, optional): zero-pad at the end if the waveform is shorter than the crop
            length. Defaults to False.

    Returns:
        torch.Tensor: the cropped, padded, or unchanged waveform.
        int: start index of the crop in the original waveform (0 unless a random crop was taken).

    Raises:
        ValueError: if both crop lengths are specified.
    """
    if crop_to_length_in_sec is not None and crop_to_length_in_sample_points is not None:
        raise ValueError(
            "Only one of crop_to_length_in_sec and crop_to_length_in_sample_points can be specified"
        )

    # Convert the requested crop length to a number of samples at `sample_rate`.
    crop_duration_in_sample = None
    if crop_to_length_in_sec:
        crop_duration_in_sample = int(sample_rate * crop_to_length_in_sec)
    elif crop_to_length_in_sample_points:
        crop_duration_in_sample = crop_to_length_in_sample_points

    start = 0
    if crop_duration_in_sample:
        n_samples = waveform.shape[-1]
        if n_samples > crop_duration_in_sample:
            if crop_randomly:
                start = random.randint(0, n_samples - crop_duration_in_sample)
            waveform = waveform[..., start:start + crop_duration_in_sample]
        elif n_samples < crop_duration_in_sample and pad:
            waveform = torch.nn.functional.pad(waveform, (0, crop_duration_in_sample - n_samples))
    return waveform, start


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the first sequence in the batch ends with any of `stops`.

    Args:
        stops (list[torch.Tensor], optional): 1-D tensors of token ids; generation stops when
            the tail of `input_ids[0]` equals one of them. They must live on the same device
            as the generated ids.
        encounters (int, optional): unused; kept for backward compatibility.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        self.stops = [] if stops is None else stops  # avoid a shared mutable default

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        generated = input_ids[0]
        for stop in self.stops:
            # A stop sequence longer than the ids generated so far cannot match; slicing would
            # return fewer elements and the comparison would fail or broadcast incorrectly.
            if len(stop) <= len(generated) and torch.all(stop == generated[-len(stop):]).item():
                return True
        return False


@functools.lru_cache(maxsize=1)
def _get_mert_processor():
    """Load the MERT feature extractor once and reuse it across calls."""
    return Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)


def get_musilingo_pred(model, text, audio_path, stopping, length_penalty=1, temperature=0.1,
                       max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5,
                       repetition_penalty=1.0):
    """Answer a text instruction about an audio file with a MusiLingo model.

    Args:
        model: the inner MusiLingo model (`musilingo.model` below). It must provide
            `encode_audio`, `instruction_prompt_wrap`, `prompt_template`, `llama_tokenizer`
            and `llama_model`.
        text (str): the instruction / question.
        audio_path (str): path to the audio file.
        stopping (StoppingCriteriaList): stopping criteria passed to `generate`.
        length_penalty, temperature, max_new_tokens, num_beams, min_length, top_p,
        repetition_penalty: forwarded to `model.llama_model.generate` (sampling is always on).

    Returns:
        str: the decoded answer, truncated at the '###' stop marker and with any text up to
        the last 'Assistant:' removed.
    """
    # The crop length is counted at the file's native sample rate (see load_audio), so for
    # 24 kHz input this keeps about 20 s. The crop start is random, so repeated calls on a long
    # file may see different segments.
    audio = load_audio(audio_path, target_sr=24000, is_mono=True, is_normalize=False,
                       crop_to_length_in_sample_points=int(30 * 16000) + 1,
                       crop_randomly=True, pad=False).cuda()
    processor = _get_mert_processor()
    audio = processor(audio, sampling_rate=24000, return_tensors="pt")['input_values'][0].cuda()
    audio_embeds, atts_audio = model.encode_audio(audio)

    prompt = ' ' + text
    instruction_prompt = [model.prompt_template.format(prompt)]
    audio_embeds, atts_audio = model.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)

    model.llama_tokenizer.padding_side = "right"

    # Prepend a BOS embedding on the same device as the audio embeddings.
    batch_size = audio_embeds.shape[0]
    bos = torch.full((batch_size, 1), model.llama_tokenizer.bos_token_id,
                     dtype=torch.long, device=audio_embeds.device)
    bos_embeds = model.llama_model.model.embed_tokens(bos)
    inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)

    outputs = model.llama_model.generate(
        inputs_embeds=inputs_embeds,
        max_new_tokens=max_new_tokens,
        stopping_criteria=stopping,
        num_beams=num_beams,
        do_sample=True,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
    )
    output_token = outputs[0]
    # The model may emit an unknown token (id 0) and/or a start token (id 1) first; drop them.
    if len(output_token) > 0 and output_token[0] == 0:
        output_token = output_token[1:]
    if len(output_token) > 0 and output_token[0] == 1:
        output_token = output_token[1:]
    output_text = model.llama_tokenizer.decode(output_token)
    output_text = output_text.split('###')[0]  # remove the stop sign '###' and anything after it
    output_text = output_text.split('Assistant:')[-1].strip()
    return output_text


musilingo = AutoModel.from_pretrained("m-a-p/MusiLingo-musicqa-v1", trust_remote_code=True)
musilingo.to("cuda")
musilingo.eval()

prompt = "this is the task instruction and input question for MusiLingo model"
audio_path = "/path/to/the/24kHz-audio"
# Presumably the LLaMA token ids for the '###' separator (835, and 2277 + 29937); verify
# against the tokenizer.
stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
                                                      torch.tensor([2277, 29937]).cuda()])])
with torch.no_grad():
    response = get_musilingo_pred(musilingo.model, prompt, audio_path, stopping,
                                  length_penalty=100, temperature=0.1)
print(response)
```

## Citing This Work

If you find this work useful for your research, please consider citing it with the following BibTeX entry:

```
@inproceedings{deng2024musilingo,
  title={MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response},
  author={Deng, Zihao and Ma, Yinghao and Liu, Yudong and Guo, Rongchen and Zhang, Ge and Chen, Wenhu and Huang, Wenhao and Benetos, Emmanouil},
  booktitle={Proceedings of the 2024 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL 2024)},
  year={2024},
  organization={Association for Computational Linguistics}
}
```