tmzh committed
Commit 371d8a8
1 Parent(s): 5a96289

enable gpu

Files changed (2)
  1. app.py +32 -21
  2. requirements.txt +1 -1
app.py CHANGED
@@ -14,15 +14,19 @@ from chromadb.utils import embedding_functions
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
 
+print(f"Is CUDA available: {torch.cuda.is_available()}")
+print(
+    f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
+
 models = {
-    "wizardLM-7B-HF" : "TheBloke/wizardLM-7B-HF",
-    "wizard-vicuna-13B-GPTQ" : "TheBloke/wizard-vicuna-13B-GPTQ",
-    "Wizard-Vicuna-13B-Uncensored" : "ehartford/Wizard-Vicuna-13B-Uncensored",
-    "WizardLM-13B" : "TheBloke/WizardLM-13B-V1.0-Uncensored-GPTQ",
-    "Llama-2-7B" : "TheBloke/Llama-2-7b-Chat-GPTQ",
-    "Vicuna-13B" : "TheBloke/vicuna-13B-v1.5-GPTQ",
-    "WizardLM-13B-V1.2" : "TheBloke/WizardLM-13B-V1.2-GPTQ", # Trained from Llama-2 13b
-    "Mistral-7B" : "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ"
+    "wizardLM-7B-HF": "TheBloke/wizardLM-7B-HF",
+    "wizard-vicuna-13B-GPTQ": "TheBloke/wizard-vicuna-13B-GPTQ",
+    "Wizard-Vicuna-13B-Uncensored": "ehartford/Wizard-Vicuna-13B-Uncensored",
+    "WizardLM-13B": "TheBloke/WizardLM-13B-V1.0-Uncensored-GPTQ",
+    "Llama-2-7B": "TheBloke/Llama-2-7b-Chat-GPTQ",
+    "Vicuna-13B": "TheBloke/vicuna-13B-v1.5-GPTQ",
+    "WizardLM-13B-V1.2": "TheBloke/WizardLM-13B-V1.2-GPTQ",  # Trained from Llama-2 13b
+    "Mistral-7B": "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ"
 }
 
 
@@ -32,27 +36,29 @@ tokenizer = AutoTokenizer.from_pretrained(models[model_name])
 # tokenizer.use_default_system_prompt = True
 tokenizer.chat_template = tokenizer.default_chat_template
 
-model = AutoModelForCausalLM.from_pretrained(models[model_name],
-                                             torch_dtype=torch.float16,
+model = AutoModelForCausalLM.from_pretrained(models[model_name],
+                                             torch_dtype=torch.float16,
                                              device_map="auto")
 
 
-file_path='./data/faq_dataset.json'
+file_path = './data/faq_dataset.json'
 data = json.loads(Path(file_path).read_text())
 
 
 client = chromadb.Client()
 
-emb_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="BAAI/bge-small-en-v1.5")
+emb_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
+    model_name="BAAI/bge-small-en-v1.5")
 
 collection = client.create_collection(
     name="retrieval_qa",
     embedding_function=emb_fn,
-    metadata={"hnsw:space": "cosine"} # l2 is the default
+    metadata={"hnsw:space": "cosine"}  # l2 is the default
 )
 
-documents = [json.dumps(q) for q in data['questions']] # encode QnA as json strings for generating embeddings
-metadatas = data['questions'] # retain QnA as dict in metadatas
+# encode QnA as json strings for generating embeddings
+documents = [json.dumps(q) for q in data['questions']]
+metadatas = data['questions']  # retain QnA as dict in metadatas
 ids = [str(uuid.uuid1()) for _ in documents]
 
 
@@ -99,8 +105,10 @@ def respond(query):
 
     model.to(model.device)
 
-    generated_ids = model.generate(model_inputs, streamer=streamer, temperature=0.01, max_new_tokens=100, do_sample=True)
-    answer = tokenizer.batch_decode(generated_ids[:, model_inputs.shape[1]:])[0]
+    generated_ids = model.generate(
+        model_inputs, streamer=streamer, temperature=0.01, max_new_tokens=100, do_sample=True)
+    answer = tokenizer.batch_decode(
+        generated_ids[:, model_inputs.shape[1]:])[0]
     answer = answer.replace('</s>', '')
     samples = related_questions
 
@@ -119,13 +127,16 @@ with gr.Blocks() as chatbot:
     with gr.Column():
         answer_block = gr.Textbox(label="Answers", lines=2)
         question = gr.Textbox(label="Question")
-        examples = gr.Dataset(samples=samples, components=[question], label="Similar questions", type="index")
+        examples = gr.Dataset(samples=samples, components=[
+                              question], label="Similar questions", type="index")
        generate = gr.Button(value="Ask")
     with gr.Column():
-        references_block = gr.Markdown("## References\n", label="global variable")
+        references_block = gr.Markdown(
+            "## References\n", label="global variable")
 
     examples.click(load_example, inputs=[examples], outputs=[question])
-    generate.click(respond, inputs=question, outputs=[answer_block, references_block, examples])
+    generate.click(respond, inputs=question, outputs=[
+                   answer_block, references_block, examples])
 
 chatbot.queue()
-chatbot.launch()
+chatbot.launch()
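
A quick way to confirm the GPU path after this change (a minimal sketch, not part of the commit; it reuses one of the GPTQ ids from the models dict and the same from_pretrained arguments):

import torch
from transformers import AutoModelForCausalLM

print(f"Is CUDA available: {torch.cuda.is_available()}")

# device_map="auto" should place the quantized weights on the GPU when one is visible.
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ",
    torch_dtype=torch.float16,
    device_map="auto")

print(model.device)                            # expect cuda:0 on a GPU Space
print({p.device for p in model.parameters()})  # placement of every parameter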
requirements.txt CHANGED
@@ -6,5 +6,5 @@ huggingface_hub
 optimum
 sentence_transformers
 spaces
-torch==2.3.0
+torch==2.3.0 --index-url https://download.pytorch.org/whl/cu121
 transformers==4.43.0.dev0
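
And a minimal check (a sketch; the printed version strings are illustrative) that the wheel pulled from the cu121 index is actually a CUDA build:

import torch

print(torch.__version__)          # expect something like "2.3.0+cu121"
print(torch.version.cuda)         # CUDA toolkit the wheel was built against, e.g. "12.1"
print(torch.cuda.is_available())  # True only when a GPU is attached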