animikhaich commited on
Commit
032d7c2
1 Parent(s): ac90b4d

Added Google API key input + text field persistence

Browse files
Files changed (2) hide show
  1. engine/video_descriptor.py +2 -2
  2. main.py +55 -22
engine/video_descriptor.py CHANGED
@@ -37,9 +37,9 @@ You must return your response using this JSON schema: {json_schema}
37
 
38
 
39
  class DescribeVideo:
40
- def __init__(self, model="flash"):
41
  self.model = self.get_model_name(model)
42
- __api_key = self.load_api_key()
43
  self.is_safety_set = False
44
  self.safety_settings = self.get_safety_settings()
45
 
 
37
 
38
 
39
  class DescribeVideo:
40
+ def __init__(self, model="flash", google_api_key=None):
41
  self.model = self.get_model_name(model)
42
+ __api_key = google_api_key # self.load_api_key()
43
  self.is_safety_set = False
44
  self.safety_settings = self.get_safety_settings()
45
 
main.py CHANGED
@@ -83,9 +83,30 @@ if "orig_audio_vol" not in st.session_state:
83
  st.session_state.orig_audio_vol = 100
84
  if "generated_audio_vol" not in st.session_state:
85
  st.session_state.generated_audio_vol = 100
 
 
 
 
 
 
 
 
 
 
86
 
87
  # Sidebar
88
- st.sidebar.title("Settings")
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  # Basic Settings
91
  st.session_state.video_model = st.sidebar.selectbox(
@@ -141,15 +162,19 @@ generate_button = st.sidebar.button("Generate Music")
141
 
142
  # Cache the model loading
143
  @st.cache_resource
144
- def load_models(video_model_key, music_model_key):
145
- video_descriptor = DescribeVideo(model=video_model_map[video_model_key])
 
 
146
  audio_generator = GenerateAudio(model=music_model_map[music_model_key])
147
  return video_descriptor, audio_generator
148
 
149
 
150
  # Load models
151
  video_descriptor, audio_generator = load_models(
152
- st.session_state.video_model, st.session_state.music_model
 
 
153
  )
154
 
155
  # Video Uploader
@@ -177,31 +202,37 @@ if generate_button:
177
  user_keywords=st.session_state.user_keywords,
178
  )
179
  video_duration = VideoFileClip(f"{user_session_id}/temp.mp4").duration
180
- music_prompt = video_description["Music Prompt"]
 
 
 
181
 
