"use client";

import Image from "next/image";
import {
  AutoTokenizer,
  MusicgenForConditionalGeneration,
  BaseStreamer,
} from "@xenova/transformers";
import { Prompt } from "@/components/prompt";
import { Length } from "@/components/length";
import { Styles } from "@/components/styles";
import { Moods } from "@/components/moods";
import { useGeneration } from "@/components/hooks/useGeneration";
import { useState, useRef, useEffect } from "react";
import { encodeWAV, MODEL_ID } from "@/utils";
import classNames from "classnames";

/** Number of model/tokenizer files whose progress must be reported before a total is shown. */
const EXPECTED_MODEL_FILES = 5;

/**
 * Per-file download progress as read from `progress_callback` events.
 * Only the fields this component uses are declared; the event may carry more.
 */
type FileProgress = { loaded: number; total: number };

/**
 * Streamer that forwards every `put(value)` from `model.generate` to a single
 * callback, and calls the same callback with no argument from `end()`, which
 * the caller treats as "generation finished".
 */
class CallbackStreamer extends BaseStreamer {
  // The library does not give us a precise type for streamed values here, so
  // the boundary stays loosely typed; the caller indexes it as value[0].length.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  private readonly callback_fn: (value?: any) => unknown;

  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  constructor(callback_fn: (value?: any) => unknown) {
    super();
    this.callback_fn = callback_fn;
  }

  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  put(value: any) {
    return this.callback_fn(value);
  }

  end() {
    return this.callback_fn();
  }
}

/**
 * Music-generation form.
 *
 * On mount it starts loading the MusicGen model and tokenizer (once per
 * component instance, guarded by refs) and aggregates per-file download
 * progress into `progress` / `statusText`. `generateMusic` tokenizes the
 * formatted prompt, runs generation with a progress-reporting streamer,
 * encodes the result as WAV and stores a blob URL in `track`.
 *
 * @param children - rendered inside the form.
 */
export const Form = ({ children }: { children: React.ReactNode }) => {
  const [modelLoaded, setModelLoaded] = useState(false);
  const [progress, setProgress] = useState(0);
  const [statusText, setStatusText] = useState("Loading model (656MB)...");
  const [loadProgress, setLoadProgress] = useState<Record<string, FileProgress>>({});
  const [track, setTrack] = useState("");

  const {
    form,
    setForm,
    formattedPrompt,
    generate,
    results,
    loading,
    setResults,
  } = useGeneration();

  const modelPromise = useRef<Promise<unknown> | null>(null);
  const tokenizerPromise = useRef<Promise<unknown> | null>(null);
  // Blob URL of the most recent track, kept so it can be revoked when replaced.
  const trackUrlRef = useRef<string | null>(null);

  useEffect(() => {
    // `??=` keeps a second effect run (e.g. StrictMode) from starting another download.
    modelPromise.current ??= MusicgenForConditionalGeneration.from_pretrained(
      MODEL_ID,
      {
        progress_callback: (data: any) => {
          if (data.status !== "progress") return;
          setLoadProgress((prev) => ({ ...prev, [data.file]: data }));
        },
        dtype: {
          text_encoder: "q8",
          decoder_model_merged: "q8",
          encodec_decode: "fp32",
        },
        device: "wasm",
      }
    );
    // @ts-ignore -- kept from the original; NOTE(review): confirm whether the typings still require it
    tokenizerPromise.current ??= AutoTokenizer.from_pretrained(MODEL_ID);

    // Surface load failures instead of leaving the loading text up forever.
    void Promise.all([modelPromise.current, tokenizerPromise.current]).catch(
      (error: unknown) => {
        console.error("Failed to load model", error);
        setStatusText("Failed to load model. Please reload the page.");
      }
    );
  }, []);

  useEffect(() => {
    const files = Object.values(loadProgress);
    // Only report once every expected file has sent at least one event, so
    // `total` covers all of them.
    if (files.length !== EXPECTED_MODEL_FILES) return undefined;

    let loaded = 0;
    let total = 0;
    for (const file of files) {
      loaded += file.loaded;
      total += file.total;
    }

    const fraction = loaded / total;
    setProgress(fraction);
    setStatusText(
      fraction === 1
        ? "Ready!"
        : `Loading model (${(fraction * 100).toFixed()}% of 656MB)...`
    );
    if (fraction !== 1) return undefined;

    // Keep "Ready!" visible briefly before flipping `modelLoaded`; cancel the
    // pending switch if the effect re-runs or the component unmounts.
    const timer = setTimeout(() => setModelLoaded(true), 1500);
    return () => clearTimeout(timer);
  }, [loadProgress]);

  const generateMusic = async () => {
    try {
      const tokenizer: any = await tokenizerPromise.current;
      const model: any = await modelPromise.current;
      if (!tokenizer || !model) return null;

      // NOTE(review): assumes `form.length` is in seconds at ~50 tokens/s, with
      // 4 extra steps of headroom; capped by the model's own max_length.
      const max_length = Math.min(
        Math.max(Math.floor(form.length * 50), 1) + 4,
        model?.generation_config?.max_length ?? 1500
      );

      const streamer = new CallbackStreamer((value?: any) => {
        // `undefined` comes from CallbackStreamer.end() and means "finished".
        const percent = value === undefined ? 1 : value[0].length / max_length;
        setStatusText(`Generating (${(percent * 100).toFixed()}%)...`);
        setProgress(percent);
      });

      const inputs = tokenizer(formattedPrompt);
      const audio_values = await model.generate({
        ...inputs,
        max_length,
        streamer,
      });

      setStatusText("Encoding audio...");
      const sampling_rate = model.config.audio_encoder.sampling_rate;
      const wav = encodeWAV(audio_values.data, sampling_rate);
      const blob = new Blob([wav], { type: "audio/wav" });

      // Release the previous track's blob so repeated generations don't leak memory.
      const url = URL.createObjectURL(blob);
      if (trackUrlRef.current !== null) URL.revokeObjectURL(trackUrlRef.current);
      trackUrlRef.current = url;
      setTrack(url);
      setStatusText("Done!");
    } catch (error: unknown) {
      console.error("Music generation failed", error);
      setStatusText("Generation failed. Please try again.");
    }
  };

  return (
{children} setForm({ ...form, prompt: value })} /> setForm({ ...form, length: value })} /> setForm({ ...form, style: value })} /> setForm({ ...form, mood: value })} />

Generated prompt

"{formattedPrompt}"

{(loading || results?.title || results?.cover) && (
{results.cover ? ( Cover art ) : (
)} {results.title ? (

{results.title}

) : (
)} {modelLoaded && (track !== "" ? (
) : (

{statusText}

))}
)}

{statusText}

); };