orionweller
commited on
Commit
•
9bed19f
1
Parent(s):
2860a50
updates
Browse files
app.py
CHANGED
@@ -22,7 +22,7 @@ print(f"Using device: {device}")
|
|
22 |
model_name = "jhu-clsp/FollowIR-7B"
|
23 |
|
24 |
try:
|
25 |
-
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
|
26 |
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
27 |
except ValueError as e:
|
28 |
print(f"Error loading model or tokenizer: {e}")
|
@@ -42,6 +42,11 @@ Relevant (only output one word, either "true" or "false"): [/INST] """
|
|
42 |
|
43 |
@spaces.GPU
|
44 |
def check_relevance(query, instruction, passage):
|
|
|
|
|
|
|
|
|
|
|
45 |
if torch.cuda.is_available():
|
46 |
device = "cuda"
|
47 |
model = model.to(device)
|
@@ -62,6 +67,7 @@ def check_relevance(query, instruction, passage):
|
|
62 |
|
63 |
with torch.no_grad():
|
64 |
batch_scores = model(**tokens).logits[:, -1, :]
|
|
|
65 |
true_vector = batch_scores[:, token_true_id]
|
66 |
false_vector = batch_scores[:, token_false_id]
|
67 |
batch_scores = torch.stack([false_vector, true_vector], dim=1)
|
@@ -70,6 +76,25 @@ def check_relevance(query, instruction, passage):
|
|
70 |
|
71 |
return f"{score:.4f}"
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
# Gradio Interface
|
74 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
75 |
gr.Markdown("# FollowIR Relevance Checker")
|
@@ -85,6 +110,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
85 |
with gr.Column():
|
86 |
output = gr.Textbox(label="Relevance Probability")
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
submit_button.click(
|
89 |
check_relevance,
|
90 |
inputs=[query_input, instruction_input, passage_input],
|
|
|
22 |
model_name = "jhu-clsp/FollowIR-7B"
|
23 |
|
24 |
try:
|
25 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
|
26 |
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
27 |
except ValueError as e:
|
28 |
print(f"Error loading model or tokenizer: {e}")
|
|
|
42 |
|
43 |
@spaces.GPU
|
44 |
def check_relevance(query, instruction, passage):
|
45 |
+
global model
|
46 |
+
global tokenizer
|
47 |
+
global template
|
48 |
+
global token_false_id
|
49 |
+
global token_true_id
|
50 |
if torch.cuda.is_available():
|
51 |
device = "cuda"
|
52 |
model = model.to(device)
|
|
|
67 |
|
68 |
with torch.no_grad():
|
69 |
batch_scores = model(**tokens).logits[:, -1, :]
|
70 |
+
|
71 |
true_vector = batch_scores[:, token_true_id]
|
72 |
false_vector = batch_scores[:, token_false_id]
|
73 |
batch_scores = torch.stack([false_vector, true_vector], dim=1)
|
|
|
76 |
|
77 |
return f"{score:.4f}"
|
78 |
|
79 |
+
# Example inputs
|
80 |
+
examples = [
|
81 |
+
[
|
82 |
+
"What movies were directed by James Cameron?",
|
83 |
+
"A relevant document would describe any movie that was directed by James Cameron",
|
84 |
+
"Avatar: The Way of Water is a 2022 American epic science fiction film co-produced and directed by James Cameron, who co-wrote the screenplay with Rick Jaffa and Amanda Silver."
|
85 |
+
],
|
86 |
+
[
|
87 |
+
"What are the health benefits of green tea?",
|
88 |
+
"A relevant document would discuss specific health benefits associated with drinking green tea",
|
89 |
+
"Green tea is rich in polyphenols, which are natural compounds that have health benefits, such as reducing inflammation and helping to fight cancer. Green tea contains a catechin called epigallocatechin-3-gallate (EGCG). Catechins are natural antioxidants that help prevent cell damage and provide other benefits."
|
90 |
+
],
|
91 |
+
[
|
92 |
+
"Who won the Nobel Prize in Physics in 2022?",
|
93 |
+
"A relevant document would mention the names of the physicists who won the Nobel Prize in Physics in 2022",
|
94 |
+
"The 2021 Nobel Prize in Physics was awarded jointly to Syukuro Manabe, Klaus Hasselmann and Giorgio Parisi for groundbreaking contributions to our understanding of complex physical systems."
|
95 |
+
]
|
96 |
+
]
|
97 |
+
|
98 |
# Gradio Interface
|
99 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
100 |
gr.Markdown("# FollowIR Relevance Checker")
|
|
|
110 |
with gr.Column():
|
111 |
output = gr.Textbox(label="Relevance Probability")
|
112 |
|
113 |
+
gr.Examples(
|
114 |
+
examples=examples,
|
115 |
+
inputs=[query_input, instruction_input, passage_input],
|
116 |
+
outputs=output,
|
117 |
+
fn=check_relevance,
|
118 |
+
cache_examples=True,
|
119 |
+
)
|
120 |
+
|
121 |
submit_button.click(
|
122 |
check_relevance,
|
123 |
inputs=[query_input, instruction_input, passage_input],
|