animikhaich committed on
Commit
ee4f393
1 Parent(s): 99896f8

Pre-Final Changes, Release Ready

Browse files
Files changed (3) hide show
  1. engine/audio_generator.py +33 -22
  2. engine/video_descriptor.py +23 -8
  3. main.py +94 -85
engine/audio_generator.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import warnings
 
3
 
4
  warnings.simplefilter("ignore")
5
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -9,13 +10,11 @@ import numpy as np
9
  from audiocraft.models import musicgen
10
  from scipy.io.wavfile import write as wav_write
11
 
12
- import logging
 
 
 
13
 
14
- FORMAT = "%(asctime)s: %(levelname)s: %(message)s"
15
- logging.basicConfig(filename='logs.log', level=logging.INFO, format=FORMAT)
16
- STDERRLOGGER = logging.StreamHandler()
17
- STDERRLOGGER.setFormatter(logging.Formatter(FORMAT))
18
- logging.getLogger().addHandler(STDERRLOGGER)
19
 
20
  class GenerateAudio:
21
  def __init__(self, model="musicgen-stereo-small"):
@@ -32,7 +31,9 @@ class GenerateAudio:
32
  logging.info(f"Loaded model: {model}")
33
  return model
34
  except Exception as e:
35
- logging.error(f"Failed to load model: {e}")
 
 
36
  raise ValueError(f"Failed to load model: {e}")
37
  return
38
 
@@ -41,14 +42,18 @@ class GenerateAudio:
41
  if model_name.startswith("facebook/"):
42
  return model_name
43
  return f"facebook/{model_name}"
44
-
45
  @staticmethod
46
  def duration_sanity_check(duration):
47
  if duration < 1:
48
- logging.warning("Duration is less than 1 second. Setting duration to 1 second.")
 
 
49
  return 1
50
  elif duration > 30:
51
- logging.warning("Duration is greater than 30 seconds. Setting duration to 30 seconds.")
 
 
52
  return 30
53
  return duration
54
 
@@ -62,16 +67,16 @@ class GenerateAudio:
62
  for prompt in prompts:
63
  if not isinstance(prompt, str):
64
  raise ValueError("Prompts should be a string or a list of strings.")
65
- if len(prompts) > 8: # Too many prompts will cause OOM error
66
  raise ValueError("Maximum number of prompts allowed is 8.")
67
  return prompts
68
-
69
 
70
  def generate_audio(self, prompts, duration=10):
71
  duration = self.duration_sanity_check(duration)
72
  prompts = self.prompts_sanity_check(prompts)
73
 
74
  try:
 
75
  if duration <= 30:
76
  self.model.set_generation_params(duration=duration)
77
  result = self.model.generate(prompts, progress=False)
@@ -79,18 +84,23 @@ class GenerateAudio:
79
  self.model.set_generation_params(duration=30)
80
  result = self.model.generate(prompts, progress=False)
81
  self.model.set_generation_params(duration=duration)
82
- result = self.model.generate_with_chroma(prompts, result, melody_sample_rate=self.sampling_rate, progress=False)
 
 
 
 
 
83
  self.result = result.cpu().numpy().T
84
  self.result = self.result.transpose((2, 0, 1))
85
- self.sampling_rate = self.model.sample_rate
86
 
87
  logging.info(
88
  f"Generated audio with shape: {self.result.shape}, sample rate: {self.sampling_rate} Hz"
89
  )
90
- print(f"Generated audio with shape: {self.result.shape}, sample rate: {self.sampling_rate} Hz")
91
  return self.sampling_rate, self.result
92
  except Exception as e:
93
- logging.error(f"Failed to generate audio: {e}")
 
 
94
  raise ValueError(f"Failed to generate audio: {e}")
95
 
96
  def save_audio(self, audio_dir="generated_audio"):
@@ -121,17 +131,18 @@ class GenerateAudio:
121
  buffers.append(buffer)
122
  return buffers
123
 
 
124
  if __name__ == "__main__":
125
  audio_gen = GenerateAudio()
