Spaces:
Configuration error
Configuration error
from diffusers import StableDiffusionImg2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline | |
import torch | |
from PIL import Image, ImageDraw | |
import os | |
import numpy as np | |
from scipy.io.wavfile import read | |
import gradio as gr | |
from share_btn import community_icon_html, loading_icon_html, share_js | |
os.system('git clone https://github.com/hmartiro/riffusion-inference.git riffusion') | |
from riffusion.riffusion.riffusion_pipeline import RiffusionPipeline | |
from riffusion.riffusion.datatypes import PromptInput, InferenceInput | |
from riffusion.riffusion.audio import wav_bytes_from_spectrogram_image | |
from PIL import Image | |
import struct | |
import random | |
repo_id = "riffusion/riffusion-model-v1"

# Use float16 only when a GPU is present: PyTorch's CPU backend has
# historically lacked many half-precision kernels, and this Space explicitly
# supports CPU-only hardware.
_use_cuda = torch.cuda.is_available()
_dtype = torch.float16 if _use_cuda else torch.float32

# Riffusion pipeline used for text/image-to-spectrogram generation.
model = RiffusionPipeline.from_pretrained(
    repo_id,
    revision="main",
    torch_dtype=_dtype,
    # Disable the safety checker: return the images unchanged and no NSFW flag.
    safety_checker=lambda images, **kwargs: (images, False),
)
if _use_cuda:
    model.to("cuda")
    model.enable_xformers_memory_efficient_attention()

# Inpainting pipeline (same weights) used to outpaint successive sections.
pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
    repo_id,
    torch_dtype=_dtype,
    safety_checker=lambda images, **kwargs: (images, False),
)
pipe_inpaint.scheduler = DPMSolverMultistepScheduler.from_config(pipe_inpaint.scheduler.config)
if _use_cuda:
    pipe_inpaint = pipe_inpaint.to("cuda")
    # This used to call `pipe.enable_...()`, but no `pipe` is defined in this
    # file, so GPU startup failed with NameError.
    pipe_inpaint.enable_xformers_memory_efficient_attention()
def get_init_image(image, overlap, feel):
    """Build the init image for the next outpainting step.

    Loads the seed image for `feel` and overwrites its left edge with the
    rightmost ``int(width * overlap)`` columns of `image`.

    Args:
        image: The previously generated PIL image.
        overlap: Fraction of `image`'s width to carry over.
        feel: Name of the seed image in ``riffusion/seed_images/``.

    Returns:
        The RGB seed image with the carried-over strip pasted at (0, 0).
    """
    src_w, src_h = image.size
    strip_w = int(src_w * overlap)
    canvas = Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
    strip = image.crop((src_w - strip_w, 0, src_w, src_h))
    canvas.paste(strip, (0, 0))
    return canvas
def get_mask(image, overlap):
    """Build the inpainting mask for an image prepared by get_init_image.

    The mask has `image`'s size. The leftmost ``int(overlap * width)`` columns
    are black and everything else is white. Diffusers' inpainting pipelines
    keep black mask pixels and repaint white ones, so only the carried-over
    strip is preserved.

    Args:
        image: The init image; only its size is used.
        overlap: Fraction of the width to keep.

    Returns:
        An RGB PIL mask image.
    """
    width, height = image.size
    mask = Image.new("RGB", (width, height), color="white")
    keep_width = int(overlap * width)
    if keep_width > 0 and height > 0:
        # ImageDraw.rectangle includes both corner points, so the last kept
        # column is keep_width - 1. Passing keep_width itself marked one extra
        # column (not part of the pasted strip) as "keep".
        ImageDraw.Draw(mask).rectangle((0, 0, keep_width - 1, height - 1), fill="black")
    return mask
