# Commit 1e40f75 by khizon: "use test_sub.jsonl"
import os
import json
import numpy as np
import pandas as pd
import random
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
@st.cache(allow_output_mutation=True)
def init_model():
    """Load the tokenizer and the fine-tuned two-class DistilBERT classifier.

    Wrapped in ``st.cache`` so Streamlit reruns reuse the loaded objects;
    ``allow_output_mutation=True`` tells Streamlit not to check the returned
    objects for mutation.

    Returns:
        ``(tokenizer, model)`` tuple.
    """
    tokenizer_checkpoint = 'distilbert-base-cased'
    classifier_checkpoint = 'khizon/distilbert-unreliable-news-eng-4L'
    tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_checkpoint)
    model = DistilBertForSequenceClassification.from_pretrained(
        classifier_checkpoint,
        num_labels=2,
    )
    return tokenizer, model
def download_dataset(
    url='https://drive.google.com/drive/folders/11mRvsHAkggFEJvG4axH4mmWI6FHMQp7X?usp=sharing',
    output_dir='data/nela_gt_2018_site_split',
):
    """Download a shared Google Drive folder with the ``gdown`` command-line tool.

    Args:
        url: Google Drive folder URL passed to ``gdown --folder``.
        output_dir: destination directory passed to ``gdown -O``.

    Returns:
        Exit status of the ``gdown`` process (0 on success). If the ``gdown``
        executable cannot be found, returns 127, the shell's "command not
        found" status, instead of raising.
    """
    import subprocess

    # An argument list (no shell) keeps the URL's '?' from being treated as a
    # glob character and avoids any other shell interpretation of the values.
    try:
        completed = subprocess.run(
            ['gdown', '--folder', url, '-O', output_dir],
            check=False,
        )
    except FileNotFoundError:
        # Preserve the previous non-raising behaviour when gdown is missing.
        return 127
    return completed.returncode
@st.cache(allow_output_mutation=True)
def jsonl_to_df(file_path):
    """Read a JSON Lines file into a flattened DataFrame.

    Each non-blank line is parsed once as a JSON record. Nested objects are
    flattened into dotted column names by ``pd.json_normalize``.

    Args:
        file_path: path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        DataFrame with one row per record. It is empty when the file holds no
        records.
    """
    with open(file_path, encoding='utf-8') as f:
        # Iterating the file splits only on newlines, unlike str.splitlines(),
        # which also breaks on U+2028/U+0085 that may appear inside JSON strings.
        # Blank lines are skipped because json.loads('') raises.
        records = [json.loads(line) for line in f if line.strip()]
    return pd.json_normalize(records)
@st.cache
def load_test_df():
    """Load the bundled test split and one-hot encode its ``label`` column.

    The path is relative, so it is resolved against the current working
    directory.

    Returns:
        DataFrame with the JSONL fields plus ``label_<value>`` indicator
        columns from ``pd.get_dummies``. The original ``label`` column is
        replaced by these indicators.
    """
    raw_records = jsonl_to_df('test_sub.jsonl')
    return pd.get_dummies(raw_records, columns=['label'])
@st.cache(allow_output_mutation=True)
def predict(model, tokenizer, data):
    """Classify one test-set row and check the prediction against its label.

    Args:
        model: sequence-classification model whose output exposes ``logits``.
        tokenizer: tokenizer matching ``model``.
        data: one DataFrame row (pandas Series) holding ``title``,
            ``content`` and the one-hot ``label_0`` / ``label_1`` columns.

    Returns:
        ``(pred_label, correct)`` as produced by ``correct_preds``.
    """
    # A single row of a mixed-dtype frame is an object Series, and the one-hot
    # columns may be bool or uint8. Convert explicitly instead of letting
    # torch.tensor index the Series positionally, which pandas deprecates.
    labels = torch.tensor(
        data[['label_0', 'label_1']].to_numpy(dtype=np.float32)
    )
    encoding = tokenizer.encode_plus(
        data['title'],
        # NOTE(review): encode_plus already inserts [SEP] between the two
        # segments. The literal ' [SEP] ' prefix is kept, presumably to match
        # how the model was fine-tuned; confirm before removing it.
        ' [SEP] ' + data['content'],
        add_special_tokens=True,
        max_length=512,
        return_token_type_ids=False,
        padding='max_length',
        truncation='only_second',
        return_attention_mask=True,
        return_tensors='pt',
    )
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        output = model(**encoding)
    return correct_preds(output['logits'], labels)
