|
import traceback |
|
from typing import Callable |
|
import os |
|
|
|
# Set before importing gradio_client: libraries it pulls in (e.g. huggingface_hub)
# may read this flag only once at import time, so setting it afterwards can be too late.
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'

from gradio_client import Client
from gradio_client.client import Job
|
|
|
|
|
class GradioClient(Client):
    """
    Gradio client that automatically refreshes itself if the Gradio server changed.

    Every ``submit`` first compares the server's ``/system_hash`` with the hash
    seen at the last refresh. On change, a brand-new underlying ``Client`` is
    built (so the api_name -> fn_index mapping is current). Otherwise only the
    session is reset, so each call is independent.
    """

    def __init__(self, *args, **kwargs):
        """
        Create the client and record the server's current hash.

        All arguments are stored so ``refresh_client`` can later rebuild an
        identical ``Client``, then forwarded unchanged to ``Client.__init__``.
        """
        self.args = args
        self.kwargs = kwargs
        super().__init__(*args, **kwargs)
        # Hash observed at construction / last hash-triggered refresh.
        self.server_hash = self.get_server_hash()

    def get_server_hash(self):
        """
        Get server hash using super without any refresh action triggered.

        Calls the ``/system_hash`` endpoint through ``Client.submit`` directly,
        bypassing this class's ``submit`` override, and blocks for the result.

        Returns: git hash of gradio server
        """
        return super().submit(api_name='/system_hash').result()

    def refresh_client_if_should(self):
        """
        Rebuild the client if the server hash changed, else just reset the session.

        On a hash mismatch, ``refresh_client`` is called and the stored
        ``server_hash`` is updated. Otherwise ``reset_session`` is called so the
        next call still runs in a fresh session.
        """
        server_hash = self.get_server_hash()
        if self.server_hash != server_hash:
            self.refresh_client()
            self.server_hash = server_hash
        else:
            self.reset_session()

    def refresh_client(self):
        """
        Ensure every client call is independent.

        Also ensure the map between api_name and fn_index is updated in case the
        server changed (e.g. restarted with new code). This is done by
        constructing a new ``Client`` from the original constructor arguments
        and copying all of its instance attributes onto ``self``.

        NOTE: does not update ``self.server_hash``.

        Returns: None
        """
        self.reset_session()

        client = Client(*self.args, **self.kwargs)
        for k, v in client.__dict__.items():
            setattr(self, k, v)

    def submit(
        self,
        *args,
        api_name: str | None = None,
        fn_index: int | None = None,
        result_callbacks: Callable | list[Callable] | None = None,
    ) -> Job:
        """
        Submit a job, refreshing the client first if the server changed.

        Behaves like ``Client.submit`` with two differences:
          * before submitting it calls ``refresh_client_if_should``. If that
            check or the submission raises, the client is fully refreshed and
            the job is submitted once more. Exceptions from this retry propagate.
          * if the returned job has already failed by the time it is inspected,
            the client is fully refreshed and the job is submitted once more.
            The check is non-blocking, so only failures that already happened
            are caught.

        Args:
            *args: endpoint inputs, forwarded unchanged.
            api_name: endpoint name, forwarded to ``Client.submit``.
            fn_index: endpoint index, forwarded to ``Client.submit``.
            result_callbacks: callback(s), forwarded to ``Client.submit``.

        Returns:
            The ``Job`` from the last submission attempt.
        """
        try:
            self.refresh_client_if_should()
            job = self._submit_once(args, api_name, fn_index, result_callbacks)
        except Exception as e:
            print("Hit e=%s" % str(e), flush=True)

            # Hash check or submission failed (e.g. server went away); rebuild and retry.
            self.refresh_client()
            job = self._submit_once(args, api_name, fn_index, result_callbacks)

        e = self._job_exception(job)
        if e is not None:
            print("GR job failed: %s %s" % (str(e), ''.join(traceback.format_tb(e.__traceback__))), flush=True)

            self.refresh_client()
            job = self._submit_once(args, api_name, fn_index, result_callbacks)
            e2 = self._job_exception(job)
            if e2 is not None:
                print("GR job failed again: %s\n%s" % (str(e2), ''.join(traceback.format_tb(e2.__traceback__))), flush=True)

        return job

    def _submit_once(self, args, api_name, fn_index, result_callbacks):
        """Submit directly via ``Client.submit`` with no refresh/retry logic."""
        return super().submit(
            *args,
            api_name=api_name,
            fn_index=fn_index,
            result_callbacks=result_callbacks,
        )

    @staticmethod
    def _job_exception(job):
        """
        Return the exception ``job`` has already failed with, else None.

        Non-blocking. A job that is still running or was cancelled yields None.
        Uses only the public ``concurrent.futures.Future`` API instead of the
        private ``_exception`` attribute.
        """
        future = job.future
        if future.done() and not future.cancelled():
            return future.exception()
        return None
|
|