clarity-upscaler / utils /gradio_helpers.py
jbilcke-hf's picture
jbilcke-hf HF staff
Upload folder using huggingface_hub
0a4b1cc verified
import gradio as gr
from urllib.parse import urlparse
import requests
import time
from PIL import Image
import base64
import io
import uuid
import os
def extract_property_info(prop):
    """Flatten a JSON-schema property that may use combining keywords.

    Sub-schemas listed under ``allOf``, ``anyOf`` or ``oneOf`` are merged
    (later entries win) into a single dict. If there is nothing to merge, a
    shallow copy of ``prop`` without the combining keywords is used instead.
    In both cases the top-level ``description`` and ``default`` of ``prop``
    take precedence over merged values.

    Args:
        prop: The property schema dict. It is NOT modified, so the same
            schema can safely be processed more than once.

    Returns:
        A new dict holding the flattened property.
    """
    merge_keywords = ("allOf", "anyOf", "oneOf")

    combined_prop = {}
    for keyword in merge_keywords:
        if keyword in prop:
            for subprop in prop[keyword]:
                combined_prop.update(subprop)

    if not combined_prop:
        # Nothing was merged: fall back to the property itself, minus the
        # (empty) combining keywords.
        combined_prop = {
            key: value for key, value in prop.items() if key not in merge_keywords
        }

    for key in ("description", "default"):
        if key in prop:
            combined_prop[key] = prop[key]

    return combined_prop
def detect_file_type(filename):
    """Classify a value by the extension of its file name.

    Returns ``"audio"``, ``"image"`` or ``"video"`` for a string whose text
    from the last ``"."`` onwards (case-insensitive) is a known extension,
    ``"string"`` for any other string, ``"list"`` for a list, and ``None``
    for every other type.
    """
    if isinstance(filename, list):
        return "list"
    if not isinstance(filename, str):
        return None

    extensions_by_kind = {
        "audio": (".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a"),
        "image": (
            ".jpg",
            ".jpeg",
            ".png",
            ".gif",
            ".bmp",
            ".tiff",
            ".svg",
            ".webp",
        ),
        "video": (
            ".mp4",
            ".mov",
            ".wmv",
            ".flv",
            ".avi",
            ".avchd",
            ".mkv",
            ".webm",
        ),
    }

    # Everything from the last dot onwards, lower-cased.
    suffix = filename[filename.rfind(".") :].lower()
    for kind, known_suffixes in extensions_by_kind.items():
        if suffix in known_suffixes:
            return kind
    return "string"
def _input_component(component, **kwargs):
    """Build one input field plus the source statement that recreates it.

    The live component and the generated ``inputs.append(...)`` statement are
    derived from the same keyword arguments. Values are rendered with
    ``repr`` so quotes or newlines in schema text still yield valid source,
    and a missing value is emitted as ``None`` rather than the string
    ``"None"``.
    """
    rendered = ", ".join(f"{key}={value!r}" for key, value in kwargs.items())
    field = getattr(gr, component)(**kwargs)
    return field, f"inputs.append(gr.{component}(\n    {rendered}\n))\n"


