Eliott commited on
Commit
1f57142
1 Parent(s): e030ae6

update app

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -14,28 +14,32 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_toke
14
  def create_prompt(stars, useful, funny, cool):
15
  return f"Generate review: stars: {stars}, useful: {useful}, funny: {funny}, cool: {cool}"
16
 
 
 
 
 
17
  def generate_reviews(stars, useful, funny, cool):
18
  text = create_prompt(stars, useful, funny, cool)
19
  inputs = tokenizer(text, return_tensors='pt')
20
  out = model.generate(
21
  input_ids=inputs.input_ids,
22
  attention_mask=inputs.attention_mask,
23
- num_beams=5,
24
- num_return_sequences=3
 
25
  )
26
  reviews = []
27
  for review in out:
28
- reviews.append(tokenizer.decode(review, skip_special_tokens=True))
29
 
30
  return reviews[0], reviews[1], reviews[2]
31
 
32
  css = """
33
  #ctr {text-align: center;}
34
- #btn {color: white; background: linear-gradient(90deg, #00d2ff 0%, #3a47d5 100%);}
35
  """
36
 
37
 
38
-
39
  md_text = """## Generating Yelp reviews with BART-base ⭐⭐⭐"""
40
  demo = gr.Blocks(css=css)
41
  with demo:
 
14
  def create_prompt(stars, useful, funny, cool):
15
  return f"Generate review: stars: {stars}, useful: {useful}, funny: {funny}, cool: {cool}"
16
 
17
+ def postprocess(review):
18
+ dot = review.rfind('.')
19
+ return review[:dot]
20
+
21
  def generate_reviews(stars, useful, funny, cool):
22
  text = create_prompt(stars, useful, funny, cool)
23
  inputs = tokenizer(text, return_tensors='pt')
24
  out = model.generate(
25
  input_ids=inputs.input_ids,
26
  attention_mask=inputs.attention_mask,
27
+ num_beams=3,
28
+ num_return_sequences=3,
29
+ temperature=1.2
30
  )
31
  reviews = []
32
  for review in out:
33
+ reviews.append(postprocess(tokenizer.decode(review, skip_special_tokens=True)))
34
 
35
  return reviews[0], reviews[1], reviews[2]
36
 
37
  css = """
38
  #ctr {text-align: center;}
39
+ #btn {color: white; background: linear-gradient( 90deg, rgba(255,166,0,1) 14.7%, rgba(255,99,97,1) 73% );}
40
  """
41
 
42
 
 
43
  md_text = """## Generating Yelp reviews with BART-base ⭐⭐⭐"""
44
  demo = gr.Blocks(css=css)
45
  with demo: