surkovvv committed
Commit 2418c6c
1 Parent(s): ac3889c

Test demo for the chatbot from the Gradio docs, running on CPU

Files changed (1)
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+ from threading import Thread
+
+
+ # Load the chat model on CPU; on a GPU box, pass torch_dtype=torch.float16 and call model.to('cuda').
+ tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
+ model = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
+
+
+ class StopOnTokens(StoppingCriteria):
+     """Stop generation as soon as the model emits one of the stop token ids."""
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         stop_ids = [29, 0]  # token ids that end a turn for this model
+         for stop_id in stop_ids:
+             if input_ids[0][-1] == stop_id:
+                 return True
+         return False
+
+
+ def predict(message, history):
+     history_transformer_format = history + [[message, ""]]
+     stop = StopOnTokens()
+
+     # Flatten the chat history into the <human>/<bot> prompt format the model was tuned on.
+     messages = "".join(["".join(["\n<human>:" + item[0], "\n<bot>:" + item[1]])
+                         for item in history_transformer_format])
+
+     model_inputs = tokenizer([messages], return_tensors="pt")  # add .to("cuda") on GPU
+     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         top_p=0.95,
+         top_k=1000,
+         temperature=1.0,
+         num_beams=1,
+         stopping_criteria=StoppingCriteriaList([stop]),
+     )
+     # Run generation in a background thread so tokens can be streamed as they arrive.
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     # Accumulate streamed tokens, skipping the '<' that opens the <human> stop sequence.
+     partial_message = ""
+     for new_token in streamer:
+         if new_token != '<':
+             partial_message += new_token
+             yield partial_message
+
+
+ gr.ChatInterface(predict).launch()