File size: 3,546 Bytes
94a93b4
 
 
6f25160
ed27ead
 
 
 
 
 
 
 
94a93b4
a89eb44
513b107
94a93b4
 
 
8cdd1fc
513b107
8cdd1fc
94a93b4
ed27ead
d614278
ed27ead
ff33526
ed27ead
 
 
 
 
 
 
 
 
ff33526
ed27ead
 
 
 
 
 
 
 
55bb6df
74e146e
ff33526
ed27ead
d614278
 
 
480cb62
ed27ead
 
 
 
 
 
 
 
d614278
480cb62
94a93b4
 
ed27ead
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d614278
 
 
ed27ead
 
 
 
 
 
 
 
 
d614278
ed27ead
 
 
 
d614278
ed27ead
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
import requests
import json
import os
from screenshot import (
    before_prompt,
    prompt_to_generation,
    after_generation,
    js_save,
    js_load_script,
)
from spaces_info import description, examples, initial_prompt_value

# Deployment configuration is read from the process environment; both
# values are None when the corresponding variable is not set.
API_URL = os.environ.get("API_URL")
HF_API_TOKEN = os.environ.get("HF_API_TOKEN")


def query(payload, timeout=300):
    """POST ``payload`` to the inference endpoint and return the decoded JSON.

    Args:
        payload: JSON-serialisable request body sent to ``API_URL``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely, which would hang the
            calling UI callback on a stalled connection.

    Returns:
        The parsed JSON body of the response. If the request cannot be
        completed, or the body is not valid UTF-8 JSON, a dict of the form
        ``{"error": "<message>"}`` is returned instead of raising, so the
        failure flows through the same ``"error"`` check that ``inference``
        applies to API-reported errors.
    """
    print(payload)
    try:
        response = requests.post(
            API_URL,
            json=payload,
            headers={"Authorization": f"Bearer {HF_API_TOKEN}"},
            timeout=timeout,
        )
    except requests.exceptions.RequestException as exc:
        # Covers connection errors, timeouts and a missing/invalid API_URL.
        return {"error": f"Request to the inference API failed: {exc}"}
    print(response)

    try:
        return json.loads(response.content.decode("utf-8"))
    except ValueError:
        # UnicodeDecodeError and json.JSONDecodeError both derive from
        # ValueError; either means the body was not the JSON we expect
        # (for example an HTML error page from a gateway).
        return {
            "error": "Inference API returned a non-JSON response "
            f"(HTTP {response.status_code})"
        }


def inference(input_sentence, max_length, sample_or_greedy, seed=42):
    """Request a continuation of ``input_sentence`` and format it for the UI.

    Args:
        input_sentence: Prompt text sent to the model as ``inputs``.
        max_length: Value sent as ``max_new_tokens``.
        sample_or_greedy: ``"Sample"`` turns on sampling with ``top_p=0.9``;
            any other value requests deterministic (non-sampling) decoding.
        seed: Value forwarded as the ``seed`` generation parameter.

    Returns:
        A 3-tuple ``(markup, generated_text, error_markup)``.
        On success ``markup`` is the prompt and its continuation wrapped in
        the ``before_prompt`` / ``prompt_to_generation`` / ``after_generation``
        fragments, ``generated_text`` is the raw text returned by the API and
        ``error_markup`` is ``""``. On failure the first two items are
        ``None`` and ``error_markup`` is a red ``<span>`` with the message.
    """
    # Function-scope import keeps this block self-contained.
    import html

    def failure(message):
        # Escape the message so text such as "<SomeObject at 0x...>" is shown
        # literally instead of being parsed as a tag.
        return (
            None,
            None,
            f"<span style='color:red'>ERROR: {html.escape(str(message))} </span>",
        )

    do_sample = sample_or_greedy == "Sample"
    parameters = {
        "max_new_tokens": max_length,
        "do_sample": do_sample,
        "seed": seed,
        "early_stopping": False,
        "length_penalty": 0.0,
        "eos_token_id": None,
    }
    if do_sample:
        # Nucleus sampling threshold is only sent when sampling is enabled.
        parameters["top_p"] = 0.9

    payload = {
        "inputs": input_sentence,
        "parameters": parameters,
        "options": {"use_cache": False},
    }

    data = query(payload)

    if isinstance(data, dict) and "error" in data:
        return failure(data["error"])

    try:
        generated_text = data[0]["generated_text"]
    except (KeyError, IndexError, TypeError):
        # Empty list, missing key, or a reply that is not a list of dicts.
        return failure(f"Unexpected API response: {data!r}")

    # Strip the echoed prompt from the reply. str.partition never raises for
    # a missing separator; the explicit emptiness check avoids the
    # "empty separator" ValueError. If the prompt is not echoed back, the
    # whole reply is treated as the continuation.
    generation = generated_text
    if input_sentence:
        _, echoed_prompt, remainder = generated_text.partition(input_sentence)
        if echoed_prompt:
            generation = remainder

    # Prompt and continuation are free text placed inside HTML fragments, so
    # they are escaped; the second tuple item stays raw for the plain textbox.
    markup = (
        before_prompt
        + html.escape(input_sentence)
        + prompt_to_generation
        + html.escape(generation)
        + after_generation
    )
    return (markup, generated_text, "")


if __name__ == "__main__":
    # Option lists for the two radio groups, named up front so the layout
    # code below stays compact.
    decoding_modes = ["Sample", "Greedy"]
    sample_slots = ["Sample 1", "Sample 2", "Sample 3", "Sample 4", "Sample 5"]

    with gr.Blocks() as demo:
        # Header row with the Space description.
        with gr.Row():
            gr.Markdown(value=description)

        with gr.Row():
            # Left column: prompt and generation controls.
            with gr.Column():
                prompt_box = gr.Textbox(
                    value=" ",  # should be set to " " when plugged into a real API
                    label="Input",
                )
                token_slider = gr.Slider(
                    1, 64, label="Tokens to generate", step=1, value=32
                )
                decoding_radio = gr.Radio(
                    decoding_modes, value="Sample", label="Sample or greedy"
                )
                # type="index": this control is wired as the fourth callback
                # input, i.e. the `seed` argument of `inference`.
                variant_radio = gr.Radio(
                    sample_slots,
                    label="Sample other generations (only work in 'Sample' mode)",
                    type="index",
                    value="Sample 1",
                )

                with gr.Row():
                    submit_button = gr.Button("Submit")
                    image_button = gr.Button("Generate Image")

            # Right column: error log, raw text output and rendered HTML.
            with gr.Column():
                error_panel = gr.Markdown(label="Log information")
                output_box = gr.Textbox(label="Output")
                html_panel = gr.HTML(label="Image")
                # Attach the script from `js_load_script` to the "load" event.
                html_panel.set_event_trigger(
                    "load",
                    fn=None,
                    inputs=None,
                    outputs=None,
                    no_target=True,
                    js=js_load_script,
                )

        # The same four controls feed both the examples table and the callback.
        model_inputs = [prompt_box, token_slider, decoding_radio, variant_radio]

        with gr.Row():
            gr.Examples(examples=examples, inputs=model_inputs)

        submit_button.click(
            fn=inference,
            inputs=model_inputs,
            outputs=[html_panel, output_box, error_panel],
        )

        # Pure client-side action: runs the script from `js_save`.
        image_button.click(fn=None, inputs=None, outputs=None, _js=js_save)

    demo.launch()