|
import time |
|
from concurrent.futures import wait |
|
|
|
import gradio_client as grc |
|
import pytest |
|
from fastapi.testclient import TestClient |
|
|
|
import gradio as gr |
|
from gradio.route_utils import API_PREFIX |
|
|
|
|
|
class TestQueueing: |
|
def test_single_request(self, connect):
    """Send one request through the queue and check the greeting comes back.

    Args:
        connect: fixture called as ``connect(demo)`` and used as a context
            manager yielding a client with ``submit`` (defined outside this
            file -- presumably in conftest; confirm).
    """

    def greet(x):
        return f"Hello, {x}!"

    with gr.Blocks() as demo:
        textbox_in = gr.Textbox()
        textbox_out = gr.Textbox()
        # Only event registered on this demo, so it is fn_index 0.
        textbox_in.submit(greet, textbox_in, textbox_out)

    with connect(demo) as client:
        greeting = client.submit("x", fn_index=0).result()
        assert greeting == "Hello, x!"
|
|
|
def test_all_status_messages(self, connect):
    """Queue size grows while jobs wait and drains back to zero afterwards.

    Four jobs are submitted to a handler that sleeps 2 s and allows two
    concurrent executions, so some jobs must wait in the queue. The
    ``/queue/status`` endpoint is polled until the last job reports
    ``FINISHED``, recording each distinct ``queue_size`` value seen.

    Args:
        connect: accepted for fixture parity but unused; this test launches
            the demo itself so it can query the FastAPI app directly.
    """
    with gr.Blocks() as demo:
        name = gr.Textbox()
        output = gr.Textbox()

        def greet(x):
            time.sleep(2)
            return f"Hello, {x}!"

        name.submit(greet, name, output, concurrency_limit=2)

    app, local_url, _ = demo.launch(prevent_thread_lock=True)
    try:
        test_client = TestClient(app)
        client = grc.Client(local_url)

        def current_queue_size():
            queue_status = test_client.get(f"{API_PREFIX}/queue/status").json()
            return queue_status["queue_size"]

        client.submit("a", fn_index=0)
        job2 = client.submit("b", fn_index=0)
        client.submit("c", fn_index=0)
        job4 = client.submit("d", fn_index=0)

        # Distinct consecutive queue sizes observed while jobs are running.
        sizes = []
        # Bound the polling so a stuck job fails the test instead of hanging it.
        timeout_s = 60
        deadline = time.monotonic() + timeout_s
        while job4.status().code.value != "FINISHED":
            if time.monotonic() > deadline:
                pytest.fail(f"job4 did not finish within {timeout_s} seconds")
            queue_size = current_queue_size()
            if not sizes or queue_size != sizes[-1]:
                sizes.append(queue_size)
            time.sleep(0.01)

        # Short pause, then record the settled queue size one last time.
        time.sleep(0.1)
        queue_size = current_queue_size()
        if not sizes or queue_size != sizes[-1]:
            sizes.append(queue_size)

        assert max(sizes) in [2, 3, 4]
        assert min(sizes) == 0
        assert sizes[-1] == 0

        assert job2.result() == "Hello, b!"
        assert job4.result() == "Hello, d!"
    finally:
        # The server was launched without a thread lock; shut it down explicitly.
        demo.close()
|
|
|
@pytest.mark.parametrize(
    "default_concurrency_limit, statuses",
    [
        ("not_set", ["IN_QUEUE", "IN_QUEUE", "PROCESSING"]),
        (None, ["PROCESSING", "PROCESSING", "PROCESSING"]),
        (1, ["IN_QUEUE", "IN_QUEUE", "PROCESSING"]),
        (2, ["IN_QUEUE", "PROCESSING", "PROCESSING"]),
    ],
)
def test_default_concurrency_limits(self, default_concurrency_limit, statuses):
    """``queue(default_concurrency_limit=...)`` caps how many jobs run at once.

    Three jobs are submitted to a handler that sleeps 4 s. After 2 s the
    sorted status codes must match ``statuses``. Per the expectations above,
    ``"not_set"`` behaves like a limit of 1 and ``None`` means no limit.

    Args:
        default_concurrency_limit: value passed to ``demo.queue``.
        statuses: expected status codes, already in sorted order.
    """
    with gr.Blocks() as demo:
        a = gr.Number()
        b = gr.Number()
        output = gr.Number()

        add_btn = gr.Button("Add")

        # No per-event concurrency_limit, so the queue-wide default applies.
        @add_btn.click(inputs=[a, b], outputs=output)
        def add(x, y):
            time.sleep(4)
            return x + y

    demo.queue(default_concurrency_limit=default_concurrency_limit)
    _, local_url, _ = demo.launch(
        prevent_thread_lock=True,
    )
    try:
        client = grc.Client(local_url)

        add_jobs = [client.submit(1, 1, fn_index=0) for _ in range(3)]

        # Halfway through the 4 s handler: running jobs are PROCESSING, the
        # rest are still waiting in the queue.
        time.sleep(2)

        assert sorted(job.status().code.value for job in add_jobs) == statuses
    finally:
        # Each parametrized case launches its own server; close it so servers
        # do not accumulate across cases.
        demo.close()
