czl committed on
Commit
808fd7f
1 Parent(s): 03031ae

added more models, changed layout

Browse files
Files changed (4) hide show
  1. app.py +160 -17
  2. requirements.txt +2 -1
  3. vocab219SW/idx2word.json +0 -0
  4. vocab219SW/word2idx.json +0 -0
app.py CHANGED
@@ -1,10 +1,11 @@
1
  import json
 
2
  import re
3
  import unicodedata
4
  from typing import Tuple
5
- import random
6
 
7
  import gradio as gr
 
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
@@ -346,9 +347,52 @@ norm_model219.load_state_dict(torch.load('NormSeq2Seq-219M_epoch35.pt',
346
  map_location=torch.device('cpu')))
347
  norm_model219.to(device)
348
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  models_dict = {'AttentionSeq2Seq-188M': attn_model, 'NormalSeq2Seq-188M': norm_model,
350
  'AttentionSeq2Seq-219M': attn_model219,
351
- 'NormalSeq2Seq-219M': norm_model219}
 
 
352
 
353
  def generateAttn188(sentence, history, max_len=12,
354
  word2idx=word2idx, idx2word=idx2word,
@@ -482,23 +526,122 @@ def generateNorm219(sentence, history, max_len=12,
482
  response = lookup_words(idx2word, outputs)
483
  return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
  with gr.Blocks() as demo:
486
- gr.Markdown("""
487
- # Seq2Seq Generative Chatbot with 188M parameters
488
- """)
489
- with gr.Row():
490
- gr.ChatInterface(generateNorm188,
491
- title="NormalSeq2Seq-188M")
492
- gr.ChatInterface(generateAttn188,
493
- title="AttentionSeq2Seq-188M")
494
- gr.Markdown("""
495
- # Seq2Seq Generative Chatbot with 219M parameters
496
- """)
497
  with gr.Row():
498
- gr.ChatInterface(generateNorm219,
499
- title="NormalSeq2Seq-219M")
500
- gr.ChatInterface(generateAttn219,
501
- title="AttentionSeq2Seq-219M")
502
 
503
  if __name__ == "__main__":
504
  demo.launch()
 
1
  import json
2
+ import random
3
  import re
4
  import unicodedata
5
  from typing import Tuple
 
6
 
7
  import gradio as gr
8
+ import spacy
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
 
347
  map_location=torch.device('cpu')))
348
  norm_model219.to(device)
349
 
350
# Vocabulary used by the two "-SW" 219M models.
with open('vocab219SW/word2idx.json', 'r') as f:
    word2idx3 = json.load(f)
with open('vocab219SW/idx2word.json', 'r') as f:
    idx2word3 = json.load(f)

# Hyperparameters needed to rebuild the networks before their weights are
# loaded. 'input_dim' is the vocabulary size and is also passed as the
# decoders' output_dim below.
params219SW = {'input_dim': len(word2idx3),
               'emb_dim': 192,
               'enc_hid_dim': 256,
               'dec_hid_dim': 256,
               'dropout': 0.5,
               'attn_dim': 64,
               'teacher_forcing_ratio': 0.5,
               'epochs': 35}

# --- Attention variant ---
enc = Encoder(input_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
              enc_hid_dim=params219SW['enc_hid_dim'], dec_hid_dim=params219SW['dec_hid_dim'],
              dropout=params219SW['dropout'])
attn = Attention(enc_hid_dim=params219SW['enc_hid_dim'], dec_hid_dim=params219SW['dec_hid_dim'],
                 attn_dim=params219SW['attn_dim'])
# Fix: emb_dim was read from params219 (another model's config) while every
# other argument came from params219SW; use the SW config throughout so the
# decoder shape cannot drift from the checkpoint it is about to load.
dec = AttnDecoder(output_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
                  enc_hid_dim=params219SW['enc_hid_dim'], dec_hid_dim=params219SW['dec_hid_dim'],
                  attention=attn, dropout=params219SW['dropout'])
attn_model219SW = Seq2Seq(encoder=enc, decoder=dec, device=device)
# Weights are mapped to CPU on load, then moved to the configured device.
attn_model219SW.load_state_dict(torch.load('AttnSeq2Seq-219M-SW_epoch35.pt',
                                           map_location=torch.device('cpu')))
attn_model219SW.to(device)

# --- Variant without attention ---
enc = Encoder(input_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
              enc_hid_dim=params219SW['enc_hid_dim'],
              dec_hid_dim=params219SW['dec_hid_dim'], dropout=params219SW['dropout'])
