# Spaces: Sleeping
from transformers import T5Tokenizer, T5ForConditionalGeneration | |
from transformers import AdamW | |
import pandas as pd | |
import torch | |
import pytorch_lightning as pl | |
from pytorch_lightning.callbacks import ModelCheckpoint | |
from torch.nn.utils.rnn import pad_sequence | |
# from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler | |
# Fix the global random seed so repeated runs behave the same way.
pl.seed_everything(100)

# Identifier of the pretrained checkpoint; the tokenizer below is loaded
# from it.
MODEL_NAME = "t5-base"

# Prefer CUDA when torch reports it as available, otherwise use the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Length caps, in tokens. INPUT_MAX_LEN bounds the tokenized prompt;
# OUTPUT_MAX_LEN is, going by its name, the matching cap for the
# target/generated side.
INPUT_MAX_LEN = 128
OUTPUT_MAX_LEN = 128

# Shared tokenizer instance used when encoding prompts and decoding output.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=512)
class T5Model(pl.LightningModule):
    """Lightning module wrapping a pretrained T5 conditional-generation model.

    The wrapped Hugging Face model is stored on ``self.model``; that attribute
    name is part of the checkpoint layout and must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(
            MODEL_NAME,
            return_dict=True,
        )

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model and return ``(loss, logits)``.

        ``labels`` is passed straight through to the wrapped model; the
        returned loss is whatever the model output exposes as ``.loss``.
        """
        result = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return result.loss, result.logits

    def _batch_loss(self, batch):
        """Unpack one batch dict and return the loss from a forward pass."""
        loss, _logits = self(
            batch["input_ids"],
            batch["attention_mask"],
            batch["target"],
        )
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._batch_loss(batch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self._batch_loss(batch)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return {"val_loss": loss}

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters (lr = 1e-4)."""
        return AdamW(self.parameters(), lr=1e-4)
# Restore the fine-tuned module from disk, mapping stored tensors to DEVICE.
train_model = T5Model.load_from_checkpoint(
    "best-model.ckpt",
    map_location=DEVICE,
)

# The module is only used for inference from here on, so freeze it.
train_model.freeze()
def generate_question(question):
    """Generate text for ``question`` with the restored T5 model.

    The prompt is tokenized (padded/truncated to ``INPUT_MAX_LEN`` tokens),
    decoded with 4-beam search, and the single returned sequence is converted
    back to a string with special tokens removed.

    Args:
        question: Prompt text to feed to the model.

    Returns:
        The decoded model output as one string.
    """
    encoding = tokenizer(
        question,
        add_special_tokens=True,
        max_length=INPUT_MAX_LEN,
        padding="max_length",
        truncation="only_first",
        return_attention_mask=True,
        return_tensors="pt",
    )

    # The tokenizer always produces CPU tensors, while the checkpoint may
    # have been restored onto a GPU. Move the inputs to wherever the model's
    # weights actually live to avoid a device-mismatch error.
    model_device = next(train_model.model.parameters()).device

    generate_ids = train_model.model.generate(
        input_ids=encoding["input_ids"].to(model_device),
        attention_mask=encoding["attention_mask"].to(model_device),
        # Cap the *generated* length with the output constant rather than
        # the input one (both are currently 128, so results are unchanged).
        max_length=OUTPUT_MAX_LEN,
        num_beams=4,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    preds = [
        tokenizer.decode(
            gen_id,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for gen_id in generate_ids
    ]
    # num_return_sequences is 1, so this join yields that single prediction.
    return "".join(preds)
import gradio as gr | |
import random | |
import time | |
# Chat UI: a chatbot pane, a textbox that submits on Enter, and a clear button.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    chatbot.style(height=300)
    msg = gr.Textbox(info="Press \'Enter\' to send")
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Append the user's message as a new turn and clear the textbox.

        Returns the new textbox value (empty) and the extended history, whose
        last turn has ``None`` as the not-yet-generated reply.
        """
        return "", history + [[user_message, None]]

    def bot(history):
        """Fill in the reply of the last turn, yielding it character by character.

        Each yield hands the updated history back to the chatbot so the reply
        appears to be typed out.
        """
        bot_message = generate_question(history[-1][0])
        history[-1][1] = ""
        if not bot_message:
            # An empty reply would otherwise yield nothing at all, leaving
            # the pending turn with its ``None`` placeholder.
            yield history
            return
        for character in bot_message:
            history[-1][1] += character
            time.sleep(0.05)  # small delay that produces the typing effect
            yield history

    # First record the user's turn, then stream the model's reply into it.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=True).then(
        bot, chatbot, chatbot
    )
    # Returning None resets the chatbot component.
    clear.click(lambda: None, None, chatbot, queue=True)

demo.queue(concurrency_count=2)
demo.launch()