Spaces:
Sleeping
Sleeping
first finance bot
Browse files- app.py +54 -0
- gpt2talk.pt +3 -0
- requirements.txt +4 -0
app.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import transformers
|
2 |
+
import gradio as gr
|
3 |
+
import warnings
|
4 |
+
import torch
|
5 |
+
warnings.simplefilter('ignore')
|
6 |
+
|
7 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
8 |
+
|
9 |
+
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
|
10 |
+
#add padding token, beginstring and endstring tokens
|
11 |
+
tokenizer.add_special_tokens(
|
12 |
+
{
|
13 |
+
"pad_token":"<pad>",
|
14 |
+
"bos_token":"<startstring>",
|
15 |
+
"eos_token":"<endstring>"
|
16 |
+
})
|
17 |
+
#add bot token since it is not a special token
|
18 |
+
tokenizer.add_tokens(["<bot>:"])
|
19 |
+
|
20 |
+
model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
|
21 |
+
model.resize_token_embeddings(len(tokenizer))
|
22 |
+
model.load_state_dict(torch.load('gpt2talk.pt', map_location=torch.device('cpu')))
|
23 |
+
|
24 |
+
model.eval()
|
25 |
+
def inference(quiz):
|
26 |
+
quiz1 = quiz
|
27 |
+
quiz = "<startstring>"+quiz+" <bot>:"
|
28 |
+
|
29 |
+
quiztoken = tokenizer(quiz,
|
30 |
+
return_tensors='pt'
|
31 |
+
)
|
32 |
+
|
33 |
+
answer = model.generate(**quiztoken, max_length=200, top_k=0.7,top_p=0.1)[0]
|
34 |
+
answer = tokenizer.decode(answer, skip_special_tokens=True)
|
35 |
+
answer = answer.replace(" <bot>:","").replace(quiz1,"") + '.'
|
36 |
+
return answer
|
37 |
+
|
38 |
+
def chatbot(input_text):
|
39 |
+
response = inference(input_text)
|
40 |
+
return response
|
41 |
+
|
42 |
+
# Create the Gradio interface
|
43 |
+
iface = gr.Interface(
|
44 |
+
fn=chatbot,
|
45 |
+
inputs=gr.Textbox(),
|
46 |
+
outputs=gr.Textbox(),
|
47 |
+
live=False, #set false to avoid caching
|
48 |
+
interpretation="chat",
|
49 |
+
title="ChatFinance",
|
50 |
+
description="Ask the a question and see its response!",
|
51 |
+
)
|
52 |
+
|
53 |
+
# Launch the Gradio interface
|
54 |
+
iface.launch()
|
gpt2talk.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cbb4f3318512a3c112ad2fc12db6a7fb41b1beb98d4ead0825728a228705fb66
|
3 |
+
size 497826671
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.0.1
|
2 |
+
transformers==4.31.0
|
3 |
+
gradio==3.44.4
|
4 |
+
gradio_client==0.5.1
|