animikhaich committed on
Commit
5254bd1
1 Parent(s): 9ed1e74

UI and some headache fixes

Browse files
Files changed (2) hide show
  1. engine/audio_generator.py +7 -4
  2. main.py +152 -63
engine/audio_generator.py CHANGED
@@ -9,11 +9,13 @@ import numpy as np
9
  from audiocraft.models import musicgen
10
  from scipy.io.wavfile import write as wav_write
11
 
12
- try:
13
- from logger import logging
14
- except:
15
- import logging
16
 
 
 
 
 
 
17
 
18
  class GenerateAudio:
19
  def __init__(self, model="musicgen-stereo-small"):
@@ -75,6 +77,7 @@ class GenerateAudio:
75
  self.result = result.cpu().numpy().T
76
  self.result = self.result.transpose((2, 0, 1))
77
  self.sampling_rate = self.model.sample_rate
 
78
  logging.info(
79
  f"Generated audio with shape: {self.result.shape}, sample rate: {self.sampling_rate} Hz"
80
  )
 
9
  from audiocraft.models import musicgen
10
  from scipy.io.wavfile import write as wav_write
11
 
12
+ import logging
 
 
 
13
 
14
+ FORMAT = "%(asctime)s: %(levelname)s: %(message)s"
15
+ logging.basicConfig(filename='logs.log', level=logging.INFO, format=FORMAT)
16
+ STDERRLOGGER = logging.StreamHandler()
17
+ STDERRLOGGER.setFormatter(logging.Formatter(FORMAT))
18
+ logging.getLogger().addHandler(STDERRLOGGER)
19
 
20
  class GenerateAudio:
21
  def __init__(self, model="musicgen-stereo-small"):
 
77
  self.result = result.cpu().numpy().T
78
  self.result = self.result.transpose((2, 0, 1))
79
  self.sampling_rate = self.model.sample_rate
80
+
81
  logging.info(
82
  f"Generated audio with shape: {self.result.shape}, sample rate: {self.sampling_rate} Hz"
83
  )
main.py CHANGED
@@ -1,5 +1,8 @@
1
  import streamlit as st
2
  from engine import DescribeVideo, GenerateAudio
 
 
 
3
 
4
 
5
  video_model_map = {
@@ -13,79 +16,165 @@ music_model_map = {
13
  "Quality": "musicgen-stereo-large",
14
  }
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- st.set_page_config(page_title="VidTune: Where Videos Find Their Melody", layout="centered")
 
 
18
 
19
  # Title and Description
20
  st.title("VidTune: Where Videos Find Their Melody")
21
- st.write("VidTune is a web application that allows users to upload videos and generate melodies matching the mood of the video.")
 
 
22
 
23
 
24
  # Sidebar
25
  st.sidebar.title("Settings")
26
- video_model = st.sidebar.selectbox("Select Video Descriptor", ["Fast", "Balanced", "Quality"], index=0)
27
- music_model = st.sidebar.selectbox("Select Music Generator", ["Fast", "Balanced", "Quality"], index=0)
28
- num_samples = st.sidebar.slider("Number of samples", 1, 8, 3)
 
 
 
 
 
29
  generate_button = st.sidebar.button("Generate Music")
30
 
31
- video_descriptor = DescribeVideo(model=video_model_map[video_model])
32
- audio_generator = GenerateAudio(model=music_model_map[music_model])
33
-
34
  video_description = None
35
 
36
- # Main Page (Page 1)
37
- if 'page' not in st.session_state:
38
- st.session_state.page = 'main'
39
-
40
- if st.session_state.page == 'main':
41
- st.header("Video to Music")
42
- uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
43
-
44
- if uploaded_video is not None:
45
- st.session_state.uploaded_video = uploaded_video
46
- with open("temp.mp4", mode='wb') as w:
47
- w.write(uploaded_video.getvalue())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  video_description = video_descriptor.describe_video("temp.mp4")
49
-
50
- st.session_state.page = 'video_to_music'
51
-
52
- if st.session_state.page == 'main':
53
- st.header("Prompt to Music")
54
- prompt = st.text_area("Prompt")
55
- if generate_button:
56
- st.session_state.prompt = prompt
57
- st.session_state.page = 'prompt_to_music'
58
-
59
- # Page 2a (If the user uploads a video)
60
- if st.session_state.page == 'video_to_music':
61
- st.video(st.session_state.uploaded_video)
62
-
63
- st.text_area("Video Description", "This is a fixed video description", disabled=True)
64
- st.text_area("Music Description")
65
-
66
- if generate_button:
67
- st.session_state.page = 'result'
68
- st.session_state.device = device
69
- st.session_state.num_samples = num_samples
70
-
71
- # Page 2b (If user selects "Prompt to Music" in Page 1)
72
- if st.session_state.page == 'prompt_to_music':
73
- st.sidebar.title("Settings")
74
- device = st.sidebar.selectbox("Select Device", ["GPU", "CPU"], index=0)
75
- num_samples = st.sidebar.slider("Number of samples", 1, 10, 3)
76
-
77
- if generate_button:
78
- st.session_state.page = 'result'
79
- st.session_state.device = device
80
- st.session_state.num_samples = num_samples
81
-
82
- # Page 3 (Results Page)
83
- if st.session_state.page == 'result':
84
- st.header("Generated Music")
85
- for i in range(st.session_state.num_samples):
86
- st.write(f"Music Sample {i+1}")
87
- st.audio(f"Generated Music {i+1}.mp3", format='audio/mp3')
88
- st.download_button(f"Download Music {i+1}", f"Generated Music {i+1}.mp3")
89
-
90
- if st.button("Start Over"):
91
- st.session_state.page = 'main'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from engine import DescribeVideo, GenerateAudio
3
+ import os
4
+ import time
5
+ from moviepy.editor import VideoFileClip
6
 
7
 
8
  video_model_map = {
 
16
  "Quality": "musicgen-stereo-large",
17
  }
18
 
19
+ genre_map = {
20
+ "Pop": "Pop",
21
+ "Rock": "Rock",
22
+ "Hip Hop": "Hip-Hop/Rap",
23
+ "Jazz": "Jazz",
24
+ "Classical": "Classical",
25
+ "Blues": "Blues",
26
+ "Country": "Country",
27
+ "EDM": "Electronic/Dance",
28
+ "Metal": "Metal",
29
+ "Disco": "Disco",
30
+ "Lo-Fi": "Lo-Fi",
31
+ }
32
+
33
 
34
+ st.set_page_config(
35
+ page_title="VidTune: Where Videos Find Their Melody", layout="centered"
36
+ )
37
 
38
  # Title and Description
39
  st.title("VidTune: Where Videos Find Their Melody")
40
+ st.write(
41
+ "VidTune is a web application that allows users to upload videos and generate melodies matching the mood of the video."
42
+ )
43
 
44
 
45
  # Sidebar
46
  st.sidebar.title("Settings")
47
+ video_model = st.sidebar.selectbox(
48
+ "Select Video Descriptor", ["Fast", "Quality"], index=0
49
+ )
50
+ music_model = st.sidebar.selectbox(
51
+ "Select Music Generator", ["Fast", "Balanced", "Quality"], index=0
52
+ )
53
+ music_genre = st.sidebar.selectbox("Select Music Genre", list(genre_map.keys()))
54
+ num_samples = st.sidebar.slider("Number of samples", 1, 5, 3)
55
  generate_button = st.sidebar.button("Generate Music")
56
 
57
+ video_descriptor = None
58
+ audio_descriptor = None
 
59
  video_description = None
60
 
61
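+ # Initialize Video Descriptor and Audio Generator. NOTE(review): the guard below is always true, because both names were set to None just above and `audio_descriptor` is never reassigned (the generator is bound to `audio_generator`), so both models are constructed on every run of this script.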
62
+ if video_descriptor is None or audio_descriptor is None:
63
+ video_descriptor = DescribeVideo(model=video_model_map[video_model])
64
+ audio_generator = GenerateAudio(model=music_model_map[music_model])
65
+
66
+
67
+ # Video Uploader
68
+ uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
69
+ if uploaded_video is not None:
70
+ st.session_state.uploaded_video = uploaded_video
71
+ with open("temp.mp4", mode="wb") as w:
72
+ w.write(uploaded_video.getvalue())
73
+
74
+ # Video Player
75
+ if os.path.exists("temp.mp4") and uploaded_video is not None:
76
+ st.video(uploaded_video)
77
+
78
+ # If "Generate Music" is clicked before a video has been uploaded, show an error and stop
79
+ if generate_button and uploaded_video is None:
80
+ st.error("Please upload a video before generating music.")
81
+ st.stop()
82
+
83
+
84
+ # If "Generate Music" is clicked with a video uploaded, describe the video and generate music
85
+ if generate_button and uploaded_video is not None:
86
+ with st.spinner("Analyzing video..."):
87
  video_description = video_descriptor.describe_video("temp.mp4")
88
+ video_duration = VideoFileClip("temp.mp4").duration
89
+ music_prompt = video_description["Music Prompt"]
90
+
91
+ st.success("Video description generated successfully.")
92
+
93
+ # Display Video Description and Music Prompt
94
+ st.text_area(
95
+ "Video Description",
96
+ video_description["Content Description"],
97
+ disabled=True,
98
+ height=120,
99
+ )
100
+ music_prompt = st.text_area(
101
+ "Music Prompt",
102
+ music_prompt,
103
+ disabled=False,
104
+ height=120,
105
+ )
106
+
107
+ # Generate Music
108
+ with st.spinner("Generating music..."):
109
+ music_prompt = [music_prompt] * num_samples
110
+ audio_generator.generate_audio(music_prompt, duration=video_duration)
111
+ audio_paths = audio_generator.save_audio()
112
+ st.success("Music generated successfully.")
113
+ for i, audio_path in enumerate(audio_paths):
114
+ st.audio(audio_path, format="audio/wav")
115
+ # st.download_button(
116
+ # label=f"Download Music {i+1}",
117
+ # data=open(audio_path, "rb"),
118
+ # file_name=f"Generated Music {i+1}.wav",
119
+ # mime="audio/wav",
120
+ # )
121
+
122
+
123
+ # # Main Page (Page 1)
124
+ # if "page" not in st.session_state:
125
+ # st.session_state.page = "main"
126
+
127
+ # if st.session_state.page == "main":
128
+ # st.header("Video to Music")
129
+ # uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
130
+
131
+ # if uploaded_video is not None:
132
+ # st.session_state.uploaded_video = uploaded_video
133
+ # with open("temp.mp4", mode="wb") as w:
134
+ # w.write(uploaded_video.getvalue())
135
+ # video_description = video_descriptor.describe_video("temp.mp4")
136
+
137
+ # st.session_state.page = "video_to_music"
138
+
139
+ # if st.session_state.page == "main":
140
+ # st.header("Prompt to Music")
141
+ # prompt = st.text_area("Prompt")
142
+ # if generate_button:
143
+ # st.session_state.prompt = prompt
144
+ # st.session_state.page = "prompt_to_music"
145
+
146
+ # # Page 2a (If the user uploads a video)
147
+ # if st.session_state.page == "video_to_music":
148
+ # st.video(st.session_state.uploaded_video)
149
+
150
+ # st.text_area(
151
+ # "Video Description", "This is a fixed video description", disabled=True
152
+ # )
153
+ # st.text_area("Music Description")
154
+
155
+ # if generate_button:
156
+ # st.session_state.page = "result"
157
+ # st.session_state.device = device
158
+ # st.session_state.num_samples = num_samples
159
+
160
+ # # Page 2b (If user selects "Prompt to Music" in Page 1)
161
+ # if st.session_state.page == "prompt_to_music":
162
+ # st.sidebar.title("Settings")
163
+ # device = st.sidebar.selectbox("Select Device", ["GPU", "CPU"], index=0)
164
+ # num_samples = st.sidebar.slider("Number of samples", 1, 10, 3)
165
+
166
+ # if generate_button:
167
+ # st.session_state.page = "result"
168
+ # st.session_state.device = device
169
+ # st.session_state.num_samples = num_samples
170
+
171
+ # # Page 3 (Results Page)
172
+ # if st.session_state.page == "result":
173
+ # st.header("Generated Music")
174
+ # for i in range(st.session_state.num_samples):
175
+ # st.write(f"Music Sample {i+1}")
176
+ # st.audio(f"Generated Music {i+1}.mp3", format="audio/mp3")
177
+ # st.download_button(f"Download Music {i+1}", f"Generated Music {i+1}.mp3")
178
+
179
+ # if st.button("Start Over"):
180
+ # st.session_state.page = "main"