def build_gradio_inputs(ordered_input_schema, example_inputs=None):
    """Create Gradio input components from an ordered input schema.

    Args:
        ordered_input_schema: Iterable of ``(name, property_schema)`` pairs.
        example_inputs: Optional mapping of input name to an example value.
            For ``uri``-formatted string inputs the example is used to choose
            between image, audio, video and generic file components; without
            it such inputs become plain text boxes.

    Returns:
        A tuple ``(inputs, input_field_strings, names)``: the component
        list, Python source that rebuilds that list (and a ``names`` list),
        and the input names in schema order.
    """
    # Component used for a uri input, keyed by the detected example type.
    file_components = {
        "image": ("Image", {"type": "filepath"}),
        "audio": ("Audio", {"type": "filepath"}),
        "video": ("Video", {}),
    }

    inputs = []
    input_field_strings = "inputs = []\n"
    names = []

    for name, prop in ordered_input_schema:
        names.append(name)
        prop = extract_property_info(prop)

        title = prop.get("title")
        description = prop.get("description")
        default = prop.get("default")
        # .get() so that a schema without "type" falls through to a Textbox
        # instead of raising KeyError.
        prop_type = prop.get("type")
        minimum = prop.get("minimum")
        maximum = prop.get("maximum")

        if "enum" in prop:
            field, code = _input_component(
                "Dropdown",
                choices=prop["enum"],
                label=title,
                info=description,
                value=default,
            )
        elif prop_type in ("integer", "number"):
            # Compare against None: a bound of 0 is a perfectly valid limit.
            if minimum is not None and maximum is not None:
                slider_kwargs = {
                    "label": title,
                    "info": description,
                    "value": default,
                    "minimum": minimum,
                    "maximum": maximum,
                }
                if prop_type == "integer":
                    slider_kwargs["step"] = 1
                field, code = _input_component("Slider", **slider_kwargs)
            else:
                field, code = _input_component(
                    "Number", label=title, info=description, value=default
                )
        elif prop_type == "boolean":
            field, code = _input_component(
                "Checkbox", label=title, info=description, value=default
            )
        elif prop_type == "string" and prop.get("format") == "uri" and example_inputs:
            example = example_inputs.get(name)
            input_type = detect_file_type(example) if example else None
            component, extra = file_components.get(input_type, ("File", {}))
            field, code = _input_component(component, label=title, **extra)
        else:
            field, code = _input_component("Textbox", label=title, info=description)

        inputs.append(field)
        input_field_strings += f"{code}\n"

    input_field_strings += f"names = {names}\n"
    return inputs, input_field_strings, names
def build_gradio_outputs_replicate(output_types):
    """Create Gradio output components for a list of output type names.

    Args:
        output_types: Iterable of type names (``"image"``, ``"audio"``,
            ``"video"``, ``"string"``, ``"json"`` or ``"list"``). When it is
            empty or ``None`` a single JSON component is produced. Names that
            are not recognised also map to a JSON component.

    Returns:
        A tuple ``(outputs, output_field_strings)``: the component list and
        Python source that rebuilds the same list.
    """
    # Type name -> (gradio component class name, constructor kwargs).
    component_specs = {
        "image": ("Image", {}),
        "audio": ("Audio", {"type": "filepath"}),
        "video": ("Video", {}),
        "string": ("Textbox", {}),
        "json": ("JSON", {}),
        "list": ("JSON", {}),
    }
    fallback_spec = ("JSON", {})

    outputs = []
    output_field_strings = "outputs = []\n"

    # With no declared types, behave as if one unknown type was requested so
    # that the live list and the generated source both get the JSON fallback.
    for output in output_types or [None]:
        component, kwargs = component_specs.get(output, fallback_spec)
        rendered = ", ".join(f"{key}={value!r}" for key, value in kwargs.items())
        outputs.append(getattr(gr, component)(**kwargs))
        output_field_strings += f"outputs.append(gr.{component}({rendered}))\n"

    return outputs, output_field_strings
def build_gradio_outputs_cog():
    """Placeholder with no implementation yet; always returns ``None``."""
    return None
def process_outputs(outputs):
    """Convert raw prediction outputs into values Gradio can display.

    Falsy entries are skipped. ``data:image`` URIs are decoded and opened as
    PIL images; ``data:audio`` and ``data:video`` URIs are decoded and written
    to a new ``<uuid>.wav`` / ``<uuid>.mp4`` file in the current working
    directory, and the file name is returned in their place. Every other
    value is passed through unchanged.
    """

    def _decode(data_uri):
        # The base64 payload is everything after the first comma.
        return base64.b64decode(data_uri.split(",", 1)[1])

    def _write_to_file(data_uri, extension):
        raw_bytes = _decode(data_uri)
        path = f"{uuid.uuid4()}{extension}"
        with open(path, "wb") as handle:
            handle.write(raw_bytes)
        return path

    results = []
    for item in outputs:
        if not item:
            continue
        if not isinstance(item, str):
            results.append(item)
        elif item.startswith("data:image"):
            results.append(Image.open(io.BytesIO(_decode(item))))
        elif item.startswith("data:audio"):
            results.append(_write_to_file(item, ".wav"))
        elif item.startswith("data:video"):
            results.append(_write_to_file(item, ".mp4"))
        else:
            results.append(item)
    return results
