language:
  - en
license: cc-by-nc-4.0
library_name: transformers
tags:
  - music
  - art


Model Card for MusiLingo-musicqa-v1

Model Details

Model Description

The model consists of a music encoder MERT-v1-330M, a natural language decoder vicuna-7b-delta-v0, and a linear projection layer between the two.
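
To make the role of the projection layer concrete, here is a minimal pure-Python sketch of a linear projection that maps encoder frames into a decoder's embedding space. The function name `linear_project`, the toy dimensions (3 in, 2 out), and all numeric values are illustrative assumptions. They are not the real MERT or Vicuna hidden sizes, and this is not the MusiLingo implementation.

```python
# Toy sketch of a linear projection between an audio encoder and a text decoder.
# All dimensions and values are hypothetical and chosen for readability.

def linear_project(frames, weight, bias):
    """Map each encoder frame (length d_in) to decoder space (length d_out).

    weight has shape (d_out, d_in); bias has length d_out.
    """
    projected = []
    for frame in frames:
        row = [
            sum(w * x for w, x in zip(weight_row, frame)) + b
            for weight_row, b in zip(weight, bias)
        ]
        projected.append(row)
    return projected


# Two audio frames with a toy encoder dimension of 3.
audio_frames = [[1.0, 0.0, 2.0], [0.5, 0.5, 0.5]]
# Projection to a toy decoder dimension of 2.
weight = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]
bias = [0.0, 1.0]

audio_embeds = linear_project(audio_frames, weight, bias)
print(audio_embeds)
```

The number of frames is unchanged by the projection; only the feature dimension changes, so the audio features can sit alongside text token embeddings in the decoder input.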

This MusiLingo checkpoint was developed on the MusicQA dataset and can answer instructions about raw music audio, such as queries about tempo, emotion, genre, tags, or subjective feelings. You can use the MusicQA dataset for the following demo. For the implementation of MusicQA, please refer to our GitHub repo.

Model Sources [optional]

Getting Started

import random

from tqdm.auto import tqdm

import torch
import torchaudio
from torch.utils.data import DataLoader
from transformers import AutoModel, Wav2Vec2FeatureExtractor
from transformers import StoppingCriteria, StoppingCriteriaList


def load_audio(
    file_path, 
    target_sr, 
    is_mono=True, 
    is_normalize=False,
    crop_to_length_in_sec=None, 
    crop_to_length_in_sample_points=None, 
    crop_randomly=False, 
    pad=False,
    return_start=False,
    device=torch.device('cpu')
):
    """Load audio file and convert to target sample rate.
    Supports cropping and padding.

    Args:
        file_path (str): path to audio file
        target_sr (int): target sample rate, if not equal to sample rate of audio file, resample to target_sr
        is_mono (bool, optional): convert to mono. Defaults to True.
        is_normalize (bool, optional): normalize to [-1, 1]. Defaults to False.
        crop_to_length_in_sec (float, optional): crop to specified length in seconds. Defaults to None.
        crop_to_length_in_sample_points (int, optional): crop to specified length in sample points. Defaults to None. Note that the crop length in sample points is calculated before resampling.
        crop_randomly (bool, optional): crop randomly. Defaults to False.
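        pad (bool, optional): pad to specified length if waveform is shorter than specified length. Defaults to False.
        return_start (bool, optional): also return the start index of the crop. Defaults to False.
        device (torch.device, optional): device to use for resampling. Defaults to torch.device('cpu').

    Returns:
        torch.Tensor: waveform of shape (1, n_sample); if return_start is True, a (waveform, start) tuple is returned instead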
    """
    # TODO: deal with target_depth
    try:
        waveform, sample_rate = torchaudio.load(file_path)
    except Exception:
        # Fall back to the soundfile backend if the default loader fails
        waveform, sample_rate = torchaudio.backend.soundfile_backend.load(file_path)
    if waveform.shape[0] > 1:
        if is_mono:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
    
    if is_normalize:
        waveform = waveform / waveform.abs().max()
    
    waveform, start = crop_audio(
        waveform, 
        sample_rate, 
        crop_to_length_in_sec=crop_to_length_in_sec, 
        crop_to_length_in_sample_points=crop_to_length_in_sample_points, 
        crop_randomly=crop_randomly, 
        pad=pad,
    )
    
    if sample_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
        waveform = waveform.to(device)
        resampler = resampler.to(device)
        waveform = resampler(waveform)
    
    if return_start:
        return waveform, start
    return waveform


def crop_audio(
    waveform, 
    sample_rate, 
    crop_to_length_in_sec=None, 
    crop_to_length_in_sample_points=None, 
    crop_randomly=False, 
    pad=False,
):
    """Crop waveform to specified length in seconds or sample points.
    Supports random cropping and padding.

    Args:
        waveform (torch.Tensor): waveform of shape (1, n_sample)
        sample_rate (int): sample rate of waveform
        crop_to_length_in_sec (float, optional): crop to specified length in seconds. Defaults to None.
        crop_to_length_in_sample_points (int, optional): crop to specified length in sample points. Defaults to None.
        crop_randomly (bool, optional): crop randomly. Defaults to False.
        pad (bool, optional): pad to specified length if waveform is shorter than specified length. Defaults to False.

    Returns:
        torch.Tensor: cropped waveform
        int: start index of cropped waveform in original waveform
    """
    assert crop_to_length_in_sec is None or crop_to_length_in_sample_points is None, \
    "Only one of crop_to_length_in_sec and crop_to_length_in_sample_points can be specified"

    # convert crop length to sample points
    crop_duration_in_sample = None
    if crop_to_length_in_sec:
        crop_duration_in_sample = int(sample_rate * crop_to_length_in_sec)
    elif crop_to_length_in_sample_points:
        crop_duration_in_sample = crop_to_length_in_sample_points

    # crop
    start = 0
    if crop_duration_in_sample:
        if waveform.shape[-1] > crop_duration_in_sample:
            if crop_randomly:
                start = random.randint(0, waveform.shape[-1] - crop_duration_in_sample)
            waveform = waveform[..., start:start + crop_duration_in_sample]

        elif waveform.shape[-1] < crop_duration_in_sample:
            if pad:
                waveform = torch.nn.functional.pad(waveform, (0, crop_duration_in_sample - waveform.shape[-1]))
    
    return waveform, start



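class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the output ends with any of the given token sequences."""

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Avoid a mutable default argument; `encounters` is unused and kept only for compatibility
        self.stops = stops if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        for stop in self.stops:
            # Compare the last len(stop) tokens of the first sequence with the stop sequence
            if input_ids.shape[1] >= len(stop) and torch.all(stop == input_ids[0][-len(stop):]).item():
                return True
        return False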

def get_musilingo_pred(model, text, audio_path, stopping, length_penalty=1, temperature=0.1,
    max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5, repetition_penalty=1.0):

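    # Load, downmix to mono, crop, and resample to 24 kHz.
    # The crop length is counted in sample points at the file's original sample rate (before resampling).
    audio = load_audio(audio_path, target_sr=24000,
                        is_mono=True,
                        is_normalize=False,
                        crop_to_length_in_sample_points=int(30*16000)+1,
                        crop_randomly=True,
                        pad=False).cuda()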
    processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M",trust_remote_code=True) 
    audio = processor(audio, 
                    sampling_rate=24000, 
                    return_tensors="pt")['input_values'][0].cuda() 
        
    audio_embeds, atts_audio = model.encode_audio(audio)
        
    prompt = '<Audio><AudioHere></Audio> ' + text
    instruction_prompt = [model.prompt_template.format(prompt)]
    audio_embeds, atts_audio = model.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
    
    model.llama_tokenizer.padding_side = "right"
    batch_size = audio_embeds.shape[0]
    bos = torch.ones([batch_size, 1],
                    dtype=torch.long,
                    device=torch.device('cuda')) * model.llama_tokenizer.bos_token_id
    bos_embeds = model.llama_model.model.embed_tokens(bos)
    # atts_bos = atts_audio[:, :1]
    inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
    # attention_mask = torch.cat([atts_bos, atts_audio], dim=1)
    outputs = model.llama_model.generate(
        inputs_embeds=inputs_embeds,
        max_new_tokens=max_new_tokens,
        stopping_criteria=stopping,
        num_beams=num_beams,
        do_sample=True,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
    )
    output_token = outputs[0]
    if output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning; remove it
        output_token = output_token[1:]
    if output_token[0] == 1:  # if there is a start token <s> at the beginning, remove it
        output_token = output_token[1:]
    output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('###')[0]  # remove the stop sign '###'
    output_text = output_text.split('Assistant:')[-1].strip()
    return output_text

musilingo = AutoModel.from_pretrained("m-a-p/MusiLingo-musicqa-v1", trust_remote_code=True)
musilingo.to("cuda")
musilingo.eval()

prompt = "this is the task instruction and input question for MusiLingo model"
audio_path = "/path/to/the/24kHz-audio"
stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
                                  torch.tensor([2277, 29937]).cuda()])])
response = get_musilingo_pred(musilingo.model, prompt, audio_path, stopping, length_penalty=100, temperature=0.1)
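
The `load_audio` call in `get_musilingo_pred` crops to `int(30*16000)+1` sample points, and the docstring notes that a crop length given in sample points is applied before resampling. The sketch below works through what that means for files with different source sample rates. The helper `crop_seconds` and the list of source rates are hypothetical, added only for illustration.

```python
def crop_seconds(crop_points, source_sr):
    """Duration in seconds covered by crop_points samples at source_sr Hz."""
    return crop_points / source_sr


# Crop length used in the snippet above.
crop_points = int(30 * 16000) + 1  # 480001 sample points

# Hypothetical source sample rates, for illustration only.
durations = {sr: crop_seconds(crop_points, sr) for sr in (16000, 24000, 44100)}
for sr, seconds in durations.items():
    print(f"{sr} Hz source: crop keeps about {seconds:.2f} s")
```

With this crop length, a 16 kHz file keeps about 30 s of audio, while a file that is already at 24 kHz keeps about 20 s. If a fixed duration matters for your use case, `crop_to_length_in_sec` expresses it independently of the source rate.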
        
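
The two text-side steps in the snippet, the stop-sequence check and the cleanup of the decoded string, can be tried without a GPU or model weights. The following is a torch-free sketch that mirrors the logic of `StoppingCriteriaSub` and the final lines of `get_musilingo_pred` on plain Python lists and strings. The helper names `ends_with_stop` and `clean_response` and the sample inputs are assumptions made for illustration; the stop token ids are copied from the snippet.

```python
def ends_with_stop(token_ids, stop_sequences):
    """Return True if token_ids ends with any non-empty stop sequence."""
    for stop in stop_sequences:
        if stop and len(token_ids) >= len(stop) and token_ids[-len(stop):] == stop:
            return True
    return False


def clean_response(text):
    """Keep the text before the first '###' and after the last 'Assistant:'."""
    text = text.split("###")[0]
    return text.split("Assistant:")[-1].strip()


# Stop sequences copied from the snippet above, as plain lists of token ids.
stops = [[835], [2277, 29937]]

# Made-up token id sequences and decoded text, for illustration only.
print(ends_with_stop([1, 20, 835], stops))
print(ends_with_stop([1, 20, 2277], stops))
print(clean_response("Assistant: The tempo is fast.###Human: next question"))
```

The length check before the comparison guards against sequences shorter than a stop sequence, which a plain slice comparison would otherwise handle inconsistently.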

Citing This Work

If you find the work useful for your research, please consider citing it using the following BibTeX entry:

@inproceedings{deng2024musilingo,
  title={MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response},
  author={Deng, Zihao and Ma, Yinghao and Liu, Yudong and Guo, Rongchen and Zhang, Ge and Chen, Wenhu and Huang, Wenhao and Benetos, Emmanouil},
  booktitle={Proceedings of the 2024 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL 2024)},
  year={2024},
  organization={Association for Computational Linguistics}
}