126
  sample_rate, result = audio_gen.generate_audio(
127
  [
128
- "A piano playing a jazz melody",
129
- "A guitar playing a rock riff",
130
- "A LoFi music for coding"
131
- ],
132
- duration=10
133
  )
134
  paths = audio_gen.save_audio()
135
  print(f"Saved audio to: {paths}")
136
  buffers = audio_gen.get_audio_buffer()
137
- print(f"Audio buffers: {buffers}")
 
1
  import os
2
  import warnings
3
+ import traceback
4
 
5
  warnings.simplefilter("ignore")
6
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 
10
  from audiocraft.models import musicgen
11
  from scipy.io.wavfile import write as wav_write
12
 
13
+ try:
14
+ from logger import logging
15
+ except:
16
+ import logging
17
 
 
 
 
 
 
18
 
19
  class GenerateAudio:
20
  def __init__(self, model="musicgen-stereo-small"):
 
31
  logging.info(f"Loaded model: {model}")
32
  return model
33
  except Exception as e:
34
+ logging.error(
35
+ f"Failed to load model: {e}, Traceback: {traceback.format_exc()}"
36
+ )
37
  raise ValueError(f"Failed to load model: {e}")
38
  return
39
 
 
42
  if model_name.startswith("facebook/"):
43
  return model_name
44
  return f"facebook/{model_name}"
45
+
46
  @staticmethod
47
  def duration_sanity_check(duration):
48
  if duration < 1:
49
+ logging.warning(
50
+ "Duration is less than 1 second. Setting duration to 1 second."
51
+ )
52
  return 1
53
  elif duration > 30:
54
+ logging.warning(
55
+ "Duration is greater than 30 seconds. Setting duration to 30 seconds."
56
+ )
57
  return 30
58
  return duration
59
 
 
67
  for prompt in prompts:
68
  if not isinstance(prompt, str):
69
  raise ValueError("Prompts should be a string or a list of strings.")
70
+ if len(prompts) > 8: # Too many prompts will cause OOM error
71
  raise ValueError("Maximum number of prompts allowed is 8.")
72
  return prompts
 
73
 
74
  def generate_audio(self, prompts, duration=10):
75
  duration = self.duration_sanity_check(duration)
76
  prompts = self.prompts_sanity_check(prompts)
77
 
78
  try:
79
+ self.sampling_rate = self.model.sample_rate
80
  if duration <= 30:
81
  self.model.set_generation_params(duration=duration)
82
  result = self.model.generate(prompts, progress=False)
 
84
  self.model.set_generation_params(duration=30)
85
  result = self.model.generate(prompts, progress=False)
86
  self.model.set_generation_params(duration=duration)
87
+ result = self.model.generate_with_chroma(
88
+ prompts,
89
+ result,
90
+ melody_sample_rate=self.sampling_rate,
91
+ progress=False,
92
+ )
93
  self.result = result.cpu().numpy().T
94
  self.result = self.result.transpose((2, 0, 1))
 
95
 
96
  logging.info(
97
  f"Generated audio with shape: {self.result.shape}, sample rate: {self.sampling_rate} Hz"
98
  )
 
99
  return self.sampling_rate, self.result
100
  except Exception as e:
101
+ logging.error(
102
+ f"Failed to generate audio: {e}, Traceback: {traceback.format_exc()}"
103
+ )
104
  raise ValueError(f"Failed to generate audio: {e}")
105
 
106
  def save_audio(self, audio_dir="generated_audio"):
 
131
  buffers.append(buffer)
132
  return buffers
133
 
134
+
135
  if __name__ == "__main__":
136
  audio_gen = GenerateAudio()
137
  sample_rate, result = audio_gen.generate_audio(
138
  [
139
+ "A piano playing a jazz melody",
140
+ "A guitar playing a rock riff",
141
+ "A LoFi music for coding",
142
+ ],
143
+ duration=10,
144
  )
145
  paths = audio_gen.save_audio()
146
  print(f"Saved audio to: {paths}")
147
  buffers = audio_gen.get_audio_buffer()
148
+ print(f"Audio buffers: {buffers}")
engine/video_descriptor.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  from warnings import simplefilter
 
3
 
4
  simplefilter("ignore")