def i2i(prompt, steps, feel, seed):
    """Generate the first spectrogram section from the `feel` seed image.

    Start and end prompts use the same prompt text and seed, so this is a
    plain image-to-image pass through ``model.riffuse``.

    Args:
        prompt: Text prompt.
        steps: Number of inference steps.
        feel: Name of the seed image in ``riffusion/seed_images/``.
        seed: Seed used for both prompt endpoints.

    Returns:
        Whatever ``model.riffuse`` returns (used as a PIL image by callers).
    """
    start, end = (PromptInput(prompt=prompt, seed=seed) for _ in range(2))
    request = InferenceInput(
        start=start,
        end=end,
        alpha=1.0,
        num_inference_steps=steps,
    )
    seed_image = Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
    return model.riffuse(inputs=request, init_image=seed_image)
def outpaint(prompt, init_image, mask, steps):
    """Complete `init_image` with the inpainting pipeline.

    Args:
        prompt: Text prompt.
        init_image: Image whose masked region is regenerated.
        mask: Mask image (as built by get_mask).
        steps: Number of inference steps.

    Returns:
        The first image produced by ``pipe_inpaint``.
    """
    result = pipe_inpaint(
        prompt,
        mask_image=mask,
        image=init_image,
        num_inference_steps=steps,
    )
    generated = result.images
    return generated[0]
def generate(prompt, steps, num_iterations, feel, seed):
    """Generate single-prompt audio by outpainting spectrogram sections.

    A first section is produced by i2i() from the `feel` seed image. Each of
    the `num_iterations` sections is then built by outpainting: the right half
    of the previous section is pasted onto the seed image and the rest is
    regenerated. The sections are stitched with 50% overlap, converted to WAV,
    and rendered to a waveform video.

    Args:
        prompt: Text prompt used for every section.
        steps: Inference steps per section (i2i and every outpaint).
        num_iterations: Number of outpainted sections to stitch (>= 1).
        feel: Name of the seed image in ``riffusion/seed_images/``.
        seed: Seed for the first (i2i) section; 0 picks a random seed.

    Returns:
        The value returned by ``gr.make_waveform`` for the generated audio.

    Raises:
        ValueError: If `num_iterations` is less than 1.
    """
    import tempfile

    num_images = int(num_iterations)
    if num_images < 1:
        raise ValueError("num_iterations must be at least 1")
    if seed == 0:
        seed = random.randint(0, 4294967295)

    overlap = 0.5
    image_width, image_height = 512, 512  # dimensions of each output section
    overlap_px = int(overlap * image_width)
    stride_px = int((1 - overlap) * image_width)
    total_width = num_images * image_width - (num_images - 1) * overlap_px
    stitched_image = Image.new("RGB", (total_width, image_height), color="white")

    image = i2i(prompt, steps, feel, seed)
    for index in range(num_images):
        init_image = get_init_image(image, overlap, feel)
        mask = get_mask(init_image, overlap)
        # Honour the user's "Steps per section". A hard-coded 25 used to
        # override it here, so the slider only affected the first i2i pass.
        image = outpaint(prompt, init_image, mask, steps)
        stitched_image.paste(image, (index * stride_px, 0))

    wav_bytes, duration_s = wav_bytes_from_spectrogram_image(stitched_image)

    # Per-request temp files. With app.queue(concurrency_count=2), two requests
    # can run at once, and the fixed "output.wav" / "bg_image.png" names in the
    # working directory let them overwrite each other's outputs.
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = os.path.join(tmp_dir, "output.wav")
        bg_path = os.path.join(tmp_dir, "bg_image.png")
        # The last outpainting input is used as the video background.
        init_image.save(bg_path)
        with open(wav_path, "wb") as f:
            f.write(wav_bytes.read())
        # Relies on gr.make_waveform rendering the video synchronously into
        # its own output file, so these inputs can be deleted afterwards.
        return gr.make_waveform(wav_path, bg_image=bg_path, bar_count=int(duration_s * 25))
