Ventsislav Muchinov committed
Commit df914b7
Parent: 1f2e6c9

Upload app.py

Files changed (1): app.py (+19, -34)
app.py CHANGED
@@ -12,30 +12,30 @@ DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 ACCESS_TOKEN = os.getenv("HF_TOKEN", "")
 
-model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    torch_dtype=torch.float16,
-    device_map="auto",
-    trust_remote_code=True,
-    token=ACCESS_TOKEN)
-tokenizer = AutoTokenizer.from_pretrained(
-    model_id,
-    trust_remote_code=True,
-    token=ACCESS_TOKEN)
-tokenizer.use_default_system_prompt = False
-
-
 @spaces.GPU
 def generate(
+    model: str,
     message: str,
     system_prompt: str,
     max_new_tokens: int = 1024,
     temperature: float = 0.01,
     top_p: float = 0.01,
-    top_k: int = 50,
-    repetition_penalty: float = 1.0,
 ) -> Iterator[str]:
+
+    model_id = model
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        torch_dtype=torch.float16,
+        device_map="auto",
+        trust_remote_code=True,
+        token=ACCESS_TOKEN)
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_id,
+        trust_remote_code=True,
+        token=ACCESS_TOKEN)
+    tokenizer.use_default_system_prompt = False
+
+
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
@@ -60,10 +60,8 @@ def generate(
         eos_token_id=terminators,
         do_sample=True,
         top_p=top_p,
-        top_k=top_k,
         temperature=temperature,
         num_beams=1,
-        repetition_penalty=repetition_penalty,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
@@ -77,6 +75,7 @@ def generate(
 chat_interface = gr.Interface(
     fn=generate,
     inputs=[
+        gr.Textbox(lines=1, placeholder="Model", label="Model name"),
         gr.Textbox(lines=2, placeholder="Prompt", label="Prompt"),
     ],
     outputs="text",
@@ -102,24 +101,10 @@ chat_interface = gr.Interface(
             maximum=1.0,
            step=0.01,
            value=0.01,
-        ),
-        gr.Slider(
-            label="Top-k",
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=50,
-        ),
-        gr.Slider(
-            label="Repetition penalty",
-            minimum=1.0,
-            maximum=2.0,
-            step=0.05,
-            value=1.0,
-        ),
+        ),
     ],
     title="Model testing",
    description="Provide system settings and a prompt to interact with the model.",
 )
 
-chat_interface.queue(max_size=20).launch(share = True)
+chat_interface.queue(max_size=20).launch()
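
Taken together, the hunks move model loading out of module scope and into the request handler, so the checkpoint is chosen at request time from the new "Model name" textbox. For orientation, below is a minimal sketch of how the revised generate() plausibly fits together. The diff does not show the file's imports, the tokenization step, the streamer, or where terminators comes from, so those parts (TextIteratorStreamer, apply_chat_template, and the streamer/input_ids names) are assumptions based on the standard transformers streaming pattern, not the file's actual code.

    # Sketch only: everything not shown in the diff is assumed, following the
    # usual transformers streaming pattern.
    import os
    from threading import Thread
    from typing import Iterator

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    ACCESS_TOKEN = os.getenv("HF_TOKEN", "")


    def generate(
        model: str,
        message: str,
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.01,
        top_p: float = 0.01,
    ) -> Iterator[str]:
        # After this commit the checkpoint is loaded per request, from whatever
        # repo id the user typed into the "Model name" textbox. That makes the
        # app model-agnostic, at the cost of a full load on every call.
        model_id = model
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            token=ACCESS_TOKEN)
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
            token=ACCESS_TOKEN)
        tokenizer.use_default_system_prompt = False

        conversation = []
        if system_prompt:
            conversation.append({"role": "system", "content": system_prompt})
        conversation.append({"role": "user", "content": message})

        # Assumed: the hidden middle of the file tokenizes the conversation and
        # builds `terminators`; a typical version looks like this.
        input_ids = tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        terminators = [tokenizer.eos_token_id]

        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            eos_token_id=terminators,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            num_beams=1,
        )
        # model.generate runs on a worker thread; the streamer yields text
        # chunks so the Gradio UI can render partial output as it arrives.
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()

        outputs = []
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)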
 
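One note on the final hunk: queue(max_size=20) caps the pending-request queue at 20 events, and dropping share=True is reasonable for an app running on Spaces, where the interface is already served at a public URL; the flag only exists to open a temporary gradio.live tunnel from a locally running app. The resulting launch call, with that reasoning inline:

    # Bound the request queue at 20; on Spaces the app is already publicly
    # reachable, so the temporary share link that share=True would create
    # is unnecessary and a plain launch() suffices.
    chat_interface.queue(max_size=20).launch()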