Update app.py
Browse files
app.py
CHANGED
@@ -1,17 +1,23 @@
|
|
1 |
import streamlit as st
|
2 |
-
from transformers import
|
|
|
3 |
|
|
|
4 |
@st.cache(allow_output_mutation=True)
|
5 |
def load_model():
|
6 |
-
tokenizer =
|
7 |
-
model =
|
8 |
return model, tokenizer
|
9 |
|
10 |
model, tokenizer = load_model()
|
11 |
|
|
|
12 |
text_input = st.text_area("Enter text here:")
|
|
|
|
|
13 |
if st.button("Predict"):
|
14 |
inputs = tokenizer(text_input, return_tensors="pt")
|
15 |
-
|
|
|
16 |
prediction = outputs.logits.argmax(-1).item()
|
17 |
st.write(f"Prediction: {prediction}")
|
|
|
1 |
import streamlit as st
|
2 |
+
from transformers import RobertaTokenizer, RobertaModel
|
3 |
+
from your_model_file import MDFEND # Ensure you import your model correctly
|
4 |
|
5 |
+
# Load model and tokenizer
|
6 |
@st.cache(allow_output_mutation=True)
|
7 |
def load_model():
|
8 |
+
tokenizer = RobertaTokenizer.from_pretrained("prediction_sinhala.ipynb")
|
9 |
+
model = MDFEND.from_pretrained("prediction_sinhala.ipynb")
|
10 |
return model, tokenizer
|
11 |
|
12 |
model, tokenizer = load_model()
|
13 |
|
14 |
+
# User input
|
15 |
text_input = st.text_area("Enter text here:")
|
16 |
+
|
17 |
+
# Prediction
|
18 |
if st.button("Predict"):
|
19 |
inputs = tokenizer(text_input, return_tensors="pt")
|
20 |
+
with torch.no_grad(): # Ensure no gradients are computed
|
21 |
+
outputs = model(**inputs)
|
22 |
prediction = outputs.logits.argmax(-1).item()
|
23 |
st.write(f"Prediction: {prediction}")
|