############################################### | |
def riffuse(steps, feel, init_image, prompt_start, seed_start, denoising_start=0.75, guidance_start=7.0, prompt_end=None, seed_end=None, denoising_end=0.75, guidance_end=7.0, alpha=0.5):
    """Run one Riffusion interpolation step and decode the result to audio.

    Builds start/end PromptInputs from the ``*_start`` / ``*_end`` arguments
    and passes them, with interpolation weight `alpha`, to ``model.riffuse``
    along with `init_image`. `feel` is forwarded as ``seed_image_id``.

    Returns:
        tuple: ``(wav_bytes, image)``. The first item is the audio from
        wav_bytes_from_spectrogram_image; the second is the spectrogram image
        returned by ``model.riffuse``.
    """
    endpoints = [
        PromptInput(prompt=prompt_start, seed=seed_start, denoising=denoising_start, guidance=guidance_start),
        PromptInput(prompt=prompt_end, seed=seed_end, denoising=denoising_end, guidance=guidance_end),
    ]
    request = InferenceInput(
        start=endpoints[0],
        end=endpoints[1],
        alpha=alpha,
        num_inference_steps=steps,
        seed_image_id=feel,
    )
    spectrogram = model.riffuse(inputs=request, init_image=init_image)
    wav_bytes, _ = wav_bytes_from_spectrogram_image(spectrogram)
    return wav_bytes, spectrogram
def generate_riffuse(prompt_start, steps, num_iterations, feel, prompt_end=None, seed_start=None, seed_end=None, denoising_start=0.75, denoising_end=0.75, guidance_start=7.0, guidance_end=7.0):
    """Generate audio that interpolates from `prompt_start` to `prompt_end`.

    Calls riffuse() `num_iterations` times. Iteration ``i`` uses interpolation
    weight ``alpha = i / (num_iterations - 1)``. It starts from the previous
    iteration's spectrogram, and the seeds shift each time (start takes the
    old end, end becomes that + 1). The WAV clips are concatenated and
    rendered to a waveform video.

    Args:
        prompt_start: Prompt to start with.
        steps: Inference steps per iteration.
        num_iterations: Number of riffuse() calls / sections (>= 1).
        feel: Seed image name in ``riffusion/seed_images/``; it is also passed
            to riffuse() as the seed image id.
        prompt_end: Prompt to end with. Defaults to `prompt_start`.
        seed_start: Seed of the first section. None or 0 picks a random seed.
        seed_end: End seed of the first section. Defaults to `seed_start`.
        denoising_start, denoising_end, guidance_start, guidance_end:
            Forwarded to the start/end PromptInputs.

    Returns:
        The value returned by ``gr.make_waveform`` for the generated audio.

    Raises:
        ValueError: If `num_iterations` is less than 1.
    """
    import tempfile

    iterations = int(num_iterations)
    if iterations < 1:
        raise ValueError("num_iterations must be at least 1")

    # Open the seed image for `feel` and convert it to RGB.
    init_image = Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
    if prompt_end is None:
        prompt_end = prompt_start
    # `not seed_start` covers 0 (the UI's "random" value, also as 0.0) and the
    # None default. The old `seed_start is 0` missed both, and passed None on
    # as the seed.
    if not seed_start:
        seed_start = random.randint(0, 4294967295)
    if seed_end is None:
        seed_end = seed_start

    # One riffuse() call generates about 5 seconds of audio.
    wav_list = []
    for i in range(iterations):
        # Weight goes 0 -> 1 across the run; a single iteration stays at 0
        # (the old formula divided by zero here).
        alpha = i / (iterations - 1) if iterations > 1 else 0.0
        wav_bytes, image = riffuse(
            steps,
            feel,
            init_image,
            prompt_start,
            seed_start,
            denoising_start=denoising_start,
            guidance_start=guidance_start,
            prompt_end=prompt_end,
            seed_end=seed_end,
            denoising_end=denoising_end,
            guidance_end=guidance_end,
            alpha=alpha,
        )
        wav_list.append(wav_bytes)
        # Chain sections: the next call continues from this spectrogram.
        init_image = image
        seed_start = seed_end
        seed_end = seed_start + 1

    # Per-request temp files: fixed names in the working directory let
    # concurrent queue workers overwrite each other's outputs.
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = os.path.join(tmp_dir, "output.wav")
        bg_path = os.path.join(tmp_dir, "bg_image.png")
        # The last spectrogram is used as the video background.
        init_image.save(bg_path)
        with open(wav_path, "wb") as f:
            f.write(wav_list_to_wav(wav_list))
        # Relies on gr.make_waveform rendering synchronously into its own
        # output file, so these inputs can be deleted afterwards.
        return gr.make_waveform(wav_path, bg_image=bg_path)
