Spaces:
Running
on
Zero
Running
on
Zero
animikhaich
commited on
Commit
•
5978ae3
1
Parent(s):
d8d2011
Added: Video Descriptor and UI Boilerplate
Browse files- .gitignore +4 -1
- .streamlit/config.toml +2 -0
- client.py +0 -65
- def __init__.py +1 -0
- engine/__init__.py +1 -0
- engine/audio_generator.py +1 -0
- engine/video_descriptor.py +150 -0
- logger.py +7 -0
- main.py +67 -0
- run_test.sh +0 -28
- server.py +0 -90
.gitignore
CHANGED
@@ -164,4 +164,7 @@ cython_debug/
|
|
164 |
|
165 |
|
166 |
*.wav
|
167 |
-
*.mp3
|
|
|
|
|
|
|
|
164 |
|
165 |
|
166 |
*.wav
|
167 |
+
*.mp3
|
168 |
+
*.mp4
|
169 |
+
|
170 |
+
creds.json
|
.streamlit/config.toml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
[browser]
|
2 |
+
gatherUsageStats = false
|
client.py
DELETED
@@ -1,65 +0,0 @@
|
|
1 |
-
import requests
|
2 |
-
import argparse
|
3 |
-
|
4 |
-
# Parse command line arguments
|
5 |
-
parser = argparse.ArgumentParser(description="Music Generation Client")
|
6 |
-
parser.add_argument(
|
7 |
-
"--server_url", type=str, default="http://localhost:8000", help="URL of the server"
|
8 |
-
)
|
9 |
-
parser.add_argument(
|
10 |
-
"--prompts",
|
11 |
-
nargs="+",
|
12 |
-
type=str,
|
13 |
-
default=["Lofi Music for Coding"],
|
14 |
-
help="Prompts for music generation",
|
15 |
-
)
|
16 |
-
parser.add_argument(
|
17 |
-
"--output_file", type=str, default="output.wav", help="Output file name"
|
18 |
-
)
|
19 |
-
parser.add_argument(
|
20 |
-
"--duration", type=int, default=10, help="Duration of generated music in seconds"
|
21 |
-
)
|
22 |
-
parser.add_argument(
|
23 |
-
"--check_health", action='store_true', help="Check server health"
|
24 |
-
)
|
25 |
-
|
26 |
-
args = parser.parse_args()
|
27 |
-
|
28 |
-
def generate_music(server_url, prompts, duration, output_file):
|
29 |
-
url = f"{server_url}/generate_music"
|
30 |
-
headers = {"Content-Type": "application/json"}
|
31 |
-
data = {"prompts": prompts, "duration": duration}
|
32 |
-
|
33 |
-
response = requests.post(url, json=data, headers=headers)
|
34 |
-
|
35 |
-
if response.status_code == 200:
|
36 |
-
with open(output_file, "wb") as f:
|
37 |
-
f.write(response.content)
|
38 |
-
print(f"Music saved to {output_file}")
|
39 |
-
else:
|
40 |
-
print(f"Failed to generate music: {response.status_code}, {response.text}")
|
41 |
-
|
42 |
-
def check_server_health(server_url):
|
43 |
-
url = f"{server_url}/health"
|
44 |
-
response = requests.get(url)
|
45 |
-
|
46 |
-
if response.status_code == 200:
|
47 |
-
health_status = response.json()
|
48 |
-
print("Server Health Check:")
|
49 |
-
print(f"Server Running: {health_status['server_running']}")
|
50 |
-
print(f"Model Loaded: {health_status['model_loaded']}")
|
51 |
-
print(f"CPU Usage: {health_status['cpu_usage_percent']}%")
|
52 |
-
print(f"RAM Usage: {health_status['ram_usage_percent']}%")
|
53 |
-
if 'gpu_memory_allocated' in health_status:
|
54 |
-
gpu_memory_allocated_gb = health_status['gpu_memory_allocated'] / (1024 ** 3)
|
55 |
-
gpu_memory_reserved_gb = health_status['gpu_memory_reserved'] / (1024 ** 3)
|
56 |
-
print(f"GPU Memory Allocated: {gpu_memory_allocated_gb:.2f} GB")
|
57 |
-
print(f"GPU Memory Reserved: {gpu_memory_reserved_gb:.2f} GB")
|
58 |
-
else:
|
59 |
-
print(f"Failed to check server health: {response.status_code}, {response.text}")
|
60 |
-
|
61 |
-
if __name__ == "__main__":
|
62 |
-
if args.check_health:
|
63 |
-
check_server_health(args.server_url)
|
64 |
-
else:
|
65 |
-
generate_music(args.server_url, args.prompts, args.duration, args.output_file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .logger import logging
|
engine/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .video_descriptor import DescribeVideo
|
engine/audio_generator.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# TODO: Add from model server
|
engine/video_descriptor.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from warnings import simplefilter
|
2 |
+
|
3 |
+
simplefilter("ignore")
|
4 |
+
import os
|
5 |
+
|
6 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
7 |
+
import json
|
8 |
+
import time
|
9 |
+
import google.generativeai as genai
|
10 |
+
|
11 |
+
try:
|
12 |
+
from logger import logging
|
13 |
+
except:
|
14 |
+
import logging
|
15 |
+
|
16 |
+
music_prompt_examples = """
|
17 |
+
'A dynamic blend of hip-hop and orchestral elements, with sweeping strings and brass, evoking the vibrant energy of the city',
|
18 |
+
'Smooth jazz, with a saxophone solo, piano chords, and snare full drums',
|
19 |
+
'90s rock song with electric guitar and heavy drums'.
|
20 |
+
"""
|
21 |
+
|
22 |
+
json_schema = """
|
23 |
+
{"Content Description": "string", "Music Prompt": "string"}
|
24 |
+
"""
|
25 |
+
|
26 |
+
gemni_instructions = f"""
|
27 |
+
You are a music supervisor who analyzes the content and tone of images and videos to describe music that fits well with the mood, evokes emotions, and enhances the narrative of the visuals. Given an image or video, describe the scene and generate a prompt suitable for music generation models. Use keywords related to genre, instruments, mood, context, and setting to craft a concise single-sentence prompt, like:
|
28 |
+
|
29 |
+
{music_prompt_examples}
|
30 |
+
|
31 |
+
You must return your response using this JSON schema: {json_schema}
|
32 |
+
"""
|
33 |
+
|
34 |
+
|
35 |
+
class DescribeVideo:
|
36 |
+
def __init__(self, model="flash"):
|
37 |
+
self.model = self.get_model_name(model)
|
38 |
+
__api_key = self.load_api_key()
|
39 |
+
self.is_safety_set = False
|
40 |
+
self.safety_settings = self.get_safety_settings()
|
41 |
+
|
42 |
+
genai.configure(api_key=__api_key)
|
43 |
+
self.mllm_model = genai.GenerativeModel(self.model)
|
44 |
+
|
45 |
+
logging.info(f"Initialized DescribeVideo with model: {self.model}")
|
46 |
+
|
47 |
+
def describe_video(self, video_path):
|
48 |
+
video_file = genai.upload_file(video_path)
|
49 |
+
logging.info(f"Uploaded video: {video_path}")
|
50 |
+
|
51 |
+
while video_file.state.name == "PROCESSING":
|
52 |
+
time.sleep(0.25)
|
53 |
+
video_file = genai.get_file(video_file.name)
|
54 |
+
|
55 |
+
if video_file.state.name == "FAILED":
|
56 |
+
logging.error(f"Failed to upload video: {video_file.state.name}")
|
57 |
+
raise ValueError(f"Failed to upload video: {video_file.state.name}")
|
58 |
+
|
59 |
+
response = self.mllm_model.generate_content(
|
60 |
+
[video_file, "Explain what is happening in this video"],
|
61 |
+
request_options={"timeout": 600},
|
62 |
+
safety_settings=self.safety_settings,
|
63 |
+
)
|
64 |
+
|
65 |
+
logging.info(
|
66 |
+
f"Generated content for video: {video_path} with response: {response.text}"
|
67 |
+
)
|
68 |
+
|
69 |
+
cleaned_response = self.mllm_model.generate_content(
|
70 |
+
[
|
71 |
+
response.text,
|
72 |
+
gemni_instructions,
|
73 |
+
],
|
74 |
+
safety_settings=self.safety_settings,
|
75 |
+
)
|
76 |
+
|
77 |
+
logging.info(f"Generated : {video_path} with response: {cleaned_response.text}")
|
78 |
+
|
79 |
+
return json.loads(cleaned_response.text.strip("```json\n"))
|
80 |
+
|
81 |
+
def reset_safety_settings(self):
|
82 |
+
logging.info("Resetting safety settings")
|
83 |
+
self.is_safety_set = False
|
84 |
+
self.safety_settings = self.get_safety_settings()
|
85 |
+
|
86 |
+
def set_safety_settings(self, safety_settings):
|
87 |
+
self.safety_settings = safety_settings
|
88 |
+
# Sanity Checks
|
89 |
+
if not isinstance(safety_settings, dict):
|
90 |
+
raise ValueError("Safety settings must be a dictionary")
|
91 |
+
for harm_category, harm_block_threshold in safety_settings.items():
|
92 |
+
if harm_category not in genai.types.HarmCategory.__members__:
|
93 |
+
raise ValueError(f"Invalid harm category: {harm_category}")
|
94 |
+
if harm_block_threshold not in genai.types.HarmBlockThreshold.__members__:
|
95 |
+
raise ValueError(
|
96 |
+
f"Invalid harm block threshold: {harm_block_threshold}"
|
97 |
+
)
|
98 |
+
|
99 |
+
logging.info(f"Set safety settings: {safety_settings}")
|
100 |
+
self.safety_settings = safety_settings
|
101 |
+
self.is_safety_set = True
|
102 |
+
|
103 |
+
def get_safety_settings(self):
|
104 |
+
default_safety_settings = {
|
105 |
+
genai.types.HarmCategory.HARM_CATEGORY_HATE_SPEECH: genai.types.HarmBlockThreshold.BLOCK_NONE,
|
106 |
+
genai.types.HarmCategory.HARM_CATEGORY_HARASSMENT: genai.types.HarmBlockThreshold.BLOCK_NONE,
|
107 |
+
genai.types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: genai.types.HarmBlockThreshold.BLOCK_NONE,
|
108 |
+
genai.types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: genai.types.HarmBlockThreshold.BLOCK_NONE,
|
109 |
+
}
|
110 |
+
|
111 |
+
if self.is_safety_set:
|
112 |
+
return self.safety_settings
|
113 |
+
|
114 |
+
return default_safety_settings
|
115 |
+
|
116 |
+
@staticmethod
|
117 |
+
def load_api_key(path="./creds.json"):
|
118 |
+
with open(path) as f:
|
119 |
+
creds = json.load(f)
|
120 |
+
|
121 |
+
api_key = creds.get("google_api_key", None)
|
122 |
+
if api_key is None or not isinstance(api_key, str):
|
123 |
+
logging.error(f"Google API key not found in {path}")
|
124 |
+
raise ValueError(f"Gemini API key not found in {path}")
|
125 |
+
return api_key
|
126 |
+
|
127 |
+
@staticmethod
|
128 |
+
def get_model_name(model):
|
129 |
+
models = {
|
130 |
+
"flash": "models/gemini-1.5-flash-latest",
|
131 |
+
"pro": "models/gemini-1.5-pro-latest",
|
132 |
+
}
|
133 |
+
|
134 |
+
if model not in models:
|
135 |
+
logging.error(
|
136 |
+
f"Invalid model name '{model}'. Valid options are: {', '.join(models.keys())}"
|
137 |
+
)
|
138 |
+
raise ValueError(
|
139 |
+
f"Invalid model name '{model}'. Valid options are: {', '.join(models.keys())}"
|
140 |
+
)
|
141 |
+
|
142 |
+
logging.info(f"Selected model: {models[model]}")
|
143 |
+
return models[model]
|
144 |
+
|
145 |
+
|
146 |
+
if __name__ == "__main__":
|
147 |
+
video_path = "videos/3A49B385FD4A8FE2E3AEEF43C140D9AF_video_dashinit.mp4"
|
148 |
+
dv = DescribeVideo(model="flash")
|
149 |
+
video_description = dv.describe_video(video_path)
|
150 |
+
print(video_description)
|
logger.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
FORMAT = "%(asctime)s: %(levelname)s: %(message)s"
|
4 |
+
logging.basicConfig(filename='logs.log', level=logging.INFO, format=FORMAT)
|
5 |
+
STDERRLOGGER = logging.StreamHandler()
|
6 |
+
STDERRLOGGER.setFormatter(logging.Formatter(FORMAT))
|
7 |
+
logging.getLogger().addHandler(STDERRLOGGER)
|
main.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
def main():
|
4 |
+
st.set_page_config(page_title="VidTune: Where Videos Find Their Melody", layout="centered")
|
5 |
+
|
6 |
+
# Title and Description
|
7 |
+
st.title("VidTune: Where Videos Find Their Melody")
|
8 |
+
st.write("VidTune is a web application that allows users to upload videos and generate melodies matching the mood of the video.")
|
9 |
+
|
10 |
+
# Main Page (Page 1)
|
11 |
+
if 'page' not in st.session_state:
|
12 |
+
st.session_state.page = 'main'
|
13 |
+
|
14 |
+
if st.session_state.page == 'main':
|
15 |
+
st.header("Video to Music")
|
16 |
+
uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
|
17 |
+
if uploaded_video is not None:
|
18 |
+
st.session_state.uploaded_video = uploaded_video
|
19 |
+
st.session_state.page = 'video_to_music'
|
20 |
+
|
21 |
+
if st.session_state.page == 'main':
|
22 |
+
st.header("Prompt to Music")
|
23 |
+
prompt = st.text_area("Prompt")
|
24 |
+
if st.button("Generate"):
|
25 |
+
st.session_state.prompt = prompt
|
26 |
+
st.session_state.page = 'prompt_to_music'
|
27 |
+
|
28 |
+
# Page 2a (If the user uploads a video)
|
29 |
+
if st.session_state.page == 'video_to_music':
|
30 |
+
st.sidebar.title("Settings")
|
31 |
+
device = st.sidebar.selectbox("Select Device", ["GPU", "CPU"], index=0)
|
32 |
+
num_samples = st.sidebar.slider("Number of samples", 1, 10, 3)
|
33 |
+
|
34 |
+
st.video(st.session_state.uploaded_video)
|
35 |
+
|
36 |
+
st.text_area("Video Description", "This is a fixed video description", disabled=True)
|
37 |
+
st.text_area("Music Description")
|
38 |
+
|
39 |
+
if st.button("Generate Music"):
|
40 |
+
st.session_state.page = 'result'
|
41 |
+
st.session_state.device = device
|
42 |
+
st.session_state.num_samples = num_samples
|
43 |
+
|
44 |
+
# Page 2b (If user selects "Prompt to Music" in Page 1)
|
45 |
+
if st.session_state.page == 'prompt_to_music':
|
46 |
+
st.sidebar.title("Settings")
|
47 |
+
device = st.sidebar.selectbox("Select Device", ["GPU", "CPU"], index=0)
|
48 |
+
num_samples = st.sidebar.slider("Number of samples", 1, 10, 3)
|
49 |
+
|
50 |
+
if st.button("Generate Music"):
|
51 |
+
st.session_state.page = 'result'
|
52 |
+
st.session_state.device = device
|
53 |
+
st.session_state.num_samples = num_samples
|
54 |
+
|
55 |
+
# Page 3 (Results Page)
|
56 |
+
if st.session_state.page == 'result':
|
57 |
+
st.header("Generated Music")
|
58 |
+
for i in range(st.session_state.num_samples):
|
59 |
+
st.write(f"Music Sample {i+1}")
|
60 |
+
st.audio(f"Generated Music {i+1}.mp3", format='audio/mp3')
|
61 |
+
st.download_button(f"Download Music {i+1}", f"Generated Music {i+1}.mp3")
|
62 |
+
|
63 |
+
if st.button("Start Over"):
|
64 |
+
st.session_state.page = 'main'
|
65 |
+
|
66 |
+
if __name__ == "__main__":
|
67 |
+
main()
|
run_test.sh
DELETED
@@ -1,28 +0,0 @@
|
|
1 |
-
#!/bin/bash
|
2 |
-
|
3 |
-
echo "Script started."
|
4 |
-
|
5 |
-
# Run server
|
6 |
-
echo "Starting server..."
|
7 |
-
python server.py --duration 10 &
|
8 |
-
echo "Server started."
|
9 |
-
|
10 |
-
# Sleep
|
11 |
-
echo "Waiting for the server to startup..."
|
12 |
-
sleep 10
|
13 |
-
|
14 |
-
# Run client
|
15 |
-
echo "Starting client..."
|
16 |
-
python client.py --server_url http://localhost:8000 --prompts "Lofi Music for Coding" --output_file output.wav
|
17 |
-
echo "Client finished."
|
18 |
-
|
19 |
-
|
20 |
-
# Kill server
|
21 |
-
echo "Killing server..."
|
22 |
-
kill $(ps aux | grep 'server.py' | awk '{print $2}')
|
23 |
-
|
24 |
-
|
25 |
-
# Done
|
26 |
-
sleep 5
|
27 |
-
echo "Script finished."
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server.py
DELETED
@@ -1,90 +0,0 @@
|
|
1 |
-
import warnings
|
2 |
-
import argparse
|
3 |
-
from fastapi import FastAPI, HTTPException
|
4 |
-
from pydantic import BaseModel
|
5 |
-
from typing import List, Optional
|
6 |
-
import torch
|
7 |
-
from torch.cuda import memory_allocated, memory_reserved
|
8 |
-
from audiocraft.models import musicgen
|
9 |
-
import numpy as np
|
10 |
-
import io
|
11 |
-
from fastapi.responses import StreamingResponse, JSONResponse
|
12 |
-
from scipy.io.wavfile import write as wav_write
|
13 |
-
import uvicorn
|
14 |
-
import psutil
|
15 |
-
|
16 |
-
warnings.simplefilter('ignore')
|
17 |
-
|
18 |
-
# Parse command line arguments
|
19 |
-
parser = argparse.ArgumentParser(description="Music Generation Server")
|
20 |
-
parser.add_argument("--model", type=str, default="musicgen-stereo-small", help="Pretrained model name")
|
21 |
-
parser.add_argument("--device", type=str, default="cuda", help="Device to load the model on")
|
22 |
-
parser.add_argument("--duration", type=int, default=10, help="Duration of generated music in seconds")
|
23 |
-
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
|
24 |
-
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
|
25 |
-
|
26 |
-
args = parser.parse_args()
|
27 |
-
|
28 |
-
# Initialize the FastAPI app
|
29 |
-
app = FastAPI()
|
30 |
-
|
31 |
-
# Build the model name based on the provided arguments
|
32 |
-
if args.model.startswith('facebook/'):
|
33 |
-
args.model_name = args.model
|
34 |
-
else:
|
35 |
-
args.model_name = f"facebook/{args.model}"
|
36 |
-
|
37 |
-
# Load the model with the provided arguments
|
38 |
-
try:
|
39 |
-
musicgen_model = musicgen.MusicGen.get_pretrained(args.model_name, device=args.device)
|
40 |
-
model_loaded = True
|
41 |
-
except Exception as e:
|
42 |
-
musicgen_model = None
|
43 |
-
model_loaded = False
|
44 |
-
|
45 |
-
class MusicRequest(BaseModel):
|
46 |
-
prompts: List[str]
|
47 |
-
duration: Optional[int] = 10 # Default duration is 10 seconds if not provided
|
48 |
-
|
49 |
-
@app.post("/generate_music")
|
50 |
-
def generate_music(request: MusicRequest):
|
51 |
-
if not model_loaded:
|
52 |
-
raise HTTPException(status_code=500, detail="Model is not loaded.")
|
53 |
-
|
54 |
-
try:
|
55 |
-
musicgen_model.set_generation_params(duration=request.duration)
|
56 |
-
result = musicgen_model.generate(request.prompts, progress=False)
|
57 |
-
result = result.squeeze().cpu().numpy().T
|
58 |
-
|
59 |
-
sample_rate = musicgen_model.sample_rate
|
60 |
-
|
61 |
-
buffer = io.BytesIO()
|
62 |
-
wav_write(buffer, sample_rate, result)
|
63 |
-
|
64 |
-
buffer.seek(0)
|
65 |
-
|
66 |
-
return StreamingResponse(buffer, media_type="audio/wav")
|
67 |
-
except Exception as e:
|
68 |
-
raise HTTPException(status_code=500, detail=str(e))
|
69 |
-
|
70 |
-
@app.get("/health")
|
71 |
-
def health_check():
|
72 |
-
cpu_usage = psutil.cpu_percent(interval=1)
|
73 |
-
ram_usage = psutil.virtual_memory().percent
|
74 |
-
stats = {
|
75 |
-
"server_running": True,
|
76 |
-
"model_loaded": model_loaded,
|
77 |
-
"cpu_usage_percent": cpu_usage,
|
78 |
-
"ram_usage_percent": ram_usage
|
79 |
-
}
|
80 |
-
if args.device == "cuda" and torch.cuda.is_available():
|
81 |
-
gpu_memory_allocated = memory_allocated()
|
82 |
-
gpu_memory_reserved = memory_reserved()
|
83 |
-
stats.update({
|
84 |
-
"gpu_memory_allocated": gpu_memory_allocated,
|
85 |
-
"gpu_memory_reserved": gpu_memory_reserved
|
86 |
-
})
|
87 |
-
return JSONResponse(content=stats)
|
88 |
-
|
89 |
-
if __name__ == "__main__":
|
90 |
-
uvicorn.run(app, host=args.host, port=args.port)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|