heygemini / app.py
rishiraj's picture
Update app.py
69601bd verified
raw
history blame
5.19 kB
from huggingface_hub import InferenceClient
import gradio as gr
import os
import re
import requests
import http.client
import typing
import urllib.request
import vertexai
from vertexai.generative_models import GenerativeModel, Image
# Write the service-account JSON taken from the `credentials` environment
# variable to disk so Vertex AI can authenticate with it.
# NOTE(review): this targets ./.config/, not the gcloud default
# (~/.config/gcloud/application_default_credentials.json); presumably the
# deployment points GOOGLE_APPLICATION_CREDENTIALS here — confirm.
_credentials = os.getenv('credentials')
if _credentials:
    # Only write real credentials: str(None) would otherwise produce a file
    # containing the literal text "None". The directory may not exist yet.
    os.makedirs(".config", exist_ok=True)
    with open(".config/application_default_credentials.json", 'w', encoding="utf-8") as fh:
        fh.write(_credentials)
# Use the GCP project named by the `project_id` environment variable.
vertexai.init(project=os.getenv('project_id'))
# Multimodal Gemini model, used by `search` to describe images in prompts.
model = GenerativeModel("gemini-1.0-pro-vision")
# Gemma instruct model on Hugging Face that produces the chat replies.
client = InferenceClient("google/gemma-7b-it")
# http(s) URLs ending in a common image extension. `\S+` keeps each match
# inside one whitespace-delimited token, so two URLs on the same line are not
# merged into one match with the text between them. `\b` stops the extension
# from matching a prefix such as ".pngx".
IMAGE_URL_PATTERN = re.compile(
    r"https?://\S+\.(?:png|jpg|jpeg|gif|webp|svg)\b", flags=re.IGNORECASE
)


def extract_image_urls(text, timeout=10.0):
    """Return the first reachable image URL found in *text*, or ``""``.

    Candidates are found with ``IMAGE_URL_PATTERN`` and probed with an HTTP
    HEAD request, which avoids downloading the body. A candidate is accepted
    when the final response is 2xx and its ``content-type`` mentions
    ``image``.

    Args:
        text: Free-form user message to scan.
        timeout: Seconds to wait for each HEAD request.

    Returns:
        The first URL that passes the check, or an empty string if none do.
    """
    # dict.fromkeys removes duplicates while keeping order of appearance.
    for url in dict.fromkeys(IMAGE_URL_PATTERN.findall(text)):
        try:
            # requests.head does not follow redirects by default; without
            # this, redirected image links would be rejected as non-2xx.
            response = requests.head(url, allow_redirects=True, timeout=timeout)
        except requests.exceptions.RequestException:
            continue  # Unreachable or too slow: try the next candidate.
        if 200 <= response.status_code < 300 and 'image' in response.headers.get('content-type', ''):
            return url
    return ""
def load_image_from_url(image_url: str, timeout: float = 30.0) -> "Image":
    """Download *image_url* and wrap its bytes as a Vertex AI ``Image``.

    Args:
        image_url: Absolute ``http`` or ``https`` URL of the image.
        timeout: Seconds to wait on the network before giving up.

    Returns:
        An ``Image`` built from the raw response body via ``Image.from_bytes``.

    Raises:
        ValueError: If *image_url* does not use the http or https scheme.
        Network errors from ``urlopen`` (for example
        ``urllib.error.URLError`` or ``TimeoutError``) propagate.
    """
    from urllib.parse import urlsplit

    # The URL comes from user chat input, and urlopen also honours schemes
    # such as file:// and ftp://, so only allow web schemes.
    if urlsplit(image_url).scheme not in ("http", "https"):
        raise ValueError(f"unsupported URL scheme in {image_url!r}")
    with urllib.request.urlopen(image_url, timeout=timeout) as response:
        # Narrow the type for static checkers; at runtime this is a no-op.
        response = typing.cast(http.client.HTTPResponse, response)
        image_bytes = response.read()
    return Image.from_bytes(image_bytes)
