animikhaich commited on
Commit
5978ae3
1 Parent(s): d8d2011

Added: Video Descriptor and UI Boilerplate

Browse files
.gitignore CHANGED
@@ -164,4 +164,7 @@ cython_debug/
164
 
165
 
166
  *.wav
167
- *.mp3
 
 
 
 
164
 
165
 
166
  *.wav
167
+ *.mp3
168
+ *.mp4
169
+
170
+ creds.json
.streamlit/config.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [browser]
2
+ gatherUsageStats = false
client.py DELETED
@@ -1,65 +0,0 @@
1
- import requests
2
- import argparse
3
-
4
- # Parse command line arguments
5
- parser = argparse.ArgumentParser(description="Music Generation Client")
6
- parser.add_argument(
7
- "--server_url", type=str, default="http://localhost:8000", help="URL of the server"
8
- )
9
- parser.add_argument(
10
- "--prompts",
11
- nargs="+",
12
- type=str,
13
- default=["Lofi Music for Coding"],
14
- help="Prompts for music generation",
15
- )
16
- parser.add_argument(
17
- "--output_file", type=str, default="output.wav", help="Output file name"
18
- )
19
- parser.add_argument(
20
- "--duration", type=int, default=10, help="Duration of generated music in seconds"
21
- )
22
- parser.add_argument(
23
- "--check_health", action='store_true', help="Check server health"
24
- )
25
-
26
- args = parser.parse_args()
27
-
28
- def generate_music(server_url, prompts, duration, output_file):
29
- url = f"{server_url}/generate_music"
30
- headers = {"Content-Type": "application/json"}
31
- data = {"prompts": prompts, "duration": duration}
32
-
33
- response = requests.post(url, json=data, headers=headers)
34
-
35
- if response.status_code == 200:
36
- with open(output_file, "wb") as f:
37
- f.write(response.content)
38
- print(f"Music saved to {output_file}")
39
- else:
40
- print(f"Failed to generate music: {response.status_code}, {response.text}")
41
-
42
- def check_server_health(server_url):
43
- url = f"{server_url}/health"
44
- response = requests.get(url)
45
-
46
- if response.status_code == 200:
47
- health_status = response.json()
48
- print("Server Health Check:")
49
- print(f"Server Running: {health_status['server_running']}")
50
- print(f"Model Loaded: {health_status['model_loaded']}")
51
- print(f"CPU Usage: {health_status['cpu_usage_percent']}%")
52
- print(f"RAM Usage: {health_status['ram_usage_percent']}%")
53
- if 'gpu_memory_allocated' in health_status:
54
- gpu_memory_allocated_gb = health_status['gpu_memory_allocated'] / (1024 ** 3)
55
- gpu_memory_reserved_gb = health_status['gpu_memory_reserved'] / (1024 ** 3)
56
- print(f"GPU Memory Allocated: {gpu_memory_allocated_gb:.2f} GB")
57
- print(f"GPU Memory Reserved: {gpu_memory_reserved_gb:.2f} GB")
58
- else:
59
- print(f"Failed to check server health: {response.status_code}, {response.text}")
60
-
61
- if __name__ == "__main__":
62
- if args.check_health:
63
- check_server_health(args.server_url)
64
- else:
65
- generate_music(args.server_url, args.prompts, args.duration, args.output_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def __init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .logger import logging
engine/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .video_descriptor import DescribeVideo
engine/audio_generator.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # TODO: Add from model server
engine/video_descriptor.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from warnings import simplefilter
2
+
3
+ simplefilter("ignore")
4
+ import os
5
+
6
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
7
+ import json
8
+ import time
9
+ import google.generativeai as genai
10
+
11
+ try:
12
+ from logger import logging
13
+ except:
14
+ import logging
15
+
16
+ music_prompt_examples = """
17
+ 'A dynamic blend of hip-hop and orchestral elements, with sweeping strings and brass, evoking the vibrant energy of the city',
18
+ 'Smooth jazz, with a saxophone solo, piano chords, and snare full drums',
19
+ '90s rock song with electric guitar and heavy drums'.
20
+ """
21
+
22
+ json_schema = """
23
+ {"Content Description": "string", "Music Prompt": "string"}
24
+ """
25
+
26
+ gemni_instructions = f"""
27
+ You are a music supervisor who analyzes the content and tone of images and videos to describe music that fits well with the mood, evokes emotions, and enhances the narrative of the visuals. Given an image or video, describe the scene and generate a prompt suitable for music generation models. Use keywords related to genre, instruments, mood, context, and setting to craft a concise single-sentence prompt, like:
28
+
29
+ {music_prompt_examples}
30
+
31
+ You must return your response using this JSON schema: {json_schema}
32
+ """
33
+
34
+
35
+ class DescribeVideo:
36
+ def __init__(self, model="flash"):
37
+ self.model = self.get_model_name(model)
38
+ __api_key = self.load_api_key()
39
+ self.is_safety_set = False
40
+ self.safety_settings = self.get_safety_settings()
41
+
42
+ genai.configure(api_key=__api_key)
43
+ self.mllm_model = genai.GenerativeModel(self.model)
44
+
45
+ logging.info(f"Initialized DescribeVideo with model: {self.model}")
46
+
47
+ def describe_video(self, video_path):
48
+ video_file = genai.upload_file(video_path)
49
+ logging.info(f"Uploaded video: {video_path}")
50
+
51
+ while video_file.state.name == "PROCESSING":
52
+ time.sleep(0.25)
53
+ video_file = genai.get_file(video_file.name)
54
+
55
+ if video_file.state.name == "FAILED":
56
+ logging.error(f"Failed to upload video: {video_file.state.name}")
57
+ raise ValueError(f"Failed to upload video: {video_file.state.name}")
58
+
59
+ response = self.mllm_model.generate_content(
60
+ [video_file, "Explain what is happening in this video"],
61
+ request_options={"timeout": 600},
62
+ safety_settings=self.safety_settings,
63
+ )
64
+
65
+ logging.info(
66
+ f"Generated content for video: {video_path} with response: {response.text}"
67
+ )
68
+
69
+ cleaned_response = self.mllm_model.generate_content(
70
+ [
71
+ response.text,
72
+ gemni_instructions,
73
+ ],
74
+ safety_settings=self.safety_settings,
75
+ )
76
+
77
+ logging.info(f"Generated : {video_path} with response: {cleaned_response.text}")
78
+
79
+ return json.loads(cleaned_response.text.strip("```json\n"))
80
+
81
+ def reset_safety_settings(self):
82
+ logging.info("Resetting safety settings")
83
+ self.is_safety_set = False
84
+ self.safety_settings = self.get_safety_settings()
85
+
86
+ def set_safety_settings(self, safety_settings):
87
+ self.safety_settings = safety_settings
88
+ # Sanity Checks
89
+ if not isinstance(safety_settings, dict):
90
+ raise ValueError("Safety settings must be a dictionary")
91
+ for harm_category, harm_block_threshold in safety_settings.items():
92
+ if harm_category not in genai.types.HarmCategory.__members__:
93
+ raise ValueError(f"Invalid harm category: {harm_category}")
94
+ if harm_block_threshold not in genai.types.HarmBlockThreshold.__members__:
95
+ raise ValueError(
96
+ f"Invalid harm block threshold: {harm_block_threshold}"
97
+ )
98
+
99
+ logging.info(f"Set safety settings: {safety_settings}")
100
+ self.safety_settings = safety_settings
101
+ self.is_safety_set = True
102
+
103
+ def get_safety_settings(self):
104
+ default_safety_settings = {
105
+ genai.types.HarmCategory.HARM_CATEGORY_HATE_SPEECH: genai.types.HarmBlockThreshold.BLOCK_NONE,
106
+ genai.types.HarmCategory.HARM_CATEGORY_HARASSMENT: genai.types.HarmBlockThreshold.BLOCK_NONE,
107
+ genai.types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: genai.types.HarmBlockThreshold.BLOCK_NONE,
108
+ genai.types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: genai.types.HarmBlockThreshold.BLOCK_NONE,
109
+ }
110
+
111
+ if self.is_safety_set:
112
+ return self.safety_settings
113
+
114
+ return default_safety_settings
115
+
116
+ @staticmethod
117
+ def load_api_key(path="./creds.json"):
118
+ with open(path) as f:
119
+ creds = json.load(f)
120
+
121
+ api_key = creds.get("google_api_key", None)
122
+ if api_key is None or not isinstance(api_key, str):
123
+ logging.error(f"Google API key not found in {path}")
124
+ raise ValueError(f"Gemini API key not found in {path}")
125
+ return api_key
126
+
127
+ @staticmethod
128
+ def get_model_name(model):
129
+ models = {
130
+ "flash": "models/gemini-1.5-flash-latest",
131
+ "pro": "models/gemini-1.5-pro-latest",
132
+ }
133
+
134
+ if model not in models:
135
+ logging.error(
136
+ f"Invalid model name '{model}'. Valid options are: {', '.join(models.keys())}"
137
+ )
138
+ raise ValueError(
139
+ f"Invalid model name '{model}'. Valid options are: {', '.join(models.keys())}"
140
+ )
141
+
142
+ logging.info(f"Selected model: {models[model]}")
143
+ return models[model]
144
+
145
+
146
+ if __name__ == "__main__":
147
+ video_path = "videos/3A49B385FD4A8FE2E3AEEF43C140D9AF_video_dashinit.mp4"
148
+ dv = DescribeVideo(model="flash")
149
+ video_description = dv.describe_video(video_path)
150
+ print(video_description)
logger.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ FORMAT = "%(asctime)s: %(levelname)s: %(message)s"
4
+ logging.basicConfig(filename='logs.log', level=logging.INFO, format=FORMAT)
5
+ STDERRLOGGER = logging.StreamHandler()
6
+ STDERRLOGGER.setFormatter(logging.Formatter(FORMAT))
7
+ logging.getLogger().addHandler(STDERRLOGGER)
main.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ def main():
4
+ st.set_page_config(page_title="VidTune: Where Videos Find Their Melody", layout="centered")
5
+
6
+ # Title and Description
7
+ st.title("VidTune: Where Videos Find Their Melody")
8
+ st.write("VidTune is a web application that allows users to upload videos and generate melodies matching the mood of the video.")
9
+
10
+ # Main Page (Page 1)
11
+ if 'page' not in st.session_state:
12
+ st.session_state.page = 'main'
13
+
14
+ if st.session_state.page == 'main':
15
+ st.header("Video to Music")
16
+ uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
17
+ if uploaded_video is not None:
18
+ st.session_state.uploaded_video = uploaded_video
19
+ st.session_state.page = 'video_to_music'
20
+
21
+ if st.session_state.page == 'main':
22
+ st.header("Prompt to Music")
23
+ prompt = st.text_area("Prompt")
24
+ if st.button("Generate"):
25
+ st.session_state.prompt = prompt
26
+ st.session_state.page = 'prompt_to_music'
27
+
28
+ # Page 2a (If the user uploads a video)
29
+ if st.session_state.page == 'video_to_music':
30
+ st.sidebar.title("Settings")
31
+ device = st.sidebar.selectbox("Select Device", ["GPU", "CPU"], index=0)
32
+ num_samples = st.sidebar.slider("Number of samples", 1, 10, 3)
33
+
34
+ st.video(st.session_state.uploaded_video)
35
+
36
+ st.text_area("Video Description", "This is a fixed video description", disabled=True)
37
+ st.text_area("Music Description")
38
+
39
+ if st.button("Generate Music"):
40
+ st.session_state.page = 'result'
41
+ st.session_state.device = device
42
+ st.session_state.num_samples = num_samples
43
+
44
+ # Page 2b (If user selects "Prompt to Music" in Page 1)
45
+ if st.session_state.page == 'prompt_to_music':
46
+ st.sidebar.title("Settings")
47
+ device = st.sidebar.selectbox("Select Device", ["GPU", "CPU"], index=0)
48
+ num_samples = st.sidebar.slider("Number of samples", 1, 10, 3)
49
+
50
+ if st.button("Generate Music"):
51
+ st.session_state.page = 'result'
52
+ st.session_state.device = device
53
+ st.session_state.num_samples = num_samples
54
+
55
+ # Page 3 (Results Page)
56
+ if st.session_state.page == 'result':
57
+ st.header("Generated Music")
58
+ for i in range(st.session_state.num_samples):
59
+ st.write(f"Music Sample {i+1}")
60
+ st.audio(f"Generated Music {i+1}.mp3", format='audio/mp3')
61
+ st.download_button(f"Download Music {i+1}", f"Generated Music {i+1}.mp3")
62
+
63
+ if st.button("Start Over"):
64
+ st.session_state.page = 'main'
65
+
66
+ if __name__ == "__main__":
67
+ main()
run_test.sh DELETED
@@ -1,28 +0,0 @@
1
- #!/bin/bash
2
-
3
- echo "Script started."
4
-
5
- # Run server
6
- echo "Starting server..."
7
- python server.py --duration 10 &
8
- echo "Server started."
9
-
10
- # Sleep
11
- echo "Waiting for the server to startup..."
12
- sleep 10
13
-
14
- # Run client
15
- echo "Starting client..."
16
- python client.py --server_url http://localhost:8000 --prompts "Lofi Music for Coding" --output_file output.wav
17
- echo "Client finished."
18
-
19
-
20
- # Kill server
21
- echo "Killing server..."
22
- kill $(ps aux | grep 'server.py' | awk '{print $2}')
23
-
24
-
25
- # Done
26
- sleep 5
27
- echo "Script finished."
28
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
server.py DELETED
@@ -1,90 +0,0 @@
1
- import warnings
2
- import argparse
3
- from fastapi import FastAPI, HTTPException
4
- from pydantic import BaseModel
5
- from typing import List, Optional
6
- import torch
7
- from torch.cuda import memory_allocated, memory_reserved
8
- from audiocraft.models import musicgen
9
- import numpy as np
10
- import io
11
- from fastapi.responses import StreamingResponse, JSONResponse
12
- from scipy.io.wavfile import write as wav_write
13
- import uvicorn
14
- import psutil
15
-
16
- warnings.simplefilter('ignore')
17
-
18
- # Parse command line arguments
19
- parser = argparse.ArgumentParser(description="Music Generation Server")
20
- parser.add_argument("--model", type=str, default="musicgen-stereo-small", help="Pretrained model name")
21
- parser.add_argument("--device", type=str, default="cuda", help="Device to load the model on")
22
- parser.add_argument("--duration", type=int, default=10, help="Duration of generated music in seconds")
23
- parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
24
- parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
25
-
26
- args = parser.parse_args()
27
-
28
- # Initialize the FastAPI app
29
- app = FastAPI()
30
-
31
- # Build the model name based on the provided arguments
32
- if args.model.startswith('facebook/'):
33
- args.model_name = args.model
34
- else:
35
- args.model_name = f"facebook/{args.model}"
36
-
37
- # Load the model with the provided arguments
38
- try:
39
- musicgen_model = musicgen.MusicGen.get_pretrained(args.model_name, device=args.device)
40
- model_loaded = True
41
- except Exception as e:
42
- musicgen_model = None
43
- model_loaded = False
44
-
45
- class MusicRequest(BaseModel):
46
- prompts: List[str]
47
- duration: Optional[int] = 10 # Default duration is 10 seconds if not provided
48
-
49
- @app.post("/generate_music")
50
- def generate_music(request: MusicRequest):
51
- if not model_loaded:
52
- raise HTTPException(status_code=500, detail="Model is not loaded.")
53
-
54
- try:
55
- musicgen_model.set_generation_params(duration=request.duration)
56
- result = musicgen_model.generate(request.prompts, progress=False)
57
- result = result.squeeze().cpu().numpy().T
58
-
59
- sample_rate = musicgen_model.sample_rate
60
-
61
- buffer = io.BytesIO()
62
- wav_write(buffer, sample_rate, result)
63
-
64
- buffer.seek(0)
65
-
66
- return StreamingResponse(buffer, media_type="audio/wav")
67
- except Exception as e:
68
- raise HTTPException(status_code=500, detail=str(e))
69
-
70
- @app.get("/health")
71
- def health_check():
72
- cpu_usage = psutil.cpu_percent(interval=1)
73
- ram_usage = psutil.virtual_memory().percent
74
- stats = {
75
- "server_running": True,
76
- "model_loaded": model_loaded,
77
- "cpu_usage_percent": cpu_usage,
78
- "ram_usage_percent": ram_usage
79
- }
80
- if args.device == "cuda" and torch.cuda.is_available():
81
- gpu_memory_allocated = memory_allocated()
82
- gpu_memory_reserved = memory_reserved()
83
- stats.update({
84
- "gpu_memory_allocated": gpu_memory_allocated,
85
- "gpu_memory_reserved": gpu_memory_reserved
86
- })
87
- return JSONResponse(content=stats)
88
-
89
- if __name__ == "__main__":
90
- uvicorn.run(app, host=args.host, port=args.port)