|
from huggingface_hub import InferenceClient |
|
import gradio as gr |
|
import os |
|
import re |
|
import requests |
|
import http.client |
|
import typing |
|
import urllib.request |
|
import vertexai |
|
from vertexai.generative_models import GenerativeModel, Image |
|
|
|
# Materialise Google credentials from the `credentials` secret so the Vertex AI
# SDK can authenticate. NOTE(review): this is a CWD-relative path, not the
# gcloud default (~/.config/gcloud/...); presumably the deployment points
# GOOGLE_APPLICATION_CREDENTIALS here — confirm.
# Only write the file when the secret is set: str(None) would otherwise write
# the literal text "None", an invalid credentials file. Text-only chat (Gemma)
# still works without it.
credentials_json = os.getenv('credentials')
if credentials_json:
    os.makedirs(".config", exist_ok=True)
    with open(".config/application_default_credentials.json", 'w', encoding="utf-8") as credentials_file:
        credentials_file.write(credentials_json)

vertexai.init(project=os.getenv('project_id'))

# Gemini Pro Vision describes images (see `search`); Gemma generates the chat replies.
model = GenerativeModel("gemini-1.0-pro-vision")
client = InferenceClient("google/gemma-7b-it")
|
|
|
# Candidate image links: an http(s) URL made of non-whitespace characters that
# ends in a common image extension. `\S*` (rather than `.*`) stops a match at
# whitespace, so two links in one message are found separately instead of as a
# single span running from the first URL to the last extension.
IMAGE_URL_PATTERN = re.compile(
    r"https?://\S*\.(?:png|jpg|jpeg|gif|webp|svg)", flags=re.IGNORECASE
)


def find_candidate_image_urls(text):
    """Return every image-looking http(s) URL in *text*, in order of appearance."""
    return IMAGE_URL_PATTERN.findall(text)


def extract_image_urls(text, timeout=10):
    """Return the last URL in *text* that the remote server confirms is an image.

    Each candidate from ``find_candidate_image_urls`` is probed with an HTTP
    HEAD request (following redirects). A candidate counts as valid when the
    final response is 2xx and its ``Content-Type`` contains ``image``.

    NOTE(review): the URLs come from user text and are requested server-side;
    consider restricting hosts if the deployment can reach internal services.

    Args:
        text: Free-form user message.
        timeout: Seconds to wait for each HEAD request, so a slow or
            unresponsive host cannot stall the chat indefinitely.

    Returns:
        The last valid image URL, or ``""`` if none was found.
    """
    valid_image_url = ""
    for url in find_candidate_image_urls(text):
        try:
            response = requests.head(url, allow_redirects=True, timeout=timeout)
        except requests.exceptions.RequestException:
            # Best effort: an unreachable link is simply not treated as an image.
            continue
        if 200 <= response.status_code < 300 and 'image' in response.headers.get('content-type', ''):
            valid_image_url = url
    return valid_image_url
|
|
|
def load_image_from_url(image_url: str, timeout: float = 30) -> "Image":
    """Download *image_url* and wrap the raw bytes as a Vertex AI ``Image``.

    Args:
        image_url: An ``http://`` or ``https://`` URL pointing at an image.
        timeout: Seconds to wait on blocking network operations, so an
            unresponsive host cannot hang the request handler.

    Returns:
        An ``Image`` built from the downloaded bytes.

    Raises:
        ValueError: If *image_url* is not an http(s) URL. ``urlopen`` would
            otherwise also accept schemes such as ``file://`` and read local
            files, and the URL ultimately comes from user input.
    """
    if not image_url.lower().startswith(("http://", "https://")):
        raise ValueError(f"Only http(s) image URLs are supported, got: {image_url!r}")
    with urllib.request.urlopen(image_url, timeout=timeout) as raw_response:
        http_response = typing.cast(http.client.HTTPResponse, raw_response)
        image_bytes = http_response.read()
    return Image.from_bytes(image_bytes)
|
|
|
def search(url):
    """Ask Gemini Pro Vision what the image at *url* shows; return its text answer."""
    question = "what is shown in this image?"
    gemini_reply = model.generate_content([load_image_from_url(url), question])
    return gemini_reply.text
