"""Tests for ``gr.ChatInterface``: construction, example caching, and the API."""

import tempfile
from concurrent.futures import wait
from pathlib import Path
from unittest.mock import patch

import pytest

import gradio as gr


def invalid_fn(message):
    """Return ``message`` unchanged; takes no ``history`` argument."""
    return message


def double(message, history):
    """Return ``message`` twice, separated by a single space."""
    return message + " " + message


async def async_greet(message, history):
    """Async chat function returning ``"hi, "`` followed by ``message``."""
    return "hi, " + message


def stream(message, history):
    """Yield successively longer prefixes of ``message``, ending with all of it."""
    for i in range(len(message)):
        yield message[: i + 1]


async def async_stream(message, history):
    """Async generator yielding successively longer prefixes of ``message``."""
    for i in range(len(message)):
        yield message[: i + 1]


def count(message, history):
    """Return the length of ``history`` as a string."""
    return str(len(history))


def echo_system_prompt_plus_message(message, history, system_prompt, tokens):
    """Stream prefixes of ``"{system_prompt} {message}"``.

    At most ``int(tokens)`` prefixes are yielded, so the final value is the
    response truncated to ``tokens`` characters.
    """
    response = f"{system_prompt} {message}"
    for i in range(min(len(response), int(tokens))):
        yield response[: i + 1]


