orionweller committed on
Commit: 9bed19f
1 Parent(s): 2860a50
Files changed (1): app.py (+34, -1)
app.py CHANGED
@@ -22,7 +22,7 @@ print(f"Using device: {device}")
 model_name = "jhu-clsp/FollowIR-7B"
 
 try:
-    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
+    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
     tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
 except ValueError as e:
     print(f"Error loading model or tokenizer: {e}")
@@ -42,6 +42,11 @@ Relevant (only output one word, either "true" or "false"): [/INST] """
 
 @spaces.GPU
 def check_relevance(query, instruction, passage):
+    global model
+    global tokenizer
+    global template
+    global token_false_id
+    global token_true_id
     if torch.cuda.is_available():
         device = "cuda"
         model = model.to(device)
@@ -62,6 +67,7 @@ def check_relevance(query, instruction, passage):
 
     with torch.no_grad():
         batch_scores = model(**tokens).logits[:, -1, :]
+
         true_vector = batch_scores[:, token_true_id]
         false_vector = batch_scores[:, token_false_id]
         batch_scores = torch.stack([false_vector, true_vector], dim=1)
@@ -70,6 +76,25 @@ def check_relevance(query, instruction, passage):
 
     return f"{score:.4f}"
 
+# Example inputs
+examples = [
+    [
+        "What movies were directed by James Cameron?",
+        "A relevant document would describe any movie that was directed by James Cameron",
+        "Avatar: The Way of Water is a 2022 American epic science fiction film co-produced and directed by James Cameron, who co-wrote the screenplay with Rick Jaffa and Amanda Silver."
+    ],
+    [
+        "What are the health benefits of green tea?",
+        "A relevant document would discuss specific health benefits associated with drinking green tea",
+        "Green tea is rich in polyphenols, which are natural compounds that have health benefits, such as reducing inflammation and helping to fight cancer. Green tea contains a catechin called epigallocatechin-3-gallate (EGCG). Catechins are natural antioxidants that help prevent cell damage and provide other benefits."
+    ],
+    [
+        "Who won the Nobel Prize in Physics in 2022?",
+        "A relevant document would mention the names of the physicists who won the Nobel Prize in Physics in 2022",
+        "The 2021 Nobel Prize in Physics was awarded jointly to Syukuro Manabe, Klaus Hasselmann and Giorgio Parisi for groundbreaking contributions to our understanding of complex physical systems."
+    ]
+]
+
 # Gradio Interface
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# FollowIR Relevance Checker")
@@ -85,6 +110,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     with gr.Column():
         output = gr.Textbox(label="Relevance Probability")
 
+    gr.Examples(
+        examples=examples,
+        inputs=[query_input, instruction_input, passage_input],
+        outputs=output,
+        fn=check_relevance,
+        cache_examples=True,
+    )
+
     submit_button.click(
         check_relevance,
         inputs=[query_input, instruction_input, passage_input],
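
Note: the diff elides the lines between the stacked [false, true] logits and the score that check_relevance formats. A minimal sketch of that step is given below, assuming the score is a log-softmax over the two token logits followed by exponentiation of the "true" entry; the relevance_probability helper and the toy logits are illustrative only and are not part of app.py.

import torch

def relevance_probability(last_token_logits: torch.Tensor,
                          token_false_id: int,
                          token_true_id: int) -> float:
    """Normalise the logits of the 'false'/'true' tokens into P('true')."""
    true_vector = last_token_logits[:, token_true_id]
    false_vector = last_token_logits[:, token_false_id]
    pair = torch.stack([false_vector, true_vector], dim=1)   # shape (batch, 2)
    pair = torch.nn.functional.log_softmax(pair, dim=1)      # normalise over the two-token pair
    return pair[:, 1].exp().item()                           # probability of 'true'

# Toy usage with random logits over a tiny vocabulary; ids 3 and 5 stand in
# for the real tokenizer ids of "false" and "true".
logits = torch.randn(1, 8)
print(f"{relevance_probability(logits, token_false_id=3, token_true_id=5):.4f}")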