Tonic commited on
Commit
1d2b883
β€’
1 Parent(s): ffba04f

add do sample

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -29,7 +29,7 @@ model = AutoModelForCausalLM.from_pretrained(
29
  )
30
 
31
  @spaces.GPU()
32
- def generate_text(system_prompt, prompt, temperature, max_new_tokens, top_p, repetition_penalty):
33
  date_string = datetime.today().strftime('%Y-%m-%d')
34
  messages = [
35
  {"role": "system", "content": system_prompt},
@@ -51,15 +51,15 @@ def generate_text(system_prompt, prompt, temperature, max_new_tokens, top_p, rep
51
  temperature=temperature,
52
  top_p=top_p,
53
  repetition_penalty=repetition_penalty,
54
- do_sample=True,
55
  pad_token_id=tokenizer.eos_token_id
56
  )
57
 
58
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
59
  return generated_text.split("assistant\n")[-1].strip()
60
 
61
- def update_output(system_prompt, prompt, temperature, max_new_tokens, top_p, repetition_penalty):
62
- return generate_text(system_prompt, prompt, temperature, max_new_tokens, top_p, repetition_penalty)
63
 
64
  with gr.Blocks() as demo:
65
  gr.Markdown("# 🦎 Welcome to Tonic's Salamandra-7b-instruct Demo")
@@ -82,9 +82,10 @@ with gr.Blocks() as demo:
82
  prompt = gr.Textbox(lines=5, label="πŸ™‹β€β™‚οΈ User Prompt")
83
  generate_button = gr.Button("Generate with 🦎 Salamandra-7b-instruct")
84
 
85
- with gr.Accordion("πŸ§ͺ Parameters", open=False):
 
86
  temperature = gr.Slider(0.0, 1.0, value=0.7, label="🌑️ Temperature")
87
- max_new_tokens = gr.Slider(1, 1000, value=200, step=1, label="πŸ”’ Max New Tokens")
88
  top_p = gr.Slider(0.0, 1.0, value=0.95, label="βš›οΈ Top P")
89
  repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, label="πŸ” Repetition Penalty")
90
 
@@ -93,10 +94,16 @@ with gr.Blocks() as demo:
93
 
94
  generate_button.click(
95
  update_output,
96
- inputs=[system_prompt, prompt, temperature, max_new_tokens, top_p, repetition_penalty],
97
  outputs=output
98
  )
99
 
 
 
 
 
 
 
100
  gr.Examples(
101
  examples=[
102
  ["At what temperature does water boil?"],
 
29
  )
30
 
31
  @spaces.GPU()
32
+ def generate_text(system_prompt, prompt, temperature, max_new_tokens, top_p, repetition_penalty, do_sample):
33
  date_string = datetime.today().strftime('%Y-%m-%d')
34
  messages = [
35
  {"role": "system", "content": system_prompt},
 
51
  temperature=temperature,
52
  top_p=top_p,
53
  repetition_penalty=repetition_penalty,
54
+ do_sample=do_sample,
55
  pad_token_id=tokenizer.eos_token_id
56
  )
57
 
58
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
59
  return generated_text.split("assistant\n")[-1].strip()
60
 
61
+ def update_output(system_prompt, prompt, temperature, max_new_tokens, top_p, repetition_penalty, do_sample):
62
+ return generate_text(system_prompt, prompt, temperature, max_new_tokens, top_p, repetition_penalty, do_sample)
63
 
64
  with gr.Blocks() as demo:
65
  gr.Markdown("# 🦎 Welcome to Tonic's Salamandra-7b-instruct Demo")
 
82
  prompt = gr.Textbox(lines=5, label="πŸ™‹β€β™‚οΈ User Prompt")
83
  generate_button = gr.Button("Generate with 🦎 Salamandra-7b-instruct")
84
 
85
+ advanced_checkbox = gr.Checkbox(label="πŸ§ͺ Advanced Settings", value=False)
86
+ with gr.Column(visible=False) as advanced_settings:
87
  temperature = gr.Slider(0.0, 1.0, value=0.7, label="🌑️ Temperature")
88
+ max_new_tokens = gr.Slider(1, 2250, value=750, step=1, label="πŸ”’ Max New Tokens")
89
  top_p = gr.Slider(0.0, 1.0, value=0.95, label="βš›οΈ Top P")
90
  repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, label="πŸ” Repetition Penalty")
91
 
 
94
 
95
  generate_button.click(
96
  update_output,
97
+ inputs=[system_prompt, prompt, temperature, max_new_tokens, top_p, repetition_penalty, advanced_checkbox],
98
  outputs=output
99
  )
100
 
101
+ advanced_checkbox.change(
102
+ fn=lambda x: gr.update(visible=x),
103
+ inputs=[advanced_checkbox],
104
+ outputs=[advanced_settings]
105
+ )
106
+
107
  gr.Examples(
108
  examples=[
109
  ["At what temperature does water boil?"],