# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py
# also released under the MIT license.

import numpy as np
import argparse
from concurrent.futures import ProcessPoolExecutor
import os
from pathlib import Path
import subprocess as sp
from tempfile import NamedTemporaryFile
import time
import typing as tp
import warnings

import torch
import gradio as gr

from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen


MODEL = None  # Last used model
IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
MAX_BATCH_SIZE = 6
BATCHED_DURATION = 15
INTERRUPTING = False
# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
_old_call = sp.call


def _call_nostderr(*args, **kwargs):
    """Run the original ``subprocess.call`` with stdout and stderr silenced.

    Returns:
        Whatever the original ``subprocess.call`` returns, so code relying on
        the patched ``sp.call`` still sees the return code.
    """
    # Avoid ffmpeg vomiting on the logs.
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    return _old_call(*args, **kwargs)


sp.call = _call_nostderr
# Preallocating the pool of processes.
pool = ProcessPoolExecutor(3)
pool.__enter__()


def interrupt():
    """Set the module-level ``INTERRUPTING`` flag to ``True``."""
    global INTERRUPTING
    INTERRUPTING = True


class FileCleaner:
    """Track files and delete them once they are older than a given lifetime.

    Files are kept in insertion order; expired entries are purged every time
    a new file is registered with :meth:`add`.
    """

    def __init__(self, file_lifetime: float = 3600):
        # Maximum age, in seconds, before a tracked file is deleted.
        self.file_lifetime = file_lifetime
        # List of (time_added, path) tuples, oldest first.
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        """Purge expired files, then start tracking ``path``."""
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        """Delete and forget the leading entries older than ``file_lifetime``."""
        now = time.time()
        # Iterate over a copy: entries are popped from the front as we go.
        for time_added, path in list(self.files):
            if now - time_added <= self.file_lifetime:
                # Entries are ordered by insertion time; the rest are newer.
                break
            try:
                path.unlink()
            except FileNotFoundError:
                # Already removed by someone else; nothing left to delete.
                pass
            self.files.pop(0)


file_cleaner = FileCleaner()


def make_waveform(*args, **kwargs):
    """Call ``gr.make_waveform`` with warnings suppressed and log the time taken.

    All arguments are forwarded unchanged; the result of ``gr.make_waveform``
    is returned.
    """
    # Further remove some warnings.
    be = time.time()
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        out = gr.make_waveform(*args, **kwargs)
        print("Make a video took", time.time() - be)
    print("Returning from make_waveform")
    return out


def load_model(version='melody'):
    """Load the requested MusicGen model into ``MODEL`` unless already loaded.

    The model is (re)loaded only when no model is cached or the cached
    model's ``name`` differs from ``version``.
    """
    global MODEL
    print("Loading model", version)
    if MODEL is None or MODEL.name != version:
        MODEL = MusicGen.get_pretrained(version)


def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
    """Run generation for a batch of prompts with optional melody conditioning.

    Args:
        texts: Text prompts, one per item of the batch.
        melodies: One entry per prompt, either ``None`` or a
            ``(sample_rate, numpy_array)`` pair.
        duration: Forwarded to ``MODEL.set_generation_params``.
        progress: Forwarded to the generation call.
        **gen_kwargs: Extra parameters for ``MODEL.set_generation_params``.

    Returns:
        A ``(outputs, infos)`` pair. If generation raises ``RuntimeError``
        the error is printed and ``([], [])`` is returned, so callers that
        unpack two values keep working.
    """
    MODEL.set_generation_params(duration=duration, **gen_kwargs)
    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()
    processed_melodies = []
    target_sr = 32000
    target_ac = 1
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
            continue
        # presumably the array is (samples, channels) and .t() makes it
        # (channels, samples) -- TODO confirm against the UI component.
        sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
        if melody.dim() == 1:
            melody = melody[None]
        # Resample to target_sr and mix down to target_ac channels in one
        # call. Tensor.to(target_ac) is a device/dtype move, not a channel
        # conversion, and convert_audio needs the channel count explicitly.
        melody = convert_audio(melody, sr, target_sr, target_ac)
        processed_melodies.append(melody[None])
    try:
        # NOTE(review): generate_multiple is not defined in this file;
        # confirm the loaded model actually exposes it.
        outputs, infos = MODEL.generate_multiple(texts, processed_melodies, progress=progress)
    except RuntimeError as e:
        print(f'Runtime error in _do_predictions: {e}')
        return [], []
    print(f'Generation took {time.time() - be} seconds.')
    return outputs, infos


def _postprocess(output):
    """Write one generated output to a temporary ``.mp3`` path and return it.

    The path is registered with ``file_cleaner`` so it is deleted later.
    """
    be = time.time()
    audio_path = NamedTemporaryFile(delete=False, suffix=".mp3").name
    file_cleaner.add(audio_path)
    # NOTE(review): audiocraft documents audio_write(stem_name, wav,
    # sample_rate, ...); this argument order looks suspicious -- confirm.
    audio_write(output, audio_path)
    print(f'Audio write took {time.time() - be} seconds.')
    print("Returning from _postprocess")
    return audio_path


def _no_outputs(outputs) -> bool:
    """Return True when ``outputs`` is ``None`` or has length zero.

    An explicit length check is used because truth-testing a multi-element
    tensor is ambiguous and raises.
    """
    return outputs is None or len(outputs) == 0


def _predict_single(text: str, melody: tp.Tuple[tp.Optional[int], tp.Optional[np.ndarray]],
                    duration: float, **gen_kwargs):
    """Generate audio for a single prompt.

    Returns:
        The path of the written audio file, or ``None`` when generation
        produced no outputs.
    """
    load_model()
    print(f'_predict_single called with text: {text}, melody: {melody}, duration: {duration}, gen_kwargs: {gen_kwargs}')
    outputs, _infos = _do_predictions([text], [melody], duration, **gen_kwargs)
    if _no_outputs(outputs):
        print("No outputs in _predict_single")
        return None
    return _postprocess(outputs[0])


def _predict_batch(texts: tp.List[str],
                   melodies: tp.List[tp.Tuple[tp.Optional[int], tp.Optional[np.ndarray]]],
                   duration: float, **gen_kwargs):
    """Generate audio for a batch of prompts.

    Returns:
        A list of audio file paths, one per output, or a list of ``None``
        with one entry per prompt when generation produced no outputs.
    """
    load_model()
    print(f'_predict_batch called with texts: {texts}, melodies: {melodies}, duration: {duration}, gen_kwargs: {gen_kwargs}')
    outputs, _infos = _do_predictions(texts, melodies, duration, **gen_kwargs)
    if _no_outputs(outputs):
        print("No outputs in _predict_batch")
        return [None] * len(texts)
    return [_postprocess(output) for output in outputs]


def launch_app():
    """Build the launch options from the environment and start the UI.

    ``PORT`` and ``HOST`` environment variables, when set, are copied into
    the launch options. The batched UI is used when ``IS_BATCHED`` is true,
    the full UI otherwise.
    """
    launch_kwargs = dict(verbose=False, debug=False, inline=False)
    # NOTE(review): values are passed through as strings under the keys
    # 'port' / 'host'; how ui_batched / ui_full consume them is not visible
    # here -- confirm they match what the UI launcher expects.
    if 'PORT' in os.environ:
        launch_kwargs['port'] = os.environ['PORT']
    if 'HOST' in os.environ:
        launch_kwargs['host'] = os.environ['HOST']
    # lastly launch the app with the set parameters
    if IS_BATCHED:
        print("Launching batched UI.")
        ui_batched(launch_kwargs)
    else:
        print("Launching full UI.")
        ui_full(launch_kwargs)