|
"""This module contains the EndpointV3Compatibility class, which is used to connect to Gradio apps running 3.x.x versions of Gradio.""" |
|
|
|
from __future__ import annotations |
|
|
|
import json |
|
from pathlib import Path |
|
from typing import TYPE_CHECKING, Any, Literal |
|
|
|
import httpx |
|
import huggingface_hub |
|
import websockets |
|
from packaging import version |
|
|
|
from gradio_client import serializing, utils |
|
from gradio_client.exceptions import SerializationSetupError |
|
from gradio_client.utils import ( |
|
Communicator, |
|
) |
|
|
|
if TYPE_CHECKING: |
|
from gradio_client import Client |
|
|
|
|
|
class EndpointV3Compatibility:
    """Endpoint class for connecting to v3 endpoints. Backwards compatibility."""

    def __init__(self, client: Client, fn_index: int, dependency: dict, *_args):
        """Store endpoint metadata and build (de)serializers for this fn_index.

        Args:
            client: The parent ``Client``; supplies URLs, headers, config, etc.
            fn_index: Index of this dependency in the app's config.
            dependency: The dependency dict from the app's config.
            *_args: Ignored; accepted so the constructor signature is
                interchangeable with the newer Endpoint class.
        """
        self.client: Client = client
        self.fn_index = fn_index
        self.dependency = dependency
        api_name = dependency.get("api_name")
        # Normalize string api_names to leading-slash form; False/None pass through.
        self.api_name: str | Literal[False] | None = (
            "/" + api_name if isinstance(api_name, str) else api_name
        )
        self.use_ws = self._use_websocket(self.dependency)
        self.protocol = "ws" if self.use_ws else "http"
        self.input_component_types = []
        self.output_component_types = []
        self.root_url = client.src + "/" if not client.src.endswith("/") else client.src
        try:
            # Serializer setup can fail for unknown components/serializers; in
            # that case the endpoint is marked invalid rather than failing the
            # whole client.
            self.serializers, self.deserializers = self._setup_serializers()
            self.is_valid = self.dependency["backend_fn"] and self.api_name is not False
        except SerializationSetupError:
            self.is_valid = False
        self.backend_fn = dependency.get("backend_fn")
        self.show_api = True

    def __repr__(self):
        return f"Endpoint src: {self.client.src}, api_name: {self.api_name}, fn_index: {self.fn_index}"

    def __str__(self):
        return self.__repr__()

    def make_end_to_end_fn(self, helper: Communicator | None = None):
        """Return a callable that serializes inputs, predicts, and deserializes outputs."""
        _predict = self.make_predict(helper)

        def _inner(*data):
            if not self.is_valid:
                raise utils.InvalidAPIEndpointError()
            data = self.insert_state(*data)
            data = self.serialize(*data)
            predictions = _predict(*data)
            predictions = self.process_predictions(*predictions)
            # Append to the job outputs so that the client can store the state
            # of the generator (only record the first/final output once).
            if helper:
                with helper.lock:
                    if not helper.job.outputs:
                        helper.job.outputs.append(predictions)
            return predictions

        return _inner

    def make_cancel(self, helper: Communicator | None = None):
        # Cancellation is not supported for v3 endpoints.
        return None

    def make_predict(self, helper: Communicator | None = None):
        """Return the raw prediction callable (websocket or HTTP POST)."""

        def _predict(*data) -> tuple:
            # Pre-serialize the payload: the websocket path sends this string
            # verbatim as a text frame.
            data = json.dumps(
                {
                    "data": data,
                    "fn_index": self.fn_index,
                    "session_hash": self.client.session_hash,
                }
            )
            hash_data = json.dumps(
                {
                    "fn_index": self.fn_index,
                    "session_hash": self.client.session_hash,
                }
            )
            if self.use_ws:
                result = utils.synchronize_async(self._ws_fn, data, hash_data, helper)
                if "error" in result:
                    raise ValueError(result["error"])
            else:
                response = httpx.post(
                    self.client.api_url,
                    headers=self.client.headers,
                    # `data` is already a JSON string; decode it back so httpx
                    # sends the object itself. Passing the string to `json=`
                    # would double-encode the payload.
                    json=json.loads(data),
                    verify=self.client.ssl_verify,
                    **self.client.httpx_kwargs,
                )
                result = json.loads(response.content.decode("utf-8"))
            try:
                output = result["data"]
            except KeyError as ke:
                is_public_space = (
                    self.client.space_id
                    and not huggingface_hub.space_info(self.client.space_id).private
                )
                if "error" in result and "429" in result["error"] and is_public_space:
                    raise utils.TooManyRequestsError(
                        f"Too many requests to the API, please try again later. To avoid being rate-limited, "
                        f"please duplicate the Space using Client.duplicate({self.client.space_id}) "
                        f"and pass in your Hugging Face token."
                    ) from None
                elif "error" in result:
                    raise ValueError(result["error"]) from None
                raise KeyError(
                    f"Could not find 'data' key in response. Response received: {result}"
                ) from ke
            return tuple(output)

        return _predict

    def _predict_resolve(self, *data) -> Any:
        """Needed for gradio.load(), which has a slightly different signature for serializing/deserializing"""
        outputs = self.make_predict()(*data)
        if len(self.dependency["outputs"]) == 1:
            return outputs[0]
        return outputs

    def _upload(
        self, file_paths: list[str | list[str]]
    ) -> list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]]:
        """Upload local files to the app's upload endpoint.

        Returns a list parallel to ``file_paths`` of server-side file dicts
        (or the original paths unchanged if the upload request fails).
        """
        if not file_paths:
            return []
        # Flatten the (possibly nested) list of paths, remembering which
        # outer index each file came from so results can be regrouped below.
        files = []
        indices = []
        for i, fs in enumerate(file_paths):
            if not isinstance(fs, list):
                fs = [fs]
            for f in fs:
                # Read the bytes up front instead of passing open file
                # handles, so no file descriptors are leaked (handles passed
                # to httpx here were never closed).
                files.append(("files", (Path(f).name, Path(f).read_bytes())))
                indices.append(i)
        r = httpx.post(
            self.client.upload_url,
            headers=self.client.headers,
            files=files,
            verify=self.client.ssl_verify,
            **self.client.httpx_kwargs,
        )
        if r.status_code != 200:
            # Best-effort: fall back to the local paths on upload failure.
            uploaded = file_paths
        else:
            uploaded = []
            result = r.json()
            for i, fs in enumerate(file_paths):
                if isinstance(fs, list):
                    output = [o for ix, o in enumerate(result) if indices[ix] == i]
                    res = [
                        {
                            "is_file": True,
                            "name": o,
                            "orig_name": Path(f).name,
                            "data": None,
                        }
                        for f, o in zip(fs, output, strict=False)
                    ]
                else:
                    o = next(o for ix, o in enumerate(result) if indices[ix] == i)
                    res = {
                        "is_file": True,
                        "name": o,
                        "orig_name": Path(fs).name,
                        "data": None,
                    }
                uploaded.append(res)
        return uploaded

    def _add_uploaded_files_to_data(
        self,
        files: list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]],
        data: list[Any],
    ) -> None:
        """Helper function to modify the input data with the uploaded files."""
        file_counter = 0
        for i, t in enumerate(self.input_component_types):
            if t in ["file", "uploadbutton"]:
                data[i] = files[file_counter]
                file_counter += 1

    def insert_state(self, *data) -> tuple:
        """Insert None placeholders at the positions of state components."""
        data = list(data)
        for i, input_component_type in enumerate(self.input_component_types):
            if input_component_type == utils.STATE_COMPONENT:
                data.insert(i, None)
        return tuple(data)

    def remove_skipped_components(self, *data) -> tuple:
        """Drop outputs belonging to components that are skipped in the API."""
        data = [
            d
            for d, oct in zip(data, self.output_component_types, strict=False)
            if oct not in utils.SKIP_COMPONENTS
        ]
        return tuple(data)

    def reduce_singleton_output(self, *data) -> Any:
        """Unwrap the tuple when exactly one non-skipped output remains."""
        if (
            len(
                [
                    oct
                    for oct in self.output_component_types
                    if oct not in utils.SKIP_COMPONENTS
                ]
            )
            == 1
        ):
            return data[0]
        else:
            return data

    def serialize(self, *data) -> tuple:
        """Serialize each input with its component serializer, uploading files first.

        Raises:
            ValueError: If the number of arguments does not match the number
                of input components.
        """
        if len(data) != len(self.serializers):
            raise ValueError(
                f"Expected {len(self.serializers)} arguments, got {len(data)}"
            )
        # Upload file-typed inputs and splice the server file dicts back in.
        files = [
            f
            for f, t in zip(data, self.input_component_types, strict=False)
            if t in ["file", "uploadbutton"]
        ]
        uploaded_files = self._upload(files)
        data = list(data)
        self._add_uploaded_files_to_data(uploaded_files, data)
        o = tuple(
            [s.serialize(d) for s, d in zip(self.serializers, data, strict=False)]
        )
        return o

    def deserialize(self, *data) -> tuple:
        """Deserialize each output with its component deserializer.

        Raises:
            ValueError: If the number of outputs does not match the number of
                output components.
        """
        if len(data) != len(self.deserializers):
            raise ValueError(
                f"Expected {len(self.deserializers)} outputs, got {len(data)}"
            )
        outputs = tuple(
            [
                s.deserialize(
                    d,
                    save_dir=self.client.output_dir,
                    hf_token=self.client.hf_token,
                    root_url=self.root_url,
                )
                for s, d in zip(self.deserializers, data, strict=False)
            ]
        )
        return outputs

    def process_predictions(self, *predictions):
        """Post-process raw predictions: deserialize, drop skipped, unwrap singletons."""
        if self.client.download_files:
            predictions = self.deserialize(*predictions)
        predictions = self.remove_skipped_components(*predictions)
        predictions = self.reduce_singleton_output(*predictions)
        return predictions

    def _setup_serializers(
        self,
    ) -> tuple[list[serializing.Serializable], list[serializing.Serializable]]:
        """Build serializer/deserializer instances from the app config.

        Raises:
            SerializationSetupError: If a component or serializer name is not
                known to this client version.
        """
        inputs = self.dependency["inputs"]
        serializers = []

        for i in inputs:
            for component in self.client.config["components"]:
                if component["id"] == i:
                    component_name = component["type"]
                    self.input_component_types.append(component_name)
                    if component.get("serializer"):
                        serializer_name = component["serializer"]
                        if serializer_name not in serializing.SERIALIZER_MAPPING:
                            raise SerializationSetupError(
                                f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
                            )
                        serializer = serializing.SERIALIZER_MAPPING[serializer_name]
                    elif component_name in serializing.COMPONENT_MAPPING:
                        serializer = serializing.COMPONENT_MAPPING[component_name]
                    else:
                        raise SerializationSetupError(
                            f"Unknown component: {component_name}, you may need to update your gradio_client version."
                        )
                    serializers.append(serializer())

        outputs = self.dependency["outputs"]
        deserializers = []
        for i in outputs:
            for component in self.client.config["components"]:
                if component["id"] == i:
                    component_name = component["type"]
                    self.output_component_types.append(component_name)
                    if component.get("serializer"):
                        serializer_name = component["serializer"]
                        if serializer_name not in serializing.SERIALIZER_MAPPING:
                            raise SerializationSetupError(
                                f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
                            )
                        deserializer = serializing.SERIALIZER_MAPPING[serializer_name]
                    elif component_name in utils.SKIP_COMPONENTS:
                        # Skipped components still need a placeholder deserializer
                        # to keep the outputs aligned positionally.
                        deserializer = serializing.SimpleSerializable
                    elif component_name in serializing.COMPONENT_MAPPING:
                        deserializer = serializing.COMPONENT_MAPPING[component_name]
                    else:
                        raise SerializationSetupError(
                            f"Unknown component: {component_name}, you may need to update your gradio_client version."
                        )
                    deserializers.append(deserializer())

        return serializers, deserializers

    def _use_websocket(self, dependency: dict) -> bool:
        """Websockets are used only when the queue is enabled, the app is >= 3.2, and this dependency is queued."""
        queue_enabled = self.client.config.get("enable_queue", False)
        queue_uses_websocket = version.parse(
            self.client.config.get("version", "2.0")
        ) >= version.Version("3.2")
        dependency_uses_queue = dependency.get("queue", False) is not False
        return queue_enabled and queue_uses_websocket and dependency_uses_queue

    async def _ws_fn(self, data, hash_data, helper: Communicator | None):
        """Open a websocket to the app and await the prediction result."""
        async with websockets.connect(
            self.client.ws_url,
            open_timeout=10,
            extra_headers=self.client.headers,
            max_size=1024 * 1024 * 1024,
        ) as websocket:
            return await utils.get_pred_from_ws(websocket, data, hash_data, helper)