Happzy-WHU committed on
Commit
f1a0050
1 Parent(s): 5a54d5d

load model.

Browse files
Files changed (2) hide show
  1. V3.py +24 -21
  2. requirements.txt +27 -79
V3.py CHANGED
@@ -1,31 +1,34 @@
1
- import os
2
 
3
- from transformers import AutoTokenizer
4
- from vllm import LLM, SamplingParams
5
- from huggingface_hub import snapshot_download
6
-
7
- model_path = "happzy2633/qwen2.5-7b-ins-v3"
8
-
9
- tokenizer = AutoTokenizer.from_pretrained(model_path)
10
- sampling_params = SamplingParams(temperature=0.7, top_p=0.8, repetition_penalty=1.05, max_tokens=8192)
11
- llm = LLM(model=model_path)
12
-
13
- def api_call_batch(batch_messages):
14
- text_list = [
15
- tokenizer.apply_chat_template(conversation=messages, tokenize=False, add_generation_prompt=True, return_tensors='pt')
16
- for messages in batch_messages
17
- ]
18
- outputs = llm.generate(text_list, sampling_params)
19
- result = [output.outputs[0].text for output in outputs]
20
- return result
21
 
22
  def api_call(messages):
23
- return api_call_batch([messages])[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def call_gpt(history, prompt):
26
  return api_call(history+[{"role":"user", "content":prompt}])
27
 
28
  if __name__ == "__main__":
29
  messages = [{"role":"user", "content":"你是谁?"}]
 
30
  breakpoint()
31
- print(api_call_batch([messages]*4))
 
1
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the chat model and its tokenizer once at import time so every
# call to api_call() below reuses the same instances.
model_name = "happzy2633/qwen2.5-7b-ins-v3"

# torch_dtype="auto" keeps the dtype stored in the checkpoint;
# device_map="auto" places the weights on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
10
 
11
def api_call(messages, max_new_tokens=512):
    """Generate a single assistant reply for a chat conversation.

    Parameters
    ----------
    messages : list[dict]
        Chat history in the OpenAI-style format
        ``[{"role": ..., "content": ...}, ...]``.
    max_new_tokens : int, optional
        Upper bound on the number of tokens to generate. Defaults to 512,
        matching the previous hard-coded limit.

    Returns
    -------
    str
        The decoded assistant response, without the prompt tokens or
        special tokens.
    """
    # Render the conversation with the model's chat template and append
    # the generation prompt so the model continues as the assistant.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
    )
    # generate() returns prompt + completion; strip the prompt tokens so
    # only the newly generated portion is decoded.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response
27
 
28
def call_gpt(history, prompt):
    """Append *prompt* as a new user turn to *history* and return the reply."""
    user_turn = {"role": "user", "content": prompt}
    return api_call([*history, user_turn])
30
 
31
if __name__ == "__main__":
    # Smoke test: ask the model to introduce itself.
    # Removed the leftover breakpoint() debugger hook that followed the
    # print — it would drop any non-interactive run into pdb.
    messages = [{"role": "user", "content": "你是谁?"}]
    print(api_call(messages))
 
requirements.txt CHANGED
@@ -1,65 +1,38 @@
 
1
  aiofiles==23.2.1
2
- aiohappyeyeballs==2.3.6
3
- aiohttp==3.10.3
4
- aiosignal==1.3.1
5
  annotated-types==0.7.0
6
- anyio==4.4.0
7
- async-timeout==4.0.3
8
- attrs==24.2.0
9
- blinker==1.8.2
10
- certifi==2024.7.4
11
  charset-normalizer==3.3.2
12
  click==8.1.7
13
- cloudpickle==3.0.0
14
- cmake==3.30.2
15
  contourpy==1.3.0
16
  cycler==0.12.1
17
- datasets==2.21.0
18
- dill==0.3.8
19
- diskcache==5.6.3
20
  distro==1.9.0
21
  exceptiongroup==1.2.2