dec = Decoder(output_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
              enc_hid_dim=params219SW['enc_hid_dim'],
              dec_hid_dim=params219SW['dec_hid_dim'],
              dropout=params219SW['dropout'])
norm_model219SW = Seq2Seq(encoder=enc, decoder=dec, device=device)
norm_model219SW.load_state_dict(torch.load('NormSeq2Seq-219M-SW_epoch35.pt',
                                           map_location=torch.device('cpu')))
norm_model219SW.to(device)
388
+
389
# spaCy English pipeline, loaded once at import time.
nlp = spacy.load('en_core_web_sm')

# Registry of every loaded model, keyed by the name shown in the UI.
models_dict = {
    'AttentionSeq2Seq-188M': attn_model,
    'NormalSeq2Seq-188M': norm_model,
    'AttentionSeq2Seq-219M': attn_model219,
    'NormalSeq2Seq-219M': norm_model219,
    'AttentionSeq2Seq-219M-SW': attn_model219SW,
    'NormalSeq2Seq-219M-SW': norm_model219SW,
}
396
 
397
  def generateAttn188(sentence, history, max_len=12,
398
  word2idx=word2idx, idx2word=idx2word,
 
526
  response = lookup_words(idx2word, outputs)
527
  return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
528
 
529
def tokenize_context(text, nlp=nlp):
    """
    Split text into tokens and drop the stop words
    :param text: string to tokenize
    :param nlp: object whose ``tokenizer`` is called on the text
    :return: list with the text of every token not flagged ``is_stop``
    """
    kept = []
    for token in nlp.tokenizer(text):
        if token.is_stop:
            continue
        kept.append(token.text)
    return kept
536
+
537
def generateAttn219SW(sentence, history, max_len=12,
                      word2idx=word2idx3, idx2word=idx2word3,
                      device=device, tokenize_context=tokenize_context,
                      preprocess_text=preprocess_text,
                      lookup_words=lookup_words, models_dict=models_dict):
    """
    Generate a response with the AttentionSeq2Seq-219M-SW model
    :param sentence: user message
    :param history: chat history supplied by the chat UI (unused)
    :param max_len: maximum number of decoding steps
    :param word2idx: word to index mapping
    :param idx2word: index to word mapping
    :param device: device the input tensors are moved to
    :param tokenize_context: tokenizer applied to the preprocessed sentence
    :param preprocess_text: text normalizer applied before tokenizing
    :param lookup_words: converts the generated indices back to words
    :param models_dict: mapping from model name to loaded model
    :return: response
    """
    # Fix: the input is encoded with the SW vocabulary, so it has to be fed
    # to the model registered under the "-SW" key. The original looked up
    # 'AttentionSeq2Seq-219M', pairing the SW vocabulary with another model.
    model = models_dict['AttentionSeq2Seq-219M-SW']
    model.eval()

    bos = word2idx['<bos>']
    eos = word2idx['<eos>']
    words = tokenize_context(preprocess_text(sentence))
    # Out-of-vocabulary words fall back to <unk>.
    ids = [word2idx[w] if w in word2idx else word2idx['<unk>'] for w in words]
    # unsqueeze(1) adds a second dimension of size 1 to the 1-D index tensor.
    src = torch.tensor([bos] + ids + [eos], dtype=torch.long).unsqueeze(1).to(device)

    outputs = [bos]
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(src)
        # Greedy decoding: feed back the highest-scoring index until <eos>
        # is produced or max_len steps have run.
        for _ in range(max_len):
            prev = torch.tensor([outputs[-1]], dtype=torch.long).to(device)
            output, hidden = model.decoder(prev, hidden, encoder_outputs)
            next_id = output.max(1)[1].item()
            outputs.append(next_id)
            if next_id == eos:
                break
    response = lookup_words(idx2word, outputs)
    return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
570
+
571
def generateNorm219SW(sentence, history, max_len=12,
                      word2idx=word2idx3, idx2word=idx2word3,
                      device=device, tokenize_context=tokenize_context, preprocess_text=preprocess_text,
                      lookup_words=lookup_words, models_dict=models_dict):
    """
    Generate a response with the NormalSeq2Seq-219M-SW model
    :param sentence: user message
    :param history: chat history supplied by the chat UI (unused)
    :param max_len: maximum number of decoding steps
    :param word2idx: word to index mapping
    :param idx2word: index to word mapping
    :param device: device the input tensors are moved to
    :param tokenize_context: tokenizer applied to the preprocessed sentence
    :param preprocess_text: text normalizer applied before tokenizing
    :param lookup_words: converts the generated indices back to words
    :param models_dict: mapping from model name to loaded model
    :return: response
    """
    # Fix: the input is encoded with the SW vocabulary, so it has to be fed
    # to the model registered under the "-SW" key. The original looked up
    # 'NormalSeq2Seq-219M', pairing the SW vocabulary with another model.
    model = models_dict['NormalSeq2Seq-219M-SW']
    model.eval()

    bos = word2idx['<bos>']
    eos = word2idx['<eos>']
    words = tokenize_context(preprocess_text(sentence))
    # Out-of-vocabulary words fall back to <unk>.
    ids = [word2idx[w] if w in word2idx else word2idx['<unk>'] for w in words]
    # unsqueeze(1) adds a second dimension of size 1 to the 1-D index tensor.
    src = torch.tensor([bos] + ids + [eos], dtype=torch.long).unsqueeze(1).to(device)

    outputs = [bos]
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(src)
        # Greedy decoding: feed back the highest-scoring index until <eos>
        # is produced or max_len steps have run.
        for _ in range(max_len):
            prev = torch.tensor([outputs[-1]], dtype=torch.long).to(device)
            output, hidden = model.decoder(prev, hidden, encoder_outputs)
            next_id = output.max(1)[1].item()
            outputs.append(next_id)
            if next_id == eos:
                break
    response = lookup_words(idx2word, outputs)
    return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
603
+
604
# Chat UIs for the three models without attention. Each gr.ChatInterface wraps
# one generate* function and is later passed to gr.TabbedInterface; the
# description text states the model's trainable-parameter count.
# NOTE(review): whitespace inside the triple-quoted descriptions is not
# recoverable from this diff rendering -- check the raw file before editing.
+ norm188 = gr.ChatInterface(generateNorm188,
605
+ title="NormalSeq2Seq-188M",
606
+ description="""Seq2Seq Generative Chatbot without Attention.
607
+
608
+ 188,204,500 trainable parameters""")
609
+ norm219 = gr.ChatInterface(generateNorm219,
610
+ title="NormalSeq2Seq-219M",
611
+ description="""Seq2Seq Generative Chatbot without Attention.
612
+
613
+ 219,456,724 trainable parameters""")
614
+ norm219sw = gr.ChatInterface(generateNorm219SW,
615
+ title="NormalSeq2Seq-219M-SW",
616
+ description="""Seq2Seq Generative Chatbot without Attention.
617
+
618
+ 219,451,344 trainable parameters
619
+
620
+ Trained with stop words removed for context (input) and more data.""")
621
+
# Chat UIs for the three attention models, in the same 188M / 219M / 219M-SW
# order as the group above.
622
+ attn188 = gr.ChatInterface(generateAttn188,
623
+ title="AttentionSeq2Seq-188M",
624
+ description="""Seq2Seq Generative Chatbot with Attention.
625
+
626
+ 188,229,108 trainable parameters""")
627
+ attn219 = gr.ChatInterface(generateAttn219,
628
+ title="AttentionSeq2Seq-219M",
629
+ description="""Seq2Seq Generative Chatbot with Attention.
630
+
631
+ 219,505,940 trainable parameters
632
+ """)
633
+ attn219sw = gr.ChatInterface(generateAttn219SW,
634
+ title="AttentionSeq2Seq-219M-SW",
635
+ description="""Seq2Seq Generative Chatbot with Attention.
636
+
637
+ 219,500,560 trainable parameters
638
+
639
+ Trained with stop words removed for context (input) and more data""")
640
+
# Page layout: a single row holding two tab groups -- the models without
# attention first, the attention models second -- each with one tab per
# model size.
with gr.Blocks() as demo, gr.Row():
    gr.TabbedInterface(interface_list=[norm188, norm219, norm219sw],
                       tab_names=["188M", "219M", "219M-SW"])
    gr.TabbedInterface(interface_list=[attn188, attn219, attn219sw],
                       tab_names=["188M", "219M", "219M-SW"])

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt CHANGED
@@ -7,4 +7,5 @@ torch
7
  torchtext
8
  nltk
9
  sentence-transformers
10
- scipy
 
 
7
  torchtext
8
  nltk
9
  sentence-transformers
10
+ scipy
11
+ en-core-web-sm @ https://huggingface.co/spacy/en_core_web_sm/resolve/main/en_core_web_sm-any-py3-none-any.whl
vocab219SW/idx2word.json ADDED
The diff for this file is too large to render. See raw diff
 
vocab219SW/word2idx.json ADDED
The diff for this file is too large to render. See raw diff