czl committed on
Commit
3159b65
1 Parent(s): 9e53223

added chat interface

Browse files
Files changed (1) hide show
  1. app.py +42 -7
app.py CHANGED
@@ -310,7 +310,8 @@ norm_model.to(device)
310
 
311
  models_dict = {'AttentionSeq2Seq-188M': attn_model, 'NormalSeq2Seq-188M': norm_model}
312
 
313
- def generate(models_str, sentence, max_len=12, word2idx=word2idx, idx2word=idx2word,
 
314
  device=device, tokenize=tokenize, preprocess_text=preprocess_text,
315
  lookup_words=lookup_words, models_dict=models_dict):
316
  """
@@ -322,7 +323,8 @@ def generate(models_str, sentence, max_len=12, word2idx=word2idx, idx2word=idx2w
322
  :param idx2word: index to word mapping
323
  :return: response
324
  """
325
- model = models_dict[models_str]
 
326
  model.eval()
327
  sentence = preprocess_text(sentence)
328
  tokens = tokenize(sentence)
@@ -341,13 +343,46 @@ def generate(models_str, sentence, max_len=12, word2idx=word2idx, idx2word=idx2w
341
  response = lookup_words(idx2word, outputs)
342
  return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
 
345
 
346
- demo = gr.Interface(fn=generate,
347
- inputs=[gr.Radio(list(models_dict.keys()), label="Model"),
348
- gr.Textbox(lines=2, label="Input Text")],
349
- outputs=gr.Textbox(label="Output Text"))
350
-
351
 
352
  if __name__ == "__main__":
353
  demo.launch()
 
310
 
311
# Registry mapping UI model names to their loaded model instances.
models_dict = {
    'AttentionSeq2Seq-188M': attn_model,
    'NormalSeq2Seq-188M': norm_model,
}
312
 
313
+ def generateAttn(sentence, history, max_len=12,
314
+ word2idx=word2idx, idx2word=idx2word,
315
  device=device, tokenize=tokenize, preprocess_text=preprocess_text,
316
  lookup_words=lookup_words, models_dict=models_dict):
317
  """
 
323
  :param idx2word: index to word mapping
324
  :return: response
325
  """
326
+ history = history
327
+ model = models_dict['AttentionSeq2Seq-188M']
328
  model.eval()
329
  sentence = preprocess_text(sentence)
330
  tokens = tokenize(sentence)
 
343
  response = lookup_words(idx2word, outputs)
344
  return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
345
 
346
def generateNorm(sentence, history, max_len=12,
                 word2idx=word2idx, idx2word=idx2word,
                 device=device, tokenize=tokenize, preprocess_text=preprocess_text,
                 lookup_words=lookup_words, models_dict=models_dict):
    """
    Generate a chat response using the plain (non-attention) Seq2Seq model.

    :param sentence: raw input sentence from the user
    :param history: chat history passed by gr.ChatInterface (currently unused)
    :param max_len: maximum number of decoding steps
    :param word2idx: word to index mapping
    :param idx2word: index to word mapping
    :param device: torch device the model and tensors live on
    :param tokenize: callable splitting a sentence into tokens
    :param preprocess_text: callable normalizing the raw sentence
    :param lookup_words: callable mapping output indices back to words
    :param models_dict: registry of loaded models keyed by display name
    :return: generated response string with <bos>/<eos> markers stripped
    """
    model = models_dict['NormalSeq2Seq-188M']
    model.eval()
    sentence = preprocess_text(sentence)
    tokens = tokenize(sentence)
    # Map out-of-vocabulary tokens to <unk>, then wrap with <bos>/<eos>.
    tokens = [word2idx[token] if token in word2idx else word2idx['<unk>'] for token in tokens]
    tokens = [word2idx['<bos>']] + tokens + [word2idx['<eos>']]
    # Shape (seq_len, 1): a single-sentence batch in seq-first layout.
    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(1).to(device)
    outputs = [word2idx['<bos>']]
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(tokens)
        for _ in range(max_len):
            # Greedy decoding: feed the previous prediction back in each step.
            prev = torch.tensor([outputs[-1]], dtype=torch.long).to(device)
            output, hidden = model.decoder(prev, hidden, encoder_outputs)
            top1 = output.max(1)[1]
            outputs.append(top1.item())
            if top1.item() == word2idx['<eos>']:
                break
    response = lookup_words(idx2word, outputs)
    return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
378
 
379
# Expose both models side by side; each ChatInterface drives one model.
# gr.ChatInterface calls its fn as fn(message, history), matching the
# (sentence, history) signatures of generateNorm / generateAttn.
with gr.Blocks() as demo:
    gr.ChatInterface(generateNorm,
                     title="NormalSeq2Seq-188M")
    gr.ChatInterface(generateAttn,
                     title="AttentionSeq2Seq-188M")
386
 
387
if __name__ == "__main__":
    # Start the Gradio server when the file is executed as a script.
    demo.launch()