SHSH0819's picture
Update app.py
332a741
raw
history blame
2.87 kB
import functools
import os
import sys
# Make the app directory importable so the local event_detection_* modules resolve.
sys.path.insert(0, os.path.abspath('./'))
import torch
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForMaskedLM
from event_detection_dataset import *
from event_detection_model import *
import gradio as gr
#print(f"Gradio version: {gr.__version__}")
# Class index -> event label returned to the user (index = position of the
# classifier's output logit).
ID2TAGS = {
    0: "Acquisition",
    # NOTE(review): the "I-" prefix looks like a leftover BIO tag; kept verbatim
    # because it is the user-visible output — confirm before changing.
    1: "I-Positive Clinical Trial & FDA Approval",
    2: "Dividend Cut",
    3: "Dividend Increase",
    4: "Guidance Increase",
    5: "New Contract",
    6: "Dividend",
    7: "Reverse Stock Split",
    # NOTE(review): trailing space kept verbatim; confirm whether it is intended.
    8: "Special Dividend ",
    9: "Stock Repurchase",
    10: "Stock Split",
    11: "Others",
}

_PRETRAINED_NAME = "distilbert-base-cased"
_MODEL_PATH = "./"
_CHECKPOINT_FILE = "event_domain_final.pt"


@functools.lru_cache(maxsize=1)
def _load_tokenizer():
    """Load the DistilBERT fast tokenizer once and reuse it for every request."""
    return AutoTokenizer.from_pretrained(_PRETRAINED_NAME, use_fast=True)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the classifier, load the fine-tuned weights, and return ``(model, device)``.

    Cached so the checkpoint is read from disk only on the first request
    instead of on every call to ``predict``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = DistillBERTClass(_PRETRAINED_NAME)
    # Load onto CPU first so a checkpoint saved on GPU also works on CPU hosts.
    # torch.load unpickles the file: only use checkpoints from a trusted source.
    checkpoint = torch.load(
        os.path.join(_MODEL_PATH, _CHECKPOINT_FILE),
        map_location=torch.device("cpu"),
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()  # inference mode for layers such as dropout
    return model, device


def predict(data):
    """Classify a piece of text into one of the financial event labels.

    Args:
        data: The input text (a single string from the Gradio text box).

    Returns:
        The ``ID2TAGS`` label of the highest-scoring class.
    """
    tokenizer = _load_tokenizer()
    model, device = _load_model()
    # NOTE(review): with is_split_into_words=True, a one-element list is treated
    # as one pre-tokenized sequence (not a batch), so input_ids is a flat list.
    # Kept exactly as the original app did; confirm DistillBERTClass expects it.
    tokenized_inputs = tokenizer(
        [data],
        add_special_tokens=True,
        max_length=512,
        padding="max_length",
        return_token_type_ids=True,
        truncation=True,
        is_split_into_words=True,
    )
    ids = torch.tensor(tokenized_inputs["input_ids"]).to(device)
    mask = torch.tensor(tokenized_inputs["attention_mask"]).to(device)
    with torch.no_grad():
        outputs = model(ids, mask)
    # Index of the highest score along the class dimension.
    _, max_idx = torch.max(outputs, dim=1)
    return ID2TAGS[max_idx.item()]
title = "Financial Event Detection"
description = "Predict Financial Events."
article = "Modified from the model in the following paper: Zhou, Z., Ma, L., & Liu, H. (2021)."
# One example row per inner list; each row holds one value for the single text input.
example_list = [["Investors who receive dividends can choose to take them as cash or as additional shares."]]

# Create the Gradio demo.
demo = gr.Interface(
    fn=predict,  # maps the input text to a predicted event label
    inputs="text",  # a single free-text input box
    outputs="text",  # predict returns one label string, so one text output
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the demo without a public share link and with debug mode off.
demo.launch(debug=False, share=False)