22
- fastapi==0.112.1
23
  ffmpy==0.4.0
24
- filelock==3.15.4
25
- Flask==3.0.3
26
- Flask-Cors==4.0.1
27
  fonttools==4.54.1
28
- frozenlist==1.4.1
29
- fsspec==2024.6.1
30
  gradio==4.44.1
31
  gradio_client==1.3.0
32
  h11==0.14.0
33
- httpcore==1.0.5
34
- httptools==0.6.1
35
- httpx==0.27.0
36
- huggingface-hub==0.24.5
37
- idna==3.7
38
  importlib_resources==6.4.5
39
- interegular==0.3.3
40
- itsdangerous==2.2.0
41
  Jinja2==3.1.4
42
  jiter==0.5.0
43
- jsonschema==4.23.0
44
- jsonschema-specifications==2023.12.1
45
  kiwisolver==1.4.7
46
- lark==1.2.2
47
- llvmlite==0.43.0
48
- lm-format-enforcer==0.10.3
49
  loguru==0.7.2
50
  markdown-it-py==3.0.0
51
  MarkupSafe==2.1.5
52
  matplotlib==3.9.2
53
  mdurl==0.1.2
54
  mpmath==1.3.0
55
- msgpack==1.0.8
56
- multidict==6.0.5
57
- multiprocess==0.70.16
58
- nest-asyncio==1.6.0
59
  networkx==3.3
60
- ninja==1.11.1.1
61
- numba==0.60.0
62
- numpy==1.26.4
63
  nvidia-cublas-cu12==12.1.3.1
64
  nvidia-cuda-cupti-cu12==12.1.105
65
  nvidia-cuda-nvrtc-cu12==12.1.105
@@ -69,70 +42,45 @@ nvidia-cufft-cu12==11.0.2.54
69
  nvidia-curand-cu12==10.3.2.106
70
  nvidia-cusolver-cu12==11.4.5.107
71
  nvidia-cusparse-cu12==12.1.0.106
72
- nvidia-ml-py==12.560.30
73
  nvidia-nccl-cu12==2.20.5
74
- nvidia-nvjitlink-cu12==12.6.20
75
  nvidia-nvtx-cu12==12.1.105
76
- openai==1.40.8
77
  orjson==3.10.7
78
- outlines==0.0.46
79
  packaging==24.1
80
- pandas==2.2.2
81
  pillow==10.4.0
82
- prometheus-fastapi-instrumentator==7.0.0
83
- prometheus_client==0.20.0
84
- protobuf==5.27.3
85
  psutil==6.0.0
86
- py-cpuinfo==9.0.0
87
- pyairports==2.1.1
88
- pyarrow==17.0.0
89
- pycountry==24.6.1
90
- pydantic==2.8.2
91
- pydantic_core==2.20.1
92
  pydub==0.25.1
93
- pyext==0.7
94
  Pygments==2.18.0
95
  pyparsing==3.1.4
96
  python-dateutil==2.9.0.post0
97
  python-dotenv==1.0.1
98
  python-multipart==0.0.12
99
- pytz==2024.1
100
  PyYAML==6.0.2
101
- pyzmq==26.1.0
102
- ray==2.34.0
103
- referencing==0.35.1
104
- regex==2024.7.24
105
  requests==2.32.3
106
  rich==13.9.1
107
- rpds-py==0.20.0
108
  ruff==0.6.8
109
- safetensors==0.4.4
110
  semantic-version==2.10.0
111
- sentencepiece==0.2.0
112
  shellingham==1.5.4
113
  six==1.16.0
114
  sniffio==1.3.1
115
- starlette==0.38.2
116
- sympy==1.13.2
117
- tiktoken==0.7.0
118
- tokenizers==0.19.1
119
  tomlkit==0.12.0
120
- torch==2.4.0
121
- torchvision==0.19.0
122
  tqdm==4.66.5
123
- transformers==4.44.0
124
  triton==3.0.0
125
  typer==0.12.5
