added more models, changed layout
Browse files- app.py +160 -17
- requirements.txt +2 -1
- vocab219SW/idx2word.json +0 -0
- vocab219SW/word2idx.json +0 -0
app.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
import json
|
|
|
2 |
import re
|
3 |
import unicodedata
|
4 |
from typing import Tuple
|
5 |
-
import random
|
6 |
|
7 |
import gradio as gr
|
|
|
8 |
import torch
|
9 |
import torch.nn as nn
|
10 |
import torch.nn.functional as F
|
@@ -346,9 +347,52 @@ norm_model219.load_state_dict(torch.load('NormSeq2Seq-219M_epoch35.pt',
|
|
346 |
map_location=torch.device('cpu')))
|
347 |
norm_model219.to(device)
|
348 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
349 |
models_dict = {'AttentionSeq2Seq-188M': attn_model, 'NormalSeq2Seq-188M': norm_model,
|
350 |
'AttentionSeq2Seq-219M': attn_model219,
|
351 |
-
'NormalSeq2Seq-219M': norm_model219
|
|
|
|
|
352 |
|
353 |
def generateAttn188(sentence, history, max_len=12,
|
354 |
word2idx=word2idx, idx2word=idx2word,
|
@@ -482,23 +526,122 @@ def generateNorm219(sentence, history, max_len=12,
|
|
482 |
response = lookup_words(idx2word, outputs)
|
483 |
return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
484 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
485 |
with gr.Blocks() as demo:
|
486 |
-
gr.Markdown("""
|
487 |
-
# Seq2Seq Generative Chatbot with 188M parameters
|
488 |
-
""")
|
489 |
-
with gr.Row():
|
490 |
-
gr.ChatInterface(generateNorm188,
|
491 |
-
title="NormalSeq2Seq-188M")
|
492 |
-
gr.ChatInterface(generateAttn188,
|
493 |
-
title="AttentionSeq2Seq-188M")
|
494 |
-
gr.Markdown("""
|
495 |
-
# Seq2Seq Generative Chatbot with 219M parameters
|
496 |
-
""")
|
497 |
with gr.Row():
|
498 |
-
gr.
|
499 |
-
|
500 |
-
gr.ChatInterface(generateAttn219,
|
501 |
-
title="AttentionSeq2Seq-219M")
|
502 |
|
503 |
if __name__ == "__main__":
|
504 |
demo.launch()
|
|
|
1 |
import json
|
2 |
+
import random
|
3 |
import re
|
4 |
import unicodedata
|
5 |
from typing import Tuple
|
|
|
6 |
|
7 |
import gradio as gr
|
8 |
+
import spacy
|
9 |
import torch
|
10 |
import torch.nn as nn
|
11 |
import torch.nn.functional as F
|
|
|
347 |
map_location=torch.device('cpu')))
|
348 |
norm_model219.to(device)
|
349 |
|
350 |
+
# Vocabulary for the "-SW" models: loaded from its own directory and kept
# separate from the mappings used by the other models in this file.
with open('vocab219SW/word2idx.json', 'r') as f:
    word2idx3 = json.load(f)
with open('vocab219SW/idx2word.json', 'r') as f:
    idx2word3 = json.load(f)

# Hyper-parameters used to rebuild the 219M-SW architectures before loading
# their checkpoints. 'teacher_forcing_ratio' and 'epochs' are not read in
# this section.
params219SW = {'input_dim': len(word2idx3),
               'emb_dim': 192,
               'enc_hid_dim': 256,
               'dec_hid_dim': 256,
               'dropout': 0.5,
               'attn_dim': 64,
               'teacher_forcing_ratio': 0.5,
               'epochs': 35}

# --- Attention Seq2Seq (219M-SW) ---
enc = Encoder(input_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
              enc_hid_dim=params219SW['enc_hid_dim'], dec_hid_dim=params219SW['dec_hid_dim'],
              dropout=params219SW['dropout'])
attn = Attention(enc_hid_dim=params219SW['enc_hid_dim'], dec_hid_dim=params219SW['dec_hid_dim'],
                 attn_dim=params219SW['attn_dim'])
# emb_dim is taken from params219SW like every other argument here; it was
# previously read from params219, tying this model to another model's config.
dec = AttnDecoder(output_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
                  enc_hid_dim=params219SW['enc_hid_dim'], dec_hid_dim=params219SW['dec_hid_dim'],
                  attention=attn, dropout=params219SW['dropout'])
attn_model219SW = Seq2Seq(encoder=enc, decoder=dec, device=device)
# Weights are mapped to CPU at load time, then moved to `device`.
attn_model219SW.load_state_dict(torch.load('AttnSeq2Seq-219M-SW_epoch35.pt',
                                           map_location=torch.device('cpu')))
attn_model219SW.to(device)

# --- Seq2Seq without attention (219M-SW) ---
enc = Encoder(input_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
              enc_hid_dim=params219SW['enc_hid_dim'],
              dec_hid_dim=params219SW['dec_hid_dim'], dropout=params219SW['dropout'])
dec = Decoder(output_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
              enc_hid_dim=params219SW['enc_hid_dim'],
              dec_hid_dim=params219SW['dec_hid_dim'],
              dropout=params219SW['dropout'])
norm_model219SW = Seq2Seq(encoder=enc, decoder=dec, device=device)
norm_model219SW.load_state_dict(torch.load('NormSeq2Seq-219M-SW_epoch35.pt',
                                           map_location=torch.device('cpu')))
norm_model219SW.to(device)

# spaCy pipeline; tokenize_context uses its tokenizer and stop-word flags.
nlp = spacy.load('en_core_web_sm')

# Registry the generate* functions use to look up a model by display name.
models_dict = {'AttentionSeq2Seq-188M': attn_model, 'NormalSeq2Seq-188M': norm_model,
               'AttentionSeq2Seq-219M': attn_model219,
               'NormalSeq2Seq-219M': norm_model219,
               'AttentionSeq2Seq-219M-SW': attn_model219SW,
               'NormalSeq2Seq-219M-SW': norm_model219SW}
|
396 |
|
397 |
def generateAttn188(sentence, history, max_len=12,
|
398 |
word2idx=word2idx, idx2word=idx2word,
|
|
|
526 |
response = lookup_words(idx2word, outputs)
|
527 |
return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
528 |
|
529 |
+
def tokenize_context(text, nlp=nlp):
    """
    Split text into tokens and drop the ones flagged as stop words.

    :param text: text to be tokenized
    :param nlp: object whose ``tokenizer`` is called on the text
        (defaults to the module-level pipeline)
    :return: list of token strings whose ``is_stop`` flag is false
    """
    kept = []
    for token in nlp.tokenizer(text):
        if token.is_stop:
            continue
        kept.append(token.text)
    return kept
|
536 |
+
|
537 |
+
def generateAttn219SW(sentence, history, max_len=12,
                      word2idx=word2idx3, idx2word=idx2word3,
                      device=device, tokenize_context=tokenize_context,
                      preprocess_text=preprocess_text,
                      lookup_words=lookup_words, models_dict=models_dict):
    """
    Generate a response with the AttentionSeq2Seq-219M-SW model.

    The input is preprocessed, tokenized with stop words removed, mapped to
    indices, and decoded greedily one token at a time.

    :param sentence: user message
    :param history: chat history passed in by gr.ChatInterface (unused)
    :param max_len: maximum number of decoding steps
    :param word2idx: word to index mapping
    :param idx2word: index to word mapping
    :param device: device the input tensors are moved to
    :param tokenize_context: tokenizer applied to the preprocessed sentence
    :param preprocess_text: text normalization applied before tokenizing
    :param lookup_words: converts generated indices back to words
    :param models_dict: mapping from model name to loaded model
    :return: response
    """
    # Must be the "-SW" model: the vocabulary defaults above belong to it, so
    # using the plain 219M model would feed it indices from the wrong vocab.
    model = models_dict['AttentionSeq2Seq-219M-SW']
    model.eval()

    sentence = preprocess_text(sentence)
    unk_idx = word2idx['<unk>']
    bos_idx = word2idx['<bos>']
    eos_idx = word2idx['<eos>']

    # Out-of-vocabulary tokens fall back to <unk>.
    tokens = [word2idx.get(token, unk_idx) for token in tokenize_context(sentence)]
    tokens = [bos_idx] + tokens + [eos_idx]
    # unsqueeze(1) turns the 1-D index list into a (seq_len, 1) column.
    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(1).to(device)

    outputs = [bos_idx]
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(tokens)
        for _ in range(max_len):
            # Feed the most recently generated token back into the decoder.
            prev = torch.tensor([outputs[-1]], dtype=torch.long).to(device)
            output, hidden = model.decoder(prev, hidden, encoder_outputs)
            top1 = output.max(1)[1].item()  # greedy: highest-scoring index
            outputs.append(top1)
            if top1 == eos_idx:
                break

    response = lookup_words(idx2word, outputs)
    return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
570 |
+
|
571 |
+
def generateNorm219SW(sentence, history, max_len=12,
                      word2idx=word2idx3, idx2word=idx2word3,
                      device=device, tokenize_context=tokenize_context, preprocess_text=preprocess_text,
                      lookup_words=lookup_words, models_dict=models_dict):
    """
    Generate a response with the NormalSeq2Seq-219M-SW model.

    The input is preprocessed, tokenized with stop words removed, mapped to
    indices, and decoded greedily one token at a time.

    :param sentence: user message
    :param history: chat history passed in by gr.ChatInterface (unused)
    :param max_len: maximum number of decoding steps
    :param word2idx: word to index mapping
    :param idx2word: index to word mapping
    :param device: device the input tensors are moved to
    :param tokenize_context: tokenizer applied to the preprocessed sentence
    :param preprocess_text: text normalization applied before tokenizing
    :param lookup_words: converts generated indices back to words
    :param models_dict: mapping from model name to loaded model
    :return: response
    """
    # Must be the "-SW" model: the vocabulary defaults above belong to it, so
    # using the plain 219M model would feed it indices from the wrong vocab.
    model = models_dict['NormalSeq2Seq-219M-SW']
    model.eval()

    sentence = preprocess_text(sentence)
    unk_idx = word2idx['<unk>']
    bos_idx = word2idx['<bos>']
    eos_idx = word2idx['<eos>']

    # Out-of-vocabulary tokens fall back to <unk>.
    tokens = [word2idx.get(token, unk_idx) for token in tokenize_context(sentence)]
    tokens = [bos_idx] + tokens + [eos_idx]
    # unsqueeze(1) turns the 1-D index list into a (seq_len, 1) column.
    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(1).to(device)

    outputs = [bos_idx]
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(tokens)
        for _ in range(max_len):
            # Feed the most recently generated token back into the decoder.
            prev = torch.tensor([outputs[-1]], dtype=torch.long).to(device)
            output, hidden = model.decoder(prev, hidden, encoder_outputs)
            top1 = output.max(1)[1].item()  # greedy: highest-scoring index
            outputs.append(top1)
            if top1 == eos_idx:
                break

    response = lookup_words(idx2word, outputs)
    return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
603 |
+
|
604 |
+
norm188 = gr.ChatInterface(generateNorm188,
|
605 |
+
title="NormalSeq2Seq-188M",
|
606 |
+
description="""Seq2Seq Generative Chatbot without Attention.
|
607 |
+
|
608 |
+
188,204,500 trainable parameters""")
|
609 |
+
norm219 = gr.ChatInterface(generateNorm219,
|
610 |
+
title="NormalSeq2Seq-219M",
|
611 |
+
description="""Seq2Seq Generative Chatbot without Attention.
|
612 |
+
|
613 |
+
219,456,724 trainable parameters""")
|
614 |
+
norm219sw = gr.ChatInterface(generateNorm219SW,
|
615 |
+
title="NormalSeq2Seq-219M-SW",
|
616 |
+
description="""Seq2Seq Generative Chatbot without Attention.
|
617 |
+
|
618 |
+
219,451,344 trainable parameters
|
619 |
+
|
620 |
+
Trained with stop words removed for context (input) and more data.""")
|
621 |
+
|
622 |
+
attn188 = gr.ChatInterface(generateAttn188,
|
623 |
+
title="AttentionSeq2Seq-188M",
|
624 |
+
description="""Seq2Seq Generative Chatbot with Attention.
|
625 |
+
|
626 |
+
188,229,108 trainable parameters""")
|
627 |
+
attn219 = gr.ChatInterface(generateAttn219,
|
628 |
+
title="AttentionSeq2Seq-219M",
|
629 |
+
description="""Seq2Seq Generative Chatbot with Attention.
|
630 |
+
|
631 |
+
219,505,940 trainable parameters
|
632 |
+
""")
|
633 |
+
attn219sw = gr.ChatInterface(generateAttn219SW,
|
634 |
+
title="AttentionSeq2Seq-219M-SW",
|
635 |
+
description="""Seq2Seq Generative Chatbot with Attention.
|
636 |
+
|
637 |
+
219,500,560 trainable parameters
|
638 |
+
|
639 |
+
Trained with stop words removed for context (input) and more data""")
|
640 |
+
|
641 |
with gr.Blocks() as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
642 |
with gr.Row():
|
643 |
+
gr.TabbedInterface([norm188, norm219, norm219sw], ["188M", "219M", "219M-SW"])
|
644 |
+
gr.TabbedInterface([attn188, attn219, attn219sw], ["188M", "219M", "219M-SW"])
|
|
|
|
|
645 |
|
646 |
if __name__ == "__main__":
|
647 |
demo.launch()
|
requirements.txt
CHANGED
@@ -7,4 +7,5 @@ torch
|
|
7 |
torchtext
|
8 |
nltk
|
9 |
sentence-transformers
|
10 |
-
scipy
|
|
|
|
7 |
torchtext
|
8 |
nltk
|
9 |
sentence-transformers
|
10 |
+
scipy
|
11 |
+
en-core-web-sm @ https://huggingface.co/spacy/en_core_web_sm/resolve/main/en_core_web_sm-any-py3-none-any.whl
|
vocab219SW/idx2word.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
vocab219SW/word2idx.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|