"""This module contains the EndpointV3Compatibility class, which is used to connect to Gradio apps running 3.x.x versions of Gradio.""" from __future__ import annotations import json from pathlib import Path from typing import TYPE_CHECKING, Any, Literal import httpx import huggingface_hub import websockets from packaging import version from gradio_client import serializing, utils from gradio_client.exceptions import SerializationSetupError from gradio_client.utils import ( Communicator, ) if TYPE_CHECKING: from gradio_client import Client class EndpointV3Compatibility: """Endpoint class for connecting to v3 endpoints. Backwards compatibility.""" def __init__(self, client: Client, fn_index: int, dependency: dict, *_args): self.client: Client = client self.fn_index = fn_index self.dependency = dependency api_name = dependency.get("api_name") self.api_name: str | Literal[False] | None = ( "/" + api_name if isinstance(api_name, str) else api_name ) self.use_ws = self._use_websocket(self.dependency) self.protocol = "ws" if self.use_ws else "http" self.input_component_types = [] self.output_component_types = [] self.root_url = client.src + "/" if not client.src.endswith("/") else client.src try: # Only a real API endpoint if backend_fn is True (so not just a frontend function), serializers are valid, # and api_name is not False (meaning that the developer has explicitly disabled the API endpoint) self.serializers, self.deserializers = self._setup_serializers() self.is_valid = self.dependency["backend_fn"] and self.api_name is not False except SerializationSetupError: self.is_valid = False self.backend_fn = dependency.get("backend_fn") self.show_api = True def __repr__(self): return f"Endpoint src: {self.client.src}, api_name: {self.api_name}, fn_index: {self.fn_index}" def __str__(self): return self.__repr__() def make_end_to_end_fn(self, helper: Communicator | None = None): _predict = self.make_predict(helper) def _inner(*data): if not self.is_valid: raise utils.InvalidAPIEndpointError() data = self.insert_state(*data) data = self.serialize(*data) predictions = _predict(*data) predictions = self.process_predictions(*predictions) # Append final output only if not already present # for consistency between generators and not generators if helper: with helper.lock: if not helper.job.outputs: helper.job.outputs.append(predictions) return predictions return _inner def make_cancel(self, helper: Communicator | None = None): # noqa: ARG002 (needed so that both endpoints classes have the same api) return None def make_predict(self, helper: Communicator | None = None): def _predict(*data) -> tuple: data = json.dumps( { "data": data, "fn_index": self.fn_index, "session_hash": self.client.session_hash, } ) hash_data = json.dumps( { "fn_index": self.fn_index, "session_hash": self.client.session_hash, } ) if self.use_ws: result = utils.synchronize_async(self._ws_fn, data, hash_data, helper) if "error" in result: raise ValueError(result["error"]) else: response = httpx.post( self.client.api_url, headers=self.client.headers, json=data, verify=self.client.ssl_verify, **self.client.httpx_kwargs, ) result = json.loads(response.content.decode("utf-8")) try: output = result["data"] except KeyError as ke: is_public_space = ( self.client.space_id and not huggingface_hub.space_info(self.client.space_id).private ) if "error" in result and "429" in result["error"] and is_public_space: raise utils.TooManyRequestsError( f"Too many requests to the API, please try again later. 
To avoid being rate-limited, " f"please duplicate the Space using Client.duplicate({self.client.space_id}) " f"and pass in your Hugging Face token." ) from None elif "error" in result: raise ValueError(result["error"]) from None raise KeyError( f"Could not find 'data' key in response. Response received: {result}" ) from ke return tuple(output) return _predict def _predict_resolve(self, *data) -> Any: """Needed for gradio.load(), which has a slightly different signature for serializing/deserializing""" outputs = self.make_predict()(*data) if len(self.dependency["outputs"]) == 1: return outputs[0] return outputs def _upload( self, file_paths: list[str | list[str]] ) -> list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]]: if not file_paths: return [] # Put all the filepaths in one file # but then keep track of which index in the # original list they came from so we can recreate # the original structure files = [] indices = [] for i, fs in enumerate(file_paths): if not isinstance(fs, list): fs = [fs] for f in fs: files.append(("files", (Path(f).name, open(f, "rb")))) # noqa: SIM115 indices.append(i) r = httpx.post( self.client.upload_url, headers=self.client.headers, files=files, verify=self.client.ssl_verify, **self.client.httpx_kwargs, ) if r.status_code != 200: uploaded = file_paths else: uploaded = [] result = r.json() for i, fs in enumerate(file_paths): if isinstance(fs, list): output = [o for ix, o in enumerate(result) if indices[ix] == i] res = [ { "is_file": True, "name": o, "orig_name": Path(f).name, "data": None, } for f, o in zip(fs, output, strict=False) ] else: o = next(o for ix, o in enumerate(result) if indices[ix] == i) res = { "is_file": True, "name": o, "orig_name": Path(fs).name, "data": None, } uploaded.append(res) return uploaded def _add_uploaded_files_to_data( self, files: list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]], data: list[Any], ) -> None: """Helper function to modify the input data with the uploaded files.""" file_counter = 0 for i, t in enumerate(self.input_component_types): if t in ["file", "uploadbutton"]: data[i] = files[file_counter] file_counter += 1 def insert_state(self, *data) -> tuple: data = list(data) for i, input_component_type in enumerate(self.input_component_types): if input_component_type == utils.STATE_COMPONENT: data.insert(i, None) return tuple(data) def remove_skipped_components(self, *data) -> tuple: data = [ d for d, oct in zip(data, self.output_component_types, strict=False) if oct not in utils.SKIP_COMPONENTS ] return tuple(data) def reduce_singleton_output(self, *data) -> Any: if ( len( [ oct for oct in self.output_component_types if oct not in utils.SKIP_COMPONENTS ] ) == 1 ): return data[0] else: return data def serialize(self, *data) -> tuple: if len(data) != len(self.serializers): raise ValueError( f"Expected {len(self.serializers)} arguments, got {len(data)}" ) files = [ f for f, t in zip(data, self.input_component_types, strict=False) if t in ["file", "uploadbutton"] ] uploaded_files = self._upload(files) data = list(data) self._add_uploaded_files_to_data(uploaded_files, data) o = tuple( [s.serialize(d) for s, d in zip(self.serializers, data, strict=False)] ) return o def deserialize(self, *data) -> tuple: if len(data) != len(self.deserializers): raise ValueError( f"Expected {len(self.deserializers)} outputs, got {len(data)}" ) outputs = tuple( [ s.deserialize( d, save_dir=self.client.output_dir, hf_token=self.client.hf_token, root_url=self.root_url, ) for s, d in zip(self.deserializers, data, 
strict=False) ] ) return outputs def process_predictions(self, *predictions): if self.client.download_files: predictions = self.deserialize(*predictions) predictions = self.remove_skipped_components(*predictions) predictions = self.reduce_singleton_output(*predictions) return predictions def _setup_serializers( self, ) -> tuple[list[serializing.Serializable], list[serializing.Serializable]]: inputs = self.dependency["inputs"] serializers = [] for i in inputs: for component in self.client.config["components"]: if component["id"] == i: component_name = component["type"] self.input_component_types.append(component_name) if component.get("serializer"): serializer_name = component["serializer"] if serializer_name not in serializing.SERIALIZER_MAPPING: raise SerializationSetupError( f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version." ) serializer = serializing.SERIALIZER_MAPPING[serializer_name] elif component_name in serializing.COMPONENT_MAPPING: serializer = serializing.COMPONENT_MAPPING[component_name] else: raise SerializationSetupError( f"Unknown component: {component_name}, you may need to update your gradio_client version." ) serializers.append(serializer()) # type: ignore outputs = self.dependency["outputs"] deserializers = [] for i in outputs: for component in self.client.config["components"]: if component["id"] == i: component_name = component["type"] self.output_component_types.append(component_name) if component.get("serializer"): serializer_name = component["serializer"] if serializer_name not in serializing.SERIALIZER_MAPPING: raise SerializationSetupError( f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version." ) deserializer = serializing.SERIALIZER_MAPPING[serializer_name] elif component_name in utils.SKIP_COMPONENTS: deserializer = serializing.SimpleSerializable elif component_name in serializing.COMPONENT_MAPPING: deserializer = serializing.COMPONENT_MAPPING[component_name] else: raise SerializationSetupError( f"Unknown component: {component_name}, you may need to update your gradio_client version." ) deserializers.append(deserializer()) # type: ignore return serializers, deserializers def _use_websocket(self, dependency: dict) -> bool: queue_enabled = self.client.config.get("enable_queue", False) queue_uses_websocket = version.parse( self.client.config.get("version", "2.0") ) >= version.Version("3.2") dependency_uses_queue = dependency.get("queue", False) is not False return queue_enabled and queue_uses_websocket and dependency_uses_queue async def _ws_fn(self, data, hash_data, helper: Communicator): async with websockets.connect( # type: ignore self.client.ws_url, open_timeout=10, extra_headers=self.client.headers, max_size=1024 * 1024 * 1024, ) as websocket: return await utils.get_pred_from_ws(websocket, data, hash_data, helper)
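

# --- Illustrative usage sketch (an assumption, not part of the library API) ---
# EndpointV3Compatibility is constructed internally by Client when the connected
# app reports a Gradio 3.x config; callers normally reach it only indirectly via
# Client.predict() / Client.submit(). The Space name and api_name below are
# hypothetical placeholders for any Space still running Gradio 3.x.
if __name__ == "__main__":
    from gradio_client import Client

    client = Client("some-user/some-gradio-3x-space")  # hypothetical 3.x Space
    # Client.predict() routes through make_end_to_end_fn() above, which inserts
    # state placeholders, serializes inputs, calls _predict() over HTTP or
    # websockets, and deserializes the outputs.
    result = client.predict("Hello!", api_name="/predict")
    print(result)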