5
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -43,26 +44,38 @@ class DescribeVideo:
43
  self.safety_settings = self.get_safety_settings()
44
 
45
  genai.configure(api_key=__api_key)
46
- self.mllm_model = genai.GenerativeModel(self.model, system_instruction=gemini_instructions)
 
 
47
 
48
  logging.info(f"Initialized DescribeVideo with model: {self.model}")
49
 
50
  def describe_video(self, video_path, genre, bpm, user_keywords):
51
  video_file = genai.upload_file(video_path)
52
- logging.info(f"Uploaded video: {video_path}")
53
 
54
  while video_file.state.name == "PROCESSING":
55
  time.sleep(0.25)
56
  video_file = genai.get_file(video_file.name)
57
 
58
  if video_file.state.name == "FAILED":
59
- logging.error(f"Failed to upload video: {video_file.state.name}")
 
 
60
  raise ValueError(f"Failed to upload video: {video_file.state.name}")
61
-
62
- additional_keywords = ", ".join([genre, user_keywords, bpm]) + "bpm"
 
 
 
 
 
 
 
 
 
63
 
64
  response = self.mllm_model.generate_content(
65
- [video_file, f"Explain what is happening in this video. The following keywords are provided by the user for generating the music prompt: {additional_keywords}"],
66
  request_options={"timeout": 600},
67
  safety_settings=self.safety_settings,
68
  )
@@ -116,7 +129,9 @@ class DescribeVideo:
116
 
117
  api_key = creds.get("google_api_key", None)
118
  if api_key is None or not isinstance(api_key, str):
119
- logging.error(f"Google API key not found in {path}")
 
 
120
  raise ValueError(f"Gemini API key not found in {path}")
121
  return api_key
122
 
@@ -129,7 +144,7 @@ class DescribeVideo:
129
 
130
  if model not in models:
131
  logging.error(
132
- f"Invalid model name '{model}'. Valid options are: {', '.join(models.keys())}"
133
  )