|
|
|
def test_concurrency_limits(self, connect):
    """Per-event and shared (``concurrency_id``) limits are enforced independently.

    Four handlers, each sleeping 2 s, are registered in this order (which
    fixes their ``fn_index``):

    * 0 ``add``: ``concurrency_limit=2``
    * 1 ``sub``: ``concurrency_limit=None`` (no limit)
    * 2 ``mul`` and 3 ``div``: ``concurrency_limit=2`` shared through
      ``concurrency_id="muldiv"``

    Three jobs are submitted per group. After 2 s, ``add`` and the
    ``muldiv`` group must each have exactly one job still queued, while all
    ``sub`` jobs run together. Every job is then awaited before the demo is
    torn down.

    Args:
        connect: fixture called as ``connect(demo)`` and used as a context
            manager yielding a client with ``submit`` (defined outside this
            file -- presumably in conftest; confirm).
    """
    with gr.Blocks() as demo:
        a = gr.Number()
        b = gr.Number()
        output = gr.Number()

        add_btn = gr.Button("Add")

        @add_btn.click(inputs=[a, b], outputs=output, concurrency_limit=2)
        def add(x, y):
            time.sleep(2)
            return x + y

        sub_btn = gr.Button("Subtract")

        @sub_btn.click(inputs=[a, b], outputs=output, concurrency_limit=None)
        def sub(x, y):
            time.sleep(2)
            return x - y

        mul_btn = gr.Button("Multiply")

        @mul_btn.click(
            inputs=[a, b],
            outputs=output,
            concurrency_limit=2,
            concurrency_id="muldiv",
        )
        def mul(x, y):
            time.sleep(2)
            return x * y

        div_btn = gr.Button("Divide")

        @div_btn.click(
            inputs=[a, b],
            outputs=output,
            concurrency_limit=2,
            concurrency_id="muldiv",
        )
        def div(x, y):
            time.sleep(2)
            return x / y

    def status_codes(jobs):
        # Status code strings (e.g. "IN_QUEUE", "PROCESSING") in job order.
        return [job.status().code.value for job in jobs]

    with connect(demo) as client:
        add_jobs = [client.submit(1, 1, fn_index=0) for _ in range(3)]
        sub_jobs = [client.submit(1, 1, fn_index=1) for _ in range(3)]
        muldiv_jobs = [
            client.submit(1, 1, fn_index=2),  # mul
            client.submit(1, 1, fn_index=3),  # div
            client.submit(1, 1, fn_index=2),  # mul
        ]

        time.sleep(2)

        # Limit 2: two run, one waits.
        assert sorted(status_codes(add_jobs)) == [
            "IN_QUEUE",
            "PROCESSING",
            "PROCESSING",
        ]

        # No limit: all three run at once.
        assert status_codes(sub_jobs) == [
            "PROCESSING",
            "PROCESSING",
            "PROCESSING",
        ]

        # mul and div share one pool of 2 slots via concurrency_id="muldiv".
        assert sorted(status_codes(muldiv_jobs)) == [
            "IN_QUEUE",
            "PROCESSING",
            "PROCESSING",
        ]

        # Await every job exactly once. The timeout makes a stuck job fail
        # the test instead of hanging the suite.
        _, not_done = wait([*add_jobs, *sub_jobs, *muldiv_jobs], timeout=60)
        assert not not_done
|
|