hysts HF staff committed on
Commit
b5eb658
1 Parent(s): 8b8ee9d
Files changed (10) hide show
  1. .gitignore +10 -0
  2. .pre-commit-config.yaml +60 -0
  3. .python-version +1 -0
  4. .vscode/settings.json +30 -0
  5. README.md +4 -3
  6. app.py +136 -0
  7. pyproject.toml +13 -0
  8. requirements.txt +258 -0
  9. style.css +11 -0
  10. uv.lock +0 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
.pre-commit-config.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.6.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: end-of-file-fixer
12
+ - id: mixed-line-ending
13
+ args: ["--fix=lf"]
14
+ - id: requirements-txt-fixer
15
+ - id: trailing-whitespace
16
+ - repo: https://github.com/myint/docformatter
17
+ rev: v1.7.5
18
+ hooks:
19
+ - id: docformatter
20
+ args: ["--in-place"]
21
+ - repo: https://github.com/pycqa/isort
22
+ rev: 5.13.2
23
+ hooks:
24
+ - id: isort
25
+ args: ["--profile", "black"]
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v1.10.1
28
+ hooks:
29
+ - id: mypy
30
+ args: ["--ignore-missing-imports"]
31
+ additional_dependencies:
32
+ [
33
+ "types-python-slugify",
34
+ "types-requests",
35
+ "types-PyYAML",
36
+ "types-pytz",
37
+ ]
38
+ - repo: https://github.com/psf/black
39
+ rev: 24.4.2
40
+ hooks:
41
+ - id: black
42
+ language_version: python3.10
43
+ args: ["--line-length", "119"]
44
+ - repo: https://github.com/kynan/nbstripout
45
+ rev: 0.7.1
46
+ hooks:
47
+ - id: nbstripout
48
+ args:
49
+ [
50
+ "--extra-keys",
51
+ "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
52
+ ]
53
+ - repo: https://github.com/nbQA-dev/nbQA
54
+ rev: 1.8.5
55
+ hooks:
56
+ - id: nbqa-black
57
+ - id: nbqa-pyupgrade
58
+ args: ["--py37-plus"]
59
+ - id: nbqa-isort
60
+ args: ["--float-to-top"]
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.10
.vscode/settings.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "editor.formatOnSave": true,
3
+ "files.insertFinalNewline": false,
4
+ "[python]": {
5
+ "editor.defaultFormatter": "ms-python.black-formatter",
6
+ "editor.formatOnType": true,
7
+ "editor.codeActionsOnSave": {
8
+ "source.organizeImports": "explicit"
9
+ }
10
+ },
11
+ "[jupyter]": {
12
+ "files.insertFinalNewline": false
13
+ },
14
+ "black-formatter.args": [
15
+ "--line-length=119"
16
+ ],
17
+ "isort.args": ["--profile", "black"],
18
+ "flake8.args": [
19
+ "--max-line-length=119"
20
+ ],
21
+ "ruff.lint.args": [
22
+ "--line-length=119"
23
+ ],
24
+ "notebook.output.scrolling": true,
25
+ "notebook.formatOnCellExecution": true,
26
+ "notebook.formatOnSave.enabled": true,
27
+ "notebook.codeActionsOnSave": {
28
+ "source.organizeImports": "explicit"
29
+ }
30
+ }
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: Gemma 2 2b Jpn It
3
- emoji: 🔥
4
  colorFrom: indigo
5
- colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 4.44.1
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Gemma 2 2B JPN IT
3
+ emoji: 😻
4
  colorFrom: indigo
5
+ colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.44.1
8
  app_file: app.py
9
  pinned: false
10
+ short_description: Chatbot
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

DESCRIPTION = """\
# Gemma 2 2B JPN IT

Gemma-2-JPN は日本語の文章で fine-tune された Gemma 2 2B モデルです。英語のみのクエリと同レベルの性能で日本語をサポートします。

(Gemma-2-JPN is a Gemma 2 2B model fine-tuned on Japanese text. It supports the Japanese language at the same level of performance as English-only queries on Gemma 2.)
"""

# Generation limits: the UI slider is capped at MAX_MAX_NEW_TOKENS, and the
# prompt fed to the model is truncated to at most MAX_INPUT_TOKEN_LENGTH tokens.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# NOTE(review): "gg-hf" looks like a gated staging org; the public release is
# "google/gemma-2-2b-jpn-it" — confirm which repo this Space should load.
model_id = "gg-hf/gemma-2-2b-jpn-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # put weights on GPU when available, CPU otherwise
    torch_dtype=torch.bfloat16,
)
model.config.sliding_window = 4096
model.eval()