126
  typing_extensions==4.12.2
127
- tzdata==2024.1
128
- urllib3==2.2.2
129
- uvicorn==0.30.6
130
- uvloop==0.20.0
131
- vllm==0.5.4
132
- vllm-flash-attn==2.6.1
133
- watchfiles==0.23.0
134
  websockets==12.0
135
- Werkzeug==3.0.3
136
- xformers==0.0.27.post2
137
- xxhash==3.4.1
138
- yarl==1.9.4
 
1
+ accelerate==0.34.2
2
  aiofiles==23.2.1
 
 
 
3
  annotated-types==0.7.0
4
+ anyio==4.6.0
5
+ certifi==2024.8.30
 
 
 
6
  charset-normalizer==3.3.2
7
  click==8.1.7
 
 
8
  contourpy==1.3.0
9
  cycler==0.12.1
 
 
 
10
  distro==1.9.0
11
  exceptiongroup==1.2.2
12
+ fastapi==0.115.0
13
  ffmpy==0.4.0
14
+ filelock==3.16.1
 
 
15
  fonttools==4.54.1
16
+ fsspec==2024.9.0
 
17
  gradio==4.44.1
18
  gradio_client==1.3.0
19
  h11==0.14.0
20
+ httpcore==1.0.6
21
+ httpx==0.27.2
22
+ huggingface-hub==0.25.1
23
+ idna==3.10
 
24
  importlib_resources==6.4.5
 
 
25
  Jinja2==3.1.4
26
  jiter==0.5.0
 
 
27
  kiwisolver==1.4.7
 
 
 
28
  loguru==0.7.2
29
  markdown-it-py==3.0.0
30
  MarkupSafe==2.1.5
31
  matplotlib==3.9.2
32
  mdurl==0.1.2
33
  mpmath==1.3.0
 
 
 
 
34
  networkx==3.3
35
+ numpy==2.1.1
 
 
36
  nvidia-cublas-cu12==12.1.3.1
37
  nvidia-cuda-cupti-cu12==12.1.105
38
  nvidia-cuda-nvrtc-cu12==12.1.105
 
42
  nvidia-curand-cu12==10.3.2.106
43
  nvidia-cusolver-cu12==11.4.5.107
44
  nvidia-cusparse-cu12==12.1.0.106
 
45
  nvidia-nccl-cu12==2.20.5
46
+ nvidia-nvjitlink-cu12==12.6.77
47
  nvidia-nvtx-cu12==12.1.105
48
+ openai==1.51.0
49
  orjson==3.10.7
 
50
  packaging==24.1
51
+ pandas==2.2.3
52
  pillow==10.4.0
 
 
 
53
  psutil==6.0.0
54
+ pydantic==2.9.2
55
+ pydantic_core==2.23.4
 
 
 
 
56
  pydub==0.25.1
 
57
  Pygments==2.18.0
58
  pyparsing==3.1.4
59
  python-dateutil==2.9.0.post0
60
  python-dotenv==1.0.1
61
  python-multipart==0.0.12
62
+ pytz==2024.2
63
  PyYAML==6.0.2
64
+ regex==2024.9.11
 
 
 
65
  requests==2.32.3
66
  rich==13.9.1
 
67
  ruff==0.6.8
68
+ safetensors==0.4.5
69
  semantic-version==2.10.0
 
70
  shellingham==1.5.4
71
  six==1.16.0
72
  sniffio==1.3.1
73
+ starlette==0.38.6
74
+ sympy==1.13.3
75
+ tokenizers==0.20.0
 
76
  tomlkit==0.12.0
77
+ torch==2.4.1
 
78
  tqdm==4.66.5
79
+ transformers==4.45.1
80
  triton==3.0.0
81
  typer==0.12.5
82
  typing_extensions==4.12.2
83
+ tzdata==2024.2
84
+ urllib3==2.2.3
85
+ uvicorn==0.31.0
 
 
 
 
86
  websockets==12.0