class TestInit:
    """Construction-time behaviour of ``gr.ChatInterface``."""

    def test_no_fn(self):
        """Constructing without a chat function raises ``TypeError``."""
        with pytest.raises(TypeError):
            gr.ChatInterface()  # type: ignore

    def test_concurrency_limit(self):
        """``concurrency_limit`` is stored and applied to the submit functions."""
        chat = gr.ChatInterface(double, concurrency_limit=10)
        assert chat.concurrency_limit == 10
        fns = [
            fn
            for fn in chat.fns.values()
            if fn.name in {"_submit_fn", "_api_submit_fn"}
        ]
        assert all(fn.concurrency_limit == 10 for fn in fns)

    def test_custom_textbox(self):
        """A custom ``Textbox`` or ``MultimodalTextbox`` is accepted."""

        def chat():
            return "Hello"

        gr.ChatInterface(
            chat,
            chatbot=gr.Chatbot(height=400),
            textbox=gr.Textbox(placeholder="Type Message", container=False, scale=7),
            title="Test",
        )
        gr.ChatInterface(
            chat,
            chatbot=gr.Chatbot(height=400),
            textbox=gr.MultimodalTextbox(container=False, scale=7),
            title="Test",
        )

    def test_events_attached(self):
        """Some dependency targets exactly the textbox ``submit`` event."""
        chatbot = gr.ChatInterface(double)
        dependencies = chatbot.fns.values()
        textbox = chatbot.textbox._id
        assert next(
            (d for d in dependencies if d.targets == [(textbox, "submit")]),
            None,
        )

    def test_example_caching(self, connect):
        """Cached examples hold the (message, response) pair for a sync fn."""
        # A fresh temp dir per test keeps cached examples isolated.
        with patch(
            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
        ):
            chatbot = gr.ChatInterface(
                double, examples=["hello", "hi"], cache_examples=True
            )
            with connect(chatbot):
                prediction_hello = chatbot.examples_handler.load_from_cache(0)
                prediction_hi = chatbot.examples_handler.load_from_cache(1)
                assert prediction_hello[0].root[0] == ("hello", "hello hello")
                assert prediction_hi[0].root[0] == ("hi", "hi hi")

    @pytest.mark.asyncio
    async def test_example_caching_lazy(self, connect):
        """With ``cache_mode="lazy"`` only the example that was run is cached."""
        with patch(
            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
        ):
            chatbot = gr.ChatInterface(
                double,
                examples=["hello", "hi"],
                cache_examples=True,
                cache_mode="lazy",
            )
            # Drive the lazy cache for example 0 only.
            async for _ in chatbot.examples_handler.async_lazy_cache(
                (0, ["hello"]), "hello"
            ):
                pass

            with connect(chatbot):
                prediction_hello = chatbot.examples_handler.load_from_cache(0)
                assert prediction_hello[0].root[0] == ("hello", "hello hello")
                # Example 1 was never run, so loading it must fail. Only the
                # raising call belongs in the ``raises`` body: anything after
                # it would be unreachable.
                with pytest.raises(IndexError):
                    chatbot.examples_handler.load_from_cache(1)

    def test_example_caching_async(self, connect):
        """Cached examples hold the (message, response) pair for an async fn."""
        with patch(
            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
        ):
            chatbot = gr.ChatInterface(
                async_greet, examples=["abubakar", "tom"], cache_examples=True
            )
            with connect(chatbot):
                prediction_hello = chatbot.examples_handler.load_from_cache(0)
                prediction_hi = chatbot.examples_handler.load_from_cache(1)
                assert prediction_hello[0].root[0] == ("abubakar", "hi, abubakar")
                assert prediction_hi[0].root[0] == ("tom", "hi, tom")

    def test_example_caching_with_streaming(self, connect):
        """For a generator fn the cache stores the final streamed value."""
        with patch(
            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
        ):
            chatbot = gr.ChatInterface(
                stream, examples=["hello", "hi"], cache_examples=True
            )
            with connect(chatbot):
                prediction_hello = chatbot.examples_handler.load_from_cache(0)
                prediction_hi = chatbot.examples_handler.load_from_cache(1)
                assert prediction_hello[0].root[0] == ("hello", "hello")
                assert prediction_hi[0].root[0] == ("hi", "hi")

    def test_example_caching_with_streaming_async(self, connect):
        """For an async generator fn the cache stores the final streamed value."""
        with patch(
            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
        ):
            chatbot = gr.ChatInterface(
                async_stream, examples=["hello", "hi"], cache_examples=True
            )
            with connect(chatbot):
                prediction_hello = chatbot.examples_handler.load_from_cache(0)
                prediction_hi = chatbot.examples_handler.load_from_cache(1)
                assert prediction_hello[0].root[0] == ("hello", "hello")
                assert prediction_hi[0].root[0] == ("hi", "hi")

    def test_default_accordion_params(self):
        """Default accordion is closed and labelled "Additional Inputs"."""
        chatbot = gr.ChatInterface(
            echo_system_prompt_plus_message,
            additional_inputs=["textbox", "slider"],
        )
        # First block whose config reports itself as an accordion.
        accordion = next(
            comp
            for comp in chatbot.blocks.values()
            if comp.get_config().get("name") == "accordion"
        )
        assert accordion.get_config().get("open") is False
        assert accordion.get_config().get("label") == "Additional Inputs"

    def test_setting_accordion_params(self, monkeypatch):
        """A user-supplied ``gr.Accordion`` keeps its ``open`` and ``label``."""
        chatbot = gr.ChatInterface(
            echo_system_prompt_plus_message,
            additional_inputs=["textbox", "slider"],
            additional_inputs_accordion=gr.Accordion(open=True, label="MOAR"),
        )
        # First block whose config reports itself as an accordion.
        accordion = next(
            comp
            for comp in chatbot.blocks.values()
            if comp.get_config().get("name") == "accordion"
        )
        assert accordion.get_config().get("open") is True
        assert accordion.get_config().get("label") == "MOAR"

    def test_example_caching_with_additional_inputs(self, monkeypatch, connect):
        """Additional inputs in each example row are passed to the chat fn."""
        with patch(
            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
        ):
            chatbot = gr.ChatInterface(
                echo_system_prompt_plus_message,
                additional_inputs=["textbox", "slider"],
                examples=[["hello", "robot", 100], ["hi", "robot", 2]],
                cache_examples=True,
            )
            with connect(chatbot):
                prediction_hello = chatbot.examples_handler.load_from_cache(0)
                prediction_hi = chatbot.examples_handler.load_from_cache(1)
                assert prediction_hello[0].root[0] == ("hello", "robot hello")
                # tokens=2 truncates "robot hi" to its first two characters.
                assert prediction_hi[0].root[0] == ("hi", "ro")

    def test_example_caching_with_additional_inputs_already_rendered(
        self, monkeypatch, connect
    ):
        """Same as above, with the inputs created in an enclosing ``Blocks``."""
        with patch(
            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
        ):
            with gr.Blocks() as demo:
                with gr.Accordion("Inputs"):
                    text = gr.Textbox()
                    slider = gr.Slider()
                    chatbot = gr.ChatInterface(
                        echo_system_prompt_plus_message,
                        additional_inputs=[text, slider],
                        examples=[["hello", "robot", 100], ["hi", "robot", 2]],
                        cache_examples=True,
                    )
            with connect(demo):
                prediction_hello = chatbot.examples_handler.load_from_cache(0)
                prediction_hi = chatbot.examples_handler.load_from_cache(1)
                assert prediction_hello[0].root[0] == ("hello", "robot hello")
                # tokens=2 truncates "robot hi" to its first two characters.
                assert prediction_hi[0].root[0] == ("hi", "ro")

    def test_custom_chatbot_with_events(self):
        """A ``like`` event on a custom chatbot is still among the dependencies."""
        with gr.Blocks() as demo:
            chatbot = gr.Chatbot()
            chatbot.like(lambda: None, None, None)
            gr.ChatInterface(fn=lambda x, y: x, chatbot=chatbot)
        dependencies = demo.fns.values()
        assert next(
            (d for d in dependencies if d.targets == [(chatbot._id, "like")]),
            None,
        )