@st.cache(allow_output_mutation=True)
def predict_new(model, tokenizer, title, content):
    """Classify a user-supplied article.

    Args:
        model: sequence-classification model whose output exposes ``logits``.
        tokenizer: tokenizer matching ``model``.
        title: article headline, encoded as the first segment.
        content: article body, encoded as the second segment. It is the only
            segment truncated to fit ``max_length``.

    Returns:
        ``'reliable'`` if the top-scoring class index is greater than 0,
        otherwise ``'unreliable'``.
    """
    encoding = tokenizer.encode_plus(
        title,
        # NOTE(review): encode_plus already inserts [SEP] between the two
        # segments. The literal ' [SEP] ' prefix is kept, presumably to match
        # how the model was fine-tuned; confirm before removing it.
        ' [SEP] ' + content,
        add_special_tokens=True,
        max_length=512,
        return_token_type_ids=False,
        padding='max_length',
        truncation='only_second',
        return_attention_mask=True,
        return_tensors='pt',
    )
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        output = model(**encoding)
    # Softmax is monotonic, so argmax over the raw logits picks the same class.
    p_idx = torch.argmax(output['logits'], dim=1)
    return 'reliable' if p_idx > 0 else 'unreliable'
def correct_preds(preds, labels):
    """Turn classifier logits into a label string and compare it with the target.

    Args:
        preds: logits for a single sample, classes on the last axis
            (e.g. shape ``(1, 2)``).
        labels: one-hot target for the same sample, classes on the last axis
            (shape ``(2,)`` or ``(1, 2)``).

    Returns:
        ``(pred_label, correct)``:
        - ``pred_label`` is ``'reliable'`` when the predicted class index is
          greater than 0, otherwise ``'unreliable'``.
        - ``correct`` is a bool that is True when the predicted class equals
          the target's argmax.
    """
    # Softmax is monotonic, so argmax over the raw logits picks the same class.
    p_idx = torch.argmax(preds, dim=-1)
    # Reduce over the class axis (the last one). The former dim=0 only worked
    # for 1-D labels; for (1, 2) labels it reduced over the batch axis.
    l_idx = torch.argmax(labels, dim=-1)
    pred_label = 'reliable' if p_idx.item() > 0 else 'unreliable'
    correct = bool((p_idx == l_idx).any())
    return pred_label, correct
if __name__ == '__main__':
    df = load_test_df()
    tokenizer, model = init_model()

    st.title("Unreliable News classifier")
    mode = st.radio(
        '', ('Test article', 'Input own article')
    )

    if mode == 'Test article':
        if st.button('Get random article'):
            row_idx = np.random.randint(0, len(df))
            sample = df.iloc[row_idx]
            prediction, correct = predict(model, tokenizer, sample)
            label = 'reliable' if sample['label_1'] > sample['label_0'] else 'unreliable'

            st.header(sample['title'])
            # Green banner when the prediction matches the ground truth, red otherwise.
            if correct:
                st.success(f'Prediction: {prediction}')
            else:
                st.error(f'Prediction: {prediction}')
            st.caption(f'Source: {sample["source"]} ({label})')

            # Obscure part of the article text. randint(0, 99) > 45 holds for
            # 54 of the 100 values, and the first word is always shown. The
            # masked text goes into a local list rather than back into
            # `sample`, so the row taken from the cached DataFrame is never
            # written to.
            shown_words = []
            for word_idx, word in enumerate(sample['content'].split()):
                if random.randint(0, 99) > 45 and word_idx > 0:
                    word = '▒' * len(word)
                shown_words.append(word)
            st.markdown(' '.join(shown_words))
    else:
        title = st.text_input('Article title', 'Test title')
        content = st.text_area('Article content', 'Lorem ipsum')
        if st.button('Submit'):
            pred = predict_new(model, tokenizer, title, content)
            st.markdown(f'Prediction: {pred}')