# Spaces: Paused
from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
import torch | |
from diffusers import CogVideoXImageToVideoPipeline, CogVideoXPipeline | |
from diffusers.utils import export_to_video, load_image | |
import boto3 | |
from botocore.exceptions import NoCredentialsError | |
import tempfile | |
import os | |
import logging | |
from openai import OpenAI | |
import uuid | |
import openai | |
from pydantic import BaseModel | |
from dotenv import load_dotenv | |
import platform | |
# Load variables from a local .env file (if one exists) before anything below
# reads its configuration from the environment. With the default arguments,
# variables already present in the process environment are not overridden.
load_dotenv()

app = FastAPI(docs_url="/docs")

# Initialize the OpenAI client. It is constructed without arguments, so its
# credentials come from the environment.
client = OpenAI()

# AWS region shared by the S3 client and the public URL prefix below.
_AWS_REGION = os.environ.get("AWS_REGION")

# Initialize S3 client
s3_client = boto3.client(
    "s3",
    aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
    region_name=_AWS_REGION,
)

# Module-level OpenAI API key, set in addition to the client object above.
openai.api_key = os.environ.get("OPENAI_API_KEY")

S3_BUCKET = os.environ.get("S3_BUCKET")

# Public URL prefix for objects in S3_BUCKET. When no region is configured,
# use the region-less host form (the same form upload_audio_to_s3 builds)
# instead of embedding the literal text "None" in the host name.
if _AWS_REGION:
    S3_BASE_URL = f"https://{S3_BUCKET}.s3.{_AWS_REGION}.amazonaws.com/"
else:
    S3_BASE_URL = f"https://{S3_BUCKET}.s3.amazonaws.com/"

# OpenAI TTS voices: display name shown to API users -> OpenAI voice name.
OPENAI_SPEAKERS = {
    "Chris (Male)": "alloy",
    "Ryan (Male)": "echo",
    "Louis (Male)": "fable",
    "Alex (Male)": "onyx",
    "Sophia (Female)": "nova",
    "Hannah (Female)": "shimmer",
}
# Request model for input
class AudioRequest(BaseModel):
    """Request body accepted by the audio generation endpoint."""

    # Text to be converted to speech.
    text: str
    # Display name of the voice; generate_audio rejects values that are not
    # keys of OPENAI_SPEAKERS.
    speaker: str
# Function to generate audio using OpenAI and save to file
def generate_audio_using_openai(text, model, output_file: str):
    """Synthesize speech for ``text`` and write it to ``output_file``.

    Args:
        text: The text to convert to speech.
        model: The OpenAI voice name (passed as ``voice``). Despite the
            parameter name, the TTS model itself is fixed to ``tts-1-hd``.
        output_file: Destination path. Its extension selects the audio
            encoding requested from the API.
    """
    # Request the encoding that matches the file extension so that, for
    # example, a ".wav" path really contains WAV data. Unknown or missing
    # extensions fall back to "mp3", the API's default format.
    supported_formats = {"mp3", "opus", "aac", "flac", "wav", "pcm"}
    extension = os.path.splitext(output_file)[1].lstrip(".").lower()
    audio_format = extension if extension in supported_formats else "mp3"

    response = client.audio.speech.create(
        model="tts-1-hd",
        voice=model,
        input=text,
        response_format=audio_format,
    )
    response.stream_to_file(output_file)
# Function to upload file to S3 and get the file URL
def upload_audio_to_s3(file_path: str, bucket: str, object_name: str) -> str:
    """Upload a local file to S3 and return the URL built for it.

    Args:
        file_path: Path of the local file to upload.
        bucket: Name of the destination bucket.
        object_name: Key to store the object under.

    Returns:
        A URL of the form ``https://<bucket>.s3.amazonaws.com/<object_name>``.
        The key is inserted as-is, without URL-encoding.

    Raises:
        HTTPException: With status 500 if anything in the upload fails.
    """
    try:
        s3_client.upload_file(file_path, bucket, object_name)
        return f"https://{bucket}.s3.amazonaws.com/{object_name}"
    except Exception as upload_error:
        failure_detail = f"Failed to upload to S3: {str(upload_error)}"
        raise HTTPException(status_code=500, detail=failure_detail)