class TestAPI:
    """Behaviour of the API endpoint exposed by ``gr.ChatInterface``."""

    def test_get_api_info(self):
        """Exactly one named endpoint, ``/chat``, and no unnamed endpoints."""
        chatbot = gr.ChatInterface(double)
        api_info = chatbot.get_api_info()
        assert api_info
        assert len(api_info["named_endpoints"]) == 1
        assert len(api_info["unnamed_endpoints"]) == 0
        assert "/chat" in api_info["named_endpoints"]

    @pytest.mark.parametrize("type", ["tuples", "messages"])
    def test_streaming_api(self, type, connect):
        """A streaming fn yields every intermediate prefix through the client."""
        chatbot = gr.ChatInterface(stream, type=type).queue()
        with connect(chatbot) as client:
            job = client.submit("hello")
            wait([job])
            assert job.outputs() == ["h", "he", "hel", "hell", "hello"]

    @pytest.mark.parametrize("type", ["tuples", "messages"])
    def test_streaming_api_async(self, type, connect):
        """An async streaming fn yields every prefix through the client."""
        chatbot = gr.ChatInterface(async_stream, type=type).queue()
        with connect(chatbot) as client:
            job = client.submit("hello")
            wait([job])
            assert job.outputs() == ["h", "he", "hel", "hell", "hello"]

    @pytest.mark.parametrize("type", ["tuples", "messages"])
    def test_non_streaming_api(self, type, connect):
        """``client.predict`` returns the sync fn's response string."""
        chatbot = gr.ChatInterface(double, type=type)
        with connect(chatbot) as client:
            result = client.predict("hello")
            assert result == "hello hello"

    @pytest.mark.parametrize("type", ["tuples", "messages"])
    def test_non_streaming_api_async(self, type, connect):
        """``client.predict`` returns the async fn's response string."""
        chatbot = gr.ChatInterface(async_greet, type=type)
        with connect(chatbot) as client:
            result = client.predict("gradio")
            assert result == "hi, gradio"

    @pytest.mark.parametrize("type", ["tuples", "messages"])
    def test_streaming_api_with_additional_inputs(self, type, connect):
        """Additional inputs are passed positionally after the message."""
        chatbot = gr.ChatInterface(
            echo_system_prompt_plus_message,
            type=type,
            additional_inputs=["textbox", "slider"],
        ).queue()
        with connect(chatbot) as client:
            job = client.submit("hello", "robot", 7)
            wait([job])
            # tokens=7 limits the stream to the first seven prefixes.
            assert job.outputs() == [
                "r",
                "ro",
                "rob",
                "robo",
                "robot",
                "robot ",
                "robot h",
            ]

    @pytest.mark.parametrize("type", ["tuples", "messages"])
    def test_multimodal_api(self, type, connect):
        """With ``multimodal=True`` the fn receives a dict with a ``text`` key."""

        def double_multimodal(msg, history):
            return msg["text"] + " " + msg["text"]

        chatbot = gr.ChatInterface(
            double_multimodal,
            type=type,
            multimodal=True,
        )
        with connect(chatbot) as client:
            result = client.predict({"text": "hello", "files": []}, api_name="/chat")
            assert result == "hello hello"