Spaces:
Runtime error
Runtime error
File size: 3,546 Bytes
94a93b4 6f25160 ed27ead 94a93b4 a89eb44 513b107 94a93b4 8cdd1fc 513b107 8cdd1fc 94a93b4 ed27ead d614278 ed27ead ff33526 ed27ead ff33526 ed27ead 55bb6df 74e146e ff33526 ed27ead d614278 480cb62 ed27ead d614278 480cb62 94a93b4 ed27ead d614278 ed27ead d614278 ed27ead d614278 ed27ead |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import gradio as gr
import requests
import json
import os
from screenshot import (
before_prompt,
prompt_to_generation,
after_generation,
js_save,
js_load_script,
)
from spaces_info import description, examples, initial_prompt_value
# Endpoint and bearer token for the hosted text-generation API, taken from the
# process environment. Each is None when the corresponding variable is unset.
API_URL = os.environ.get("API_URL")
HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
def query(payload, *, timeout=120):
    """POST *payload* as JSON to ``API_URL`` and return the decoded JSON reply.

    The token in ``HF_API_TOKEN`` is sent as a bearer ``Authorization`` header.

    Args:
        payload: JSON-serialisable request body (``inference`` sends
            ``{"inputs": ..., "parameters": ..., "options": ...}``).
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely on a stalled
            connection, freezing the Gradio event handler.

    Returns:
        The decoded JSON body. If the request itself fails, or the body is not
        valid UTF-8 JSON (e.g. an HTML error page), a dict
        ``{"error": <message>}`` is returned instead -- the same shape
        ``inference`` already checks for -- rather than raising.
    """
    print(payload)
    try:
        response = requests.post(
            API_URL,
            json=payload,
            headers={"Authorization": f"Bearer {HF_API_TOKEN}"},
            timeout=timeout,
        )
    except requests.exceptions.RequestException as err:
        # Includes connection errors, timeouts and a missing/invalid API_URL.
        return {"error": f"request failed: {err}"}
    print(response)
    try:
        return json.loads(response.content.decode("utf-8"))
    except ValueError:
        # Both UnicodeDecodeError and json.JSONDecodeError subclass ValueError.
        return {"error": f"HTTP {response.status_code}: response was not valid JSON"}
def _error_result(message):
    """Build the (display, raw text, log) triple shown when generation fails."""
    return (None, None, f"<span style='color:red'>ERROR: {message} </span>")


def inference(input_sentence, max_length, sample_or_greedy, seed=42):
    """Generate a continuation of *input_sentence* and format it for the UI.

    Args:
        input_sentence: Prompt text sent to the model.
        max_length: Number of new tokens to request (``max_new_tokens``).
            Coerced to ``int`` because the UI slider may deliver a float.
        sample_or_greedy: ``"Sample"`` enables sampling with ``top_p=0.9``;
            any other value requests greedy decoding (``do_sample=False``).
        seed: Seed forwarded to the API. The UI passes the index of the
            "Sample N" radio here.

    Returns:
        A 3-tuple ``(html, generated_text, log)``:
        the prompt and generation wrapped in the ``before_prompt`` /
        ``prompt_to_generation`` / ``after_generation`` snippets, the raw
        text returned by the API, and an empty log string.
        On failure the first two items are ``None`` and the log holds a red
        HTML error span.
    """
    do_sample = sample_or_greedy == "Sample"
    parameters = {
        "max_new_tokens": int(max_length),
        "do_sample": do_sample,
        "seed": seed,
        "early_stopping": False,
        "length_penalty": 0.0,
        "eos_token_id": None,
    }
    if do_sample:
        parameters["top_p"] = 0.9
    payload = {
        "inputs": input_sentence,
        "parameters": parameters,
        "options": {"use_cache": False},
    }

    data = query(payload)
    if isinstance(data, dict) and "error" in data:
        return _error_result(data["error"])
    try:
        full_text = data[0]["generated_text"]
    except (IndexError, KeyError, TypeError):
        # Empty list, dict without "error", or some other unexpected shape.
        return _error_result(f"unexpected API response: {data!r}")

    # The API echoes the prompt; keep only what follows its first occurrence.
    # str.split("") raises ValueError and a prompt that is not echoed back
    # would raise IndexError, so fall back to the whole text in those cases.
    if input_sentence and input_sentence in full_text:
        generation = full_text.split(input_sentence, 1)[1]
    else:
        generation = full_text

    # NOTE(review): prompt and generation are inserted into HTML unescaped;
    # acceptable for a single-user demo, escape if output is ever shared.
    return (
        before_prompt
        + input_sentence
        + prompt_to_generation
        + generation
        + after_generation,
        full_text,
        "",
    )
# Script entry point: build the Gradio Blocks UI, wire its events and launch it.
# Layout follows statement order inside the `with` blocks, so do not reorder.
if __name__ == "__main__":
    demo = gr.Blocks()
    with demo:
        with gr.Row():
            gr.Markdown(value=description)
        with gr.Row():
            # Left column: generation controls.
            with gr.Column():
                text = gr.Textbox(
                    label="Input",
                    # Original note: should be " " when plugged into a real API
                    # (it already is). inference() strips this prompt from the
                    # text the API returns.
                    value=" ",
                )
                # Passed to inference() as max_length (-> max_new_tokens).
                tokens = gr.Slider(1, 64, value=32, step=1, label="Tokens to generate")
                sampling = gr.Radio(
                    ["Sample", "Greedy"], label="Sample or greedy", value="Sample"
                )
                # type="index" delivers 0-4, which inference() receives as `seed`,
                # so each choice yields a different sampled generation.
                sampling2 = gr.Radio(
                    ["Sample 1", "Sample 2", "Sample 3", "Sample 4", "Sample 5"],
                    value="Sample 1",
                    label="Sample other generations (only work in 'Sample' mode)",
                    type="index",
                )
                with gr.Row():
                    submit = gr.Button("Submit")
                    load_image = gr.Button("Generate Image")
            # Right column: outputs, in the order of inference()'s return tuple
            # (display_out, text_out, text_error) as wired in submit.click below.
            with gr.Column():
                text_error = gr.Markdown(label="Log information")
                text_out = gr.Textbox(label="Output")
                display_out = gr.HTML(label="Image")
                # Runs js_load_script client-side on the "load" event with no
                # Python callback. NOTE(review): set_event_trigger/no_target look
                # like Gradio internals and are presumably version-specific;
                # check the pinned Gradio version if the Space fails to start.
                display_out.set_event_trigger(
                    "load",
                    fn=None,
                    inputs=None,
                    outputs=None,
                    no_target=True,
                    js=js_load_script,
                )
        with gr.Row():
            gr.Examples(examples=examples, inputs=[text, tokens, sampling, sampling2])
        # inputs map positionally to inference(input_sentence, max_length,
        # sample_or_greedy, seed).
        submit.click(
            inference,
            inputs=[text, tokens, sampling, sampling2],
            outputs=[display_out, text_out, text_error],
        )
        # fn=None: runs only the js_save script in the browser. NOTE(review):
        # `_js` is presumably the Gradio 3.x keyword (renamed in later versions).
        load_image.click(fn=None, inputs=None, outputs=None, _js=js_save)
    demo.launch()
|