Update inference.py
Browse files- inference.py +3 -5
inference.py
CHANGED
@@ -2,10 +2,8 @@ import streamlit as st
|
|
2 |
import torch
|
3 |
from transformers import BertForTokenClassification, BertTokenizerFast # Import BertTokenizerFast
|
4 |
|
5 |
-
def load_model(model_name='
|
6 |
-
|
7 |
-
model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=2)
|
8 |
-
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
|
9 |
model.eval() # Set the model to inference mode
|
10 |
return model
|
11 |
|
@@ -45,7 +43,7 @@ def predict_and_annotate(model, tokenizer, text):
|
|
45 |
st.title("BERT Token Classification for Anchor Text Prediction")
|
46 |
|
47 |
# Load the model and tokenizer
|
48 |
-
model = load_model('
|
49 |
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') # Use BertTokenizerFast
|
50 |
|
51 |
# User input text area
|
|
|
2 |
import torch
|
3 |
from transformers import BertForTokenClassification, BertTokenizerFast # Import BertTokenizerFast
|
4 |
|
5 |
+
def load_model(model_name='dejanseo/LinkBERT'):
|
6 |
+
model = BertForTokenClassification.from_pretrained(model_name, num_labels=2)
|
|
|
|
|
7 |
model.eval() # Set the model to inference mode
|
8 |
return model
|
9 |
|
|
|
43 |
st.title("BERT Token Classification for Anchor Text Prediction")
|
44 |
|
45 |
# Load the model and tokenizer
|
46 |
+
model = load_model('dejanseo/LinkBERT')
|
47 |
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') # Use BertTokenizerFast
|
48 |
|
49 |
# User input text area
|