Shanat committed on
Commit
5b12d75
1 Parent(s): 2be5043

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -11,8 +11,17 @@ import os
11
 
12
 
13
  #chatbot = pipeline(model="NCSOFT/Llama-3-OffsetBias-8B")
14
- token = os.getenv("HF_TOKEN")
15
- chatbot = pipeline(model="meta-llama/Llama-3.2-1B")
 
 
 
 
 
 
 
 
 
16
  #chatbot = pipeline(model="facebook/blenderbot-400M-distill")
17
 
18
  message_list = []
@@ -20,9 +29,13 @@ response_list = []
20
 
21
 
22
  def vanilla_chatbot(message, history):
23
- conversation = chatbot(message)
 
 
 
 
24
 
25
- return conversation[0]['generated_text']
26
 
27
  demo_chatbot = gr.ChatInterface(vanilla_chatbot, title="Vanilla Chatbot", description="Enter text to start chatting.")
28
 
 
# Authenticate to the Hugging Face Hub so the gated Llama weights can be
# downloaded. HF_TOKEN must be set in the environment (e.g. Space secrets).
login(token=os.getenv("HF_TOKEN"))

# Single source of truth for the checkpoint name — previously the id was
# hard-coded separately in both from_pretrained calls.
MODEL_ID = "meta-llama/Llama-3.2-1B"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",   # place on GPU when available, else CPU
    torch_dtype="auto",  # keep the checkpoint's native precision
)

# Conversation accumulator kept from the original script (a sibling
# response_list exists elsewhere in the file; neither is read by
# vanilla_chatbot below).
message_list = []
 
29
 
30
 
31
  def vanilla_chatbot(message, history):
32
+ inputs = tokenizer(message['text'], return_tensors="pt").to("cpu")
33
+ with torch.no_grad():
34
+ outputs = model.generate(inputs.input_ids, max_length=100)
35
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
36
+ #conversation = chatbot(message)
37
 
38
+ #return conversation[0]['generated_text']
39
 
40
# Wire the chat function into a minimal Gradio chat UI.
demo_chatbot = gr.ChatInterface(
    vanilla_chatbot,
    title="Vanilla Chatbot",
    description="Enter text to start chatting.",
)