# API endpoint to generate audio using OpenAI, save it to S3, and return the URL
# NOTE(review): no route decorator is visible on this handler in this view of
# the file; confirm how it is registered with `app`.
async def generate_audio(request: AudioRequest):
    """Generate speech for the request text, upload it to S3 and return its URL.

    Args:
        request: Text to speak and the display name of the voice to use.

    Returns:
        ``{"audio_url": <S3 URL of the uploaded audio>}``.

    Raises:
        HTTPException: 400 if the speaker is not a key of OPENAI_SPEAKERS;
            500 if audio generation or the S3 upload fails.
    """
    if request.speaker not in OPENAI_SPEAKERS:
        raise HTTPException(status_code=400, detail="Invalid speaker selection")

    # Generate a unique file name in the platform's temporary directory.
    unique_filename = f"{uuid.uuid4()}.wav"
    local_file_path = os.path.join(tempfile.gettempdir(), unique_filename)

    try:
        # Generate the audio
        try:
            generate_audio_using_openai(
                request.text,
                OPENAI_SPEAKERS[request.speaker],
                output_file=local_file_path,
            )
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Error generating audio: {str(e)}"
            ) from e

        # Upload the audio file to S3
        try:
            s3_audio_url = upload_audio_to_s3(local_file_path, S3_BUCKET, unique_filename)
        except HTTPException:
            # upload_audio_to_s3 already reports failures as an HTTP 500;
            # pass it through instead of wrapping the message a second time.
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Error uploading to S3: {str(e)}"
            ) from e
    finally:
        # The local copy is only a staging file; remove it on success and on
        # every failure path so the temporary directory does not fill up.
        if os.path.exists(local_file_path):
            os.remove(local_file_path)

    return {"audio_url": s3_audio_url}
# # Function to enhance prompt using OpenAI API (chat-based completion) | |
# def enrich_prompt_with_openai(prompt): | |
# try: | |
# response = openai.ChatCompletion.create( | |
# model="gpt-3.5-turbo", | |
# messages=[ | |
# {"role": "system", "content": "You are a creative assistant for generating video prompts. Please improve the following video generation prompt."}, | |
# {"role": "user", "content": f"Improve the following video generation prompt: {prompt}"} | |
# ] | |
# ) | |
# # Accessing the message content from the first choice | |
# enriched_prompt = response.choices[0].message['content'].strip() | |
# return enriched_prompt | |
# except Exception as e: | |
# logging.error(f"Error enriching prompt with OpenAI: {str(e)}") | |
# return prompt | |
# Function to enrich prompt using OpenAI chat-based completion
def enrich_prompt_with_openai(prompt):
    """Ask the chat model to improve a video-generation prompt.

    This is best-effort: any failure of the API call, or an empty reply,
    results in the original prompt being returned unchanged.

    Args:
        prompt: The user's original video-generation prompt.

    Returns:
        The stripped model reply, or ``prompt`` itself when no usable reply
        was obtained.
    """
    system_message = (
        "You are a creative and professional assistant for generating high quality video prompts. "
        "The generated video must be ultra realistic, high natural look, proper color balanced, "
        "without any single distortion, not any single blurry frames. "
        "Please improve the following video generation prompt according to data and add negative "
        "prompts as well to increase the video quality much more."
    )
    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_message},
                {
                    "role": "user",
                    "content": f"Improve the following video generation prompt: {prompt}",
                },
            ],
        )
        content = response.choices[0].message.content
    except Exception as e:
        logging.error(f"Error enriching prompt with OpenAI: {str(e)}")
        return prompt

    # A missing or whitespace-only reply is not a usable prompt; fall back to
    # the original rather than handing an empty string to the caller.
    enriched_prompt = (content or "").strip()
    if not enriched_prompt:
        logging.warning("OpenAI returned an empty enriched prompt; using the original prompt")
        return prompt

    logging.info(enriched_prompt)  # Logs the enriched prompt content
    return enriched_prompt
# Define the video generation pipeline
def generate_video(prompt=None, image_path=None, num_videos=1, num_steps=50, frames=49, guidance=6):
    """Run a CogVideoX pipeline and return the frames of the first video.

    The image-to-video checkpoint is used when ``image_path`` is truthy,
    otherwise the text-to-video checkpoint. The pipeline is loaded from the
    pretrained checkpoint on every call, and the random generator is seeded
    with a fixed value (42).

    Args:
        prompt: Text prompt passed to the pipeline.
        image_path: Optional path of a conditioning image.
        num_videos: Passed as ``num_videos_per_prompt``; only the first
            result (``frames[0]``) is returned.
        num_steps: Passed as ``num_inference_steps``.
        frames: Passed as ``num_frames``.
        guidance: Passed as ``guidance_scale``.

    Returns:
        ``frames[0]`` of the pipeline output.
    """
    use_image = bool(image_path)

    # Pick the pipeline class and checkpoint for the requested mode.
    if use_image:
        pipeline_cls, checkpoint = CogVideoXImageToVideoPipeline, "THUDM/CogVideoX-5b-I2V"
    else:
        pipeline_cls, checkpoint = CogVideoXPipeline, "THUDM/CogVideoX-5b"
    pipe = pipeline_cls.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)

    pipe.enable_sequential_cpu_offload()
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()

    # Both modes share every argument except the conditioning image, which is
    # loaded first so a bad image path fails before the generator is created.
    call_kwargs = {}
    if use_image:
        call_kwargs["image"] = load_image(image=image_path)
    call_kwargs.update(
        prompt=prompt,
        num_videos_per_prompt=num_videos,
        num_inference_steps=num_steps,
        num_frames=frames,
        guidance_scale=guidance,
        generator=torch.Generator(device="cuda").manual_seed(42),
    )

    result = pipe(**call_kwargs)
    return result.frames[0]
