|
import asyncio |
|
import os |
|
import tempfile |
|
import time |
|
from pathlib import Path |
|
from unittest.mock import patch |
|
|
|
import gradio_client as grc |
|
import pytest |
|
from gradio_client import media_data |
|
from gradio_client import utils as client_utils |
|
from pydub import AudioSegment |
|
from starlette.testclient import TestClient |
|
from tqdm import tqdm |
|
|
|
import gradio as gr |
|
from gradio import helpers, utils |
|
from gradio.route_utils import API_PREFIX |
|
|
|
|
|
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestExamples:
    """Tests for how ``gr.Examples`` normalizes, processes and caches examples.

    The class-level ``patch`` points the examples cache at a temporary
    directory; each test receives the resulting mock as ``patched_cache_folder``.
    """

    def test_handle_single_input(self, patched_cache_folder):
        """Flat and nested example lists both normalize to one row per example."""
        flat = gr.Examples(["hello", "hi"], gr.Textbox())
        assert flat.non_none_processed_examples.as_list() == [["hello"], ["hi"]]

        nested = gr.Examples([["hello"]], gr.Textbox())
        assert nested.non_none_processed_examples.as_list() == [["hello"]]

        images = gr.Examples(["test/test_files/bus.png"], gr.Image())
        image_path = images.non_none_processed_examples.as_list()[0][0]["path"]
        encoded_image = client_utils.encode_file_to_base64(image_path)
        assert encoded_image == media_data.BASE64_IMAGE

    def test_handle_multiple_inputs(self, patched_cache_folder):
        """Each example row maps positionally onto the input components."""
        examples = gr.Examples(
            [["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()]
        )
        first_row = examples.non_none_processed_examples.as_list()[0]
        assert first_row[0] == "hello"
        encoded_image = client_utils.encode_file_to_base64(first_row[1]["path"])
        assert encoded_image == media_data.BASE64_IMAGE

    def test_handle_directory(self, patched_cache_folder):
        """A directory of images becomes one example per file."""
        examples = gr.Examples("test/test_files/images", gr.Image())
        rows = examples.non_none_processed_examples.as_list()
        assert len(rows) == 2
        for cell in (value for row in rows for value in row):
            encoded_image = client_utils.encode_file_to_base64(cell["path"])
            assert encoded_image == media_data.BASE64_IMAGE

    def test_handle_directory_with_log_file(self, patched_cache_folder):
        """A directory with a log file yields multi-column examples with absolute paths."""
        examples = gr.Examples(
            "test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
        )

        def to_base64(file_dict):
            return client_utils.encode_file_to_base64(file_dict["path"])

        def is_existing_file(value):
            return isinstance(value, dict) and Path(value["path"]).exists()

        encoded_rows = client_utils.traverse(
            examples.non_none_processed_examples.as_list(),
            to_base64,
            is_existing_file,
        )
        expected_rows = [
            [media_data.BASE64_IMAGE, "hello"],
            [media_data.BASE64_IMAGE, "hi"],
        ]
        assert encoded_rows == expected_rows
        # Samples stored on the dataset must reference files by absolute path.
        assert all(
            os.path.isabs(sample[0]["path"]) for sample in examples.dataset.samples
        )

    def test_examples_per_page(self, patched_cache_folder):
        """``examples_per_page`` is forwarded to the dataset config."""
        examples = gr.Examples(["hello", "hi"], gr.Textbox(), examples_per_page=2)
        dataset_config = examples.dataset.get_config()
        assert dataset_config["samples_per_page"] == 2

    def test_no_preprocessing(self, patched_cache_folder, connect):
        """With ``preprocess=False`` the fn receives the raw file payload."""
        with gr.Blocks() as demo:
            image_input = gr.Image()
            path_output = gr.Textbox()

            examples = gr.Examples(
                examples=["test/test_files/bus.png"],
                inputs=image_input,
                outputs=path_output,
                fn=lambda x: x["path"],
                cache_examples=True,
                preprocess=False,
            )

        with connect(demo):
            cached = examples.load_from_cache(0)
        encoded_image = client_utils.encode_file_to_base64(cached[0])
        assert encoded_image == media_data.BASE64_IMAGE

    def test_no_postprocessing(self, patched_cache_folder, connect):
        """With ``postprocess=False`` the fn's return value is used as-is."""

        def gallery_payload(x):
            return [
                {
                    "image": {
                        "path": "test/test_files/bus.png",
                    },
                    "caption": "hi",
                }
            ]

        with gr.Blocks() as demo:
            text_input = gr.Textbox()
            gallery = gr.Gallery()

            examples = gr.Examples(
                examples=["hi"],
                inputs=text_input,
                outputs=gallery,
                fn=gallery_payload,
                cache_examples=True,
                postprocess=False,
            )

        with connect(demo):
            cached = examples.load_from_cache(0)
        cached_image_path = cached[0].root[0].image.path
        expected = client_utils.encode_url_or_file_to_base64("test/test_files/bus.png")
        assert client_utils.encode_url_or_file_to_base64(cached_image_path) == expected
|
|
|
|
|
def test_setting_cache_dir_env_variable(monkeypatch, connect):
    """Cached example outputs are written under ``GRADIO_EXAMPLES_CACHE``.

    The cache directory is a ``TemporaryDirectory`` so it is removed after the
    test; ``monkeypatch`` restores the environment variable automatically at
    teardown, so no manual ``delenv`` is needed.
    """
    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir:
        monkeypatch.setenv("GRADIO_EXAMPLES_CACHE", temp_dir)
        with gr.Blocks() as demo:
            image = gr.Image()
            image2 = gr.Image()

            examples = gr.Examples(
                examples=["test/test_files/bus.png"],
                inputs=image,
                outputs=image2,
                fn=lambda x: x,
                cache_examples=True,
            )

        with connect(demo):
            prediction = examples.load_from_cache(0)
        path_to_cached_file = Path(prediction[0].path)
        # Check while the temporary directory still exists.
        assert utils.is_in_or_equal(path_to_cached_file, temp_dir)
|
|
|
|
|
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestExamplesDataset:
    """Tests for the headers and sample labels of the Dataset behind ``gr.Examples``.

    The class-level ``patch`` points the examples cache at a temporary
    directory; each test receives the mock as ``patched_cache_folder``.
    """

    def test_no_headers(self, patched_cache_folder):
        """Components without labels produce an empty header list."""
        components = [gr.Image(), gr.Text()]
        examples = gr.Examples("test/test_files/images_log", components)
        assert examples.dataset.headers == []

    def test_all_headers(self, patched_cache_folder):
        """Every component label becomes a header, in order."""
        components = [gr.Image(label="im"), gr.Text(label="your text")]
        examples = gr.Examples("test/test_files/images_log", components)
        assert examples.dataset.headers == ["im", "your text"]

    def test_some_headers(self, patched_cache_folder):
        """Unlabelled components get an empty-string header."""
        components = [gr.Image(label="im"), gr.Text()]
        examples = gr.Examples("test/test_files/images_log", components)
        assert examples.dataset.headers == ["im", ""]

    def test_example_labels(self, patched_cache_folder):
        """``example_labels`` are exposed as the dataset's ``sample_labels``."""
        operations = ["add", "divide", "multiply", "subtract"]
        rows = [
            [5, "add", 3],
            [4, "divide", 2],
            [-4, "multiply", 2.5],
            [0, "subtract", 1.2],
        ]
        examples = gr.Examples(
            examples=rows,
            inputs=[gr.Number(), gr.Radio(list(operations)), gr.Number()],
            example_labels=list(operations),
        )
        assert examples.dataset.sample_labels == operations
|
|
|
|
|
def test_example_caching_relaunch(connect):
    """Cached examples are still served after the same demo is relaunched.

    The examples cache is redirected to a throwaway directory, matching the
    class-level ``patch`` used by the other caching tests. Without it, this
    test writes into the working directory's default cache and could reuse a
    stale cache left behind by an earlier run.
    """

    def combine(a, b):
        return a + " " + b

    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as cache_dir:
        with patch("gradio.utils.get_cache_folder", return_value=Path(cache_dir)):
            with gr.Blocks() as demo:
                txt = gr.Textbox(label="Input")
                txt_2 = gr.Textbox(label="Input 2")
                txt_3 = gr.Textbox(value="", label="Output")
                btn = gr.Button(value="Submit")
                btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
                gr.Examples(
                    [["hi", "Adam"], ["hello", "Eve"]],
                    [txt, txt_2],
                    txt_3,
                    combine,
                    cache_examples=True,
                    api_name="examples",
                )

            with connect(demo) as client:
                assert client.predict(1, api_name="/examples") == (
                    "hello",
                    "Eve",
                    "hello Eve",
                )

            # NOTE(review): presumably gives the first server time to shut
            # down before the relaunch — confirm whether it is still required.
            time.sleep(1)

            with connect(demo) as client:
                assert client.predict(1, api_name="/examples") == (
                    "hello",
                    "Eve",
                    "hello Eve",
                )
|
|
|
|
|
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestProcessExamples:
    """Tests for computing, caching and serving example outputs.

    The class-level ``patch`` points the examples cache at a temporary
    directory; each test receives the mock as ``patched_cache_folder``.
    Tests that call ``demo.launch`` directly close the demo in a ``finally``
    block, so no server keeps running after the test.
    """

    def test_caching(self, patched_cache_folder, connect):
        io = gr.Interface(
            lambda x: f"Hello {x}",
            "text",
            "text",
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(1)
        assert prediction[0] == "Hello Dunya"

    def test_example_caching_relaunch(self, patched_cache_folder, connect):
        def combine(a, b):
            return a + " " + b

        with gr.Blocks() as demo:
            txt = gr.Textbox(label="Input")
            txt_2 = gr.Textbox(label="Input 2")
            txt_3 = gr.Textbox(value="", label="Output")
            btn = gr.Button(value="Submit")
            btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
            gr.Examples(
                [["hi", "Adam"], ["hello", "Eve"]],
                [txt, txt_2],
                txt_3,
                combine,
                cache_examples=True,
                api_name="examples",
            )

        with connect(demo) as client:
            assert client.predict(1, api_name="/examples") == (
                "hello",
                "Eve",
                "hello Eve",
            )

        with connect(demo) as client:
            assert client.predict(1, api_name="/examples") == (
                "hello",
                "Eve",
                "hello Eve",
            )

    def test_caching_image(self, patched_cache_folder, connect):
        io = gr.Interface(
            lambda x: x,
            "image",
            "image",
            examples=[["test/test_files/bus.png"]],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(0)
        assert prediction[0].path.endswith(".webp")

    def test_caching_audio(self, patched_cache_folder, connect):
        io = gr.Interface(
            lambda x: x,
            "audio",
            "audio",
            examples=[["test/test_files/audio_sample.wav"]],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(0)
        file = prediction[0].path
        assert client_utils.encode_url_or_file_to_base64(file).startswith(
            "data:audio/wav;base64,UklGRgA/"
        )

    def test_caching_with_update(self, patched_cache_folder, connect):
        io = gr.Interface(
            lambda x: gr.update(visible=False),
            "text",
            "image",
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(1)
        assert prediction[0] == {
            "visible": False,
            "__type__": "update",
        }

    def test_caching_with_mix_update(self, patched_cache_folder, connect):
        io = gr.Interface(
            lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
            "text",
            ["text", "image"],
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(1)
        assert prediction[0] == {
            "lines": 4,
            "value": "hello",
            "__type__": "update",
        }

    def test_caching_with_dict(self, patched_cache_folder, connect):
        text = gr.Textbox()
        out = gr.Label()

        io = gr.Interface(
            lambda _: {text: gr.update(lines=4, interactive=False), out: "lion"},
            "textbox",
            [text, out],
            examples=["abc"],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(0)
        assert prediction == [
            {"lines": 4, "__type__": "update", "interactive": False},
            gr.Label.data_model(**{"label": "lion", "confidences": None}),
        ]

    def test_caching_with_generators(self, patched_cache_folder, connect):
        def test_generator(x):
            for y in range(len(x)):
                yield "Your output: " + x[: y + 1]

        io = gr.Interface(
            test_generator,
            "textbox",
            "textbox",
            examples=["abcdef"],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(0)
        # Only the final yielded value is expected in the cache.
        assert prediction[0] == "Your output: abcdef"

    def test_caching_with_generators_and_streamed_output(
        self, patched_cache_folder, connect
    ):
        file_dir = Path(Path(__file__).parent, "test_files")
        audio = str(file_dir / "audio_sample.wav")

        def test_generator(x):
            for y in range(int(x)):
                yield audio, y * 5

        io = gr.Interface(
            test_generator,
            "number",
            [gr.Audio(streaming=True), "number"],
            examples=[3],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(0)
        len_input_audio = len(AudioSegment.from_file(audio))
        len_output_audio = len(AudioSegment.from_file(prediction[0].path))
        length_ratio = len_output_audio / len_input_audio
        # The streamed audio chunks (3 yields) are expected to be concatenated.
        assert 3 <= round(length_ratio, 1) < 4
        assert float(prediction[1]) == 10.0

    def test_caching_with_async_generators(self, patched_cache_folder, connect):
        async def test_generator(x):
            for y in range(len(x)):
                yield "Your output: " + x[: y + 1]

        io = gr.Interface(
            test_generator,
            "textbox",
            "textbox",
            examples=["abcdef"],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(0)
        assert prediction[0] == "Your output: abcdef"

    @pytest.mark.asyncio
    async def test_raise_helpful_error_message_if_providing_partial_examples(
        self, patched_cache_folder, tmp_path
    ):
        def foo(a, b):
            return a + b

        with pytest.warns(
            UserWarning,
            match="^Examples will be cached but not all input components have",
        ):
            with pytest.raises(Exception):
                io = gr.Interface(
                    foo,
                    inputs=["text", "text"],
                    outputs=["text"],
                    examples=[["foo"], ["bar"]],
                    cache_examples=True,
                )
                await io.examples_handler._start_caching()

        with pytest.warns(
            UserWarning,
            match="^Examples will be cached but not all input components have",
        ):
            with pytest.raises(Exception):
                io = gr.Interface(
                    foo,
                    inputs=["text", "text"],
                    outputs=["text"],
                    examples=[["foo", "bar"], ["bar", None]],
                    cache_examples=True,
                )
                await io.examples_handler._start_caching()

        def foo_no_exception(a, b=2):
            return a * b

        gr.Interface(
            foo_no_exception,
            inputs=["text", "number"],
            outputs=["text"],
            examples=[["foo"], ["bar"]],
            cache_examples=True,
        )

        def many_missing(a, b, c):
            return a * b

        with pytest.warns(
            UserWarning,
            match="^Examples will be cached but not all input components have",
        ):
            with pytest.raises(Exception):
                io = gr.Interface(
                    many_missing,
                    inputs=["text", "number", "number"],
                    outputs=["text"],
                    examples=[["foo", None, None], ["bar", 2, 3]],
                    cache_examples=True,
                )
                await io.examples_handler._start_caching()

    def test_caching_with_batch(self, patched_cache_folder, connect):
        def trim_words(words, lens):
            trimmed_words = [
                word[:length] for word, length in zip(words, lens, strict=False)
            ]
            return [trimmed_words]

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox"],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(0)
        assert prediction == ["hel"]

    def test_caching_with_batch_multiple_outputs(self, patched_cache_folder, connect):
        def trim_words(words, lens):
            trimmed_words = [
                word[:length] for word, length in zip(words, lens, strict=False)
            ]
            return trimmed_words, lens

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox", gr.Number(precision=0)],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(0)
        assert prediction == ["hel", "3"]

    def test_caching_with_non_io_component(self, patched_cache_folder, connect):
        def predict(name):
            return name, gr.update(visible=True)

        with gr.Blocks() as demo:
            t1 = gr.Textbox()
            with gr.Column(visible=False) as c:
                t2 = gr.Textbox()

            examples = gr.Examples(
                [["John"], ["Mary"]],
                fn=predict,
                inputs=[t1],
                outputs=[t2, c],
                cache_examples=True,
            )

        with connect(demo):
            prediction = examples.load_from_cache(0)
        assert prediction == ["John", {"visible": True, "__type__": "update"}]

    def test_end_to_end(self, patched_cache_folder):
        """Without caching, the example endpoint returns input-component updates."""

        def concatenate(str1, str2):
            return str1 + str2

        def textbox_update(value):
            # Expected update payload for a default gr.Textbox holding `value`.
            return {
                "lines": 1,
                "max_lines": 20,
                "show_label": True,
                "container": True,
                "min_width": 160,
                "autofocus": False,
                "autoscroll": True,
                "elem_classes": [],
                "rtl": False,
                "show_copy_button": False,
                "__type__": "update",
                "visible": True,
                "value": value,
                "type": "text",
                "stop_btn": False,
                "submit_btn": False,
            }

        with gr.Blocks() as demo:
            t1 = gr.Textbox()
            t2 = gr.Textbox()
            t1.submit(concatenate, [t1, t2], t2)

            gr.Examples(
                [["Hello,", None], ["Michael", None]],
                inputs=[t1, t2],
                api_name="load_example",
            )

        app, _, _ = demo.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)

            response = client.post(
                f"{API_PREFIX}/api/load_example/", json={"data": [0]}
            )
            assert response.json()["data"] == [textbox_update("Hello,")]

            response = client.post(
                f"{API_PREFIX}/api/load_example/", json={"data": [1]}
            )
            assert response.json()["data"] == [textbox_update("Michael")]
        finally:
            # launch() started a server; shut it down so it does not leak.
            demo.close()

    def test_end_to_end_cache_examples(self, patched_cache_folder):
        """With caching, the example endpoint returns inputs plus the cached output."""

        def concatenate(str1, str2):
            return f"{str1} {str2}"

        with gr.Blocks() as demo:
            t1 = gr.Textbox()
            t2 = gr.Textbox()
            t1.submit(concatenate, [t1, t2], t2)

            gr.Examples(
                examples=[["Hello,", "World"], ["Michael", "Jordan"]],
                inputs=[t1, t2],
                outputs=[t2],
                fn=concatenate,
                cache_examples=True,
                api_name="load_example",
            )

        app, _, _ = demo.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)

            response = client.post(
                f"{API_PREFIX}/api/load_example/", json={"data": [0]}
            )
            assert response.json()["data"] == ["Hello,", "World", "Hello, World"]

            response = client.post(
                f"{API_PREFIX}/api/load_example/", json={"data": [1]}
            )
            assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"]
        finally:
            demo.close()

    def test_end_to_end_lazy_cache_examples(self, patched_cache_folder):
        """With lazy caching, loading an example returns updates for its inputs."""

        def image_identity(image, string):
            return image

        with gr.Blocks() as demo:
            i1 = gr.Image()
            t = gr.Textbox()
            i2 = gr.Image()

            gr.Examples(
                examples=[
                    ["test/test_files/cheetah1.jpg", "cheetah"],
                    ["test/test_files/bus.png", "bus"],
                ],
                inputs=[i1, t],
                outputs=[i2],
                fn=image_identity,
                cache_examples=True,
                cache_mode="lazy",
                api_name="load_example",
            )

        app, _, _ = demo.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)

            response = client.post(
                f"{API_PREFIX}/api/load_example/", json={"data": [0]}
            )
            data = response.json()["data"]
            assert data[0]["value"]["path"].endswith("cheetah1.jpg")
            assert data[1]["value"] == "cheetah"

            response = client.post(
                f"{API_PREFIX}/api/load_example/", json={"data": [1]}
            )
            data = response.json()["data"]
            assert data[0]["value"]["path"].endswith("bus.png")
            assert data[1]["value"] == "bus"
        finally:
            demo.close()
|
|
|
|
|
def test_multiple_file_flagging(tmp_path, connect):
    """A cached ``gr.Files`` output holds one ``FileData`` per returned file.

    Despite the name, this test exercises example caching; ``tmp_path`` serves
    as the examples cache folder.
    """
    sample_frames = ["test/test_files/cheetah1.jpg", "test/test_files/bus.png"]

    def echo_all(*frames):
        return list(frames)

    with patch("gradio.utils.get_cache_folder", return_value=tmp_path):
        frame_inputs = [
            gr.Image(type="filepath", label="frame 1"),
            gr.Image(type="filepath", label="frame 2"),
        ]
        io = gr.Interface(
            fn=echo_all,
            inputs=frame_inputs,
            outputs=[gr.Files()],
            examples=[sample_frames],
            cache_examples=True,
        )
        with connect(io):
            cached = io.examples_handler.load_from_cache(0)

        cached_files = cached[0].root
        assert len(cached_files) == 2
        assert all(isinstance(entry, gr.FileData) for entry in cached_files)
|
|
|
|
|
def test_examples_keep_all_suffixes(tmp_path, connect):
    """Cached files keep every suffix (``foo.bar.txt``) and distinct contents.

    The examples cache lives in a subdirectory of ``tmp_path``, matching
    ``test_multiple_file_flagging``. The previous ``tempfile.mkdtemp()``
    directory was never removed.
    """
    cache_dir = tmp_path / "cache"
    cache_dir.mkdir()
    with patch("gradio.utils.get_cache_folder", return_value=cache_dir):
        # Two different files that share the same multi-suffix basename.
        file_1 = tmp_path / "foo.bar.txt"
        file_1.write_text("file 1")
        file_2 = tmp_path / "file_2"
        file_2.mkdir(parents=True)
        file_2 = file_2 / "foo.bar.txt"
        file_2.write_text("file 2")
        io = gr.Interface(
            fn=lambda x: x.name,
            inputs=gr.File(),
            outputs=[gr.File()],
            examples=[[str(file_1)], [str(file_2)]],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(0)
            assert Path(prediction[0].path).read_text() == "file 1"
            assert prediction[0].orig_name == "foo.bar.txt"
            assert prediction[0].path.endswith("foo.bar.txt")
            prediction = io.examples_handler.load_from_cache(1)
            assert Path(prediction[0].path).read_text() == "file 2"
            assert prediction[0].orig_name == "foo.bar.txt"
            assert prediction[0].path.endswith("foo.bar.txt")
|
|
|
|
|
class TestProgressBar:
    """Tests for progress, info and warning updates streamed to the client.

    Each test launches its own queued demo, submits one job, and polls the
    job status, recording a value only when it differs from the last one
    recorded. The demo is closed in a ``finally`` block so the launched server
    does not outlive the test.
    """

    @pytest.mark.asyncio
    async def test_progress_bar(self):
        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress()):
                prog(0, desc="start")
                time.sleep(0.15)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.15)
                time.sleep(0.15)
                # Without track_tqdm, this plain tqdm loop is expected NOT to
                # show up in the status updates (see the assertion below).
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.15)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)
        try:
            assert demo.local_url

            client = grc.Client(demo.local_url)
            job = client.submit("Gradio")

            status_updates = []
            while not job.done():
                status = job.status()
                update = (
                    status.progress_data[0].index if status.progress_data else None,
                    status.progress_data[0].desc if status.progress_data else None,
                )
                if update != (None, None) and (
                    len(status_updates) == 0 or status_updates[-1] != update
                ):
                    status_updates.append(update)
                time.sleep(0.05)

            assert status_updates == [
                (None, "start"),
                (0, None),
                (1, None),
                (2, None),
                (3, None),
                (4, None),
            ]
        finally:
            demo.close()

    @pytest.mark.asyncio
    async def test_progress_bar_track_tqdm(self):
        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress(track_tqdm=True)):
                prog(0, desc="start")
                time.sleep(0.15)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.15)
                time.sleep(0.15)
                # With track_tqdm=True, this plain tqdm loop is also reported.
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.15)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)
        try:
            assert demo.local_url

            client = grc.Client(demo.local_url)
            job = client.submit("Gradio")

            status_updates = []
            while not job.done():
                status = job.status()
                update = (
                    status.progress_data[0].index if status.progress_data else None,
                    status.progress_data[0].desc if status.progress_data else None,
                )
                if update != (None, None) and (
                    len(status_updates) == 0 or status_updates[-1] != update
                ):
                    status_updates.append(update)
                time.sleep(0.05)

            assert status_updates == [
                (None, "start"),
                (0, None),
                (1, None),
                (2, None),
                (3, None),
                (4, None),
                (0, "alphabet"),
                (1, "alphabet"),
                (2, "alphabet"),
            ]
        finally:
            demo.close()

    @pytest.mark.asyncio
    @pytest.mark.flaky(reruns=5)
    async def test_progress_bar_track_tqdm_without_iterable(self):
        def greet(s, _=gr.Progress(track_tqdm=True)):
            with tqdm(total=len(s)) as progress_bar:
                for _c in s:
                    progress_bar.update()
                    time.sleep(0.1)
            return f"Hello, {s}!"

        demo = gr.Interface(greet, "text", "text")
        demo.queue().launch(prevent_thread_lock=True)
        try:
            assert demo.local_url

            client = grc.Client(demo.local_url)
            job = client.submit("Gradio")

            status_updates = []
            while not job.done():
                status = job.status()
                update = (
                    status.progress_data[0].index if status.progress_data else None,
                    status.progress_data[0].unit if status.progress_data else None,
                )
                if update != (None, None) and (
                    len(status_updates) == 0 or status_updates[-1] != update
                ):
                    status_updates.append(update)
                time.sleep(0.05)

            # "Gradio" has 6 characters, so the manual tqdm reaches 6.
            assert status_updates[-1] == (6, "steps")
        finally:
            demo.close()

    @pytest.mark.asyncio
    async def test_info_and_warning_alerts(self):
        def greet(s):
            for _c in s:
                gr.Info(f"Letter {_c}")
                time.sleep(0.15)
            if len(s) < 5:
                gr.Warning("Too short!")
                time.sleep(0.15)
            return f"Hello, {s}!"

        demo = gr.Interface(greet, "text", "text")
        demo.queue().launch(prevent_thread_lock=True)
        try:
            assert demo.local_url

            client = grc.Client(demo.local_url)
            job = client.submit("Jon")

            status_updates = []
            while not job.done():
                status = job.status()
                update = status.log
                if update is not None and (
                    len(status_updates) == 0 or status_updates[-1] != update
                ):
                    status_updates.append(update)
                time.sleep(0.05)

            assert status_updates == [
                ("Letter J", "info"),
                ("Letter o", "info"),
                ("Letter n", "info"),
                ("Too short!", "warning"),
            ]
        finally:
            demo.close()
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("async_handler", [True, False])
async def test_info_isolation(async_handler: bool):
    """A ``gr.Info`` message reaches only the session whose handler emitted it.

    Two sessions run concurrently against a demo with ``concurrency_limit=2``.
    Bob's session starts one second after Alice's, so each handler's
    ``gr.Info`` fires while the other session's job is still in flight.

    Previously, ``delay`` was ignored and the polling loop used the blocking
    ``time.sleep``. As a result ``asyncio.gather`` ran the two sessions one
    after the other, and they never overlapped.
    """

    async def greet_async(name):
        await asyncio.sleep(2)
        gr.Info(f"Hello {name}")
        await asyncio.sleep(1)
        return name

    def greet_sync(name):
        time.sleep(2)
        gr.Info(f"Hello {name}")
        time.sleep(1)
        return name

    demo = gr.Interface(
        greet_async if async_handler else greet_sync,
        "text",
        "text",
        concurrency_limit=2,
    )
    demo.launch(prevent_thread_lock=True)

    async def session_interaction(name, delay=0):
        """Submit `name` after `delay` seconds; return the last log message seen."""
        # Stagger this session's start without blocking the event loop.
        await asyncio.sleep(delay)
        assert demo.local_url
        client = grc.Client(demo.local_url)
        job = client.submit(name)

        status_updates = []
        while not job.done():
            status = job.status()
            update = status.log
            if update is not None and (
                len(status_updates) == 0 or status_updates[-1] != update
            ):
                status_updates.append(update)
            # Yield to the event loop so the other session can progress.
            await asyncio.sleep(0.05)
        return status_updates[-1][0] if status_updates else None

    try:
        alice_logs, bob_logs = await asyncio.gather(
            session_interaction("Alice"),
            session_interaction("Bob", delay=1),
        )
    finally:
        demo.close()

    assert alice_logs == "Hello Alice"
    assert bob_logs == "Hello Bob"
|
|
|
|
|
def test_check_event_data_in_cache():
    """``special_args`` raises ``gr.Error`` for this SelectData payload.

    The event's ``index`` is a dict shaped like a serialized
    ``gradio.FileData`` rather than a plain index.
    """

    def get_select_index(evt: gr.SelectData):
        return evt.index

    file_like_index = {"path": "foo", "meta": {"_type": "gradio.FileData"}}
    event_payload = {"index": file_like_index, "value": "whatever"}

    with pytest.raises(gr.Error):
        event = helpers.EventData(None, event_payload)
        helpers.special_args(get_select_index, inputs=[], event_data=event)
|
|