134
  raise ValueError(
135
  f"Invalid model name '{model}'. Valid options are: {', '.join(models.keys())}"
 
1
  import os
2
  from warnings import simplefilter
3
+ import traceback
4
 
5
  simplefilter("ignore")
6
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 
44
  self.safety_settings = self.get_safety_settings()
45
 
46
  genai.configure(api_key=__api_key)
47
+ self.mllm_model = genai.GenerativeModel(
48
+ self.model, system_instruction=gemini_instructions
49
+ )
50
 
51
  logging.info(f"Initialized DescribeVideo with model: {self.model}")
52
 
53
  def describe_video(self, video_path, genre, bpm, user_keywords):
54
  video_file = genai.upload_file(video_path)
 
55
 
56
  while video_file.state.name == "PROCESSING":
57
  time.sleep(0.25)
58
  video_file = genai.get_file(video_file.name)
59
 
60
  if video_file.state.name == "FAILED":
61
+ logging.error(
62
+ f"Failed to upload video: {video_file.state.name}, Traceback: {traceback.format_exc()}"
63
+ )
64
  raise ValueError(f"Failed to upload video: {video_file.state.name}")
65
+
66
+ additional_keywords = ", ".join(filter(None, [genre, user_keywords])) + (
67
+ f", {bpm} bpm" if bpm else ""
68
+ )
69
+
70
+ logging.info(f"Uploaded video: {video_path} and config: {additional_keywords}")
71
+
72
+ user_prompt = "Explain what is happening in this video."
73
+
74
+ if additional_keywords:
75
+ user_prompt += f" The following keywords are provided by the user for generating the music prompt: {additional_keywords}"
76
 
77
  response = self.mllm_model.generate_content(
78
+ [video_file, user_prompt],
79
  request_options={"timeout": 600},
80
  safety_settings=self.safety_settings,
81
  )
 
129
 
130
  api_key = creds.get("google_api_key", None)
131
  if api_key is None or not isinstance(api_key, str):
132
+ logging.error(
133
+ f"Google API key not found in {path}, Traceback: {traceback.format_exc()}"
134
+ )
135
  raise ValueError(f"Gemini API key not found in {path}")
136
  return api_key
137
 
 
144
 
145
  if model not in models:
146
  logging.error(
147
+ f"Invalid model name '{model}'. Valid options are: {', '.join(models.keys())}, Traceback: {traceback.format_exc()}"
148
  )
149
  raise ValueError(
150
  f"Invalid model name '{model}'. Valid options are: {', '.join(models.keys())}"
main.py CHANGED
@@ -1,10 +1,9 @@
1
  import streamlit as st
2
  from engine import DescribeVideo, GenerateAudio
3
  import os
4
- import time
5
  from moviepy.editor import VideoFileClip
6
 
7
-
8
  video_model_map = {
9
  "Fast": "flash",
10
  "Quality": "pro",
@@ -16,7 +15,13 @@ music_model_map = {
16
  "Quality": "musicgen-stereo-large",
17
  }
18
 
 
 
 
 
 
19
  genre_map = {
 
20
  "Pop": "Pop",
21
  "Rock": "Rock",
22
  "Hip Hop": "Hip-Hop/Rap",
@@ -30,7 +35,6 @@ genre_map = {
30
  "Lo-Fi": "Lo-Fi",
31
  }
32
 
33
-
34
  st.set_page_config(
35
  page_title="VidTune: Where Videos Find Their Melody", layout="centered"
36
  )
@@ -41,29 +45,90 @@ st.write(
41
  "VidTune is a web application that allows users to upload videos and generate melodies matching the mood of the video."
42
  )
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  # Sidebar
46
  st.sidebar.title("Settings")
47
- video_model = st.sidebar.selectbox(
48
- "Select Video Descriptor", ["Fast", "Quality"], index=0
 
 
 
 
 
 
 
 
 
49
  )
50
- music_model = st.sidebar.selectbox(
51
- "Select Music Generator", ["Fast", "Balanced", "Quality"], index=0
52
  )
53
- music_genre = st.sidebar.selectbox("Select Music Genre", list(genre_map.keys()))
54
- num_samples = st.sidebar.slider("Number of samples", 1, 5, 3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  generate_button = st.sidebar.button("Generate Music")
56
 
57
- video_descriptor = None
58
- audio_descriptor = None
59
- video_description = None
60
 
61
- # Initialize Video Descriptor and Audio Generator
62
- if video_descriptor is None or audio_descriptor is None:
63
- video_descriptor = DescribeVideo(model=video_model_map[video_model])
64
- audio_generator = GenerateAudio(model=music_model_map[music_model])
 
 
65
 
66
 
 
 
 
 
 
67
  # Video Uploader
68
  uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
69
  if uploaded_video is not None:
@@ -80,11 +145,15 @@ if generate_button and uploaded_video is None:
80
  st.error("Please upload a video before generating music.")
81
  st.stop()
82
 
83
-
84
  # Submit Button and music generation if video is uploaded
85
  if generate_button and uploaded_video is not None:
86
  with st.spinner("Analyzing video..."):
87
- video_description = video_descriptor.describe_video("temp.mp4")
 
 
 
 
 
88
  video_duration = VideoFileClip("temp.mp4").duration
89
  music_prompt = video_description["Music Prompt"]
90
 
@@ -106,75 +175,15 @@ if generate_button and uploaded_video is not None:
106
 
107
  # Generate Music
108
  with st.spinner("Generating music..."):
109
- music_prompt = [music_prompt] * num_samples
 
 
 
 
110
  audio_generator.generate_audio(music_prompt, duration=video_duration)
111
  audio_paths = audio_generator.save_audio()
112
  st.success("Music generated successfully.")
113
  for i, audio_path in enumerate(audio_paths):
114
  st.audio(audio_path, format="audio/wav")
115
- # st.download_button(
116
- # label=f"Download Music {i+1}",
117
- # data=open(audio_path, "rb"),
118
- # file_name=f"Generated Music {i+1}.wav",
119
- # mime="audio/wav",
120
- # )
121
-
122
-
123
- # # Main Page (Page 1)
124
- # if "page" not in st.session_state:
125
- # st.session_state.page = "main"
126
-
127
- # if st.session_state.page == "main":
128
- # st.header("Video to Music")
129
- # uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
130
-
131
- # if uploaded_video is not None:
132
- # st.session_state.uploaded_video = uploaded_video
133
- # with open("temp.mp4", mode="wb") as w:
134
- # w.write(uploaded_video.getvalue())
135
- # video_description = video_descriptor.describe_video("temp.mp4")
136
-
137
- # st.session_state.page = "video_to_music"
138
-
139
- # if st.session_state.page == "main":
140
- # st.header("Prompt to Music")
141
- # prompt = st.text_area("Prompt")
142
- # if generate_button:
143
- # st.session_state.prompt = prompt
144
- # st.session_state.page = "prompt_to_music"
145
-
146
- # # Page 2a (If the user uploads a video)
147
- # if st.session_state.page == "video_to_music":
148
- # st.video(st.session_state.uploaded_video)
149
-
150
- # st.text_area(
151
- # "Video Description", "This is a fixed video description", disabled=True
152
- # )
153
- # st.text_area("Music Description")
154
-
155
- # if generate_button:
156
- # st.session_state.page = "result"
157
- # st.session_state.device = device
158
- # st.session_state.num_samples = num_samples
159
-
160
- # # Page 2b (If user selects "Prompt to Music" in Page 1)
161
- # if st.session_state.page == "prompt_to_music":
162
- # st.sidebar.title("Settings")
163
- # device = st.sidebar.selectbox("Select Device", ["GPU", "CPU"], index=0)
164
- # num_samples = st.sidebar.slider("Number of samples", 1, 10, 3)
165
-
166
- # if generate_button:
167
- # st.session_state.page = "result"
168
- # st.session_state.device = device
169
- # st.session_state.num_samples = num_samples
170
-
171
- # # Page 3 (Results Page)
172
- # if st.session_state.page == "result":
173
- # st.header("Generated Music")
174
- # for i in range(st.session_state.num_samples):
175
- # st.write(f"Music Sample {i+1}")
176
- # st.audio(f"Generated Music {i+1}.mp3", format="audio/mp3")
177
- # st.download_button(f"Download Music {i+1}", f"Generated Music {i+1}.mp3")
178
-
179
- # if st.button("Start Over"):
180
- # st.session_state.page = "main"
 
1
  import streamlit as st
2
  from engine import DescribeVideo, GenerateAudio
3
  import os
 
4
  from moviepy.editor import VideoFileClip
5
 
6
+ # Define model maps
7
  video_model_map = {
8
  "Fast": "flash",
9
  "Quality": "pro",
 
15
  "Quality": "musicgen-stereo-large",
16
  }
17
 
18
+ # music_model_map = {
19
+ # "Fast": "facebook/musicgen-melody",
20
+ # "Quality": "facebook/musicgen-melody-large",
21
+ # }
22
+
23
  genre_map = {
24
+ "None": None,
25
  "Pop": "Pop",
26
  "Rock": "Rock",
27
  "Hip Hop": "Hip-Hop/Rap",
 
35
  "Lo-Fi": "Lo-Fi",
36
  }
37
 
 
38
  st.set_page_config(
39
  page_title="VidTune: Where Videos Find Their Melody", layout="centered"
40
  )
 
45
  "VidTune is a web application that allows users to upload videos and generate melodies matching the mood of the video."
46
  )
47
 
48
+ # Initialize session state for advanced settings and other inputs
49
+ if "show_advanced" not in st.session_state:
50
+ st.session_state.show_advanced = False
51
+ if "video_model" not in st.session_state:
52
+ st.session_state.video_model = "Fast"
53
+ if "music_model" not in st.session_state:
54
+ st.session_state.music_model = "Fast"
55
+ if "num_samples" not in st.session_state:
56
+ st.session_state.num_samples = 3
57
+ if "music_genre" not in st.session_state:
58
+ st.session_state.music_genre = None
59
+ if "music_bpm" not in st.session_state:
60
+ st.session_state.music_bpm = 100
61
+ if "user_keywords" not in st.session_state:
62
+ st.session_state.user_keywords = None
63
 
64
  # Sidebar
65
  st.sidebar.title("Settings")
66
+
67
+ # Basic Settings
68
+ st.session_state.video_model = st.sidebar.selectbox(
69
+ "Select Video Descriptor",
70
+ ["Fast", "Quality"],
71
+ index=["Fast", "Quality"].index(st.session_state.video_model),
72
+ )
73
+ st.session_state.music_model = st.sidebar.selectbox(
74
+ "Select Music Generator",
75
+ ["Fast", "Balanced", "Quality"],
76
+ index=["Fast", "Balanced", "Quality"].index(st.session_state.music_model),
77
  )
78
+ st.session_state.num_samples = st.sidebar.slider(
79
+ "Number of samples", 1, 5, st.session_state.num_samples
80
  )
81
+
82
+ # Sidebar for advanced settings
83
+ with st.sidebar:
84
+ # Create a placeholder for the advanced settings button
85
+ placeholder = st.empty()
86
+
87
+ # Button to toggle advanced settings
88
+ if placeholder.button("Advanced"):
89
+ st.session_state.show_advanced = not st.session_state.show_advanced
90
+ st.rerun() # Refresh the layout after button click
91
+
92
+ # Display advanced settings if enabled
93
+ if st.session_state.show_advanced:
94
+ # Advanced settings
95
+ st.session_state.music_bpm = st.sidebar.slider("Beats Per Minute", 35, 180, 100)
96
+ st.session_state.music_genre = st.sidebar.selectbox(
97
+ "Select Music Genre",
98
+ list(genre_map.keys()),
99
+ index=(
100
+ list(genre_map.keys()).index(st.session_state.music_genre)
101
+ if st.session_state.music_genre in genre_map.keys()
102
+ else 0
103
+ ),
104
+ )
105
+ st.session_state.user_keywords = st.sidebar.text_input(
106
+ "User Keywords",
107
+ value=st.session_state.user_keywords,
108
+ help="Enter keywords separated by commas.",
109
+ )
110
+ else:
111
+ st.session_state.music_genre = None
112
+ st.session_state.music_bpm = None
113
+ st.session_state.user_keywords = None
114
+
115
+ # Generate Button
116
  generate_button = st.sidebar.button("Generate Music")
117
 
 
 
 
118
 
119
+ # Cache the model loading
120
+ @st.cache_resource
121
+ def load_models(video_model_key, music_model_key):
122
+ video_descriptor = DescribeVideo(model=video_model_map[video_model_key])
123
+ audio_generator = GenerateAudio(model=music_model_map[music_model_key])
124
+ return video_descriptor, audio_generator
125
 
126
 
127
+ # Load models
128
+ video_descriptor, audio_generator = load_models(
129
+ st.session_state.video_model, st.session_state.music_model
130
+ )
131
+
132
  # Video Uploader
133
  uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
134
  if uploaded_video is not None:
 
145
  st.error("Please upload a video before generating music.")
146
  st.stop()
147
 
 
148
  # Submit Button and music generation if video is uploaded
149
  if generate_button and uploaded_video is not None:
150
  with st.spinner("Analyzing video..."):
151
+ video_description = video_descriptor.describe_video(
152
+ "temp.mp4",
153
+ genre=st.session_state.music_genre,
154
+ bpm=st.session_state.music_bpm,
155
+ user_keywords=st.session_state.user_keywords,
156
+ )
157
  video_duration = VideoFileClip("temp.mp4").duration
158
  music_prompt = video_description["Music Prompt"]
159
 
 
175
 
176
  # Generate Music
177
  with st.spinner("Generating music..."):
178
+ if video_duration > 30:
179
+ st.warning(
180
+ "Due to hardware limitations, the maximum music length is capped at 30 seconds."
181
+ )
182
+ music_prompt = [music_prompt] * st.session_state.num_samples
183
  audio_generator.generate_audio(music_prompt, duration=video_duration)
184
  audio_paths = audio_generator.save_audio()
185
  st.success("Music generated successfully.")
186
  for i, audio_path in enumerate(audio_paths):
187
  st.audio(audio_path, format="audio/wav")
188
+
189
+ st.balloons()