# (Hugging Face Spaces page header captured with this file: "Spaces: Sleeping")
import os

import streamlit as st
from moviepy.audio.fx.volumex import volumex
from moviepy.editor import VideoFileClip, AudioFileClip, CompositeAudioClip
from streamlit.runtime.scriptrunner import get_script_run_ctx

from engine import DescribeVideo, GenerateAudio
def get_session_id():
    """Return a filesystem-friendly identifier for the current Streamlit session.

    Hyphens in the raw Streamlit session id are replaced with underscores and
    the result is prefixed with ``_id_``. The module uses it as the name of a
    per-session working directory.
    """
    raw_id = get_script_run_ctx().session_id
    return "_id_" + raw_id.replace("-", "_")
# Each browser session gets its own directory for uploaded and generated media.
user_session_id = get_session_id()
os.makedirs(name=user_session_id, exist_ok=True)
# UI quality tier -> model identifier handed to DescribeVideo.
video_model_map = dict(
    Fast="flash",
    Quality="pro",
)

# UI quality tier -> model identifier handed to GenerateAudio.
# Alternative checkpoints: "Fast": facebook/musicgen-melody,
# "Quality": facebook/musicgen-melody-large.
music_model_map = dict(
    Fast="musicgen-stereo-small",
    Balanced="musicgen-stereo-medium",
    Quality="musicgen-stereo-large",
)

# Genre label shown in the UI -> genre name; the "None" option maps to None.
# Insertion order is the order of the genre selectbox options.
genre_map = dict(
    [
        ("None", None),
        ("Pop", "Pop"),
        ("Rock", "Rock"),
        ("Hip Hop", "Hip-Hop/Rap"),
        ("Jazz", "Jazz"),
        ("Classical", "Classical"),
        ("Blues", "Blues"),
        ("Country", "Country"),
        ("EDM", "Electronic/Dance"),
        ("Metal", "Metal"),
        ("Disco", "Disco"),
        ("Lo-Fi", "Lo-Fi"),
    ]
)
# Page-level settings (tab title, favicon, centred layout).
st.set_page_config(
    page_title="VidTune: Where Videos Find Their Melody",
    page_icon="assets/favicon.png",
    layout="centered",
)

# Centre the logo by placing it in the middle one of three equal columns.
_, logo_column, _ = st.columns(3)
logo_column.image("assets/VidTune-Logo-Without-BG.png", use_column_width=False, width=200)

