Spaces:
app v2
app.py
CHANGED
@@ -1,45 +1,286 @@
-from diffusers import
-import
-pipe = DiffusionPipeline.from_pretrained(repo_id)
-mel_spectr = pipe(prompt, num_inference_steps=steps).images[0]
-wav_bytes, duration_s = audio.wav_bytes_from_spectrogram_image(mel_spectr)
-with gr.
-app.launch()
+from diffusers import StableDiffusionImg2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline
 import torch
+from PIL import Image, ImageDraw
 import os
 import numpy as np
 from scipy.io.wavfile import read
+import gradio as gr
+

 os.system('git clone https://github.com/hmartiro/riffusion-inference.git riffusion')

+from riffusion.riffusion.riffusion_pipeline import RiffusionPipeline
+from riffusion.riffusion.datatypes import PromptInput, InferenceInput
+from riffusion.riffusion.audio import wav_bytes_from_spectrogram_image
+from PIL import Image
+import struct
+import random
+
 repo_id = "riffusion/riffusion-model-v1"

+model = RiffusionPipeline.from_pretrained(
+    repo_id,
+    revision="main",
+    torch_dtype=torch.float16,
+    safety_checker=lambda images, **kwargs: (images, False),
+)
+
+if torch.cuda.is_available():
+    model.to("cuda")
+
+pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, safety_checker=lambda images, **kwargs: (images, False),)
+pipe_inpaint.scheduler = DPMSolverMultistepScheduler.from_config(pipe_inpaint.scheduler.config)
+
+# pipe_inpaint.enable_xformers_memory_efficient_attention()
+
 if torch.cuda.is_available():
+    pipe_inpaint = pipe_inpaint.to("cuda")
+
+
+def get_init_image(image, overlap, feel):
+
+    width, height = image.size
+    init_image = Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
+    # Crop the right side of the original image with `overlap_width`
+    cropped_img = image.crop((width - int(width*overlap), 0, width, height))
+    init_image.paste(cropped_img, (0, 0))
+
+    return init_image
+
+def get_mask(image, overlap):
+
+    width, height = image.size
+
+    mask = Image.new("RGB", (width, height), color="white")
+    draw = ImageDraw.Draw(mask)
+    draw.rectangle((0, 0, int(overlap * width), height), fill="black")
+    return mask
+
+def i2i(prompt, steps, feel, seed):
+    # return pipe_i2i(
+    #     prompt,
+    #     num_inference_steps=steps,
+    #     image=Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB"),
+    # ).images[0]
+
+    prompt_input_start = PromptInput(prompt=prompt, seed=seed)
+    prompt_input_end = PromptInput(prompt=prompt, seed=seed)
+
+    return model.riffuse(
+        inputs=InferenceInput(
+            start=prompt_input_start,
+            end=prompt_input_end,
+            alpha=1.0,
+            num_inference_steps=steps),
+        init_image=Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
+    )
+
+def outpaint(prompt, init_image, mask, steps):
+    return pipe_inpaint(
+        prompt,
+        num_inference_steps=steps,
+        image=init_image,
+        mask_image=mask,
+    ).images[0]
+
+
+def generate(prompt, steps, num_iterations, feel, seed):
+
+    if seed == 0:
+        seed = random.randint(0,4294967295)
+
+    num_images = num_iterations
+    overlap = 0.5
+    image_width, image_height = 512, 512  # dimensions of each output image
+    total_width = num_images * image_width - (num_images - 1) * int(overlap * image_width)  # total width of the stitched image
+
+    # Create a blank image with the desired dimensions
+    stitched_image = Image.new("RGB", (total_width, image_height), color="white")
+
+    # Initialize the x position for pasting the next image
+    x_pos = 0
+
+    image = i2i(prompt, steps, feel, seed)
+
+    for i in range(num_images):
+        # Generate the prompt, initial image, and mask for this iteration
+        init_image = get_init_image(image, overlap, feel)
+        mask = get_mask(init_image, overlap)
+
+        # Run the outpaint function to generate the output image
+        steps = 25
+        image = outpaint(prompt, init_image, mask, steps)
+
+        # Paste the output image onto the stitched image
+        stitched_image.paste(image, (x_pos, 0))
+
+        # Update the x position for the next iteration
+        x_pos += int((1 - overlap) * image_width)
+
+    wav_bytes, duration_s = wav_bytes_from_spectrogram_image(stitched_image)
+
+    mask = Image.new("RGB", (512, 512), color="white")
+    bg_image = outpaint(prompt, init_image, mask, steps)
+    bg_image.save("bg_image.png")

+    # return read(wav_bytes)
+    with open("output.wav", "wb") as f:
+        f.write(wav_bytes.read())

+    return gr.make_waveform("output.wav", bg_image="bg_image.png", bar_count=int(duration_s*25))


+###############################################
+
+def riffuse(steps, feel, init_image, prompt_start, seed_start, denoising_start=0.75, guidance_start=7.0, prompt_end=None, seed_end=None, denoising_end=0.75, guidance_end=7.0, alpha=0.5):
+
+    prompt_input_start = PromptInput(prompt=prompt_start, seed=seed_start, denoising=denoising_start, guidance=guidance_start)
+
+    prompt_input_end = PromptInput(prompt=prompt_end, seed=seed_end, denoising=denoising_end, guidance=guidance_end)
+
+    input = InferenceInput(
+        start=prompt_input_start,
+        end=prompt_input_end,
+        alpha=alpha,
+        num_inference_steps=steps,
+        seed_image_id=feel,
+        # mask_image_id="mask_beat_lines_80.png"
+    )
+
+    image = model.riffuse(inputs=input, init_image=init_image)
+
+    wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
+
+    return wav_bytes, image
+
+def generate_riffuse(prompt_start, steps, num_iterations, feel, prompt_end=None, seed_start=None, seed_end=None, denoising_start=0.75, denoising_end=0.75, guidance_start=7.0, guidance_end=7.0):
+    """Generate a WAV file of length seconds using the Riffusion model.
+
+    Args:
+        length (int): Length of the WAV file in seconds, must be divisible by 5.
+        prompt_start (str): Prompt to start with.
+        prompt_end (str, optional): Prompt to end with. Defaults to prompt_start.
+        overlap (float, optional): Overlap between audio clips as a fraction of the image size. Defaults to 0.2.
+    """
+
+    # open the initial image and convert it to RGB
+    init_image = Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
+
+    if prompt_end is None:
+        prompt_end = prompt_start
+    if seed_start is None:
+        seed_start = random.randint(0,4294967295)
+    if seed_end is None:
+        seed_end = seed_start
+
+    # one riffuse() generates 5 seconds of audio
+    wav_list = []
+
+    for i in range(int(num_iterations)):
+
+        alpha = i / (num_iterations - 1)
+        print(alpha)
+        wav_bytes, image = riffuse(steps, feel, init_image, prompt_start, seed_start, denoising_start, guidance_start, prompt_end, seed_end, denoising_end, guidance_end, alpha=alpha)
+        wav_list.append(wav_bytes)
+
+        init_image = image
+
+        seed_start = seed_end
+        seed_end = seed_start + 1
+
+    # return read(wav_bytes)
+
+    mask = Image.new("RGB", (512, 512), color="white")
+    bg_image = outpaint(f"{prompt_start} and {prompt_end}", init_image, mask, steps)
+    bg_image.save("bg_image.png")
+
+    with open("output.wav", "wb") as f:
+        f.write(wav_bytes.read())
+
+    return gr.make_waveform("output.wav", bg_image="bg_image.png")
+
+
+def wav_list_to_wav(wav_list):
+
+    # remove headers from the WAV files
+    data = [wav.read()[44:] for wav in wav_list]
+
+    # concatenate the data
+    concatenated_data = b"".join(data)
+
+    # create a new RIFF header
+    channels = 1
+    sample_rate = 44100
+    bytes_per_second = channels * sample_rate
+    new_header = struct.pack("<4sI4s4sIHHIIHH4sI", b"RIFF", len(concatenated_data) + 44 - 8, b"WAVE", b"fmt ", 16, 1, channels, sample_rate, bytes_per_second, 2, 16, b"data", len(concatenated_data))
+
+    # combine the header and data to create the final WAV file
+    final_wav = new_header + concatenated_data
+    return final_wav
+
+###############################################
+
+def on_submit(prompt_1, prompt_2, steps, num_iterations, feel, seed):
+    if prompt_1 == "":
+        return None, gr.update(value="First prompt is required.")
+    if prompt_2 == "":
+        return generate(prompt_1, steps, num_iterations, feel, seed), None
+    else:
+        return generate_riffuse(prompt_1, steps, num_iterations, feel, prompt_end=prompt_2, seed_start=seed), None
+
+
+def on_num_iterations_change(n, prompt_2):
+    if n is None:
+        return gr.update(value="")
+    x = 5 if prompt_2 != "" else 2.5
+    total_length = x + x * n
+    return gr.update(value=f"Total length: {total_length:.2f} seconds")

 with gr.Blocks() as app:
+    gr.Markdown("## Riffusion")
+    gr.Markdown("""Generate audio using the [Riffusion](https://huggingface.co/riffusion/riffusion-model-v1) model.<br>
+    In single prompt mode you can generate up to ~1 minute of audio with smooth transitions between sections. (beta)<br>
+    Bi-prompt mode interpolates between two prompts. It can generate up to ~2 minutes of audio, but the transitions between sections are more abrupt.""")
+
+    with gr.Row():
+        with gr.Group():
+            with gr.Row():
+                prompt_1 = gr.Textbox(lines=1, label="Start from", placeholder="Starting prompt")
+                prompt_2 = gr.Textbox(lines=1, label="End with (optional)", placeholder="Prompt to shift towards at the end")
+            with gr.Row():
+                steps = gr.Slider(minimum=1, maximum=100, value=25, label="Steps per section")
+                num_iterations = gr.Slider(minimum=2, maximum=25, value=2, step=1, label="Number of sections")
+            with gr.Row():
+                feel = gr.Dropdown(["og_beat", "agile", "vibes", "motorway", "marim"], value="og_beat", label="Feel")
+                seed = gr.Slider(minimum=0, maximum=4294967295, value=0, step=1, label="Seed (0 for random)")
+
+            info = gr.Markdown()
+            btn_generate = gr.Button(value="Generate")
+        with gr.Column():
+            video = gr.Video()

+    inputs = [prompt_1, prompt_2, steps, num_iterations, feel, seed]
+    outputs = [video, info]

+    num_iterations.change(on_num_iterations_change, [num_iterations, prompt_2], [info])
+    prompt_1.submit(on_submit, inputs, outputs)
+    prompt_2.submit(on_submit, inputs, outputs)
+    btn_generate.click(on_submit, inputs, outputs)

+    examples = gr.Examples(
+        examples=[
+            ["typing", "dance beat", "og_beat", 10],
+            ["synthwave", "jazz", "agile", 10],
+            ["rap battle freestyle", "", "og_beat", 10],
+            ["techno club banger", "", "og_beat", 10],
+            ["acoustic folk ballad", "", "agile", 10],
+            ["blues guitar riff", "", "agile", 5],
+            ["jazzy trumpet solo", "", "og_beat", 5],
+            ["classical symphony orchestra", "", "vibes", 10],
+            ["rock and roll power chord", "", "motorway", 5],
+            ["soulful R&B love song", "", "marim", 10],
+            ["reggae dub beat", "sunset chill", "og_beat", 10],
+            ["country western twangy guitar", "", "agile", 10]],
+        inputs=[prompt_1, prompt_2, feel, num_iterations])

+app.launch()
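
Note on the stitching logic added in this change: `generate` extends a single Riffusion spectrogram by carrying the right half of each tile into the left half of the next one, letting the inpainting pipeline fill the remainder, and pasting the tiles onto one wide spectrogram before converting it to audio. The sketch below illustrates only that crop/mask/paste geometry with plain PIL; `fake_outpaint` is a hypothetical stand-in for `pipe_inpaint` so the snippet runs without model weights or a GPU.

from PIL import Image, ImageDraw

overlap = 0.5
width, height = 512, 512          # each tile is a 512x512 spectrogram (~5 s of audio)
num_images = 3
total_width = num_images * width - (num_images - 1) * int(overlap * width)

def fake_outpaint(init_image, mask):
    # Stand-in for pipe_inpaint: the real pipeline regenerates the white (masked)
    # region from the prompt; here the tile is returned unchanged.
    return init_image

stitched = Image.new("RGB", (total_width, height), color="white")
image = Image.new("RGB", (width, height), color="gray")  # pretend first spectrogram
x_pos = 0

for i in range(num_images):
    # Carry the right `overlap` fraction of the previous tile into the left of the
    # next one, so consecutive clips share audio content at the seam.
    init_image = Image.new("RGB", (width, height), color="white")
    init_image.paste(image.crop((width - int(width * overlap), 0, width, height)), (0, 0))

    # Inpainting mask: black = keep (the carried-over strip), white = regenerate.
    mask = Image.new("RGB", (width, height), color="white")
    ImageDraw.Draw(mask).rectangle((0, 0, int(overlap * width), height), fill="black")

    image = fake_outpaint(init_image, mask)
    stitched.paste(image, (x_pos, 0))
    x_pos += int((1 - overlap) * width)

print(stitched.size)  # (1024, 512): three half-overlapping tiles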