# Upload video to S3
def upload_to_s3(file_path, file_name):
    """Upload a local file to S3_BUCKET and return its URL under S3_BASE_URL.

    Args:
        file_path: Path of the local file to upload.
        file_name: Key to store the object under; appended to S3_BASE_URL
            as-is to form the returned URL.

    Returns:
        ``S3_BASE_URL`` followed by ``file_name``.

    Raises:
        Exception: With a fixed message when the local file is missing or no
            AWS credentials are available. Other upload errors propagate
            unchanged.
    """
    try:
        s3_client.upload_file(file_path, S3_BUCKET, file_name)
        public_url = f"{S3_BASE_URL}{file_name}"
        return public_url
    except FileNotFoundError:
        raise Exception("The file was not found.")
    except NoCredentialsError:
        raise Exception("Credentials not available.")
# Configure logging at INFO level and create the logger used by the handlers
# in this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# NOTE(review): no route decorator is visible on this handler in this view of
# the file; confirm how it is registered with `app`.
async def generate_video_api(
    prompt: str = Form(...),
    num_steps: int = Form(50),
    num_frames: int = Form(49),
    guidance: float = Form(6.0),
    image: UploadFile = File(None),
):
    """Generate a video from a prompt (and optional image) and upload it to S3.

    Args:
        prompt: Text prompt; it is enriched via OpenAI before generation,
            falling back to the original text if enrichment fails.
        num_steps: Number of inference steps passed to ``generate_video``.
        num_frames: Number of frames passed to ``generate_video``.
        guidance: Guidance scale passed to ``generate_video``.
        image: Optional uploaded image; when present it is saved to a
            temporary file and its path is passed to ``generate_video``.

    Returns:
        ``{"video_url": <S3 URL of the uploaded video>}``.

    Raises:
        HTTPException: 500 if generation, saving or the upload fails.
    """
    # NOTE(review): generation below is a synchronous call inside an async
    # handler, so it occupies the event loop for its whole duration.
    tmp_img_path = None
    output_path = None
    logger.info("Starting video generation...")

    try:
        if image:
            logger.info("Received image file.")
            # Save uploaded image temporarily. The path is recorded before
            # writing so that cleanup also covers a failed write.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_img:
                tmp_img_path = tmp_img.name
                tmp_img.write(await image.read())
            logger.info("Image saved to %s", tmp_img_path)

        # Enhance the prompt using OpenAI
        try:
            enriched_prompt = enrich_prompt_with_openai(prompt)
            logger.info("Prompt enriched with OpenAI: %s", enriched_prompt)
        except Exception as e:
            logger.error("Error with OpenAI prompt enhancement: %s", e)
            enriched_prompt = prompt  # Fallback to original prompt

        # Generate the video
        try:
            video_frames = generate_video(
                prompt=enriched_prompt,
                image_path=tmp_img_path,
                num_steps=num_steps,
                frames=num_frames,
                guidance=guidance,
            )
        except Exception as e:
            logger.error("Error generating video: %s", e)
            raise HTTPException(status_code=500, detail="Video generation failed") from e

        # Save the generated video to a temporary file. The handle is closed
        # before exporting so the exporter writes to a path nobody holds open.
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_video:
                output_path = tmp_video.name
            export_to_video(video_frames, output_path, fps=8)
            logger.info("Video saved to %s", output_path)
        except Exception as e:
            logger.error("Error saving video: %s", e)
            raise HTTPException(status_code=500, detail="Failed to save video") from e

        # Upload to AWS S3. The object key is a random identifier rather than
        # the enriched prompt: that text is model-generated, may span several
        # lines, and can contain characters that break the URL or exceed the
        # S3 key length limit.
        try:
            video_name = f"generated_video_{uuid.uuid4().hex}.mp4"
            s3_video_url = upload_to_s3(output_path, video_name)
            logger.info("Video uploaded to S3 at %s", s3_video_url)
        except Exception as e:
            logger.error("Error uploading to S3: %s", e)
            raise HTTPException(status_code=500, detail="Failed to upload video to S3") from e

        return {"video_url": s3_video_url}
    finally:
        # Clean up temporary files on success and on every failure path.
        for temp_path in (tmp_img_path, output_path):
            if temp_path and os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except OSError as cleanup_error:
                    logger.warning(
                        "Could not remove temporary file %s: %s", temp_path, cleanup_error
                    )
# NOTE(review): no route decorator is visible on this handler in this view of
# the file; confirm how it is registered with `app`.
def greet_json():
    """Return a greeting together with the running Python version."""
    return {"Hello": "World!", "Python Version": platform.python_version()}