dejanseo commited on
Commit
2d563ee
1 Parent(s): 59e59b8

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +3 -5
inference.py CHANGED
@@ -2,10 +2,8 @@ import streamlit as st
2
  import torch
3
  from transformers import BertForTokenClassification, BertTokenizerFast # Import BertTokenizerFast
4
 
5
- def load_model(model_name='linkbert.pth'):
6
- model_path = model_name
7
- model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=2)
8
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
9
  model.eval() # Set the model to inference mode
10
  return model
11
 
@@ -45,7 +43,7 @@ def predict_and_annotate(model, tokenizer, text):
45
  st.title("BERT Token Classification for Anchor Text Prediction")
46
 
47
  # Load the model and tokenizer
48
- model = load_model('linkbert.pth')
49
  tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') # Use BertTokenizerFast
50
 
51
  # User input text area
 
2
  import torch
3
  from transformers import BertForTokenClassification, BertTokenizerFast # Import BertTokenizerFast
4
 
5
+ def load_model(model_name='dejanseo/LinkBERT'):
6
+ model = BertForTokenClassification.from_pretrained(model_name, num_labels=2)
 
 
7
  model.eval() # Set the model to inference mode
8
  return model
9
 
 
43
  st.title("BERT Token Classification for Anchor Text Prediction")
44
 
45
  # Load the model and tokenizer
46
+ model = load_model('dejanseo/LinkBERT')
47
  tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') # Use BertTokenizerFast
48
 
49
  # User input text area