orionweller's picture
updates
9bed19f
raw
history blame
5.16 kB
import sys
import warnings
# Emitted before the third-party imports below, so it is shown even if one of
# those imports fails.
print("Warning: This application requires specific library versions. Please ensure you have the correct versions installed.")
# ``spaces`` supplies the ``@spaces.GPU`` decorator applied to the inference
# function further down (presumably the Hugging Face Spaces / ZeroGPU helper).
# NOTE(review): imported ahead of torch; keep this order unless confirmed safe.
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
# Log the installed numeric-stack versions to help diagnose version-mismatch
# reports.
print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")
# Suppress CUDA initialization warning
warnings.filterwarnings("ignore", category=UserWarning, message="Can't initialize NVML")

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Model loading and setup
model_name = "jhu-clsp/FollowIR-7B"
try:
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
    # Left padding keeps the last real prompt token at position -1, which is
    # where the relevance scorer reads the next-token logits.
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
except (ValueError, ImportError) as e:
    # Both failure modes point at the installed library versions: a tokenizer
    # that cannot be constructed (ValueError) or a missing tokenizer backend
    # such as sentencepiece (ImportError). Nothing below can run without the
    # model and tokenizer, so report on stderr and stop.
    print(f"Error loading model or tokenizer: {e}", file=sys.stderr)
    print("Please ensure you have the correct versions of transformers and sentencepiece installed.", file=sys.stderr)
    sys.exit(1)

# Reuse EOS as the padding token so that ``padding=True`` can be requested
# when tokenizing prompts.
tokenizer.pad_token = tokenizer.eos_token

# Vocabulary ids of the two answer words the prompt asks the model to emit.
# ``get_vocab()`` builds the whole vocabulary dict, so call it only once.
_vocab = tokenizer.get_vocab()
token_false_id = _vocab["false"]
token_true_id = _vocab["true"]
del _vocab

# Prompt asking the model to answer with a single word, "true" or "false".
# NOTE(review): the template starts with a literal "<s>"; if the tokenizer
# also prepends a BOS token the prompt carries two. Confirm against the model
# card before changing, since it affects scores.
template = """<s> [INST] You are an expert Google searcher, whose job is to determine if the following document is relevant to the query (true/false). Answer using only one word, one of those two choices.
Query: {query}
Document: {text}
Relevant (only output one word, either "true" or "false"): [/INST] """
@spaces.GPU
def check_relevance(query, instruction, passage):
    """Estimate how relevant ``passage`` is to ``query`` plus ``instruction``.

    The query and instruction are joined with a single space and inserted,
    together with the passage, into the module-level ``template`` prompt.
    The model's next-token logits for the "false" and "true" tokens are
    compared with a two-way softmax, and the probability mass on "true"
    is returned.

    Args:
        query: The search query text.
        instruction: Extra relevance criteria appended to the query.
        passage: The document text to judge.

    Returns:
        The probability that the passage is relevant, formatted as a string
        with four decimal places (e.g. ``"0.9873"``).
    """
    # ``model`` is rebound below when it is moved to the GPU, so it must be
    # declared global; the other module-level names are only read.
    global model

    # A CUDA device may only be usable inside this decorated call, so move
    # the model here rather than relying on the import-time placement.
    if torch.cuda.is_available():
        model = model.to("cuda")
    # Derive the input device from the model itself. It is then always bound,
    # including on CPU-only hosts, where a name assigned only inside the
    # branch above would be an unbound local.
    run_device = model.device

    full_query = f"{query} {instruction}"
    prompt = template.format(query=full_query, text=passage)
    # NOTE(review): truncation cuts from the tokenizer's truncation side; if
    # that is the right, an over-long passage could lose the trailing
    # "[/INST]" cue. Confirm before relying on scores for very long inputs.
    encoded = tokenizer(
        [prompt],
        padding=True,
        truncation=True,
        return_tensors="pt",
        pad_to_multiple_of=None,
    )
    inputs = {name: tensor.to(run_device) for name, tensor in encoded.items()}

    with torch.no_grad():
        # Logits predicting the token that follows the prompt; padding is on
        # the left, so position -1 is the last prompt token.
        last_logits = model(**inputs).logits[:, -1, :]

    # Upcast before the softmax: the model runs in bfloat16, and doing the
    # normalisation in that precision would show up in a 4-decimal score.
    answer_logits = torch.stack(
        [last_logits[:, token_false_id], last_logits[:, token_true_id]],
        dim=1,
    ).float()
    probabilities = torch.softmax(answer_logits, dim=1)
    score = probabilities[0, 1].item()
    return f"{score:.4f}"
# Example inputs for the demo. Each row is [query, instruction, passage],
# in the same order as the three input boxes they populate.
examples = [
    # Passage names a film directed by James Cameron.
    [
        "What movies were directed by James Cameron?",
        "A relevant document would describe any movie that was directed by James Cameron",
        "Avatar: The Way of Water is a 2022 American epic science fiction film co-produced and directed by James Cameron, who co-wrote the screenplay with Rick Jaffa and Amanda Silver.",
    ],
    # Passage lists specific health effects attributed to green tea.
    [
        "What are the health benefits of green tea?",
        "A relevant document would discuss specific health benefits associated with drinking green tea",
        "Green tea is rich in polyphenols, which are natural compounds that have health benefits, such as reducing inflammation and helping to fight cancer. Green tea contains a catechin called epigallocatechin-3-gallate (EGCG). Catechins are natural antioxidants that help prevent cell damage and provide other benefits.",
    ],
    # NOTE(review): query and instruction ask about the 2022 prize, while the
    # passage describes the 2021 prize. This is presumably a deliberate
    # non-relevant example; confirm.
    [
        "Who won the Nobel Prize in Physics in 2022?",
        "A relevant document would mention the names of the physicists who won the Nobel Prize in Physics in 2022",
        "The 2021 Nobel Prize in Physics was awarded jointly to Syukuro Manabe, Klaus Hasselmann and Giorgio Parisi for groundbreaking contributions to our understanding of complex physical systems.",
    ],
]
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Title and a one-line description of the demo.
    gr.Markdown("# FollowIR Relevance Checker")
    gr.Markdown("This app uses the FollowIR-7B model to determine the relevance of a passage to a given query and instruction.")

    # Side-by-side columns: the input form on the left, the score on the right.
    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(
                label="Query",
                placeholder="Enter your search query here",
            )
            instruction_input = gr.Textbox(
                label="Instruction",
                placeholder="Enter additional instructions or criteria",
            )
            passage_input = gr.Textbox(
                label="Passage",
                placeholder="Enter the passage to check for relevance",
                lines=5,
            )
            submit_button = gr.Button("Check Relevance")
        with gr.Column():
            output = gr.Textbox(label="Relevance Probability")

    # Clickable sample rows; cache_examples=True asks Gradio to precompute
    # their outputs with check_relevance.
    gr.Examples(
        examples=examples,
        inputs=[query_input, instruction_input, passage_input],
        outputs=output,
        fn=check_relevance,
        cache_examples=True,
    )

    # Score whatever is currently in the form when the button is pressed.
    submit_button.click(
        fn=check_relevance,
        inputs=[query_input, instruction_input, passage_input],
        outputs=[output],
    )
if __name__ == "__main__":
if np.__version__.startswith("2."):
print("Error: This application is not compatible with NumPy 2.x. Please downgrade to NumPy < 2.0.0.")
sys.exit(1)
demo.launch()