vmuchinov committed
Commit 205ebd7
1 parent: df914b7

Upload app.py

Files changed (1)
  1. app.py +14 -17
app.py CHANGED
@@ -12,30 +12,28 @@ DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 ACCESS_TOKEN = os.getenv("HF_TOKEN", "")
 
+model_id = "meta-llama/Llama-2-13b-chat"
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch.float16,
+    device_map="auto",
+    trust_remote_code=True,
+    token=ACCESS_TOKEN)
+tokenizer = AutoTokenizer.from_pretrained(
+    model_id,
+    trust_remote_code=True,
+    token=ACCESS_TOKEN)
+tokenizer.use_default_system_prompt = False
+
+
 @spaces.GPU
 def generate(
-    model: str,
     message: str,
     system_prompt: str,
     max_new_tokens: int = 1024,
     temperature: float = 0.01,
     top_p: float = 0.01,
 ) -> Iterator[str]:
-
-    model_id = model
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        trust_remote_code=True,
-        token=ACCESS_TOKEN)
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_id,
-        trust_remote_code=True,
-        token=ACCESS_TOKEN)
-    tokenizer.use_default_system_prompt = False
-
-
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
@@ -75,7 +73,6 @@ def generate(
 chat_interface = gr.Interface(
     fn=generate,
     inputs=[
-        gr.Textbox(lines=1, placeholder="Model", label="Model name"),
         gr.Textbox(lines=2, placeholder="Prompt", label="Prompt"),
     ],
     outputs="text",
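
The change pins the Space to meta-llama/Llama-2-13b-chat and loads the weights once at module import instead of on every @spaces.GPU call, which is also why the free-form "Model name" textbox is dropped from the interface. Setting tokenizer.use_default_system_prompt = False stops the Llama-2 tokenizer from injecting its built-in default system prompt, so only the user-supplied system_prompt is used; and with temperature and top_p both defaulting to 0.01, sampling is effectively greedy. The hunk headers elide the body of generate (old lines 41-74); for orientation only, here is a minimal sketch of the streaming pattern such Spaces typically follow — an assumption about code the diff does not show, not the file's actual contents:

# Hypothetical sketch of the elided generate() body (old lines 41-74 are not
# shown in the diff); the standard transformers streaming recipe, NOT the
# file's verbatim code. `model`, `tokenizer`, and MAX_INPUT_TOKEN_LENGTH are
# the module-level names from the diff above.
from collections.abc import Iterator
from threading import Thread

from transformers import TextIteratorStreamer


def generate_sketch(
    message: str,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.01,
    top_p: float = 0.01,
) -> Iterator[str]:
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation.append({"role": "user", "content": message})

    # Render the chat template and clip to the input-token budget.
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:].to(model.device)

    # Generate on a background thread and yield text as tokens arrive.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        ),
    ).start()

    chunks: list[str] = []
    for text in streamer:
        chunks.append(text)
        yield "".join(chunks)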