MrVicente committed on
Commit
af44808
1 Parent(s): 19b6aec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -0
app.py CHANGED
@@ -4,6 +4,11 @@ from transformers import (
4
  BartTokenizer
5
  )
6
  import torch
 
 
 
 
 
7
 
8
  def get_device():
9
  # If there's a GPU available...
@@ -27,6 +32,12 @@ model.eval()
27
 
28
  def run_bart(question, censor):
29
  print(question, censor)
 
 
 
 
 
 
30
 
31
  model_input = tokenizer(question, truncation=True, padding=True, return_tensors="pt")
32
  generated_answers_encoded = model.generate(input_ids=model_input["input_ids"].to(device),
@@ -36,6 +47,7 @@ def run_bart(question, censor):
36
  min_length=1,
37
  max_length=100,
38
  do_sample=True,
 
39
  early_stopping=True,
40
  num_beams=4,
41
  temperature=1.0,
 
4
  BartTokenizer
5
  )
6
  import torch
7
+ import json
8
+
9
def read_json_file_2_dict(filename, store_dir='.'):
    """Load a UTF-8 encoded JSON file and return its parsed content.

    Args:
        filename: Name of the JSON file, relative to ``store_dir``.
        store_dir: Directory that contains the file. Defaults to the
            current working directory.

    Returns:
        The object produced by ``json.load`` for the file content
        (a dict, list, or other JSON value).

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    import os  # local import so this block does not depend on other edits

    # os.path.join rather than f'{store_dir}/{filename}': with an empty
    # store_dir the f-string would yield '/<filename>' (filesystem root),
    # whereas join keeps the path relative to the working directory.
    path = os.path.join(store_dir, filename)
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)
12
 
13
  def get_device():
14
  # If there's a GPU available...
 
32
 
33
  def run_bart(question, censor):
34
  print(question, censor)
35
+ if censor:
36
+ bad_words = read_json_file_2_dict('bad_words_file.json')
37
+ bad_words_ids = tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).get('input_ids')
38
+ else:
39
+ bad_words_ids = None
40
+
41
 
42
  model_input = tokenizer(question, truncation=True, padding=True, return_tensors="pt")
43
  generated_answers_encoded = model.generate(input_ids=model_input["input_ids"].to(device),
 
47
  min_length=1,
48
  max_length=100,
49
  do_sample=True,
50
+ bad_words_ids=bad_words_ids,
51
  early_stopping=True,
52
  num_beams=4,
53
  temperature=1.0,