added chat interface
Browse files
app.py
CHANGED
@@ -310,7 +310,8 @@ norm_model.to(device)
|
|
310 |
|
311 |
models_dict = {'AttentionSeq2Seq-188M': attn_model, 'NormalSeq2Seq-188M': norm_model}
|
312 |
|
313 |
-
def
|
|
|
314 |
device=device, tokenize=tokenize, preprocess_text=preprocess_text,
|
315 |
lookup_words=lookup_words, models_dict=models_dict):
|
316 |
"""
|
@@ -322,7 +323,8 @@ def generate(models_str, sentence, max_len=12, word2idx=word2idx, idx2word=idx2w
|
|
322 |
:param idx2word: index to word mapping
|
323 |
:return: response
|
324 |
"""
|
325 |
-
|
|
|
326 |
model.eval()
|
327 |
sentence = preprocess_text(sentence)
|
328 |
tokens = tokenize(sentence)
|
@@ -341,13 +343,46 @@ def generate(models_str, sentence, max_len=12, word2idx=word2idx, idx2word=idx2w
|
|
341 |
response = lookup_words(idx2word, outputs)
|
342 |
return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
343 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
|
|
|
345 |
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
|
352 |
if __name__ == "__main__":
|
353 |
demo.launch()
|
|
|
310 |
|
311 |
models_dict = {'AttentionSeq2Seq-188M': attn_model, 'NormalSeq2Seq-188M': norm_model}
|
312 |
|
313 |
+
def generateAttn(sentence, history, max_len=12,
|
314 |
+
word2idx=word2idx, idx2word=idx2word,
|
315 |
device=device, tokenize=tokenize, preprocess_text=preprocess_text,
|
316 |
lookup_words=lookup_words, models_dict=models_dict):
|
317 |
"""
|
|
|
323 |
:param idx2word: index to word mapping
|
324 |
:return: response
|
325 |
"""
|
326 |
+
history = history
|
327 |
+
model = models_dict['AttentionSeq2Seq-188M']
|
328 |
model.eval()
|
329 |
sentence = preprocess_text(sentence)
|
330 |
tokens = tokenize(sentence)
|
|
|
343 |
response = lookup_words(idx2word, outputs)
|
344 |
return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
345 |
|
346 |
+
def generateNorm(sentence, history, max_len=12,
|
347 |
+
word2idx=word2idx, idx2word=idx2word,
|
348 |
+
device=device, tokenize=tokenize, preprocess_text=preprocess_text,
|
349 |
+
lookup_words=lookup_words, models_dict=models_dict):
|
350 |
+
"""
|
351 |
+
Generate response
|
352 |
+
:param model: model
|
353 |
+
:param sentence: sentence
|
354 |
+
:param max_len: maximum length of sequence
|
355 |
+
:param word2idx: word to index mapping
|
356 |
+
:param idx2word: index to word mapping
|
357 |
+
:return: response
|
358 |
+
"""
|
359 |
+
history = history
|
360 |
+
model = models_dict['NormalSeq2Seq-188M']
|
361 |
+
model.eval()
|
362 |
+
sentence = preprocess_text(sentence)
|
363 |
+
tokens = tokenize(sentence)
|
364 |
+
tokens = [word2idx[token] if token in word2idx else word2idx['<unk>'] for token in tokens]
|
365 |
+
tokens = [word2idx['<bos>']] + tokens + [word2idx['<eos>']]
|
366 |
+
tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(1).to(device)
|
367 |
+
outputs = [word2idx['<bos>']]
|
368 |
+
with torch.no_grad():
|
369 |
+
encoder_outputs, hidden = model.encoder(tokens)
|
370 |
+
for t in range(max_len):
|
371 |
+
output, hidden = model.decoder(torch.tensor([outputs[-1]], dtype=torch.long).to(device), hidden, encoder_outputs)
|
372 |
+
top1 = output.max(1)[1]
|
373 |
+
outputs.append(top1.item())
|
374 |
+
if top1.item() == word2idx['<eos>']:
|
375 |
+
break
|
376 |
+
response = lookup_words(idx2word, outputs)
|
377 |
+
return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
378 |
|
379 |
+
# demo = gr.ChatInterface(generate, title="AttentionSeq2Seq-188M")
|
380 |
|
381 |
+
with gr.Blocks() as demo:
|
382 |
+
gr.ChatInterface(generateNorm,
|
383 |
+
title="NormalSeq2Seq-188M")
|
384 |
+
gr.ChatInterface(generateAttn,
|
385 |
+
title="AttentionSeq2Seq-188M")
|
386 |
|
387 |
if __name__ == "__main__":
|
388 |
demo.launch()
|