@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a model reply to *message* given the prior *chat_history*.

    Yields the accumulated response text after each decoded chunk so the
    Gradio ChatInterface can render the answer incrementally.

    Args:
        message: The latest user message.
        chat_history: Prior turns as ``{"role": ..., "content": ...}`` dicts
            (Gradio "messages" format).
        max_new_tokens: Upper bound on generated tokens.
        temperature, top_p, top_k, repetition_penalty: Sampling parameters
            forwarded to ``model.generate``.
    """
    conversation = chat_history + [{"role": "user", "content": message}]

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    # Keep only the most recent tokens when the conversation exceeds the limit.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    # Run generation on a worker thread; the streamer hands decoded text back
    # to this generator as it is produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    description=DESCRIPTION,
    css="style.css",
    fill_height=True,
    textbox=gr.Textbox(placeholder="ここにメッセージを入力してください。", scale=7, autofocus=True),
    additional_inputs_accordion=gr.Accordion(label="詳細設定", open=False),
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    submit_btn="送信",
    retry_btn="🔄 再実行",
    undo_btn="↩️ 元に戻す",
    clear_btn="🗑️ クリア",
    stop_btn=None,
    examples=[
        ["こんにちは、自己紹介をしてください。"],
        ["マシンラーニングについての詩を書いてください。"],
        [
            "次の文章を英語にして: Gemma-2-JPN は日本語の文章で fine-tune された Gemma 2 2B モデルです。英語のみのクエリと同レベルの性能で日本語をサポートします。"
        ],
    ],
    cache_examples=False,
)


if __name__ == "__main__":
    demo.launch()
pyproject.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "gemma-2-2b-jpn-it"
3
+ version = "0.1.0"
4
+ description = ""
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "accelerate>=0.34.2",
9
+ "gradio>=4.44.1",
10
+ "spaces>=0.30.2",
11
+ "torch==2.4.0",
12
+ "transformers>=4.45.1",
13
+ ]
requirements.txt ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile pyproject.toml -o requirements.txt
3
+ accelerate==0.34.2
4
+ # via gemma-2-2b-jpn-it (pyproject.toml)
5
+ aiofiles==23.2.1
6
+ # via gradio
7
+ annotated-types==0.7.0
8
+ # via pydantic
9
+ anyio==4.6.0
10
+ # via
11
+ # gradio
12
+ # httpx
13
+ # starlette
14
+ certifi==2024.8.30
15
+ # via
16
+ # httpcore
17
+ # httpx
18
+ # requests
19
+ charset-normalizer==3.3.2
20
+ # via requests
21
+ click==8.1.7
22
+ # via
23
+ # typer
24
+ # uvicorn
25
+ contourpy==1.3.0
26
+ # via matplotlib
27
+ cycler==0.12.1
28
+ # via matplotlib
29
+ exceptiongroup==1.2.2
30
+ # via anyio
31
+ fastapi==0.115.0
32
+ # via gradio
33
+ ffmpy==0.4.0
34
+ # via gradio
35
+ filelock==3.16.1
36
+ # via
37
+ # huggingface-hub
38
+ # torch
39
+ # transformers
40
+ # triton
41
+ fonttools==4.54.1
42
+ # via matplotlib
43
+ fsspec==2024.9.0
44
+ # via
45
+ # gradio-client
46
+ # huggingface-hub
47
+ # torch
48
+ gradio==4.44.1
49
+ # via
50
+ # gemma-2-2b-jpn-it (pyproject.toml)
51
+ # spaces
52
+ gradio-client==1.3.0
53
+ # via gradio
54
+ h11==0.14.0
55
+ # via
56
+ # httpcore
57
+ # uvicorn
58
+ httpcore==1.0.6
59
+ # via httpx
60
+ httpx==0.27.2
61
+ # via
62
+ # gradio
63
+ # gradio-client
64
+ # spaces
65
+ huggingface-hub==0.25.1
66
+ # via
67
+ # accelerate
68
+ # gradio
69
+ # gradio-client
70
+ # tokenizers
71
+ # transformers
72
+ idna==3.10
73
+ # via
74
+ # anyio
75
+ # httpx
76
+ # requests
77
+ importlib-resources==6.4.5
78
+ # via gradio
79
+ jinja2==3.1.4
80
+ # via
81
+ # gradio
82
+ # torch
83
+ kiwisolver==1.4.7
84
+ # via matplotlib
85
+ markdown-it-py==3.0.0
86
+ # via rich
87
+ markupsafe==2.1.5
88
+ # via
89
+ # gradio
90
+ # jinja2
91
+ matplotlib==3.9.2
92
+ # via gradio
93
+ mdurl==0.1.2
94
+ # via markdown-it-py
95
+ mpmath==1.3.0
96
+ # via sympy
97
+ networkx==3.3
98
+ # via torch
99
+ numpy==2.1.1
100
+ # via
101
+ # accelerate
102
+ # contourpy
103
+ # gradio
104
+ # matplotlib
105
+ # pandas
106
+ # transformers
107
+ nvidia-cublas-cu12==12.1.3.1
108
+ # via
109
+ # nvidia-cudnn-cu12
110
+ # nvidia-cusolver-cu12
111
+ # torch
112
+ nvidia-cuda-cupti-cu12==12.1.105
113
+ # via torch
114
+ nvidia-cuda-nvrtc-cu12==12.1.105
115
+ # via torch
116
+ nvidia-cuda-runtime-cu12==12.1.105
117
+ # via torch
118
+ nvidia-cudnn-cu12==9.1.0.70
119
+ # via torch
120
+ nvidia-cufft-cu12==11.0.2.54
121
+ # via torch
122
+ nvidia-curand-cu12==10.3.2.106
123
+ # via torch
124
+ nvidia-cusolver-cu12==11.4.5.107
125
+ # via torch
126
+ nvidia-cusparse-cu12==12.1.0.106
127
+ # via
128
+ # nvidia-cusolver-cu12
129
+ # torch
130
+ nvidia-nccl-cu12==2.20.5
131
+ # via torch
132
+ nvidia-nvjitlink-cu12==12.6.77
133
+ # via
134
+ # nvidia-cusolver-cu12
135
+ # nvidia-cusparse-cu12
136
+ nvidia-nvtx-cu12==12.1.105
137
+ # via torch
138
+ orjson==3.10.7
139
+ # via gradio
140
+ packaging==24.1
141
+ # via
142
+ # accelerate
143
+ # gradio
144
+ # gradio-client
145
+ # huggingface-hub
146
+ # matplotlib
147
+ # spaces
148
+ # transformers
149
+ pandas==2.2.3
150
+ # via gradio
151
+ pillow==10.4.0
152
+ # via
153
+ # gradio
154
+ # matplotlib
155
+ psutil==5.9.8
156
+ # via
157
+ # accelerate
158
+ # spaces
159
+ pydantic==2.9.2
160
+ # via
161
+ # fastapi
162
+ # gradio
163
+ # spaces
164
+ pydantic-core==2.23.4
165
+ # via pydantic
166
+ pydub==0.25.1
167
+ # via gradio
168
+ pygments==2.18.0
169
+ # via rich
170
+ pyparsing==3.1.4
171
+ # via matplotlib
172
+ python-dateutil==2.9.0.post0
173
+ # via
174
+ # matplotlib
175
+ # pandas
176
+ python-multipart==0.0.12
177
+ # via gradio
178
+ pytz==2024.2
179
+ # via pandas
180
+ pyyaml==6.0.2
181
+ # via
182
+ # accelerate
183
+ # gradio
184
+ # huggingface-hub
185
+ # transformers
186
+ regex==2024.9.11
187
+ # via transformers
188
+ requests==2.32.3
189
+ # via
190
+ # huggingface-hub
191
+ # spaces
192
+ # transformers
193
+ rich==13.9.1
194
+ # via typer
195
+ ruff==0.6.8
196
+ # via gradio
197
+ safetensors==0.4.5
198
+ # via
199
+ # accelerate
200
+ # transformers
201
+ semantic-version==2.10.0
202
+ # via gradio
203
+ shellingham==1.5.4
204
+ # via typer
205
+ six==1.16.0
206
+ # via python-dateutil
207
+ sniffio==1.3.1
208
+ # via
209
+ # anyio
210
+ # httpx
211
+ spaces==0.30.2
212
+ # via gemma-2-2b-jpn-it (pyproject.toml)
213
+ starlette==0.38.6
214
+ # via fastapi
215
+ sympy==1.13.3
216
+ # via torch
217
+ tokenizers==0.20.0
218
+ # via transformers
219
+ tomlkit==0.12.0
220
+ # via gradio
221
+ torch==2.4.0
222
+ # via
223
+ # gemma-2-2b-jpn-it (pyproject.toml)
224
+ # accelerate
225
+ tqdm==4.66.5
226
+ # via
227
+ # huggingface-hub
228
+ # transformers
229
+ transformers==4.45.1
230
+ # via gemma-2-2b-jpn-it (pyproject.toml)
231
+ triton==3.0.0
232
+ # via torch
233
+ typer==0.12.5
234
+ # via gradio
235
+ typing-extensions==4.12.2
236
+ # via
237
+ # anyio
238
+ # fastapi
239
+ # gradio
240
+ # gradio-client
241
+ # huggingface-hub
242
+ # pydantic
243
+ # pydantic-core
244
+ # rich
245
+ # spaces
246
+ # torch
247
+ # typer
248
+ # uvicorn
249
+ tzdata==2024.2
250
+ # via pandas
251
+ urllib3==2.2.3
252
+ # via
253
+ # gradio
254
+ # requests
255
+ uvicorn==0.31.0
256
+ # via gradio
257
+ websockets==12.0
258
+ # via gradio-client
style.css ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ display: block;
4
+ }
5
+
6
+ #duplicate-button {
7
+ margin: auto;
8
+ color: #fff;
9
+ background: #1565c0;
10
+ border-radius: 100vh;
11
+ }
uv.lock ADDED
The diff for this file is too large to render. See raw diff