aisuko committed on
Commit
0f60bae
1 Parent(s): 26a3280

Init commit


Signed-off-by: Aisuko <[email protected]>

Files changed (2)
  1. app.py +77 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,77 @@
+ import gradio as gr
+ import torch
+ from torch import LongTensor, FloatTensor
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     StoppingCriteria,
+     StoppingCriteriaList,
+     TextIteratorStreamer,
+ )
+ from threading import Thread
+
+ tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
+ model = AutoModelForCausalLM.from_pretrained(
+     "togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.bfloat16
+ )
+
+
+ class StopOnTokens(StoppingCriteria):
+     """Stop generation as soon as the model emits one of the stop token ids."""
+
+     def __call__(self, input_ids: LongTensor, scores: FloatTensor, **kwargs) -> bool:
+         stop_ids = [29, 0]  # token ids that end a turn for this chat model
+         for stop_id in stop_ids:
+             if input_ids[0][-1] == stop_id:
+                 return True
+         return False
+
+
+ def predict(message, history):
+     try:
+         # Append the new user message with an empty bot slot to be filled in.
+         history_transformer_format = history + [[message, ""]]
+         stop = StopOnTokens()
+
+         # Flatten the chat history into the <human>:/<bot>: prompt format
+         # the RedPajama-INCITE chat model was fine-tuned on.
+         messages = "".join(
+             "".join(["\n<human>:" + item[0], "\n<bot>:" + item[1]])
+             for item in history_transformer_format
+         )
+
+         model_inputs = tokenizer([messages], return_tensors="pt")
+         streamer = TextIteratorStreamer(
+             tokenizer,
+             timeout=10.0,
+             skip_prompt=True,
+             skip_special_tokens=True,
+         )
+
+         generate_kwargs = dict(
+             model_inputs,
+             streamer=streamer,
+             max_new_tokens=1024,
+             do_sample=True,
+             top_p=0.95,
+             top_k=1000,
+             temperature=1.0,
+             num_beams=1,
+             stopping_criteria=StoppingCriteriaList([stop]),
+         )
+
+         # Run generation on a background thread so tokens can be streamed
+         # back to the UI as they are produced.
+         t = Thread(target=model.generate, kwargs=generate_kwargs)
+         t.start()
+
+         partial_message = ""
+         for new_token in streamer:
+             if new_token != "<":  # drop the "<" that opens the model's next turn marker
+                 partial_message += new_token
+                 yield partial_message
+     except Exception:
+         yield "Sorry, I don't understand that."
+
+
+ gr.ChatInterface(predict).queue().launch()
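
Note that predict is a Python generator whose every yield is the cumulative reply so far, which is the streaming contract gr.ChatInterface expects. A minimal sketch of driving it outside the UI, where the prompt string is purely illustrative:

    reply = ""
    for partial in predict("What is a good name for a cat?", history=[]):
        reply = partial  # each yield is the full reply so far, not a delta
    print(reply)
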
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ torch==2.1.1
+ transformers==4.35.2
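
requirements.txt pins only torch and transformers; gradio is presumably supplied by the Space's Gradio SDK runtime. Running app.py locally would mean installing gradio alongside these pins before launching it with python app.py.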