# sutra-avatar-v2 / cloud_task_executor.py
# Author: mikesapi
# Commit ec17e66: "initial commit of sutra-avatar-v2"
import base64
import json
import ntpath
import os
import time
import gradio as gr
import requests
from google.cloud import storage
from base_task_executor import BaseTaskExecutor
# ---
enc = "utf-8"


def decode(string):
    """Decode a base64 string that wraps a JSON document.

    The text is encoded with ``enc``, base64-decoded, decoded back to text
    with ``enc`` and finally parsed with ``json.loads``.

    Returns:
        The parsed JSON value (dict, list, str, number, bool or None).
    """
    raw_bytes = base64.b64decode(string.encode(enc))
    json_text = raw_bytes.decode(enc)
    return json.loads(json_text)
def get_storage_client_from_env():
    """Create a Google Cloud Storage client from the ``GCP_API_KEY`` env var.

    The variable's value is unpacked with ``decode`` (base64 -> JSON) and the
    resulting mapping is handed to ``storage.Client.from_service_account_info``.

    Raises:
        KeyError: if ``GCP_API_KEY`` is not set.
    """
    service_account_info = decode(os.environ["GCP_API_KEY"])
    client = storage.Client.from_service_account_info(service_account_info)
    return client
def get_name_ext(filepath):
    """Split the last component of ``filepath`` into ``(stem, extension)``.

    The path is made absolute first, so a trailing separator is ignored
    (``"dir/sub/"`` -> ``("sub", "")``). The extension keeps its leading dot
    and is ``""`` when there is none.
    """
    base_name = os.path.basename(os.path.abspath(filepath))
    stem, extension = os.path.splitext(base_name)
    return stem, extension
def make_remote_media_path(request_id, media_path):
    """Build the remote (bucket object) path for a local media file.

    The path is made of ``request_id[:3]``, ``request_id[3:6]``,
    ``request_id[6:]`` and the local file's base name, joined with ``/``::

        make_remote_media_path("abcdef123", "/tmp/clip.mp4") == "abc/def/123/clip.mp4"

    Raises:
        ValueError: if ``request_id`` has 6 or fewer characters.
        FileNotFoundError: if ``media_path`` does not exist locally.
    """
    # Explicit checks rather than ``assert``: assertions vanish under ``python -O``.
    if len(request_id) <= 6:
        raise ValueError(
            f"request_id must be longer than 6 characters, got {request_id!r}"
        )
    if not os.path.exists(media_path):
        raise FileNotFoundError(f"media file not found: {media_path!r}")

    src_id = request_id[:3]
    slot_id = request_id[3:6]
    request_suffix = request_id[6:]
    # Same value as ``name + ext`` from get_name_ext().
    file_name = os.path.basename(os.path.abspath(media_path))
    # Object names always use "/" whatever the local OS separator is; a plain
    # join also never discards earlier parts the way os.path.join does when a
    # later component looks absolute.
    return "/".join((src_id, slot_id, request_suffix, file_name))
def copy_file_to_gcloud(bucket, local_file_path, remote_file_path):
    """Upload ``local_file_path`` into ``bucket`` as object ``remote_file_path``.

    ``bucket`` is used through ``bucket.blob(name).upload_from_filename(path)``
    (the google-cloud-storage Bucket/Blob API). Returns None.
    """
    bucket.blob(remote_file_path).upload_from_filename(local_file_path)
def copy_to_gcloud(storage_client, local_media_path, bucket_name, remote_media_path):
    """Upload a local file to the bucket named ``bucket_name``.

    The bucket is fetched with ``storage_client.get_bucket`` and the transfer
    is delegated to ``copy_file_to_gcloud``. Returns None.
    """
    copy_file_to_gcloud(
        storage_client.get_bucket(bucket_name),
        local_media_path,
        remote_media_path,
    )
