Ventsislav Muchinov commited on
Commit
e8081aa
1 Parent(s): f74f77b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -12,11 +12,11 @@ DEFAULT_MAX_NEW_TOKENS = 1024
12
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
13
  ACCESS_TOKEN = os.getenv("HF_TOKEN", "")
14
 
15
- model_id = "hugging-quants/Meta-Llama-3.1-70B-Instruct-GPTQ-INT4"
16
  model = AutoModelForCausalLM.from_pretrained(
17
  model_id,
18
  torch_dtype=torch.float16,
19
- device_map="cuda",
20
  trust_remote_code=True,
21
  token=ACCESS_TOKEN)
22
  tokenizer = AutoTokenizer.from_pretrained(
@@ -39,7 +39,7 @@ def generate(
39
  conversation.append({"role": "system", "content": system_prompt})
40
  conversation.append({"role": "user", "content": message})
41
 
42
- input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to("cuda")
43
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
44
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
45
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
@@ -101,7 +101,7 @@ chat_interface = gr.Interface(
101
  value=0.01,
102
  ),
103
  ],
104
- title="Model testing",
105
  description="Provide system settings and a prompt to interact with the model.",
106
  )
107
 
 
12
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
13
  ACCESS_TOKEN = os.getenv("HF_TOKEN", "")
14
 
15
+ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
16
  model = AutoModelForCausalLM.from_pretrained(
17
  model_id,
18
  torch_dtype=torch.float16,
19
+ device_map="auto",
20
  trust_remote_code=True,
21
  token=ACCESS_TOKEN)
22
  tokenizer = AutoTokenizer.from_pretrained(
 
39
  conversation.append({"role": "system", "content": system_prompt})
40
  conversation.append({"role": "user", "content": message})
41
 
42
+ input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
43
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
44
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
45
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
101
  value=0.01,
102
  ),
103
  ],
104
+ title="Model testing - Meta-Llama-3-8B-Instruct",
105
  description="Provide system settings and a prompt to interact with the model.",
106
  )
107