def parse_outputs(data):
    """Flatten a nested prediction output into a list of values.

    * A primitive becomes a one-element list.
    * A list is flattened recursively.
    * A dict contributes its values in order: a value that is itself a list
      is kept as ONE nested (flattened) list, any other value is flattened
      into the result.
    """
    if isinstance(data, dict):
        collected = []
        for value in data.values():
            parsed = parse_outputs(value)
            if isinstance(value, list):
                # Keep list-valued fields grouped as a single entry.
                collected.append(parsed)
            else:
                collected.extend(parsed)
        return collected

    if isinstance(data, list):
        return [leaf for item in data for leaf in parse_outputs(item)]

    return [data]
def create_dynamic_gradio_app(
    inputs,
    outputs,
    api_url,
    api_id=None,
    replicate_token=None,
    title="",
    model_description="",
    names=None,
    local_base=False,
    hostname="0.0.0.0",
):
    """Build a ``gr.Interface`` whose handler forwards inputs to a prediction API.

    Args:
        inputs: Gradio input components.
        outputs: Gradio output components; their count is the number of
            values the handler returns.
        api_url: URL the prediction request is POSTed to.
        api_id: Optional value sent as ``"version"`` in the payload.
        replicate_token: Optional token sent as ``Authorization: Token ...``.
        title: Interface title.
        model_description: Interface description.
        names: Payload key for each positional input, in input order.
        local_base: When true, file inputs are exposed through
            ``http://{hostname}:7860`` instead of the URL of the incoming
            request.
        hostname: Host name used when ``local_base`` is true.

    Returns:
        The configured ``gr.Interface`` (not launched).
    """
    names = list(names) if names is not None else []
    expected_outputs = len(outputs)

    def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):
        payload = {"input": {}}
        if api_id:
            payload["version"] = api_id

        if local_base:
            base_url = f"http://{hostname}:7860"
        else:
            parsed_url = urlparse(str(request.url))
            base_url = parsed_url.scheme + "://" + parsed_url.netloc

        for i, key in enumerate(names):
            value = args[i]
            # Local files are sent as a URL served by this Gradio app.
            if value and os.path.exists(str(value)):
                value = f"{base_url}/file={value}"
            if value is not None and value != "":
                payload["input"][key] = value
        print(payload)

        headers = {"Content-Type": "application/json"}
        if replicate_token:
            headers["Authorization"] = f"Token {replicate_token}"
        # The headers are deliberately not printed: they carry the API token.

        response = requests.post(api_url, headers=headers, json=payload)

        if response.status_code == 201:
            # The prediction was accepted asynchronously: poll until it ends.
            follow_up_url = response.json()["urls"]["get"]
            while True:
                response = requests.get(follow_up_url, headers=headers)
                status = response.json()["status"]
                if status == "succeeded":
                    break
                # "canceled" is terminal too; without it the loop never ends.
                if status in ("failed", "canceled"):
                    raise gr.Error("The submission failed!")
                time.sleep(1)
                # TODO: Add a failing mechanism if the API gets stuck

        if response.status_code != 200:
            if response.status_code == 409:
                raise gr.Error(
                    "Sorry, the Cog image is still processing. Try again in a bit."
                )
            raise gr.Error(f"The submission failed! Error: {response.status_code}")

        json_response = response.json()
        # If the output component is JSON return the entire output response
        if outputs[0].get_config()["name"] == "json":
            return json_response["output"]

        predict_outputs = parse_outputs(json_response["output"])
        processed_outputs = process_outputs(predict_outputs)
        difference_outputs = expected_outputs - len(processed_outputs)
        if difference_outputs > 0:
            # Fewer values than components: hide the unused components.
            processed_outputs.extend([gr.update(visible=False)] * difference_outputs)
        elif difference_outputs < 0:
            # More values than components: drop the surplus.
            processed_outputs = processed_outputs[:difference_outputs]

        if len(processed_outputs) > 1:
            return tuple(processed_outputs)
        return processed_outputs[0]

    app = gr.Interface(
        fn=predict,
        inputs=inputs,
        outputs=outputs,
        title=title,
        description=model_description,
        allow_flagging="never",
    )
    return app
# Static pieces of the generated script. They are plain (non f-) strings, so
# every brace below is written to the script verbatim.
_APP_SCRIPT_IMPORTS = """import gradio as gr
from urllib.parse import urlparse
import requests
import time
import os
from utils.gradio_helpers import parse_outputs, process_outputs
"""

_APP_SCRIPT_PAYLOAD_LOOP = """for i, key in enumerate(names):
    value = args[i]
    if value and (os.path.exists(str(value))):
        value = f"{base_url}/file=" + value
    if value is not None and value != "":
        payload["input"][key] = value
"""