# ---
class CloudTaskExecutor(BaseTaskExecutor):
    """Run avatar generation tasks on the remote Sutra Avatar service.

    ``generate`` uploads the input media to a Google Cloud Storage bucket,
    submits a task over HTTP and polls the task-status endpoint until the task
    reaches a terminal state or the polling budget is exhausted.

    ``__init__`` reads its configuration from the environment:
    ``SUTRA_AVATAR_BASE_URL``, ``SUTRA_AVATAR_API_KEY``,
    ``SUTRA_AVATAR_BUCKET_NAME`` and, through ``get_storage_client_from_env``,
    ``GCP_API_KEY``.
    """

    # Timeout in seconds applied to each HTTP call (connect and read); without
    # one, ``requests`` may wait indefinitely on an unresponsive server.
    request_timeout = 30
    # Base polling budget in seconds; extended by the server's wait estimate.
    poll_timeout_seconds = 240
    # Pause in seconds between two status polls.
    poll_interval_seconds = 3
    # Task states after which polling stops.
    completion_statuses = frozenset({"Succeeded", "Cancelled", "Failed", "NotFound"})
    # User-facing message reported when polling runs out of time.
    _TIMEOUT_MESSAGE = (
        "The task did not complete within the timeout period.\n"
        " The server is very busy serving other requests.\n"
        " Please try again."
    )

    def __init__(self):
        super().__init__()
        self.base_url = os.getenv("SUTRA_AVATAR_BASE_URL")
        # NOTE(review): if SUTRA_AVATAR_API_KEY is unset, the literal string
        # "None" is sent as the Authorization header -- confirm this is intended.
        self.headers = {"Authorization": f'{os.getenv("SUTRA_AVATAR_API_KEY")}', "Content-Type": "application/json"}
        self.bucket_name = os.getenv("SUTRA_AVATAR_BUCKET_NAME")
        self.storage_client = get_storage_client_from_env()

    def submit_task(self, submit_request):
        """POST ``submit_request`` as JSON to ``{base_url}/task/submit``.

        Returns:
            The decoded JSON reply.

        Raises:
            requests.HTTPError: on a 4xx/5xx response.
            requests.Timeout: if the server does not answer within
                ``request_timeout`` seconds.
        """
        url = f"{self.base_url}/task/submit"
        response = requests.post(
            url, json=submit_request, headers=self.headers, timeout=self.request_timeout
        )
        # raise_for_status() is a no-op for non-error codes, so an
        # "if 200 ... else raise_for_status()" shape returns None on e.g. 201/202.
        response.raise_for_status()
        return response.json()

    def get_task_status(self, request_id):
        """GET ``{base_url}/task/status?rid=<request_id>`` and return its JSON reply.

        Raises:
            requests.HTTPError: on a 4xx/5xx response.
            requests.Timeout: if the server does not answer within
                ``request_timeout`` seconds.
        """
        url = f"{self.base_url}/task/status"
        response = requests.get(
            url, params={"rid": request_id}, headers=self.headers, timeout=self.request_timeout
        )
        response.raise_for_status()
        return response.json()

    def generate(
        self,
        input_base_path,
        input_driving_path,
        base_motion_expression,
        input_driving_audio_path,
        output_video_path,
        request_id,
    ):
        """Run one remote generation task and wait for it to finish.

        Args:
            input_base_path: Local path of the base media; uploaded when truthy.
            input_driving_path: Accepted for interface compatibility but unused;
                the submit request always carries an empty ``input_driving_path``.
            base_motion_expression: Forwarded unchanged in the submit request.
            input_driving_audio_path: Local path of the driving audio; uploaded
                when truthy.
            output_video_path: Only its base name is sent to the server; it is
                returned unchanged unless the task succeeds.
            request_id: Task identifier, also used to derive the upload paths
                (see ``make_remote_media_path``).

        Returns:
            ``(result, output_video_path)``. ``result`` is a dict with a boolean
            ``"success"`` and ``"messages"``. On success the second item is the
            ``videoURL`` reported by the server.
        """
        self._upload_media(request_id, (input_base_path, input_driving_audio_path))

        submit_request = {
            "requestId": request_id,
            "input_base_path": self._remote_basename(input_base_path),
            "input_driving_path": "",
            "base_motion_expression": base_motion_expression,
            "input_driving_audio_path": self._remote_basename(input_driving_audio_path),
            "output_video_path": self._remote_basename(output_video_path),
        }
        submit_reply = self.submit_task(submit_request)

        timeout = self.poll_timeout_seconds
        estimated_wait = submit_reply.get("estimatedWaitSeconds")
        # Only a real number extends the budget; bool is excluded because it
        # is a subclass of int.
        if isinstance(estimated_wait, (int, float)) and not isinstance(estimated_wait, bool):
            timeout += max(0, estimated_wait)

        status_reply, timed_out = self._wait_for_completion(request_id, timeout)

        result = {"messages": ""}
        if timed_out:
            # Reported through ``result`` (not raised) so the caller still
            # receives the (result, path) pair.
            result["success"] = False
            result["messages"] = self._TIMEOUT_MESSAGE

        if status_reply["taskStatus"] == "Succeeded":
            pipe_reply = status_reply["pipeReply"]
            result["success"] = pipe_reply["status"] == "success"
            result["messages"] = pipe_reply["messages"]
            output_video_path = status_reply["videoURL"]
        else:
            result["success"] = False
            pipe_reply = status_reply.get("pipeReply")
            if pipe_reply is not None:
                result["messages"] += pipe_reply["messages"]
        return result, output_video_path

    def _upload_media(self, request_id, media_paths):
        """Upload every truthy path in ``media_paths`` to the configured bucket."""
        for media_path in media_paths:
            if media_path:
                remote_media_path = make_remote_media_path(request_id, media_path)
                copy_to_gcloud(self.storage_client, media_path, self.bucket_name, remote_media_path)

    def _wait_for_completion(self, request_id, timeout):
        """Poll the task status until it is terminal or ``timeout`` seconds pass.

        Returns:
            ``(status_reply, timed_out)``, where ``status_reply`` is the last
            reply received from ``get_task_status``.
        """
        # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
        start_time = time.monotonic()
        while True:
            status_reply = self.get_task_status(request_id)
            if status_reply["taskStatus"] in self.completion_statuses:
                return status_reply, False
            if time.monotonic() - start_time > timeout:
                return status_reply, True
            time.sleep(self.poll_interval_seconds)

    @staticmethod
    def _remote_basename(path):
        """Return the file name of ``path``, or ``""`` when ``path`` is falsy.

        ``ntpath`` treats both "/" and backslash as separators, so client paths
        from either OS are reduced to their last component.
        """
        return ntpath.basename(path) if path else ""