def wav_list_to_wav(wav_list, *, channels=1, sample_rate=44100, bits_per_sample=16):
    """Concatenate PCM WAV streams into the bytes of a single WAV file.

    The first 44 bytes of every input are discarded as its header. Each input
    must therefore be a canonical PCM WAV (RIFF header, 16-byte ``fmt `` chunk,
    then the ``data`` chunk) in the format given by the keyword arguments.
    Formats are not checked. The defaults (mono, 44.1 kHz, 16-bit) are the
    values this function previously hard-coded.

    Args:
        wav_list: Iterable of binary file-like objects; each is read from its
            current position to the end.
        channels: Channel count written to the output header.
        sample_rate: Sample rate in Hz written to the output header.
        bits_per_sample: Sample width in bits written to the output header.

    Returns:
        bytes: A complete PCM WAV file (44-byte header followed by the
        concatenated sample data).
    """
    header_size = 44
    samples = b"".join(wav.read()[header_size:] for wav in wav_list)

    block_align = channels * bits_per_sample // 8
    # ByteRate must be sample_rate * block_align. The old header wrote
    # sample_rate * channels, which is half the real rate for 16-bit audio.
    byte_rate = sample_rate * block_align
    # RIFF chunks are word-aligned: odd-sized data gets one pad byte, counted
    # in the RIFF size but not in the data chunk size.
    padding = b"\x00" * (len(samples) % 2)

    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", header_size - 8 + len(samples) + len(padding), b"WAVE",
        b"fmt ", 16, 1, channels, sample_rate, byte_rate, block_align, bits_per_sample,
        b"data", len(samples),
    )
    return header + samples + padding
############################################### | |
def on_submit(prompt_1, prompt_2, steps, num_iterations, feel, seed):
    """Gradio callback shared by the Generate button and both prompt boxes.

    Returns a 5-tuple for (video, info text, community icon, loading icon,
    share button). With an empty first prompt nothing is generated and an
    error message is shown. With an empty second prompt generate() is used;
    otherwise generate_riffuse() interpolates between the two prompts.
    """
    if prompt_1 == "":
        return (
            None,
            gr.update(value="First prompt is required."),
            *(gr.update(visible=False) for _ in range(3)),
        )

    if prompt_2 == "":
        video = generate(prompt_1, steps, num_iterations, feel, seed)
    else:
        video = generate_riffuse(prompt_1, steps, num_iterations, feel, prompt_end=prompt_2, seed_start=seed)

    # A separate update dict for each of the three share-button components.
    return (video, None, *(gr.update(visible=True) for _ in range(3)))
def on_num_iterations_change(n, prompt_2):
    """Return an info-text update with the estimated total length for `n` sections.

    Bi-prompt mode (non-empty `prompt_2`) is estimated at 5 s per section.
    Single-prompt mode is estimated at 2.5 s plus 2.5 s per section. A missing
    `n` clears the text.
    """
    if n is None:
        return gr.update(value="")
    bi_prompt = prompt_2 != ""
    seconds = 5 * n if bi_prompt else 2.5 + 2.5 * n
    return gr.update(value=f"Total length: {seconds:.2f} seconds")
