# Hugging Face Spaces page header captured with this file (Space status: Sleeping); not part of the program.
from typing import Optional | |
import gradio as gr | |
from build_rag import get_context | |
from models.gemma import gemma_predict | |
from models.gemini import get_gemini_response | |
def clear():
    """Return three ``None`` values, one per UI component to reset.

    The Clear and Cancel buttons call this with
    ``outputs=[query, model, output]``, so each ``None`` blanks the
    corresponding component.
    """
    return None, None, None
def generate_text(query_text, model_name: Optional[str] = "google/gemma-2b-it", tokens: Optional[int] = 1024,
                  temp: Optional[float] = 0.49):
    """Answer a query with the selected model, using retrieved context.

    Args:
        query_text: The user's question; passed to ``get_context`` to build
            the text that is handed to the model.
        model_name: ``"google/gemma-2b-it"`` or ``"gemini-1.0-pro"``.
            ``None`` (e.g. nothing chosen in the dropdown) selects Gemma.
        tokens: Forwarded to the model as ``max_output_tokens``.
        temp: Forwarded to the model as ``temperature``.

    Returns:
        Whatever ``gemma_predict`` / ``get_gemini_response`` returns for a
        supported model, or a fixed apology string for any other model name.
    """
    gemma_model = "google/gemma-2b-it"
    gemini_model = "gemini-1.0-pro"

    # Resolve "no selection" to the default so the backend receives a real
    # model identifier instead of None.
    if model_name is None:
        model_name = gemma_model

    # Reject unsupported names before doing any retrieval work.
    if model_name not in (gemma_model, gemini_model):
        return "Sorry, something went wrong! Please try again."

    combined_information = get_context(query_text)
    gen_config = {
        "temperature": temp,
        "max_output_tokens": tokens,
    }

    if model_name == gemma_model:
        return gemma_predict(combined_information, model_name, gen_config)
    return get_gemini_response(combined_information, model_name, gen_config)
# Sample prompts shown in the UI; each row is [query text, model name] and
# lines up with the `inputs=[query, model]` pair given to gr.Examples.
examples = [
    [
        "I'm planning a vacation to France. Can you suggest a one-week itinerary including must-visit places and "
        "local cuisines to try?",
        "google/gemma-2b-it",
    ],
    [
        "I want to explore off-the-beaten-path destinations in Europe, any suggestions?",
        "gemini-1.0-pro",
    ],
    [
        "Suggest some cities that can be visited from London and are very rich in history and culture.",
        "google/gemma-2b-it",
    ],
]
# Gradio UI: a query box, a model picker, generation settings, and
# Submit / Clear / Cancel buttons, followed by clickable examples.
with gr.Blocks() as demo:
    gr.HTML("""<center><h1 style='font-size:xx-large;'>🇪🇺 Euro City Recommender using Gemini & Gemma 🇪🇺</h1><br><h3>Gemini
    & Gemma Sprints 2024 submissions by Ashmi Banerjee. </h3></center> <br><p>We're testing the compatibility of
    Retrieval Augmented Generation (RAG) implementations with Google's <b>Gemma-2b-it</b> & <b>Gemini 1.0 Pro</b>
    models through HuggingFace and VertexAI, respectively, to generate travel recommendations. This early version (read
    quick and dirty implementation) aims to see if functionalities work smoothly. It relies on Wikipedia abstracts
    from 160 European cities to provide answers to your questions. Please be kind with it, as it's a work in progress!
    </p> <br><p>Google Cloud credits are provided for this project. </p>
    """)

    with gr.Group():
        query = gr.Textbox(label="Query", placeholder="Ask for your city recommendation here!")
        # No default value: the dropdown yields None until the user picks a model.
        model = gr.Dropdown(
            ["google/gemma-2b-it", "gemini-1.0-pro"], label="Model", info="Select your model. Will add more models "
                                                                          "later!",
        )
        output = gr.Textbox(label="Generated Results", lines=4)

        with gr.Accordion("Settings", open=False):
            max_new_tokens = gr.Slider(label="Max new tokens", value=1024, minimum=0, maximum=8192, step=64,
                                       interactive=True,
                                       visible=True, info="The maximum number of output tokens")
            temperature = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.49,
                                    interactive=True,
                                    visible=True, info="The value used to module the logits distribution")

    with gr.Group():
        with gr.Row():
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")
            cancel_btn = gr.Button("Cancel", variant="stop")

    # Pass the slider values too, so the Settings panel actually affects
    # generation; the order matches generate_text(query_text, model_name, tokens, temp).
    submit_event = submit_btn.click(
        generate_text,
        inputs=[query, model, max_new_tokens, temperature],
        outputs=[output],
    )
    clear_btn.click(clear, inputs=[], outputs=[query, model, output])
    # Cancel still resets the fields, and now also cancels a pending Submit request.
    cancel_btn.click(clear, inputs=[], outputs=[query, model, output], cancels=[submit_event])

    gr.Markdown("## Examples")
    # Examples supply only (query, model), so generate_text's default
    # tokens/temp apply when their outputs are cached.
    gr.Examples(
        examples, inputs=[query, model], label="Examples", fn=generate_text, outputs=[output],
        cache_examples=True,
    )

# Event cancellation goes through the request queue; enable it explicitly so
# the Cancel button works on Gradio versions where the queue is off by default.
demo.queue()
# Start the web app only when this file is run as a script; the API docs page is hidden.
if __name__ == "__main__":
    demo.launch(show_api=False)