File size: 761 Bytes
21ac434 fdd0b8f 78e2935 fdd0b8f 21ac434 78e2935 21ac434 fdd0b8f 21ac434 78e2935 21ac434 78e2935 21ac434 78e2935 21ac434 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import streamlit as st
import torch
from transformers import RobertaTokenizer, RobertaModel
from prediction_sinhala import MDFEND
# Load model and tokenizer
# `st.cache` is deprecated (Streamlit >= 1.18) in favour of `st.cache_resource`
# for objects such as ML models that must be shared rather than copied.
# Fall back to the legacy decorator so older Streamlit installs keep working.
_cache_model_loader = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model_loader
def load_model():
    """Load the classifier and its tokenizer from the local checkpoint.

    Both objects are read from the ``./prediction_sinhala/`` directory.
    The result is cached by Streamlit so that the load is not repeated
    every time the script is re-executed.

    Returns:
        A ``(model, tokenizer)`` tuple: the ``MDFEND`` instance and the
        ``RobertaTokenizer`` instance, in that order.
    """
    model_dir = "./prediction_sinhala/"
    tokenizer = RobertaTokenizer.from_pretrained(model_dir)
    model = MDFEND.from_pretrained(model_dir)
    return model, tokenizer
# Fetch the (cached) model and tokenizer once per script run.
model, tokenizer = load_model()

# User input
text_input = st.text_area("Enter text here:")

# Prediction
if st.button("Predict"):
    if not text_input.strip():
        # Nothing to classify: running the model on an empty or
        # whitespace-only string would only yield a meaningless label.
        st.warning("Please enter some text before predicting.")
    else:
        # truncation=True caps the encoded sequence at the tokenizer's
        # configured maximum length instead of passing arbitrarily long
        # input straight to the model.
        inputs = tokenizer(text_input, return_tensors="pt", truncation=True)
        with torch.no_grad():  # Inference only: no gradients are needed.
            outputs = model(**inputs)
        # Index of the highest-scoring logit along the last dimension.
        prediction = outputs.logits.argmax(-1).item()
        st.write(f"Prediction: {prediction}")
|