# Page header residue (not code): uploader "neuralleap"
# Commit message: "Update app.py"
# Commit: 529778d (verified)
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
import torch
from diffusers import CogVideoXImageToVideoPipeline, CogVideoXPipeline
from diffusers.utils import export_to_video, load_image
import boto3
from botocore.exceptions import NoCredentialsError
import tempfile
import os
import logging
from openai import OpenAI
import uuid
import openai
from pydantic import BaseModel
from dotenv import load_dotenv
import platform
# FastAPI application; the interactive API docs are served at /docs.
app = FastAPI(docs_url="/docs")

# Shared OpenAI client instance used by the functions in this module.
client = OpenAI()

# Shared S3 client; credentials and region are taken from the environment.
s3_client = boto3.client(
    "s3",
    aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
    region_name=os.environ.get("AWS_REGION"),
)

# Also set the key on the `openai` module itself.
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Name of the bucket that receives generated media.
S3_BUCKET = os.environ.get("S3_BUCKET")

# Public base URL for objects stored in S3_BUCKET. When AWS_REGION is not set,
# fall back to the region-less endpoint instead of formatting the literal
# text "None" into the host name.
S3_BASE_URL = (
    f"https://{S3_BUCKET}.s3.{os.environ.get('AWS_REGION')}.amazonaws.com/"
    if os.environ.get("AWS_REGION")
    else f"https://{S3_BUCKET}.s3.amazonaws.com/"
)
# OpenAI TTS Models (voice names)
# Maps the speaker label accepted in requests to the voice name sent to
# the OpenAI text-to-speech API.
OPENAI_SPEAKERS = {
    # Labels marked "(Male)"
    "Chris (Male)": "alloy",
    "Ryan (Male)": "echo",
    "Louis (Male)": "fable",
    "Alex (Male)": "onyx",
    # Labels marked "(Female)"
    "Sophia (Female)": "nova",
    "Hannah (Female)": "shimmer",
}
# Request model for input
class AudioRequest(BaseModel):
    """Request body carrying the text to speak and the chosen speaker."""

    # Text that will be converted to speech.
    text: str
    # Speaker label selected by the client.
    speaker: str
# Function to generate audio using OpenAI and save to file
def generate_audio_using_openai(text, model, output_file: str):
    """Synthesize speech for ``text`` with OpenAI TTS and save it to a file.

    Args:
        text: The text to convert to speech.
        model: Despite the name, this value is passed to the API as the
            ``voice``; the TTS model itself is fixed to ``tts-1-hd``.
        output_file: Path the audio is written to. Its extension selects the
            audio encoding so that the file content matches the file name.

    Raises:
        Any exception raised by the OpenAI client or by writing the file is
        propagated to the caller.
    """
    # Without an explicit response_format the API returns MP3 data, so a path
    # ending in ".wav" would end up holding MP3 bytes. Derive the format from
    # the extension and keep MP3 for unknown or missing extensions.
    supported_formats = {"mp3", "opus", "aac", "flac", "wav", "pcm"}
    extension = os.path.splitext(output_file)[1].lstrip(".").lower()
    response_format = extension if extension in supported_formats else "mp3"

    response = client.audio.speech.create(
        model="tts-1-hd",
        voice=model,
        input=text,
        response_format=response_format,
    )
    response.stream_to_file(output_file)
# Function to upload file to S3 and get the file URL
def upload_audio_to_s3(file_path: str, bucket: str, object_name: str) -> str:
    """Upload a local file to S3 and return the URL of the stored object.

    Args:
        file_path: Path of the local file to upload.
        bucket: Name of the destination bucket.
        object_name: Key under which the file is stored.

    Returns:
        The HTTPS URL of the uploaded object. The region from ``AWS_REGION``
        is included in the host name when it is set, and the key is
        percent-encoded so the URL stays valid for keys with special
        characters.

    Raises:
        HTTPException: With status 500 when the upload fails.
    """
    from urllib.parse import quote

    # Only the upload itself can fail; keep the try body minimal.
    try:
        s3_client.upload_file(file_path, bucket, object_name)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to upload to S3: {str(e)}"
        ) from e

    region = os.environ.get("AWS_REGION")
    host = (
        f"{bucket}.s3.{region}.amazonaws.com"
        if region
        else f"{bucket}.s3.amazonaws.com"
    )
    return f"https://{host}/{quote(object_name)}"