_APP_SCRIPT_RESPONSE_HANDLING = """if response.status_code == 201:
    follow_up_url = response.json()["urls"]["get"]
    response = requests.get(follow_up_url, headers=headers)
    while response.json()["status"] != "succeeded":
        if response.json()["status"] in ("failed", "canceled"):
            raise gr.Error("The submission failed!")
        response = requests.get(follow_up_url, headers=headers)
        time.sleep(1)
if response.status_code == 200:
    json_response = response.json()
    # If the output component is JSON return the entire output response
    if outputs[0].get_config()["name"] == "json":
        return json_response["output"]
    predict_outputs = parse_outputs(json_response["output"])
    processed_outputs = process_outputs(predict_outputs)
    difference_outputs = expected_outputs - len(processed_outputs)
    # If less outputs than expected, hide the extra ones
    if difference_outputs > 0:
        extra_outputs = [gr.update(visible=False)] * difference_outputs
        processed_outputs.extend(extra_outputs)
    # If more outputs than expected, cap the outputs to the expected number
    elif difference_outputs < 0:
        processed_outputs = processed_outputs[:difference_outputs]
    return tuple(processed_outputs) if len(processed_outputs) > 1 else processed_outputs[0]
else:
    if response.status_code == 409:
        raise gr.Error("Sorry, the Cog image is still processing. Try again in a bit.")
    raise gr.Error(f"The submission failed! Error: {response.status_code}")
"""

_APP_SCRIPT_INTERFACE = """app = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=model_description,
    allow_flagging="never",
)
app.launch(share=True)
"""


def create_gradio_app_script(
    inputs_string,
    outputs_string,
    api_url,
    api_id=None,
    replicate_token=None,
    title="",
    model_description="",
    local_base=False,
    hostname="0.0.0.0",
):
    """Return the source of a standalone Gradio app script.

    The script defines ``inputs``/``names`` and ``outputs`` from the given
    source snippets, a ``predict`` handler that POSTs to ``api_url`` and polls
    for the result, and a ``gr.Interface`` that is launched with
    ``share=True``.

    Args:
        inputs_string: Source that defines ``inputs`` and ``names``.
        outputs_string: Source that defines ``outputs``.
        api_url: URL the generated handler POSTs to.
        api_id: Optional value written as ``payload["version"]``.
        replicate_token: Optional token. NOTE: it is embedded in clear text
            in the generated ``headers`` dict, so treat the script as secret.
        title: Interface title.
        model_description: Interface description.
        local_base: When true the handler uses ``http://{hostname}:7860`` as
            base URL for file inputs instead of the incoming request's URL.
        hostname: Host name used when ``local_base`` is true.

    Returns:
        The complete script as a string.
    """
    headers = {"Content-Type": "application/json"}
    if replicate_token:
        headers["Authorization"] = f"Token {replicate_token}"

    # Body of the generated predict(); written unindented here and indented
    # as a whole below. Dynamic values are embedded with repr() so quotes or
    # newlines in them cannot break the generated source.
    body_lines = [f"headers = {headers!r}", 'payload = {"input": {}}']
    if api_id is not None:
        body_lines.append(f'payload["version"] = {str(api_id)!r}')
    if local_base:
        local_url = f"http://{hostname}:7860"
        body_lines.append(f"base_url = {local_url!r}")
    else:
        body_lines.append("parsed_url = urlparse(str(request.url))")
        body_lines.append('base_url = parsed_url.scheme + "://" + parsed_url.netloc')
    body_lines.extend(_APP_SCRIPT_PAYLOAD_LOOP.splitlines())
    body_lines.append(
        f"response = requests.post({str(api_url)!r}, headers=headers, json=payload)"
    )
    body_lines.extend(_APP_SCRIPT_RESPONSE_HANDLING.splitlines())
    predict_body = "\n".join(f"    {line}" for line in body_lines)

    sections = [
        _APP_SCRIPT_IMPORTS,
        inputs_string,
        outputs_string,
        "expected_outputs = len(outputs)",
        "def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):",
        predict_body,
        "",
        f"title = {title!r}",
        f"model_description = {model_description!r}",
        _APP_SCRIPT_INTERFACE,
    ]
    return "\n".join(sections)