import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# --- Model loading and setup -------------------------------------------------

model_name = "jhu-clsp/FollowIR-7B"

# Prefer the GPU, but fall back to CPU so the app can still start without CUDA
# (a 7B model on CPU is very slow; a GPU is strongly recommended).
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): no torch_dtype is passed, so weights load in transformers'
# default dtype (historically float32, roughly 28 GB for 7B parameters).
# Consider torch_dtype=torch.bfloat16 if GPU memory is tight.
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.eval()

# Left padding keeps the last real prompt token at position -1 for every row,
# which is where the true/false logits are read below.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# Reuse EOS as the pad token so that padding=True has a token to pad with.
tokenizer.pad_token = tokenizer.eos_token

_vocab = tokenizer.get_vocab()
token_false_id = _vocab["false"]
token_true_id = _vocab["true"]

# NOTE(review): this prompt should match the FollowIR-7B model card's template
# exactly (whitespace included) -- confirm against the model card.
template = """ [INST] You are an expert Google searcher, whose job is to determine if the following document is relevant to the query (true/false). Answer using only one word, one of those two choices.

Query: {query}
Document: {text}
Relevant (only output one word, either "true" or "false"): [/INST] """


def _build_prompt(query, instruction, passage):
    """Return the model prompt for a query, an optional instruction and a passage.

    The query and instruction are joined with a single space. Empty or
    whitespace-only parts are skipped so a blank instruction does not leave a
    dangling space in the prompt.
    """
    parts = [(part or "").strip() for part in (query, instruction)]
    full_query = " ".join(part for part in parts if part)
    # str.format only parses the template, so braces in user text are safe.
    return template.format(query=full_query, text=passage or "")


@torch.inference_mode()  # no autograd graph: big memory saving for a 7B model
def _true_probability(prompt):
    """Return P("true") vs. P("false") for the next token after ``prompt``.

    Only the two candidate logits at the final position are compared, so the
    result is a probability in [0, 1] renormalised over {"false", "true"}.
    """
    # NOTE(review): truncation=True cuts from the right by default, so a prompt
    # longer than the tokenizer's model_max_length would lose the trailing
    # "[/INST]" suffix and make the score meaningless -- consider truncating
    # the passage instead if very long inputs are expected.
    tokens = tokenizer(
        [prompt],
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(model.device)

    last_logits = model(**tokens).logits[:, -1, :]
    pair = torch.stack(
        [last_logits[:, token_false_id], last_logits[:, token_true_id]], dim=1
    )
    # Softmax over the (false, true) pair; index 1 is the "true" probability.
    return torch.softmax(pair.float(), dim=1)[:, 1].item()


def check_relevance(query, instruction, passage):
    """Score how relevant ``passage`` is to ``query`` under ``instruction``.

    Args:
        query: The search query entered by the user.
        instruction: Extra relevance criteria appended to the query (may be empty).
        passage: The document text to judge.

    Returns:
        The model's probability that the passage is relevant, formatted as a
        string with four decimal places (e.g. ``"0.8731"``).
    """
    prompt = _build_prompt(query, instruction, passage)
    score = _true_probability(prompt)
    return f"{score:.4f}"


# --- Gradio Interface --------------------------------------------------------

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# FollowIR Relevance Checker")
    gr.Markdown("This app uses the FollowIR-7B model to determine the relevance of a passage to a given query and instruction.")

    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(label="Query", placeholder="Enter your search query here")
            instruction_input = gr.Textbox(label="Instruction", placeholder="Enter additional instructions or criteria")
            passage_input = gr.Textbox(label="Passage", placeholder="Enter the passage to check for relevance", lines=5)
            submit_button = gr.Button("Check Relevance")

        with gr.Column():
            output = gr.Textbox(label="Relevance Probability")

    # Event listeners must be registered inside the Blocks context.
    submit_button.click(
        check_relevance,
        inputs=[query_input, instruction_input, passage_input],
        outputs=[output],
    )

if __name__ == "__main__":
    demo.launch()