# API endpoint to generate audio using OpenAI, save it to S3, and return the URL
@app.post("/generate-audio-openai")
async def generate_audio(request: AudioRequest):
    """Generate speech for the request text, upload it to S3, return its URL.

    Args:
        request: Body with the text to speak and a speaker label that must be
            a key of ``OPENAI_SPEAKERS``.

    Returns:
        ``{"audio_url": <url>}`` pointing at the uploaded audio file.

    Raises:
        HTTPException: 400 for an unknown speaker; 500 when audio generation
            or the S3 upload fails.
    """
    voice = OPENAI_SPEAKERS.get(request.speaker)
    if voice is None:
        raise HTTPException(status_code=400, detail="Invalid speaker selection")

    # Unique name used both for the local scratch file and for the S3 key.
    unique_filename = f"{uuid.uuid4()}.wav"
    local_file_path = os.path.join(tempfile.gettempdir(), unique_filename)

    try:
        try:
            generate_audio_using_openai(request.text, voice, output_file=local_file_path)
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Error generating audio: {str(e)}"
            ) from e

        try:
            s3_audio_url = upload_audio_to_s3(local_file_path, S3_BUCKET, unique_filename)
        except HTTPException:
            # The upload helper already produced a client-facing error; do not
            # wrap it a second time.
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Error uploading to S3: {str(e)}"
            ) from e
    finally:
        # The scratch file is no longer needed whether or not the request
        # succeeded. It may not exist if generation failed before writing.
        try:
            os.remove(local_file_path)
        except OSError:
            pass

    return {"audio_url": s3_audio_url}
# # Function to enhance prompt using OpenAI API (chat-based completion)
# def enrich_prompt_with_openai(prompt):
# try:
# response = openai.ChatCompletion.create(
# model="gpt-3.5-turbo",
# messages=[
# {"role": "system", "content": "You are a creative assistant for generating video prompts. Please improve the following video generation prompt."},
# {"role": "user", "content": f"Improve the following video generation prompt: {prompt}"}
# ]
# )
# # Accessing the message content from the first choice
# enriched_prompt = response.choices[0].message['content'].strip()
# return enriched_prompt
# except Exception as e:
# logging.error(f"Error enriching prompt with OpenAI: {str(e)}")
# return prompt
# Function to enrich prompt using OpenAI chat-based completion
def enrich_prompt_with_openai(prompt):
    """Ask the chat model to rewrite ``prompt`` into a richer video prompt.

    Args:
        prompt: The user's original video generation prompt.

    Returns:
        The model's rewritten prompt with surrounding whitespace removed, or
        the original ``prompt`` unchanged when the request fails or the model
        returns no usable text.
    """
    log = logging.getLogger(__name__)
    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": "You are a creative and professional assistant for generating high quality video prompts. The generated video must be ultra realistic, high natural look, proper color balanced, without any single distortion, not any single blurry frames. Please improve the following video generation prompt according to data and add negative prompts as well to increase the video quality much more."},
                {
                    "role": "user",
                    "content": f"Improve the following video generation prompt: {prompt}",
                },
            ],
        )
        content = response.choices[0].message.content
    except Exception as e:
        # Enrichment is best-effort: any failure falls back to the input.
        log.error(f"Error enriching prompt with OpenAI: {str(e)}")
        return prompt

    # Handle a missing or blank completion explicitly instead of handing an
    # empty prompt to the video pipeline.
    enriched_prompt = (content or "").strip()
    if not enriched_prompt:
        log.warning("OpenAI returned an empty prompt; using the original prompt")
        return prompt

    log.info(enriched_prompt)
    return enriched_prompt
# Define the video generation pipeline
def generate_video(prompt=None, image_path=None, num_videos=1, num_steps=50, frames=49, guidance=6, seed=42):
    """Generate video frames with CogVideoX from a prompt and optional image.

    Args:
        prompt: Text prompt passed to the pipeline.
        image_path: Path of a conditioning image. When given, the
            image-to-video pipeline is used; otherwise text-to-video.
        num_videos: Passed as ``num_videos_per_prompt``.
        num_steps: Passed as ``num_inference_steps``.
        frames: Passed as ``num_frames``.
        guidance: Passed as ``guidance_scale``.
        seed: Seed for the random generator. Defaults to 42, the value that
            was previously hard-coded, so existing callers are unaffected.

    Returns:
        The frames of the first generated video (``.frames[0]``); additional
        videos requested via ``num_videos`` are not returned.
    """
    if image_path:
        pipeline_cls, model_id = CogVideoXImageToVideoPipeline, "THUDM/CogVideoX-5b-I2V"
    else:
        pipeline_cls, model_id = CogVideoXPipeline, "THUDM/CogVideoX-5b"

    # NOTE(review): the pipeline is loaded from scratch on every call.
    pipe = pipeline_cls.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    # Offloading and VAE tiling/slicing options of the pipeline.
    pipe.enable_sequential_cpu_offload()
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()

    # Both pipelines share every argument except the conditioning image.
    call_kwargs = {
        "prompt": prompt,
        "num_videos_per_prompt": num_videos,
        "num_inference_steps": num_steps,
        "num_frames": frames,
        "guidance_scale": guidance,
    }
    if image_path:
        call_kwargs["image"] = load_image(image=image_path)
    call_kwargs["generator"] = torch.Generator(device="cuda").manual_seed(seed)

    return pipe(**call_kwargs).frames[0]
