LinkangZhan commited on
Commit
f595202
1 Parent(s): ce089fa

fix gpu version

Browse files
Files changed (2) hide show
  1. app.py +10 -8
  2. requirements.txt +5 -5
app.py CHANGED
@@ -13,19 +13,21 @@ import torch
13
  # else lora_folder),
14
  # trust_remote_code=True)
15
  # model = AutoModelForCausalLM.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
16
- # else model_folder), torch_dtype=torch.float32, trust_remote_code=True)
 
 
 
17
  # model = PeftModel.from_pretrained(model,
18
  # ("Junity/Genshin-World-Model" if lora_folder == ''
19
- # else lora_folder)
20
- # , torch_dtype=torch.float32, trust_remote_code=True)
 
 
21
  # tokenizer = AutoTokenizer.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
22
  # else model_folder),
23
  # trust_remote_code=True)
24
- # history = []
25
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
- # if device == "cuda":
27
- # model.cuda()
28
- # model = model.half()
29
 
30
 
31
  def respond(role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k):
 
13
  # else lora_folder),
14
  # trust_remote_code=True)
15
  # model = AutoModelForCausalLM.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
16
+ # else model_folder),
17
+ # torch_dtype=torch.float16,
18
+ # device_map="auto",
19
+ # trust_remote_code=True)
20
  # model = PeftModel.from_pretrained(model,
21
  # ("Junity/Genshin-World-Model" if lora_folder == ''
22
+ # else lora_folder),
23
+ # device_map="auto",
24
+ # torch_dtype=torch.float32,
25
+ # trust_remote_code=True)
26
  # tokenizer = AutoTokenizer.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
27
  # else model_folder),
28
  # trust_remote_code=True)
29
+ history = []
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
31
 
32
 
33
  def respond(role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k):
requirements.txt CHANGED
@@ -1,7 +1,7 @@
1
  gradio==3.40.1
2
  peft==0.4.0
3
- torch==2.0.0+cu118
4
- transformers==4.31.0
5
- transformers_stream_generator==0.0.4
6
- tiktoken
7
- sentencepiece
 
1
  gradio==3.40.1
2
  peft==0.4.0
3
+ transformers_stream_generator
4
+ sentencepiece
5
+ accelerate
6
+ colorama
7
+ cpm_kernels