|
|
|
def format_prompt(message, history):
    """Render the conversation in Gemma's turn-tagged chat format.

    Every ``(user, bot)`` pair in *history* becomes a user turn followed by a
    model turn. *message* is appended as the final user turn, and an open model
    turn is left at the end for the model to complete.
    """
    turns = []
    for user_text, model_text in history:
        turns.append(f"<start_of_turn>user\n{user_text}<end_of_turn>\n")
        turns.append(f"<start_of_turn>model\n{model_text}<end_of_turn>\n")
    turns.append(f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n")
    return "".join(turns)
|
|
|
def combine_system_prompt(system_prompt, prompt):
    """Fold *system_prompt* into the user message as ``"<system>, <prompt>"``.

    ``format_prompt`` only emits user/model turns, so the system prompt has to
    travel inside the user turn. A missing (``None``) or blank system prompt
    returns *prompt* unchanged instead of producing a stray leading ``", "``
    or ``"None, "``.
    """
    if system_prompt is None or not str(system_prompt).strip():
        return prompt
    return f"{system_prompt}, {prompt}"


def generate(
    prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream Gemma's reply to *prompt*, first describing any linked image.

    If *prompt* contains a reachable image URL, that URL is replaced by the
    Gemini Pro Vision description returned by ``search``. This lets the
    text-only Gemma model reason about the image.

    Args:
        prompt: The latest user message.
        history: Prior ``(user, bot)`` message pairs from the chat.
        system_prompt: Optional instructions prepended to the user message.
        temperature, max_new_tokens, top_p, repetition_penalty: Sampling
            controls from the UI. ``None`` (example rows pass it) falls back
            to the same value as the keyword default.

    Yields:
        The reply accumulated so far, after each streamed token.
    """
    # Example rows supply None for the sliders; mirror the signature defaults.
    # The inference server requires a strictly positive temperature and a
    # positive number of new tokens.
    temperature = max(float(0.9 if temperature is None else temperature), 1e-2)
    max_new_tokens = max(int(256 if max_new_tokens is None else max_new_tokens), 1)
    top_p = float(0.95 if top_p is None else top_p)
    repetition_penalty = float(1.0 if repetition_penalty is None else repetition_penalty)

    # text-generation-inference validates 0 < top_p < 1. A value of 1.0 (the
    # slider max) means "no nucleus truncation", which is expressed by not
    # sending top_p at all.
    if top_p >= 1.0:
        top_p = None
    elif top_p < 1e-2:
        top_p = 1e-2

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed: identical inputs reproduce the same sample
    )

    # Swap a linked image URL for Gemini's description so Gemma only sees text.
    image = extract_image_urls(prompt)
    if image:
        prompt = prompt.replace(image, search(image))

    formatted_prompt = format_prompt(combine_system_prompt(system_prompt, prompt), history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)

    output = ""
    for response in stream:
        # Skip special tokens (e.g. end-of-sequence) so their markup does not
        # leak into the chat transcript.
        if response.token.special:
            continue
        output += response.token.text
        yield output
|
|
|
|
|
# Extra controls shown under the chat box. ChatInterface passes their values to
# `generate` after (message, history), in this order: system_prompt,
# temperature, max_new_tokens, top_p, repetition_penalty.
additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    # Range starts at one step (64) because a request for zero new tokens is
    # invalid. 1024 is the largest multiple of the 64 step, so the whole range
    # is reachable on the slider's grid.
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=64,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
|
|
|
# Starter prompts shown under the chat box. Each row is the message followed by
# None for each of the five additional inputs (system prompt and four sliders).
examples = [
    [example_message, None, None, None, None, None]
    for example_message in (
        "I'm planning a vacation to Japan. Can you suggest a one-week itinerary including must-visit places and local cuisines to try?",
        "Can you write a short story about a time-traveling detective who solves historical mysteries?",
        "I'm trying to learn French. Can you provide some common phrases that would be useful for a beginner, along with their pronunciations?",
        "I have chicken, rice, and bell peppers in my kitchen. Can you suggest an easy recipe I can make with these ingredients?",
        "Can you explain how the QuickSort algorithm works and provide a Python implementation?",
        "What are some unique features of Rust that make it stand out compared to other systems programming languages like C++?",
    )
]
|
|
|
# Build the chat UI: `generate` streams replies, and the sampling controls from
# `additional_inputs` are exposed under the chat box. The API page is hidden.
chat_window = gr.Chatbot(
    show_label=True,
    show_share_button=True,
    show_copy_button=True,
    likeable=True,
    layout="bubble",
    bubble_full_width=False,
)
demo = gr.ChatInterface(
    fn=generate,
    chatbot=chat_window,
    additional_inputs=additional_inputs,
    title="Hey Gemini",
    description="Gemini Sprint submission by Rishiraj Acharya. Uses Google's Gemini 1.0 Pro Vision multimodal model from Vertex AI with Google's Gemma 7B Instruct model from Hugging Face.",
    theme="Soft",
    examples=examples,
    concurrency_limit=20,
)
demo.launch(show_api=False)