anzorq committed on
Commit
3df612c
1 Parent(s): 3523177
Files changed (1)
  1. app.py +266 -25
app.py CHANGED
@@ -1,45 +1,286 @@
-from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
 import torch
-import gradio as gr
 import os
 import numpy as np
 from scipy.io.wavfile import read


 os.system('git clone https://github.com/hmartiro/riffusion-inference.git riffusion')

 repo_id = "riffusion/riffusion-model-v1"
-pipe = DiffusionPipeline.from_pretrained(repo_id)
-
-pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

 if torch.cuda.is_available():
-    pipe = pipe.to("cuda")
-
-
-def infer(prompt, steps):
-
-    from riffusion.riffusion import audio
-
-    mel_spectr = pipe(prompt, num_inference_steps=steps).images[0]
-    wav_bytes, duration_s = audio.wav_bytes_from_spectrogram_image(mel_spectr)
-
-    return read(wav_bytes)
-

 with gr.Blocks() as app:
-    with gr.Row():
-        with gr.Column():
-            prompt = gr.Textbox(lines=1, label="Prompt")
-            steps = gr.Slider(minimum=1, maximum=100, value=25, label="Steps")
-            btn_generate = gr.Button(value="Generate")
-        with gr.Column():
-            audio = gr.Audio(label="Audio")
-
-    inputs = [prompt, steps]
-    outputs = [audio]
-
-    prompt.submit(infer, inputs, outputs)
-    btn_generate.click(infer, inputs, outputs)
-
-    examples = gr.Examples(
-        examples=[["rap battle freestyle"], ["techno club banger"], ["acoustic folk ballad"], ["blues guitar riff"], ["jazzy trumpet solo"], ["classical symphony orchestra"], ["rock and roll power chord"], ["soulful R&B love song"], ["reggae dub beat"], ["country western twangy guitar"], ["all 25 steps"]],
-        inputs=[prompt])
-
-app.launch()
+from diffusers import StableDiffusionImg2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline
 import torch
+from PIL import Image, ImageDraw
 import os
 import numpy as np
 from scipy.io.wavfile import read
+import gradio as gr
+

 os.system('git clone https://github.com/hmartiro/riffusion-inference.git riffusion')

+from riffusion.riffusion.riffusion_pipeline import RiffusionPipeline
+from riffusion.riffusion.datatypes import PromptInput, InferenceInput
+from riffusion.riffusion.audio import wav_bytes_from_spectrogram_image
+import struct
+import random
+
 repo_id = "riffusion/riffusion-model-v1"

+model = RiffusionPipeline.from_pretrained(
+    repo_id,
+    revision="main",
+    torch_dtype=torch.float16,
+    safety_checker=lambda images, **kwargs: (images, False),
+)
+
+if torch.cuda.is_available():
+    model.to("cuda")
+
+pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, safety_checker=lambda images, **kwargs: (images, False))
+pipe_inpaint.scheduler = DPMSolverMultistepScheduler.from_config(pipe_inpaint.scheduler.config)
+
+# pipe_inpaint.enable_xformers_memory_efficient_attention()
+
 if torch.cuda.is_available():
+    pipe_inpaint = pipe_inpaint.to("cuda")
+
+
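+# Sliding-window helpers: each new spectrogram tile is seeded with the right
+# `overlap` fraction of the previous tile, and the inpainting mask keeps that
+# pasted region fixed (black) while the rest (white) is regenerated.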
+def get_init_image(image, overlap, feel):
+    width, height = image.size
+    init_image = Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
+    # Crop the right `overlap` fraction of the previous image and paste it onto the left of the seed image
+    cropped_img = image.crop((width - int(width * overlap), 0, width, height))
+    init_image.paste(cropped_img, (0, 0))
+
+    return init_image
+
+def get_mask(image, overlap):
+    width, height = image.size
+
+    mask = Image.new("RGB", (width, height), color="white")
+    draw = ImageDraw.Draw(mask)
+    draw.rectangle((0, 0, int(overlap * width), height), fill="black")
+    return mask
+
+def i2i(prompt, steps, feel, seed):
+    # return pipe_i2i(
+    #     prompt,
+    #     num_inference_steps=steps,
+    #     image=Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB"),
+    # ).images[0]
+
+    prompt_input_start = PromptInput(prompt=prompt, seed=seed)
+    prompt_input_end = PromptInput(prompt=prompt, seed=seed)
+
+    return model.riffuse(
+        inputs=InferenceInput(
+            start=prompt_input_start,
+            end=prompt_input_end,
+            alpha=1.0,
+            num_inference_steps=steps),
+        init_image=Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
+    )
+
+def outpaint(prompt, init_image, mask, steps):
+    return pipe_inpaint(
+        prompt,
+        num_inference_steps=steps,
+        image=init_image,
+        mask_image=mask,
+    ).images[0]
+
+
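+# Single-prompt mode: start from one riffused tile, then repeatedly outpaint to the
+# right with 50% overlap, stitch the tiles into one wide spectrogram, decode it to WAV,
+# and render a waveform video over an outpainted background image.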
+def generate(prompt, steps, num_iterations, feel, seed):
+    if seed == 0:
+        seed = random.randint(0, 4294967295)
+
+    num_images = num_iterations
+    overlap = 0.5
+    image_width, image_height = 512, 512  # dimensions of each output image
+    total_width = num_images * image_width - (num_images - 1) * int(overlap * image_width)  # total width of the stitched image
+
+    # Create a blank image with the desired dimensions
+    stitched_image = Image.new("RGB", (total_width, image_height), color="white")
+
+    # Initialize the x position for pasting the next image
+    x_pos = 0
+
+    image = i2i(prompt, steps, feel, seed)
+
+    for i in range(num_images):
+        # Build the initial image and mask for this iteration
+        init_image = get_init_image(image, overlap, feel)
+        mask = get_mask(init_image, overlap)
+
+        # Run the outpaint function to generate the output image
+        steps = 25
+        image = outpaint(prompt, init_image, mask, steps)
+
+        # Paste the output image onto the stitched image
+        stitched_image.paste(image, (x_pos, 0))
+
+        # Update the x position for the next iteration
+        x_pos += int((1 - overlap) * image_width)
+
+    wav_bytes, duration_s = wav_bytes_from_spectrogram_image(stitched_image)
+
+    mask = Image.new("RGB", (512, 512), color="white")
+    bg_image = outpaint(prompt, init_image, mask, steps)
+    bg_image.save("bg_image.png")

+    # return read(wav_bytes)
+    with open("output.wav", "wb") as f:
+        f.write(wav_bytes.read())

+    return gr.make_waveform("output.wav", bg_image="bg_image.png", bar_count=int(duration_s * 25))


+###############################################
+
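+# Bi-prompt mode: each riffuse() call interpolates between the start and end prompts
+# at a given alpha and returns ~5 seconds of audio plus the generated tile, which is
+# fed back as the init image for the next section.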
+def riffuse(steps, feel, init_image, prompt_start, seed_start, denoising_start=0.75, guidance_start=7.0, prompt_end=None, seed_end=None, denoising_end=0.75, guidance_end=7.0, alpha=0.5):
+
+    prompt_input_start = PromptInput(prompt=prompt_start, seed=seed_start, denoising=denoising_start, guidance=guidance_start)
+    prompt_input_end = PromptInput(prompt=prompt_end, seed=seed_end, denoising=denoising_end, guidance=guidance_end)
+
+    input = InferenceInput(
+        start=prompt_input_start,
+        end=prompt_input_end,
+        alpha=alpha,
+        num_inference_steps=steps,
+        seed_image_id=feel,
+        # mask_image_id="mask_beat_lines_80.png"
+    )
+
+    image = model.riffuse(inputs=input, init_image=init_image)
+
+    wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
+
+    return wav_bytes, image
+
+def generate_riffuse(prompt_start, steps, num_iterations, feel, prompt_end=None, seed_start=None, seed_end=None, denoising_start=0.75, denoising_end=0.75, guidance_start=7.0, guidance_end=7.0):
+    """Generate a WAV file of num_iterations five-second sections using the Riffusion model.
+
+    Args:
+        prompt_start (str): Prompt to start with.
+        steps (int): Number of denoising steps per section.
+        num_iterations (int): Number of sections to generate.
+        feel (str): Name of the seed spectrogram image.
+        prompt_end (str, optional): Prompt to end with. Defaults to prompt_start.
+    """
+
+    # open the initial image and convert it to RGB
+    init_image = Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
+
+    if prompt_end is None:
+        prompt_end = prompt_start
+    if seed_start is None:
+        seed_start = random.randint(0, 4294967295)
+    if seed_end is None:
+        seed_end = seed_start
+
+    # one riffuse() generates 5 seconds of audio
+    wav_list = []
+
+    for i in range(int(num_iterations)):
+        alpha = i / (num_iterations - 1)
+        print(alpha)
+        wav_bytes, image = riffuse(steps, feel, init_image, prompt_start, seed_start, denoising_start, guidance_start, prompt_end, seed_end, denoising_end, guidance_end, alpha=alpha)
+        wav_list.append(wav_bytes)
+
+        init_image = image
+
+        seed_start = seed_end
+        seed_end = seed_start + 1
+
+    # return read(wav_bytes)
+
+    mask = Image.new("RGB", (512, 512), color="white")
+    bg_image = outpaint(f"{prompt_start} and {prompt_end}", init_image, mask, steps)
+    bg_image.save("bg_image.png")
+
+    with open("output.wav", "wb") as f:
+        f.write(wav_bytes.read())
+
+    return gr.make_waveform("output.wav", bg_image="bg_image.png")
+
+
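+# Concatenate per-section WAV clips into a single file by stripping the 44-byte RIFF
+# headers and writing a fresh header for 16-bit mono 44.1 kHz PCM.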
+def wav_list_to_wav(wav_list):
+    # remove headers from the WAV files
+    data = [wav.read()[44:] for wav in wav_list]
+
+    # concatenate the data
+    concatenated_data = b"".join(data)
+
+    # create a new RIFF header (mono, 44100 Hz, 16-bit PCM)
+    channels = 1
+    sample_rate = 44100
+    bytes_per_second = channels * sample_rate * 2  # 2 bytes per 16-bit sample
+    new_header = struct.pack("<4sI4s4sIHHIIHH4sI", b"RIFF", len(concatenated_data) + 44 - 8, b"WAVE", b"fmt ", 16, 1, channels, sample_rate, bytes_per_second, 2, 16, b"data", len(concatenated_data))
+
+    # combine the header and data to create the final WAV file
+    final_wav = new_header + concatenated_data
+    return final_wav
+
+###############################################
+
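+# Dispatch from the UI: a single prompt runs generate(), two prompts run
+# generate_riffuse() with prompt interpolation.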
+def on_submit(prompt_1, prompt_2, steps, num_iterations, feel, seed):
+    if prompt_1 == "":
+        return None, gr.update(value="First prompt is required.")
+    if prompt_2 == "":
+        return generate(prompt_1, steps, num_iterations, feel, seed), None
+    else:
+        return generate_riffuse(prompt_1, steps, num_iterations, feel, prompt_end=prompt_2, seed_start=seed), None
+
+
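+# Rough length estimate shown in the UI: about 2.5 s per section in single-prompt
+# mode (50% tile overlap) and about 5 s per section in bi-prompt mode.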
+def on_num_iterations_change(n, prompt_2):
+    if n is None:
+        return gr.update(value="")
+    x = 5 if prompt_2 != "" else 2.5
+    total_length = x + x * n
+    return gr.update(value=f"Total length: {total_length:.2f} seconds")

 with gr.Blocks() as app:
+    gr.Markdown("## Riffusion")
+    gr.Markdown("""Generate audio using the [Riffusion](https://huggingface.co/riffusion/riffusion-model-v1) model.<br>
+In single prompt mode you can generate up to ~1 minute of audio with smooth transitions between sections. (beta)<br>
+Bi-prompt mode interpolates between two prompts. It can generate up to ~2 minutes of audio, but the transitions between sections are more abrupt.""")
+
+    with gr.Row():
+        with gr.Group():
+            with gr.Row():
+                prompt_1 = gr.Textbox(lines=1, label="Start from", placeholder="Starting prompt")
+                prompt_2 = gr.Textbox(lines=1, label="End with (optional)", placeholder="Prompt to shift towards at the end")
+            with gr.Row():
+                steps = gr.Slider(minimum=1, maximum=100, value=25, label="Steps per section")
+                num_iterations = gr.Slider(minimum=2, maximum=25, value=2, step=1, label="Number of sections")
+            with gr.Row():
+                feel = gr.Dropdown(["og_beat", "agile", "vibes", "motorway", "marim"], value="og_beat", label="Feel")
+                seed = gr.Slider(minimum=0, maximum=4294967295, value=0, step=1, label="Seed (0 for random)")
+
+            info = gr.Markdown()
+            btn_generate = gr.Button(value="Generate")
+        with gr.Column():
+            video = gr.Video()

+    inputs = [prompt_1, prompt_2, steps, num_iterations, feel, seed]
+    outputs = [video, info]

+    num_iterations.change(on_num_iterations_change, [num_iterations, prompt_2], [info])
+    prompt_1.submit(on_submit, inputs, outputs)
+    prompt_2.submit(on_submit, inputs, outputs)
+    btn_generate.click(on_submit, inputs, outputs)

+    examples = gr.Examples(
+        examples=[
+            ["typing", "dance beat", "og_beat", 10],
+            ["synthwave", "jazz", "agile", 10],
+            ["rap battle freestyle", "", "og_beat", 10],
+            ["techno club banger", "", "og_beat", 10],
+            ["acoustic folk ballad", "", "agile", 10],
+            ["blues guitar riff", "", "agile", 5],
+            ["jazzy trumpet solo", "", "og_beat", 5],
+            ["classical symphony orchestra", "", "vibes", 10],
+            ["rock and roll power chord", "", "motorway", 5],
+            ["soulful R&B love song", "", "marim", 10],
+            ["reggae dub beat", "sunset chill", "og_beat", 10],
+            ["country western twangy guitar", "", "agile", 10]],
+        inputs=[prompt_1, prompt_2, feel, num_iterations])

+app.launch()