Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import base64 | |
import json | |
import ntpath | |
import os | |
import time | |
import gradio as gr | |
import requests | |
from google.cloud import storage | |
from base_task_executor import BaseTaskExecutor | |
# ---
# Text encoding used both for the base64 transport layer and for the JSON text.
enc = "utf-8"


def decode(string):
    """Decode a base64-encoded JSON document and return the parsed object.

    Args:
        string: base64 text whose decoded bytes are ``enc``-encoded JSON.

    Returns:
        Whatever ``json.loads`` produces for the decoded payload.

    Raises:
        binascii.Error: if ``string`` is not valid base64.
        UnicodeDecodeError: if the decoded bytes are not valid ``enc`` text.
        json.JSONDecodeError: if the decoded text is not valid JSON.
    """
    raw_bytes = base64.b64decode(string.encode(enc))
    payload_text = raw_bytes.decode(enc)
    return json.loads(payload_text)
def get_storage_client_from_env():
    """Build a Google Cloud Storage client from credentials held in the environment.

    The ``GCP_API_KEY`` environment variable must contain the service-account
    JSON, base64-encoded (it is unpacked with ``decode``).

    Raises:
        KeyError: if ``GCP_API_KEY`` is not set.
    """
    encoded_credentials = os.environ["GCP_API_KEY"]
    service_account_info = decode(encoded_credentials)
    return storage.Client.from_service_account_info(service_account_info)
def get_name_ext(filepath):
    """Return ``(name, ext)`` for the final component of ``filepath``.

    The path is made absolute (and therefore normalized, so a trailing
    separator is dropped) before the last component is taken. ``ext``
    includes the leading dot and is ``""`` when there is no extension.
    """
    base = os.path.basename(os.path.abspath(filepath))
    stem, suffix = os.path.splitext(base)
    return stem, suffix
def make_remote_media_path(request_id, media_path):
    """Build the bucket object name under which a local media file is stored.

    The name has the form ``<src>/<slot>/<suffix>/<filename>``: the first three
    characters of ``request_id`` give ``<src>``, the next three give
    ``<slot>``, the remainder gives ``<suffix>``, and ``<filename>`` is the
    final component of ``media_path``.

    Args:
        request_id: Request identifier; must be longer than 6 characters.
        media_path: Path to an existing local file.

    Returns:
        The object name. Components are always joined with "/" because
        Cloud Storage object names use forward slashes regardless of host OS.

    Raises:
        ValueError: if ``request_id`` has 6 or fewer characters.
        FileNotFoundError: if ``media_path`` does not exist.
    """
    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if len(request_id) <= 6:
        raise ValueError(
            f"request_id must be longer than 6 characters, got {request_id!r}"
        )
    if not os.path.exists(media_path):
        raise FileNotFoundError(f"media file not found: {media_path!r}")

    src_id = request_id[:3]
    slot_id = request_id[3:6]
    request_suffix = request_id[6:]
    # basename of the normalized absolute path == name + ext from splitext.
    filename = os.path.basename(os.path.abspath(media_path))
    return "/".join((src_id, slot_id, request_suffix, filename))
def copy_file_to_gcloud(bucket, local_file_path, remote_file_path):
    """Upload ``local_file_path`` into ``bucket`` as the object ``remote_file_path``."""
    bucket.blob(remote_file_path).upload_from_filename(local_file_path)
def copy_to_gcloud(storage_client, local_media_path, bucket_name, remote_media_path):
    """Upload a local file to a named bucket.

    The bucket is obtained with ``storage_client.get_bucket(bucket_name)`` and
    the file is then uploaded as object ``remote_media_path`` via
    ``copy_file_to_gcloud``.
    """
    target_bucket = storage_client.get_bucket(bucket_name)
    copy_file_to_gcloud(
        bucket=target_bucket,
        local_file_path=local_media_path,
        remote_file_path=remote_media_path,
    )