# Title and tagline, rendered as raw HTML with a small centring stylesheet.
header_html = """
<style>
h2, p, div, img {
text-align: center;
}
</style>
<div style="font-size: 35px; font-weight: bold;">VidTune: Where Videos Find Their Melody</div>
<p>VidTune is a web application to effortlessly tailor perfect soundtracks for your videos with AI.</p>
"""
st.markdown(header_html, unsafe_allow_html=True)
# Initialize session state for advanced settings and other inputs.
def _init_session_state():
    """Seed ``st.session_state`` with a default for every key not yet present.

    Only missing keys are written, so values stored by earlier script runs
    (or by widgets) are left untouched.
    """
    defaults = {
        "show_advanced": False,
        "video_model": "Fast",
        "music_model": "Fast",
        "num_samples": 3,
        "music_genre": None,
        "music_bpm": 100,
        "user_keywords": None,
        # Backs the "selected_audio" selectbox, whose options are integer
        # indices into ["None", "Generated Music 1", ...]; 0 is "None".
        # It must be an int: on_audio_selection_change compares it with
        # ``> 0``, and a string such as "None" is not one of the options.
        "selected_audio": 0,
        "audio_paths": [],
        "selected_audio_path": None,
        "orig_audio_vol": 100,
        "generated_audio_vol": 100,
        "generate_button_flag": False,
        "video_description_content": "",
        "music_prompt": "",
        "audio_mix_flag": False,
        "google_api_key": "",
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


_init_session_state()
# Sidebar: the Google API key is mandatory, so halt this run until one is entered.
st.sidebar.title("Configuration")

api_key_label = (
    "Enter your [Google API Key](https://ai.google.dev/gemini-api/docs/api-key) to get started :"
)
st.session_state.google_api_key = st.sidebar.text_input(
    api_key_label,
    st.session_state.google_api_key,
    type="password",
)

if not st.session_state.google_api_key:
    st.warning("Please enter your Google API Key to proceed.")
    st.stop()
# Basic settings: model tiers and how many music samples to generate.
video_tiers = ["Fast", "Quality"]
music_tiers = ["Fast", "Balanced", "Quality"]

st.session_state.video_model = st.sidebar.selectbox(
    "Select Video Descriptor",
    video_tiers,
    index=video_tiers.index(st.session_state.video_model),
)
st.session_state.music_model = st.sidebar.selectbox(
    "Select Music Generator",
    music_tiers,
    index=music_tiers.index(st.session_state.music_model),
)
st.session_state.num_samples = st.sidebar.slider(
    "Number of samples",
    min_value=1,
    max_value=5,
    value=st.session_state.num_samples,
)
# Sidebar toggle for the advanced settings panel.
with st.sidebar:
    # The toggle button is rendered through an st.empty() placeholder.
    advanced_toggle_slot = st.empty()
    if advanced_toggle_slot.button("Advanced"):
        st.session_state.show_advanced = not st.session_state.show_advanced
        st.rerun()  # redraw the sidebar with the new panel state

if not st.session_state.show_advanced:
    # Panel hidden: clear the overrides so describe_video receives None for each.
    st.session_state.music_genre = None
    st.session_state.music_bpm = None
    st.session_state.user_keywords = None
else:
    genre_labels = list(genre_map.keys())
    st.session_state.music_bpm = st.sidebar.slider("Beats Per Minute", 35, 180, 100)
    if st.session_state.music_genre in genre_labels:
        genre_index = genre_labels.index(st.session_state.music_genre)
    else:
        genre_index = 0
    st.session_state.music_genre = st.sidebar.selectbox(
        "Select Music Genre",
        genre_labels,
        index=genre_index,
    )
    st.session_state.user_keywords = st.sidebar.text_input(
        "User Keywords",
        value=st.session_state.user_keywords,
        help="Enter keywords separated by commas.",
    )

# Generate Button
generate_button = st.sidebar.button("Generate Music")
# Model loading, cached per browser session.
def load_models(video_model_key, music_model_key, google_api_key):
    """Return ``(video_descriptor, audio_generator)`` for the selected settings.

    Streamlit re-executes the whole script on every widget interaction, so the
    models are kept in this session's ``st.session_state`` and rebuilt only when
    one of the three arguments changes. They are cached per session rather
    than process-wide (``st.cache_resource``): the app calls
    ``generate_audio`` and then ``save_audio`` on the same ``GenerateAudio``
    instance, which presumably keeps the last generation internally, so a
    shared instance could mix up output between concurrent sessions.
    NOTE(review): each active session holds its own model copies in memory.

    Args:
        video_model_key: Key into ``video_model_map``.
        music_model_key: Key into ``music_model_map``.
        google_api_key: API key forwarded to ``DescribeVideo``.

    Returns:
        Tuple of the ``DescribeVideo`` and ``GenerateAudio`` instances.
    """
    cache_key = (video_model_key, music_model_key, google_api_key)
    cached = st.session_state.get("_loaded_models")
    if cached is not None and cached[0] == cache_key:
        video_descriptor, audio_generator = cached[1]
    else:
        video_descriptor = DescribeVideo(
            model=video_model_map[video_model_key], google_api_key=google_api_key
        )
        audio_generator = GenerateAudio(model=music_model_map[music_model_key])
        st.session_state["_loaded_models"] = (
            cache_key,
            (video_descriptor, audio_generator),
        )
    # Shown on every run, as before, so the hint stays visible.
    if audio_generator.device == "cpu":
        st.warning(
            "The music generator model is running on CPU. For faster results, consider using a GPU."
        )
    return video_descriptor, audio_generator
# Obtain the models for the tiers chosen in the sidebar.
video_descriptor, audio_generator = load_models(
    video_model_key=st.session_state.video_model,
    music_model_key=st.session_state.music_model,
    google_api_key=st.session_state.google_api_key,
)
# Video upload: persist the uploaded file into the session directory and preview it.
uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
temp_video_path = f"{user_session_id}/temp.mp4"
if uploaded_video is not None:
    st.session_state.uploaded_video = uploaded_video
    with open(temp_video_path, mode="wb") as video_out:
        video_out.write(uploaded_video.getvalue())
    if os.path.exists(temp_video_path):
        st.video(uploaded_video)
# When "Generate Music" is pressed: require an upload, then describe the video.
if generate_button:
    if uploaded_video is None:
        st.error("Please upload a video before generating music.")
        st.stop()
    video_path = f"{user_session_id}/temp.mp4"
    with st.spinner("Analyzing video..."):
        video_description = video_descriptor.describe_video(
            video_path,
            # music_genre holds the selectbox label (a genre_map key), or None
            # when the advanced panel is hidden. genre_map turns labels into
            # genre names (e.g. "EDM" -> "Electronic/Dance") and the "None"
            # option into None, matching the hidden-panel case.
            genre=genre_map.get(st.session_state.music_genre),
            bpm=st.session_state.music_bpm,
            user_keywords=st.session_state.user_keywords,
        )
        # The clip is opened only to read its length; the context manager
        # closes it so its underlying file reader is released.
        with VideoFileClip(video_path) as duration_probe:
            video_duration = duration_probe.duration
        st.session_state.video_description_content = video_description[
            "Content Description"
        ]
        st.session_state.music_prompt = video_description["Music Prompt"]
    st.success("Video description generated successfully.")
    st.session_state.generate_button_flag = True
# Once a description exists, show it and the derived music prompt read-only.
if st.session_state.generate_button_flag:
    read_only_box = dict(disabled=True, height=120)
    st.text_area(
        "Video Description",
        st.session_state.video_description_content,
        **read_only_box,
    )
    music_prompt = st.text_area(
        "Music Prompt",
        st.session_state.music_prompt,
        **read_only_box,
    )
# Generate music: repeat the prompt num_samples times and save the results.
if generate_button:
    with st.spinner("Generating music..."):
        duration_exceeds_cap = video_duration > 30
        if duration_exceeds_cap:
            # NOTE(review): the full video_duration is still passed below;
            # presumably GenerateAudio enforces the 30 s cap - confirm.
            st.warning(
                "Due to hardware limitations, the maximum music length is capped at 30 seconds."
            )
        music_prompt = [
            st.session_state.music_prompt for _ in range(st.session_state.num_samples)
        ]
        audio_generator.generate_audio(music_prompt, duration=video_duration)
        generated_files = audio_generator.save_audio()
        st.session_state.audio_paths = generated_files
    st.success("Music generated successfully.")
    st.balloons()
# Callback for the generated-audio selectbox.
def on_audio_selection_change():
    """Sync ``selected_audio_path`` with the "selected_audio" selectbox.

    Any pending mix request is cancelled. Index 0 is the "None" entry and
    clears the path; index ``k > 0`` selects ``audio_paths[k - 1]``.
    """
    st.session_state.audio_mix_flag = False
    choice = st.session_state.selected_audio
    st.session_state.selected_audio_path = (
        st.session_state.audio_paths[choice - 1] if choice > 0 else None
    )
# List the generated tracks and let the user pick one to mix into the video.
if st.session_state.audio_paths:
    generated_paths = st.session_state.audio_paths
    # Option 0 is "None"; option k refers to generated_paths[k - 1].
    audio_options = ["None"]
    audio_options.extend(f"Generated Music {i + 1}" for i in range(len(generated_paths)))

    for generated_path in generated_paths:
        st.audio(generated_path, format="audio/wav")

    selected_audio_index = st.selectbox(
        "Select one of the generated audio files for further processing:",
        range(len(audio_options)),
        format_func=lambda option_idx: audio_options[option_idx],
        index=0,
        key="selected_audio",
        on_change=on_audio_selection_change,
    )

    # Confirming the choice arms the mixing step.
    if st.button("Add Generated Music to Video"):
        st.session_state.audio_mix_flag = True
# Audio mixing and export: overlay the chosen track on the video's own audio,
# render the result, preview it and offer it for download.
if st.session_state.selected_audio_path is not None and st.session_state.audio_mix_flag:
    with st.spinner("Mixing Audio..."):
        orig_clip = VideoFileClip(f"{user_session_id}/temp.mp4")
        # Keep our own handle on the video's soundtrack: it is replaced on the
        # clip below, so orig_clip.close() would no longer release it.
        orig_clip_audio = orig_clip.audio
        generated_audio = AudioFileClip(st.session_state.selected_audio_path)
        try:
            st.session_state.orig_audio_vol = st.slider(
                "Original Audio Volume",
                0,
                200,
                st.session_state.orig_audio_vol,
                format="%d%%",
            )
            st.session_state.generated_audio_vol = st.slider(
                "Generated Music Volume",
                0,
                200,
                st.session_state.generated_audio_vol,
                format="%d%%",
            )
            tracks = []
            # Videos without an audio track have orig_clip.audio == None, which
            # volumex cannot scale; such videos get the generated music alone.
            if orig_clip_audio is not None:
                tracks.append(
                    volumex(orig_clip_audio, float(st.session_state.orig_audio_vol / 100))
                )
            tracks.append(
                volumex(generated_audio, float(st.session_state.generated_audio_vol / 100))
            )
            orig_clip.audio = CompositeAudioClip(tracks)
            final_video_path = f"{user_session_id}/out_tmp.mp4"
            orig_clip.write_videofile(final_video_path)
        finally:
            # Release file readers even if rendering fails.
            orig_clip.close()
            if orig_clip_audio is not None:
                orig_clip_audio.close()
            generated_audio.close()
    st.session_state.final_video_path = final_video_path
    st.video(final_video_path)
    if st.session_state.final_video_path:
        with open(st.session_state.final_video_path, "rb") as video_file:
            st.download_button(
                label="Download final video",
                data=video_file,
                file_name="final_video.mp4",
                mime="video/mp4",
            )