def search(url):
    """Ask the Gemini vision model to describe the image at *url*.

    Args:
        url: Location of the image, fetched with ``load_image_from_url``.

    Returns:
        The text of the model's answer to "what is shown in this image?".
    """
    question = "what is shown in this image?"
    contents = [load_image_from_url(url), question]
    return model.generate_content(contents).text
def format_prompt(message, history):
prompt = ""
for user_prompt, bot_response in history:
prompt += f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n"
prompt += f"<start_of_turn>model\n{bot_response}<end_of_turn>\n"
prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
return prompt
def generate(
    prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a Gemma reply to *prompt*, optionally grounded on an image.

    If ``extract_image_urls`` finds a reachable image URL in *prompt*, that
    URL is replaced with the description returned by ``search`` before the
    text is sent to the Gemma client.

    Args:
        prompt: The latest user message.
        history: Prior ``(user, bot)`` turn pairs from the chat interface.
        system_prompt: Optional instruction prepended to *prompt*; ignored
            when empty or ``None``.
        temperature: Sampling temperature, clamped to at least 0.01.
        max_new_tokens: Maximum number of tokens to generate (at least 1).
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The accumulated reply text after each streamed non-special token.

    Returns:
        The full reply text, as the generator's return value.
    """
    # Clicked examples can pass None for the extra inputs; fall back to this
    # function's defaults instead of crashing in float()/int().
    temperature = 0.9 if temperature is None else float(temperature)
    # Sampling needs a strictly positive temperature; the slider allows 0.
    temperature = max(temperature, 1e-2)
    top_p = 0.95 if top_p is None else float(top_p)
    # The slider's minimum is 0 and its value may arrive as a float; always
    # request at least one whole token.
    max_new_tokens = 256 if max_new_tokens is None else max(int(max_new_tokens), 1)
    repetition_penalty = 1.0 if repetition_penalty is None else float(repetition_penalty)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    image = extract_image_urls(prompt)
    if image:
        # Swap the URL for a text description that the text-only model can use.
        prompt = prompt.replace(image, search(image))
    # Only prefix a real system prompt; otherwise the model would see
    # ", message" or "None, message".
    if system_prompt:
        prompt = f"{system_prompt}, {prompt}"
    formatted_prompt = format_prompt(prompt, history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        # Skip special tokens (e.g. end-of-sequence markers) so they do not
        # appear in the chat output.
        if response.token.special:
            continue
        output += response.token.text
        yield output
    return output
# Extra controls shown below the chat box. The interface passes their values
# to `generate` after (message, history), in this order:
# system_prompt, temperature, max_new_tokens, top_p, repetition_penalty.
additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        # 1024 = 16 * 64 lies on the 64-token step grid; the previous 1048 did not.
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
# Starter prompts listed under the chat box. Each row is a message followed
# by five None placeholders, one per control in `additional_inputs`.
examples = [
    [question] + [None] * 5
    for question in (
        "I'm planning a vacation to Japan. Can you suggest a one-week itinerary including must-visit places and local cuisines to try?",
        "Can you write a short story about a time-traveling detective who solves historical mysteries?",
        "I'm trying to learn French. Can you provide some common phrases that would be useful for a beginner, along with their pronunciations?",
        "I have chicken, rice, and bell peppers in my kitchen. Can you suggest an easy recipe I can make with these ingredients?",
        "Can you explain how the QuickSort algorithm works and provide a Python implementation?",
        "What are some unique features of Rust that make it stand out compared to other systems programming languages like C++?",
    )
]
# Build the chat UI around `generate` and start the server with the API
# documentation page hidden.
demo = gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    additional_inputs=additional_inputs,
    title="Hey Gemini",
    description="Gemini Sprint submission by Rishiraj Acharya. Uses Google's Gemini 1.0 Pro Vision multimodal model from Vertex AI with Google's Gemma 7B Instruct model from Hugging Face.",
    theme="Soft",
    examples=examples,
    concurrency_limit=20,
)
demo.launch(show_api=False)