# Custom CSS passed to gr.Blocks(css=css). The selectors style the elements
# created with elem_id="share-btn-container" and elem_id="share-btn", i.e. the
# "Share to community" button group next to the video output.
css = '''
#share-btn-container {
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
'''
# Build the Gradio UI, wire the callbacks, and start the queued server.
with gr.Blocks(css=css) as app:
    gr.Markdown("## Riffusion Demo")
    gr.Markdown("""Generate audio using the [Riffusion](https://huggingface.co/riffusion/riffusion-model-v1) model.<br>
In single prompt mode you can generate up to ~1 minute of audio with smooth transitions between sections. (beta)<br>
Bi-prompt mode interpolates between two prompts. It can generate up to ~2 minutes of audio, but transitions between sections are more abrupt.""")
    gr.Markdown(f"""Running on {"**GPU 🔥**" if torch.cuda.is_available() else f"**CPU 🥶**. For faster inference it is recommended to **upgrade to GPU in space's Settings**"}<br>
[![Duplicate Space](https://bit.ly/3gLdBN6)](https://huggingface.co/spaces/$space_id?duplicate=true)""")

    with gr.Row():
        with gr.Group():
            with gr.Row():
                prompt_1 = gr.Textbox(lines=1, label="Start from", placeholder="Starting prompt", elem_id="riff-prompt_1")
                prompt_2 = gr.Textbox(lines=1, label="End with (optional)", placeholder="Prompt to shift towards at the end", elem_id="riff-prompt_2")
            with gr.Row():
                steps = gr.Slider(minimum=1, maximum=100, value=25, label="Steps per section")
                num_iterations = gr.Slider(minimum=2, maximum=25, value=2, step=1, label="Number of sections")
            with gr.Row():
                feel = gr.Dropdown(["og_beat", "agile", "vibes", "motorway", "marim"], value="og_beat", label="Feel", elem_id="riff-feel")
                seed = gr.Slider(minimum=0, maximum=4294967295, value=0, step=1, label="Seed (0 for random)", elem_id="riff-seed")
            btn_generate = gr.Button(value="Generate").style(full_width=True)
            info = gr.Markdown()
        with gr.Column():
            video = gr.Video(elem_id="riff-video")
            with gr.Group(elem_id="share-btn-container"):
                community_icon = gr.HTML(community_icon_html, elem_id="share-btn-share-icon", visible=False)
                loading_icon = gr.HTML(loading_icon_html, elem_id="share-btn-loading-icon", visible=False)
                share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)

    # Order must match on_submit's parameters and its 5-tuple return value.
    inputs = [prompt_1, prompt_2, steps, num_iterations, feel, seed]
    outputs = [video, info, community_icon, loading_icon, share_button]

    num_iterations.change(on_num_iterations_change, [num_iterations, prompt_2], [info])
    prompt_1.submit(on_submit, inputs, outputs)
    prompt_2.submit(on_submit, inputs, outputs)
    btn_generate.click(on_submit, inputs, outputs)
    share_button.click(None, [], [], _js=share_js)

    examples = gr.Examples(
        examples=[
            ["typing", "dance beat", "og_beat", 10],
            ["synthwave", "jazz", "agile", 10],
            ["rap battle freestyle", "", "og_beat", 10],
            # ["techno club banger", "", "og_beat", 10],
            ["reggae dub beat", "sunset chill", "og_beat", 10],
            ["acoustic folk ballad", "", "agile", 10],
            ["blues guitar riff", "", "agile", 5],
            ["jazzy trumpet solo", "", "og_beat", 5],
            ["classical symphony orchestra", "", "vibes", 10],
            ["rock and roll power chord", "", "motorway", 5],
            ["soulful R&B love song", "", "marim", 10],
            ["country western twangy guitar", "", "agile", 10],
        ],
        inputs=[prompt_1, prompt_2, feel, num_iterations],
        # Caching requires gr.Examples to be given `fn` and `outputs`. Neither
        # is passed here, and the rows cover only 4 of on_submit's 6 inputs, so
        # the examples just fill the input fields.
        cache_examples=False,
    )

    gr.HTML("""
<div style="border-top: 1px solid #303030;">
<br>
<p>Space by:<br>
<a href="https://twitter.com/hahahahohohe"><img src="https://img.shields.io/twitter/follow/hahahahohohe?label=%40anzorq&style=social" alt="Twitter Follow"></a><br>
<a href="https://github.com/qunash"><img alt="GitHub followers" src="https://img.shields.io/github/followers/qunash?style=social" alt="Github Follow"></a></p><br>
<a href="https://www.buymeacoffee.com/anzorq" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 24px !important;width: 81px !important;" ></a><br><br>
<p><img src="https://visitor-badge.glitch.me/badge?page_id=anzorq.riffusion-demo" alt="visitors"></p>
</div>
""")

app.queue(max_size=250, concurrency_count=2).launch()