# --- | |
class CloudTaskExecutor(BaseTaskExecutor):
    """Run avatar-generation tasks on the remote Sutra Avatar service.

    Input media are uploaded to a Google Cloud Storage bucket, a task is
    submitted over HTTP, and the task-status endpoint is polled until the task
    reaches a terminal state or a wait deadline passes.

    ``__init__`` reads its configuration from the environment:
    ``SUTRA_AVATAR_BASE_URL``, ``SUTRA_AVATAR_API_KEY`` and
    ``SUTRA_AVATAR_BUCKET_NAME``, plus ``GCP_API_KEY`` (through
    ``get_storage_client_from_env``).
    """

    # Per-request HTTP timeout in seconds. Without one, ``requests`` waits
    # indefinitely on a stalled connection.
    HTTP_TIMEOUT_SECONDS = 30
    # Base time (seconds) to wait for a submitted task; extended by the
    # server's numeric ``estimatedWaitSeconds`` when provided.
    BASE_WAIT_TIMEOUT_SECONDS = 240
    # Delay (seconds) between status polls.
    POLL_INTERVAL_SECONDS = 3
    # Task states after which polling stops.
    COMPLETION_STATUSES = frozenset({"Succeeded", "Cancelled", "Failed", "NotFound"})
    TIMEOUT_MESSAGE = "The task did not complete within the timeout period.\n The server is very busy serving other requests.\n Please try again."

    def __init__(self):
        super().__init__()
        self.base_url = os.getenv("SUTRA_AVATAR_BASE_URL")
        # NOTE(review): if SUTRA_AVATAR_API_KEY is unset this header becomes the
        # literal string "None"; confirm whether that should fail fast instead.
        self.headers = {
            "Authorization": f'{os.getenv("SUTRA_AVATAR_API_KEY")}',
            "Content-Type": "application/json",
        }
        self.bucket_name = os.getenv("SUTRA_AVATAR_BUCKET_NAME")
        self.storage_client = get_storage_client_from_env()

    def submit_task(self, submit_request):
        """POST ``submit_request`` to ``{base_url}/task/submit`` and return the JSON reply.

        Raises:
            requests.HTTPError: on a 4xx/5xx response.
            requests.Timeout: if the server does not answer within
                ``HTTP_TIMEOUT_SECONDS``.
        """
        response = requests.post(
            f"{self.base_url}/task/submit",
            json=submit_request,
            headers=self.headers,
            timeout=self.HTTP_TIMEOUT_SECONDS,
        )
        # raise_for_status() raises only for 4xx/5xx. Calling it unconditionally
        # (instead of only when status != 200) avoids silently returning None
        # for other success codes such as 201.
        response.raise_for_status()
        return response.json()

    def get_task_status(self, request_id):
        """GET ``{base_url}/task/status?rid=<request_id>`` and return the JSON reply.

        Raises:
            requests.HTTPError: on a 4xx/5xx response.
            requests.Timeout: if the server does not answer within
                ``HTTP_TIMEOUT_SECONDS``.
        """
        response = requests.get(
            f"{self.base_url}/task/status",
            params={"rid": request_id},
            headers=self.headers,
            timeout=self.HTTP_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _basename_or_empty(path):
        """Return the final path component (``/`` or ``\\`` separated), or "" for a falsy path."""
        return ntpath.basename(path) if path else ""

    def _upload_inputs(self, request_id, media_paths):
        """Upload every non-empty path in ``media_paths`` to the configured bucket."""
        for media_path in media_paths:
            if not media_path:
                continue  # optional input not supplied
            remote_media_path = make_remote_media_path(request_id, media_path)
            copy_to_gcloud(self.storage_client, media_path, self.bucket_name, remote_media_path)

    def _wait_for_completion(self, request_id, timeout):
        """Poll the task status until it is terminal or ``timeout`` seconds elapse.

        Returns:
            ``(status_reply, timed_out)``: the last status reply received, and
            whether polling stopped because the deadline passed.
        """
        # monotonic() is immune to wall-clock adjustments, unlike time.time().
        start = time.monotonic()
        while True:
            status_reply = self.get_task_status(request_id)
            if status_reply["taskStatus"] in self.COMPLETION_STATUSES:
                return status_reply, False
            if time.monotonic() - start > timeout:
                return status_reply, True
            time.sleep(self.POLL_INTERVAL_SECONDS)

    def generate(
        self,
        input_base_path,
        input_driving_path,
        base_motion_expression,
        input_driving_audio_path,
        output_video_path,
        request_id,
    ):
        """Upload inputs, submit a generation task, and wait for its outcome.

        ``input_driving_path`` is not used by this implementation; the service
        is always sent an empty ``input_driving_path``.

        Returns:
            ``(result, output_video_path)``. ``result`` is a dict with
            ``"success"`` and ``"messages"`` keys. ``output_video_path`` is the
            server's ``videoURL`` when the task succeeded; otherwise it is the
            argument passed in, unchanged.
        """
        self._upload_inputs(request_id, (input_base_path, input_driving_audio_path))

        # Only file names are sent; the service locates the uploaded objects itself.
        submit_request = {
            "requestId": request_id,
            "input_base_path": self._basename_or_empty(input_base_path),
            "input_driving_path": "",
            "base_motion_expression": base_motion_expression,
            "input_driving_audio_path": self._basename_or_empty(input_driving_audio_path),
            "output_video_path": ntpath.basename(output_video_path),
        }
        submit_reply = self.submit_task(submit_request)

        timeout = self.BASE_WAIT_TIMEOUT_SECONDS
        estimated_wait = submit_reply.get("estimatedWaitSeconds")
        # Extend the deadline only by a real number (bool is an int subclass).
        if isinstance(estimated_wait, (int, float)) and not isinstance(estimated_wait, bool):
            timeout += estimated_wait

        result = {"messages": ""}
        status_reply, timed_out = self._wait_for_completion(request_id, timeout)
        if timed_out:
            # Reported through ``result`` rather than raised, so the caller
            # always receives a (result, path) tuple.
            result["success"] = False
            result["messages"] = self.TIMEOUT_MESSAGE

        if status_reply["taskStatus"] == "Succeeded":
            pipe_reply = status_reply["pipeReply"]
            result["success"] = pipe_reply["status"] == "success"
            result["messages"] = pipe_reply["messages"]
            output_video_path = status_reply["videoURL"]
        else:
            result["success"] = False
            if "pipeReply" in status_reply:
                result["messages"] += status_reply["pipeReply"]["messages"]
        return result, output_video_path