BusinessDev committed on
Commit
21c103a
1 Parent(s): fc9195f

finalmaybe

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -11,22 +11,27 @@ model = AutoModelForQuestionAnswering.from_pretrained(model_name)
11
 
12
 
13
  def answer_question(context, question):
14
- # Encode the context and question
15
- inputs = tokenizer(context, question, return_tensors="pt")
 
 
 
 
 
16
 
17
  # Perform question answering
18
  outputs = model(**inputs)
19
 
20
  # Get the predicted start and end token positions
21
- start_scores, end_scores = outputs.start_logits, outputs.end_logits
22
 
23
  # Decode the answer based on predicted positions
24
  answer_start = torch.argmax(start_scores)
25
  answer_end = torch.argmax(end_scores) + 1
26
 
27
- # Get answer tokens and convert them to string
28
  answer = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end])
29
- answer = "".join(answer)
30
 
31
  return answer
32
 
 
11
 
12
 
13
def answer_question(context, question):
    """Answer *question* from *context* with the module-level QA model.

    The pair is tokenized (truncated to the model's maximum length), the
    model's start/end logits are reduced with argmax to the most likely
    answer span, and that span is decoded back to a readable string.

    Parameters
    ----------
    context : str
        Passage to search for the answer.
    question : str
        Question to answer.

    Returns
    -------
    str
        The decoded answer span; may be empty if the predicted span is
        empty or consists only of special tokens.
    """
    # Encode the context and question; truncation guards against inputs
    # longer than the model's maximum sequence length.
    inputs = tokenizer(context, question, return_tensors="pt", truncation=True)

    # Perform question answering.
    outputs = model(**inputs)

    # Get the predicted start and end token positions.
    # FIX: the attribute is `end_logits`; `end_scores` does not exist on
    # QuestionAnsweringModelOutput and raised AttributeError at runtime.
    start_scores, end_scores = outputs.start_logits, outputs.end_logits

    # Decode the answer span; end index is exclusive, hence the +1.
    answer_start = torch.argmax(start_scores)
    answer_end = torch.argmax(end_scores) + 1

    # FIX: the old `"".join(tokens[2:-2])` claimed to strip [CLS]/[SEP],
    # but the argmax span need not contain them — it simply amputated two
    # tokens from each end of the answer and glued WordPiece pieces with
    # no spaces. `tokenizer.decode` strips special tokens and reassembles
    # subwords correctly.
    answer = tokenizer.decode(
        inputs["input_ids"][0][answer_start:answer_end],
        skip_special_tokens=True,
    )

    return answer