182
  st.success("Video description generated successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
- # Display Video Description and Music Prompt
185
- st.text_area(
186
- "Video Description",
187
- video_description["Content Description"],
188
- disabled=True,
189
- height=120,
190
- )
191
- music_prompt = st.text_area(
192
- "Music Prompt",
193
- music_prompt,
194
- disabled=True,
195
- height=120,
196
- )
197
-
198
  # Generate Music
199
  with st.spinner("Generating music..."):
200
  if video_duration > 30:
201
  st.warning(
202
  "Due to hardware limitations, the maximum music length is capped at 30 seconds."
203
  )
204
- music_prompt = [music_prompt] * st.session_state.num_samples
205
  audio_generator.generate_audio(music_prompt, duration=video_duration)
206
  st.session_state.audio_paths = audio_generator.save_audio()
207
  st.success("Music generated successfully.")
@@ -210,6 +241,7 @@ if generate_button:
210
 
211
  # Callback function for radio button selection change
212
  def on_audio_selection_change():
 
213
  selected_audio_index = st.session_state.selected_audio
214
  if selected_audio_index > 0:
215
  st.session_state.selected_audio_path = st.session_state.audio_paths[
@@ -235,14 +267,15 @@ if st.session_state.audio_paths:
235
  format_func=lambda x: audio_options[x],
236
  index=0,
237
  key="selected_audio",
 
238
  )
239
 
240
  # Button to confirm the selection
241
  if st.button("Add Generated Music to Video"):
242
- on_audio_selection_change()
243
 
244
  # Handle Audio Mixing and Export
245
- if st.session_state.selected_audio_path is not None:
246
  with st.spinner("Mixing Audio..."):
247
  orig_clip = VideoFileClip(f"{user_session_id}/temp.mp4")
248
  orig_clip_audio = orig_clip.audio
 
83
  st.session_state.orig_audio_vol = 100
84
  if "generated_audio_vol" not in st.session_state:
85
  st.session_state.generated_audio_vol = 100
86
+ if "generate_button_flag" not in st.session_state:
87
+ st.session_state.generate_button_flag = False
88
+ if "video_description_content" not in st.session_state:
89
+ st.session_state.video_description_content = ""
90
+ if "music_prompt" not in st.session_state:
91
+ st.session_state.music_prompt = ""
92
+ if "audio_mix_flag" not in st.session_state:
93
+ st.session_state.audio_mix_flag = False
94
+ if "google_api_key" not in st.session_state:
95
+ st.session_state.google_api_key = ""
96
 
97
  # Sidebar
98
+ st.sidebar.title("Configuration")
99
+
100
+ # Google API Key
101
+ st.session_state.google_api_key = st.sidebar.text_input(
102
+ "Enter your Google API Key to get started:",
103
+ st.session_state.google_api_key,
104
+ type="password",
105
+ )
106
+
107
+ if not st.session_state.google_api_key:
108
+ st.warning("Please enter your Google API Key to proceed.")
109
+ st.stop()
110
 
111
  # Basic Settings
112
  st.session_state.video_model = st.sidebar.selectbox(
 
162
 
163
  # Cache the model loading
164
  @st.cache_resource
165
+ def load_models(video_model_key, music_model_key, google_api_key):
166
+ video_descriptor = DescribeVideo(
167
+ model=video_model_map[video_model_key], google_api_key=google_api_key
168
+ )
169
  audio_generator = GenerateAudio(model=music_model_map[music_model_key])
170
  return video_descriptor, audio_generator
171
 
172
 
173
  # Load models
174
  video_descriptor, audio_generator = load_models(
175
+ st.session_state.video_model,
176
+ st.session_state.music_model,
177
+ st.session_state.google_api_key,
178
  )
179
 
180
  # Video Uploader
 
202
  user_keywords=st.session_state.user_keywords,
203
  )
204
  video_duration = VideoFileClip(f"{user_session_id}/temp.mp4").duration
205
+ st.session_state.video_description_content = video_description[
206
+ "Content Description"
207
+ ]
208
+ st.session_state.music_prompt = video_description["Music Prompt"]
209
 
210
  st.success("Video description generated successfully.")
211
+ st.session_state.generate_button_flag = True
212
+
213
+ # Display Video Description and Music Prompt
214
+ if st.session_state.generate_button_flag:
215
+ st.text_area(
216
+ "Video Description",
217
+ st.session_state.video_description_content,
218
+ disabled=True,
219
+ height=120,
220
+ )
221
+ music_prompt = st.text_area(
222
+ "Music Prompt",
223
+ st.session_state.music_prompt,
224
+ disabled=True,
225
+ height=120,
226
+ )
227
 
228
+ if generate_button:
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  # Generate Music
230
  with st.spinner("Generating music..."):
231
  if video_duration > 30:
232
  st.warning(
233
  "Due to hardware limitations, the maximum music length is capped at 30 seconds."
234
  )
235
+ music_prompt = [st.session_state.music_prompt] * st.session_state.num_samples
236
  audio_generator.generate_audio(music_prompt, duration=video_duration)
237
  st.session_state.audio_paths = audio_generator.save_audio()
238
  st.success("Music generated successfully.")
 
241
 
242
  # Callback function for radio button selection change
243
  def on_audio_selection_change():
244
+ st.session_state.audio_mix_flag = False
245
  selected_audio_index = st.session_state.selected_audio
246
  if selected_audio_index > 0:
247
  st.session_state.selected_audio_path = st.session_state.audio_paths[
 
267
  format_func=lambda x: audio_options[x],
268
  index=0,
269
  key="selected_audio",
270
+ on_change=on_audio_selection_change,
271
  )
272
 
273
  # Button to confirm the selection
274
  if st.button("Add Generated Music to Video"):
275
+ st.session_state.audio_mix_flag = True
276
 
277
  # Handle Audio Mixing and Export
278
+ if st.session_state.selected_audio_path is not None and st.session_state.audio_mix_flag:
279
  with st.spinner("Mixing Audio..."):
280
  orig_clip = VideoFileClip(f"{user_session_id}/temp.mp4")
281
  orig_clip_audio = orig_clip.audio