# Upload video to S3
def upload_to_s3(file_path, file_name):
    """Upload a local file to ``S3_BUCKET`` and return its public URL.

    Args:
        file_path: Path of the local file to upload.
        file_name: Key under which the file is stored in the bucket.

    Returns:
        ``S3_BASE_URL`` followed by the percent-encoded key, so the URL stays
        valid when the key contains spaces or other special characters.

    Raises:
        Exception: When the local file is missing or AWS credentials are not
            available; the original error is attached as the cause. Other
            upload errors propagate unchanged.
    """
    from urllib.parse import quote

    try:
        s3_client.upload_file(file_path, S3_BUCKET, file_name)
    except FileNotFoundError as e:
        raise Exception("The file was not found.") from e
    except NoCredentialsError as e:
        raise Exception("Credentials not available.") from e
    return f"{S3_BASE_URL}{quote(file_name)}"
# Configure logging at INFO level for the whole process.
logging.basicConfig(level=logging.INFO)
# Logger named after this module.
logger = logging.getLogger(__name__)
def _build_video_object_name(prompt_text, max_slug_length=64):
    """Return a unique S3 key for a video, derived from the prompt text."""
    import re

    # Keep only characters that are safe in S3 keys and URLs, and bound the
    # length: the enriched prompt can be arbitrarily long free-form text.
    slug = re.sub(r"[^A-Za-z0-9_-]+", "_", prompt_text)[:max_slug_length].strip("_")
    # A random suffix prevents repeated prompts from overwriting each other.
    unique_suffix = uuid.uuid4().hex
    if slug:
        return f"generated_video_{slug}_{unique_suffix}.mp4"
    return f"generated_video_{unique_suffix}.mp4"


def _remove_if_exists(path):
    """Remove a temporary file, logging instead of raising on failure."""
    if not path:
        return
    try:
        os.remove(path)
    except OSError as e:
        logger.warning(f"Could not remove temporary file {path}: {e}")


@app.post("/")
async def generate_video_api(
    prompt: str = Form(...),
    num_steps: int = Form(50),
    num_frames: int = Form(49),
    guidance: float = Form(6.0),
    image: UploadFile = File(None),
):
    """Generate a video from a prompt (and optional image) and upload it to S3.

    Args:
        prompt: Text prompt; it is enriched via OpenAI before generation.
        num_steps: Number of inference steps passed to the generator.
        num_frames: Number of frames passed to the generator.
        guidance: Guidance scale passed to the generator.
        image: Optional uploaded image; when present it is saved to a
            temporary file and used for image-to-video generation.

    Returns:
        ``{"video_url": <url>}`` pointing at the uploaded video.

    Raises:
        HTTPException: 500 when generating, saving or uploading the video
            fails.
    """
    tmp_img_path = None
    output_path = None
    logger.info("Starting video generation...")

    try:
        if image:
            logger.info("Received image file.")
            # Save uploaded image temporarily
            with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_img:
                tmp_img_path = tmp_img.name
                tmp_img.write(await image.read())
            logger.info(f"Image saved to {tmp_img_path}")

        # Enhance the prompt using OpenAI
        try:
            enriched_prompt = enrich_prompt_with_openai(prompt)
            logger.info(f"Prompt enriched with OpenAI: {enriched_prompt}")
        except Exception as e:
            logger.error(f"Error with OpenAI prompt enhancement: {e}")
            enriched_prompt = prompt  # Fallback to original prompt

        # Generate the video
        try:
            video_frames = generate_video(
                prompt=enriched_prompt,
                image_path=tmp_img_path,
                num_steps=num_steps,
                frames=num_frames,
                guidance=guidance,
            )
        except Exception as e:
            logger.error(f"Error generating video: {str(e)}")
            raise HTTPException(status_code=500, detail="Video generation failed") from e

        # Save the generated video to a temporary file
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_video:
                output_path = tmp_video.name
            export_to_video(video_frames, output_path, fps=8)
            logger.info(f"Video saved to {output_path}")
        except Exception as e:
            logger.error(f"Error saving video: {str(e)}")
            raise HTTPException(status_code=500, detail="Failed to save video") from e

        # Upload to AWS S3
        try:
            video_name = _build_video_object_name(enriched_prompt)
            s3_video_url = upload_to_s3(output_path, video_name)
            logger.info(f"Video uploaded to S3 at {s3_video_url}")
        except Exception as e:
            logger.error(f"Error uploading to S3: {str(e)}")
            raise HTTPException(status_code=500, detail="Failed to upload video to S3") from e

        return {"video_url": s3_video_url}
    finally:
        # Clean up temporary files on success and on every error path.
        _remove_if_exists(tmp_img_path)
        _remove_if_exists(output_path)
@app.get("/")
def greet_json():
    """Return a greeting together with the running Python version."""
    return {"Hello": "World!", "Python Version": platform.python_version()}