|
import sys
import warnings

print("Warning: This application requires specific library versions. Please ensure you have the correct versions installed.")

import numpy as np

# Fail fast when run as a script: the app declares NumPy 2.x unsupported, but
# the check in the ``__main__`` block only runs after torch is imported and the
# model has been loaded. Compare the numeric major version so any release
# >= 2.0.0 is rejected, matching the "< 2.0.0" requirement in the message.
if __name__ == "__main__" and int(np.__version__.split(".")[0]) >= 2:
    print("Error: This application is not compatible with NumPy 2.x. Please downgrade to NumPy < 2.0.0.")
    sys.exit(1)

# Installed before torch is imported so the NVML UserWarning is silenced
# wherever it may be emitted (import time or the CUDA probe below).
warnings.filterwarnings("ignore", category=UserWarning, message="Can't initialize NVML")

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
|
|
|
|
|
model_name = "jhu-clsp/FollowIR-7B"

# Load the (small) tokenizer before the (large) model so that tokenizer or
# sentencepiece problems are reported before the weights are loaded.
# OSError covers an unknown model id or an unreachable Hub; ImportError and
# ValueError cover missing/incompatible tokenizer dependencies.
# NOTE(review): the model is loaded in its default dtype; on small GPUs a
# reduced-precision torch_dtype may be needed — confirm target hardware.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
except (ValueError, OSError, ImportError) as e:
    print(f"Error loading model or tokenizer: {e}")
    print("Please ensure you have the correct versions of transformers and sentencepiece installed.")
    sys.exit(1)

# Reuse EOS as the padding token (presumably the tokenizer ships without a
# dedicated one — verify). Left padding keeps the prompt's final token at
# position -1, which is where check_relevance reads the next-token logits.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Vocabulary ids of the two one-word answers the prompt asks the model for.
_vocab = tokenizer.get_vocab()
token_false_id = _vocab["false"]
token_true_id = _vocab["true"]
|
|
|
# Prompt for the FollowIR-7B relevance judgement. check_relevance fills
# {query} with the query followed by the instruction and {text} with the
# passage, then compares the model's next-token logits for "true" vs "false".
# NOTE(review): the literal "<s>" may duplicate the BOS token the tokenizer
# adds by default — confirm this matches how the model was trained/evaluated
# before changing it.
template = """<s> [INST] You are an expert Google searcher, whose job is to determine if the following document is relevant to the query (true/false). Answer using only one word, one of those two choices.

Query: {query}
Document: {text}
Relevant (only output one word, either "true" or "false"): [/INST] """
|
|
|
def check_relevance(query, instruction, passage):
    """Return the model's probability that *passage* is relevant.

    The query and instruction are joined into the ``{query}`` slot of
    ``template`` and the passage fills ``{text}``. The model's next-token
    logits for "true" and "false" are renormalised against each other, so
    the result is P("true") restricted to those two answers.

    Args:
        query: Search query text (``None`` is treated as empty).
        instruction: Extra relevance criteria appended to the query
            (``None`` or blank is treated as empty).
        passage: Document text to judge (``None`` is treated as empty).

    Returns:
        The probability formatted to four decimal places, e.g. ``"0.8731"``.
    """
    # A cleared textbox or an API call may deliver None; never let the literal
    # string "None" reach the prompt, and avoid a dangling space when no
    # instruction is given.
    query = (query or "").strip()
    instruction = (instruction or "").strip()
    full_query = f"{query} {instruction}".strip()
    prompt = template.format(query=full_query, text=passage or "")

    # NOTE(review): truncation only applies beyond tokenizer.model_max_length
    # and, with the tokenizer's default truncation side (presumably "right"),
    # would drop the trailing "[/INST]" cue — confirm the limit suits the
    # expected passage lengths.
    tokens = tokenizer(
        [prompt],
        padding=True,
        truncation=True,
        return_tensors="pt",
        pad_to_multiple_of=None,
    ).to(device)

    with torch.no_grad():
        # Logits of the token that would follow the prompt; left padding
        # guarantees the last position is the prompt's final token.
        last_logits = model(**tokens).logits[:, -1, :]
        answer_logits = torch.stack(
            [last_logits[:, token_false_id], last_logits[:, token_true_id]],
            dim=1,
        )
        # Two-way softmax over (false, true); column 1 is P(true).
        score = torch.softmax(answer_logits, dim=1)[:, 1].item()

    return f"{score:.4f}"
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header: title plus a one-line description of the app.
    gr.Markdown("# FollowIR Relevance Checker")
    gr.Markdown("This app uses the FollowIR-7B model to determine the relevance of a passage to a given query and instruction.")

    with gr.Row():
        # Left column: the three free-text inputs and the trigger button.
        with gr.Column():
            query_input = gr.Textbox(
                placeholder="Enter your search query here",
                label="Query",
            )
            instruction_input = gr.Textbox(
                placeholder="Enter additional instructions or criteria",
                label="Instruction",
            )
            passage_input = gr.Textbox(
                lines=5,
                placeholder="Enter the passage to check for relevance",
                label="Passage",
            )
            submit_button = gr.Button(value="Check Relevance")

        # Right column: the formatted probability string from check_relevance.
        with gr.Column():
            output = gr.Textbox(label="Relevance Probability")

    # Component order matches check_relevance(query, instruction, passage).
    submit_button.click(
        fn=check_relevance,
        inputs=[query_input, instruction_input, passage_input],
        outputs=[output],
    )
|
|
|
if __name__ == "__main__":
    # The app requires NumPy < 2.0.0. Compare the numeric major version so
    # every later major release (2.x, 3.x, ...) is rejected, not just "2.".
    numpy_major = int(np.__version__.split(".")[0])
    if numpy_major >= 2:
        print("Error: This application is not compatible with NumPy 2.x. Please downgrade to NumPy < 2.0.0.")
        sys.exit(1)

    demo.launch()