first commit
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +332 -0
- MANIFEST.in +1 -0
- README.md +1 -13
- SETUP.cfg +8 -0
- app.py +245 -0
- dockerfiles/Dockerfile.cpu +17 -0
- dockerfiles/Dockerfile.cuda +38 -0
- examples/train_retriever.py +45 -0
- pyproject.toml +15 -0
- relik/__init__.py +1 -0
- relik/common/__init__.py +0 -0
- relik/common/log.py +97 -0
- relik/common/upload.py +128 -0
- relik/common/utils.py +609 -0
- relik/inference/__init__.py +0 -0
- relik/inference/annotator.py +422 -0
- relik/inference/data/__init__.py +0 -0
- relik/inference/data/objects.py +64 -0
- relik/inference/data/tokenizers/__init__.py +89 -0
- relik/inference/data/tokenizers/base_tokenizer.py +84 -0
- relik/inference/data/tokenizers/regex_tokenizer.py +73 -0
- relik/inference/data/tokenizers/spacy_tokenizer.py +228 -0
- relik/inference/data/tokenizers/whitespace_tokenizer.py +70 -0
- relik/inference/data/window/__init__.py +0 -0
- relik/inference/data/window/manager.py +262 -0
- relik/inference/gerbil.py +254 -0
- relik/inference/preprocessing.py +4 -0
- relik/inference/serve/__init__.py +0 -0
- relik/inference/serve/backend/__init__.py +0 -0
- relik/inference/serve/backend/relik.py +210 -0
- relik/inference/serve/backend/retriever.py +206 -0
- relik/inference/serve/backend/utils.py +29 -0
- relik/inference/serve/frontend/__init__.py +0 -0
- relik/inference/serve/frontend/relik.py +231 -0
- relik/inference/serve/frontend/style.css +33 -0
- relik/reader/__init__.py +0 -0
- relik/reader/conf/config.yaml +14 -0
- relik/reader/conf/data/base.yaml +21 -0
- relik/reader/conf/data/re.yaml +54 -0
- relik/reader/conf/training/base.yaml +12 -0
- relik/reader/conf/training/re.yaml +12 -0
- relik/reader/data/__init__.py +0 -0
- relik/reader/data/patches.py +51 -0
- relik/reader/data/relik_reader_data.py +965 -0
- relik/reader/data/relik_reader_data_utils.py +51 -0
- relik/reader/data/relik_reader_sample.py +49 -0
- relik/reader/lightning_modules/__init__.py +0 -0
- relik/reader/lightning_modules/relik_reader_pl_module.py +50 -0
- relik/reader/lightning_modules/relik_reader_re_pl_module.py +54 -0
- relik/reader/pytorch_modules/__init__.py +0 -0
.gitignore
ADDED
@@ -0,0 +1,332 @@
+# custom
+
+data/*
+experiments/*
+retrievers
+outputs
+model
+wandb
+
+# Created by https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos
+# Edit at https://www.toptal.com/developers/gitignore?templates=jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos
+
+### JetBrains+all ###
+# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+# User-specific stuff
+.idea/**/workspace.xml
+.idea/**/tasks.xml
+.idea/**/usage.statistics.xml
+.idea/**/dictionaries
+.idea/**/shelf
+
+# Generated files
+.idea/**/contentModel.xml
+
+# Sensitive or high-churn files
+.idea/**/dataSources/
+.idea/**/dataSources.ids
+.idea/**/dataSources.local.xml
+.idea/**/sqlDataSources.xml
+.idea/**/dynamic.xml
+.idea/**/uiDesigner.xml
+.idea/**/dbnavigator.xml
+
+# Gradle
+.idea/**/gradle.xml
+.idea/**/libraries
+
+# Gradle and Maven with auto-import
+# When using Gradle or Maven with auto-import, you should exclude module files,
+# since they will be recreated, and may cause churn. Uncomment if using
+# auto-import.
+# .idea/artifacts
+# .idea/compiler.xml
+# .idea/jarRepositories.xml
+# .idea/modules.xml
+# .idea/*.iml
+# .idea/modules
+# *.iml
+# *.ipr
+
+# CMake
+cmake-build-*/
+
+# Mongo Explorer plugin
+.idea/**/mongoSettings.xml
+
+# File-based project format
+*.iws
+
+# IntelliJ
+out/
+
+# mpeltonen/sbt-idea plugin
+.idea_modules/
+
+# JIRA plugin
+atlassian-ide-plugin.xml
+
+# Cursive Clojure plugin
+.idea/replstate.xml
+
+# Crashlytics plugin (for Android Studio and IntelliJ)
+com_crashlytics_export_strings.xml
+crashlytics.properties
+crashlytics-build.properties
+fabric.properties
+
+# Editor-based Rest Client
+.idea/httpRequests
+
+# Android studio 3.1+ serialized cache file
+.idea/caches/build_file_checksums.ser
+
+### JetBrains+all Patch ###
+# Ignores the whole .idea folder and all .iml files
+# See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360
+
+.idea/
+
+# Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023
+
+*.iml
+modules.xml
+.idea/misc.xml
+*.ipr
+
+# Sonarlint plugin
+.idea/sonarlint
+
+### JupyterNotebooks ###
+# gitignore template for Jupyter Notebooks
+# website: http://jupyter.org/
+
+.ipynb_checkpoints
+*/.ipynb_checkpoints/*
+
+# IPython
+profile_default/
+ipython_config.py
+
+# Remove previous ipynb_checkpoints
+#   git rm -r .ipynb_checkpoints/
+
+### Linux ###
+*~
+
+# temporary files which can be created if a process still has a handle open of a deleted file
+.fuse_hidden*
+
+# KDE directory preferences
+.directory
+
+# Linux trash folder which might appear on any partition or disk
+.Trash-*
+
+# .nfs files are created when an open file is removed but is still being accessed
+.nfs*
+
+### macOS ###
+# General
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+pytestdebug.log
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+doc/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+
+# IPython
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+pythonenv*
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# profiling data
+.prof
+
+### vscode ###
+.vscode
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+*.code-workspace
+
+### Windows ###
+# Windows thumbnail cache files
+Thumbs.db
+Thumbs.db:encryptable
+ehthumbs.db
+ehthumbs_vista.db
+
+# Dump file
+*.stackdump
+
+# Folder config file
+[Dd]esktop.ini
+
+# Recycle Bin used on file shares
+$RECYCLE.BIN/
+
+# Windows Installer files
+*.cab
+*.msi
+*.msix
+*.msm
+*.msp
+
+# Windows shortcuts
+*.lnk
+
+# End of https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos
MANIFEST.in
ADDED
@@ -0,0 +1 @@
+include requirements.txt
README.md
CHANGED
@@ -1,13 +1 @@
----
-title: Relik
-emoji: 🐨
-colorFrom: gray
-colorTo: pink
-sdk: streamlit
-sdk_version: 1.27.2
-app_file: app.py
-pinned: false
-license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# relik
SETUP.cfg
ADDED
@@ -0,0 +1,8 @@
+[metadata]
+description-file = README.md
+
+[build]
+build-base = /tmp/build
+
+[egg_info]
+egg-base = /tmp
app.py
ADDED
@@ -0,0 +1,245 @@
+import os
+import re
+import time
+from pathlib import Path
+
+import requests
+import streamlit as st
+from spacy import displacy
+from streamlit_extras.badges import badge
+from streamlit_extras.stylable_container import stylable_container
+
+# RELIK = os.getenv("RELIK", "localhost:8000/api/entities")
+
+import random
+
+from relik.inference.annotator import Relik
+
+
+def get_random_color(ents):
+    colors = {}
+    random_colors = generate_pastel_colors(len(ents))
+    for ent in ents:
+        colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1))
+    return colors
+
+
+def floatrange(start, stop, steps):
+    if int(steps) == 1:
+        return [stop]
+    return [
+        start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps)
+    ]
+
+
+def hsl_to_rgb(h, s, l):
+    def hue_2_rgb(v1, v2, v_h):
+        while v_h < 0.0:
+            v_h += 1.0
+        while v_h > 1.0:
+            v_h -= 1.0
+        if 6 * v_h < 1.0:
+            return v1 + (v2 - v1) * 6.0 * v_h
+        if 2 * v_h < 1.0:
+            return v2
+        if 3 * v_h < 2.0:
+            return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0
+        return v1
+
+    # if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1."
+    # if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1."
+
+    r, b, g = (l * 255,) * 3
+    if s != 0.0:
+        if l < 0.5:
+            var_2 = l * (1.0 + s)
+        else:
+            var_2 = (l + s) - (s * l)
+        var_1 = 2.0 * l - var_2
+        r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0))
+        g = 255 * hue_2_rgb(var_1, var_2, h)
+        b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0))
+
+    return int(round(r)), int(round(g)), int(round(b))
+
+
+def generate_pastel_colors(n):
+    """Return different pastel colours.
+
+    Input:
+        n (integer) : The number of colors to return
+
+    Output:
+        A list of colors in HTML notation (eg.['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc'])
+
+    Example:
+        >>> print generate_pastel_colors(5)
+        ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0']
+    """
+    if n == 0:
+        return []
+
+    # To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space)
+    start_hue = 0.6  # 0=red 1/3=0.333=green 2/3=0.666=blue
+    saturation = 1.0
+    lightness = 0.8
+    # We take points around the chromatic circle (hue):
+    # (Note: we generate n+1 colors, then drop the last one ([:-1]) because
+    # it equals the first one (hue 0 = hue 1))
+    return [
+        "#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness)
+        for hue in floatrange(start_hue, start_hue + 1, n + 1)
+    ][:-1]
+
+
+def set_sidebar(css):
+    white_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='{}'>{}</a>"
+    with st.sidebar:
+        st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
+        st.image(
+            "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
+            use_column_width=True,
+        )
+        st.markdown("## ReLiK")
+        st.write(
+            f"""
+            - {white_link_wrapper.format("#", "<i class='fa-solid fa-file'></i> Paper")}
+            - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "<i class='fa-brands fa-github'></i> GitHub")}
+            - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "<i class='fa-brands fa-docker'></i> Docker Hub")}
+            """,
+            unsafe_allow_html=True,
+        )
+        st.markdown("## Sapienza NLP")
+        st.write(
+            f"""
+            - {white_link_wrapper.format("https://nlp.uniroma1.it", "<i class='fa-solid fa-globe'></i> Webpage")}
+            - {white_link_wrapper.format("https://github.com/SapienzaNLP", "<i class='fa-brands fa-github'></i> GitHub")}
+            - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "<i class='fa-brands fa-twitter'></i> Twitter")}
+            - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "<i class='fa-brands fa-linkedin'></i> LinkedIn")}
+            """,
+            unsafe_allow_html=True,
+        )
+
+
+def get_el_annotations(response):
+    # swap labels key with ents
+    dict_of_ents = {"text": response.text, "ents": []}
+    dict_of_ents["ents"] = response.labels
+    label_in_text = set(l["label"] for l in dict_of_ents["ents"])
+    options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
+    return dict_of_ents, options
+
+
+def set_intro(css):
+    # intro
+    st.markdown("# ReLik")
+    st.markdown(
+        "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget"
+    )
+    # st.markdown(
+    #     "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
+    #     "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by "
+    #     "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
+    #     "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
+    # )
+    badge(type="github", name="sapienzanlp/relik")
+    badge(type="pypi", name="relik")
+
+
+def run_client():
+    with open(Path(__file__).parent / "style.css") as f:
+        css = f.read()
+
+    st.set_page_config(
+        page_title="ReLik",
+        page_icon="🦮",
+        layout="wide",
+    )
+    set_sidebar(css)
+    set_intro(css)
+
+    # text input
+    text = st.text_area(
+        "Enter Text Below:",
+        value="Obama went to Rome for a quick vacation.",
+        height=200,
+        max_chars=500,
+    )
+
+    with stylable_container(
+        key="annotate_button",
+        css_styles="""
+            button {
+                background-color: #802433;
+                color: white;
+                border-radius: 25px;
+            }
+            """,
+    ):
+        submit = st.button("Annotate")
+    # submit = st.button("Run")
+
+    relik = Relik(
+        question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
+        document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
+        reader="riccorl/relik-reader-aida-deberta-small",
+        top_k=100,
+        window_size=32,
+        window_stride=16,
+        candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
+    )
+
+    # ReLik API call
+    if submit:
+        text = text.strip()
+        if text:
+            st.markdown("####")
+            st.markdown("#### Entity Linking")
+            with st.spinner(text="In progress"):
+                response = relik(text)
+                # response = requests.post(RELIK, json=text)
+                # if response.status_code != 200:
+                #     st.error("Error: {}".format(response.status_code))
+                # else:
+                #     response = response.json()
+
+            # Entity Linking
+            # with stylable_container(
+            #     key="container_with_border",
+            #     css_styles="""
+            #         {
+            #             border: 1px solid rgba(49, 51, 63, 0.2);
+            #             border-radius: 0.5rem;
+            #             padding: 0.5rem;
+            #             padding-bottom: 2rem;
+            #         }
+            #         """,
+            # ):
+            # st.markdown("##")
+            dict_of_ents, options = get_el_annotations(response=response)
+            display = displacy.render(
+                dict_of_ents, manual=True, style="ent", options=options
+            )
+            display = display.replace("\n", " ")
+            # wsd_display = re.sub(
+            #     r"(wiki::\d+\w)",
+            #     r"<a href='https://babelnet.org/synset?id=\g<1>&orig=\g<1>&lang={}'>\g<1></a>".format(
+            #         language.upper()
+            #     ),
+            #     wsd_display,
+            # )
+            with st.container():
+                st.write(display, unsafe_allow_html=True)
+
+            st.markdown("####")
+            st.markdown("#### Relation Extraction")
+
+            with st.container():
+                st.write("Coming :)", unsafe_allow_html=True)
+
+        else:
+            st.error("Please enter some text.")
+
+
+if __name__ == "__main__":
+    run_client()
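
Stripped of the UI, the app above reduces to a single annotator call. A minimal sketch of driving it outside Streamlit, assuming the same checkpoints as in run_client; the response attributes are the ones get_el_annotations reads:

from relik.inference.annotator import Relik

# same pretrained components the app wires up; a sketch, not the app's exact setup
relik = Relik(
    question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
    document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
    reader="riccorl/relik-reader-aida-deberta-small",
    top_k=100,
    window_size=32,
    window_stride=16,
)
response = relik("Obama went to Rome for a quick vacation.")
print(response.text)    # the input text
print(response.labels)  # linked spans; each item carries the "label" key used above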
dockerfiles/Dockerfile.cpu
ADDED
@@ -0,0 +1,17 @@
+FROM tiangolo/uvicorn-gunicorn:python3.10-slim
+
+# Copy and install requirements.txt
+COPY ./requirements.txt ./requirements.txt
+COPY ./src /app
+COPY ./scripts/start.sh /start.sh
+COPY ./scripts/prestart.sh /app
+COPY ./scripts/gunicorn_conf.py /gunicorn_conf.py
+COPY ./scripts/start-reload.sh /start-reload.sh
+COPY ./VERSION /
+RUN mkdir -p /app/resources/model \
+    && pip install --no-cache-dir -r requirements.txt \
+    && chmod +x /start.sh && chmod +x /start-reload.sh
+ARG MODEL_PATH
+COPY ${MODEL_PATH}/* /app/resources/model/
+
+ENV APP_MODULE=main:app
dockerfiles/Dockerfile.cuda
ADDED
@@ -0,0 +1,38 @@
+FROM nvidia/cuda:12.2.0-base-ubuntu20.04
+
+ARG DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get update \
+    && apt-get install \
+    curl wget python3.10 \
+    python3.10-distutils \
+    python3-pip \
+    curl wget -y \
+    && rm -rf /var/lib/apt/lists/*
+
+# FastAPI section
+# device env
+ENV DEVICE="cuda"
+# Copy and install requirements.txt
+COPY ./gpu-requirements.txt ./requirements.txt
+COPY ./src /app
+COPY ./scripts/start.sh /start.sh
+COPY ./scripts/gunicorn_conf.py /gunicorn_conf.py
+COPY ./scripts/start-reload.sh /start-reload.sh
+COPY ./scripts/prestart.sh /app
+COPY ./VERSION /
+RUN mkdir -p /app/resources/model \
+    && pip install --upgrade --no-cache-dir -r requirements.txt \
+    && chmod +x /start.sh \
+    && chmod +x /start-reload.sh
+ARG MODEL_NAME_OR_PATH
+
+WORKDIR /app
+
+ENV PYTHONPATH=/app
+
+EXPOSE 80
+
+# Run the start script, it will check for an /app/prestart.sh script (e.g. for migrations)
+# And then will start Gunicorn with Uvicorn
+CMD ["/start.sh"]
examples/train_retriever.py
ADDED
@@ -0,0 +1,45 @@
+from relik.retriever.trainer import RetrieverTrainer
+from relik import GoldenRetriever
+from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
+from relik.retriever.data.datasets import AidaInBatchNegativesDataset
+
+if __name__ == "__main__":
+    # instantiate retriever
+    document_index = InMemoryDocumentIndex(
+        documents="/root/golden-retriever-v2/data/dpr-like/el/definitions.txt",
+        device="cuda",
+        precision="16",
+    )
+    retriever = GoldenRetriever(
+        question_encoder="intfloat/e5-small-v2", document_index=document_index
+    )
+
+    train_dataset = AidaInBatchNegativesDataset(
+        name="aida_train",
+        path="/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/train.jsonl",
+        tokenizer=retriever.question_tokenizer,
+        question_batch_size=64,
+        passage_batch_size=400,
+        max_passage_length=64,
+        use_topics=True,
+        shuffle=True,
+    )
+    val_dataset = AidaInBatchNegativesDataset(
+        name="aida_val",
+        path="/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/val.jsonl",
+        tokenizer=retriever.question_tokenizer,
+        question_batch_size=64,
+        passage_batch_size=400,
+        max_passage_length=64,
+        use_topics=True,
+    )
+
+    trainer = RetrieverTrainer(
+        retriever=retriever,
+        train_dataset=train_dataset,
+        val_dataset=val_dataset,
+        max_steps=25_000,
+        wandb_offline_mode=True,
+    )
+
+    trainer.train()
pyproject.toml
ADDED
@@ -0,0 +1,15 @@
+[tool.black]
+include = '\.pyi?$'
+exclude = '''
+/(
+    \.git
+  | \.hg
+  | \.mypy_cache
+  | \.tox
+  | \.venv
+  | _build
+  | buck-out
+  | build
+  | dist
+)/
+'''
relik/__init__.py
ADDED
@@ -0,0 +1 @@
+from relik.retriever.pytorch_modules.model import GoldenRetriever
relik/common/__init__.py
ADDED
File without changes
relik/common/log.py
ADDED
@@ -0,0 +1,97 @@
+import logging
+import sys
+import threading
+from typing import Optional
+
+from rich import get_console
+
+_lock = threading.Lock()
+_default_handler: Optional[logging.Handler] = None
+
+_default_log_level = logging.WARNING
+
+# fancy logger
+_console = get_console()
+
+
+def _get_library_name() -> str:
+    return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> logging.Logger:
+    return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+    global _default_handler
+
+    with _lock:
+        if _default_handler:
+            # This library has already configured the library root logger.
+            return
+        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
+        _default_handler.flush = sys.stderr.flush
+
+        # Apply our default configuration to the library root logger.
+        library_root_logger = _get_library_root_logger()
+        library_root_logger.addHandler(_default_handler)
+        library_root_logger.setLevel(_default_log_level)
+        library_root_logger.propagate = False
+
+
+def _reset_library_root_logger() -> None:
+    global _default_handler
+
+    with _lock:
+        if not _default_handler:
+            return
+
+        library_root_logger = _get_library_root_logger()
+        library_root_logger.removeHandler(_default_handler)
+        library_root_logger.setLevel(logging.NOTSET)
+        _default_handler = None
+
+
+def set_log_level(level: int, logger: logging.Logger = None) -> None:
+    """
+    Set the log level.
+    Args:
+        level (:obj:`int`):
+            Logging level.
+        logger (:obj:`logging.Logger`):
+            Logger to set the log level.
+    """
+    if not logger:
+        _configure_library_root_logger()
+        logger = _get_library_root_logger()
+    logger.setLevel(level)
+
+
+def get_logger(
+    name: Optional[str] = None,
+    level: Optional[int] = None,
+    formatter: Optional[str] = None,
+) -> logging.Logger:
+    """
+    Return a logger with the specified name.
+    """
+
+    if name is None:
+        name = _get_library_name()
+
+    _configure_library_root_logger()
+
+    if level is not None:
+        set_log_level(level)
+
+    if formatter is None:
+        formatter = logging.Formatter(
+            "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
+        )
+    _default_handler.setFormatter(formatter)
+
+    return logging.getLogger(name)
+
+
+def get_console_logger():
+    return _console
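
For reference, this is how the rest of the package consumes the module (upload.py below calls get_logger(level=logging.DEBUG)); a minimal sketch, with the log message being a made-up example:

import logging

from relik.common.log import get_logger

logger = get_logger(level=logging.DEBUG)
logger.debug("message routed through the library root handler configured above")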
relik/common/upload.py
ADDED
@@ -0,0 +1,128 @@
+import argparse
+import json
+import logging
+import os
+import tempfile
+import zipfile
+from datetime import datetime
+from pathlib import Path
+from typing import Optional, Union
+
+import huggingface_hub
+
+from relik.common.log import get_logger
+from relik.common.utils import SAPIENZANLP_DATE_FORMAT, get_md5
+
+logger = get_logger(level=logging.DEBUG)
+
+
+def create_info_file(tmpdir: Path):
+    logger.debug("Computing md5 of model.zip")
+    md5 = get_md5(tmpdir / "model.zip")
+    date = datetime.now().strftime(SAPIENZANLP_DATE_FORMAT)
+
+    logger.debug("Dumping info.json file")
+    with (tmpdir / "info.json").open("w") as f:
+        json.dump(dict(md5=md5, upload_date=date), f, indent=2)
+
+
+def zip_run(
+    dir_path: Union[str, os.PathLike],
+    tmpdir: Union[str, os.PathLike],
+    zip_name: str = "model.zip",
+) -> Path:
+    logger.debug(f"zipping {dir_path} to {tmpdir}")
+    # creates a zip version of the provided dir_path
+    run_dir = Path(dir_path)
+    zip_path = tmpdir / zip_name
+
+    with zipfile.ZipFile(zip_path, "w") as zip_file:
+        # fully zip the run directory maintaining its structure
+        for file in run_dir.rglob("*.*"):
+            if file.is_dir():
+                continue
+
+            zip_file.write(file, arcname=file.relative_to(run_dir))
+
+    return zip_path
+
+
+def upload(
+    model_dir: Union[str, os.PathLike],
+    model_name: str,
+    organization: Optional[str] = None,
+    repo_name: Optional[str] = None,
+    commit: Optional[str] = None,
+    archive: bool = False,
+):
+    token = huggingface_hub.HfFolder.get_token()
+    if token is None:
+        print(
+            "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
+        )
+        return
+
+    repo_id = repo_name or model_name
+    if organization is not None:
+        repo_id = f"{organization}/{repo_id}"
+    with tempfile.TemporaryDirectory() as tmpdir:
+        api = huggingface_hub.HfApi()
+        repo_url = api.create_repo(
+            token=token,
+            repo_id=repo_id,
+            exist_ok=True,
+        )
+        repo = huggingface_hub.Repository(
+            str(tmpdir), clone_from=repo_url, use_auth_token=token
+        )
+
+        tmp_path = Path(tmpdir)
+        if archive:
+            # otherwise we zip the model_dir
+            logger.debug(f"Zipping {model_dir} to {tmp_path}")
+            zip_run(model_dir, tmp_path)
+            create_info_file(tmp_path)
+        else:
+            # if the user wants to upload a transformers model, we don't need to zip it
+            # we just need to copy the files to the tmpdir
+            logger.debug(f"Copying {model_dir} to {tmpdir}")
+            os.system(f"cp -r {model_dir}/* {tmpdir}")
+
+        # this method automatically puts large files (>10MB) into git lfs
+        repo.push_to_hub(commit_message=commit or "Automatic push from sapienzanlp")
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "model_dir", help="The directory of the model you want to upload"
+    )
+    parser.add_argument("model_name", help="The model you want to upload")
+    parser.add_argument(
+        "--organization",
+        help="the name of the organization where you want to upload the model",
+    )
+    parser.add_argument(
+        "--repo_name",
+        help="Optional name to use when uploading to the HuggingFace repository",
+    )
+    parser.add_argument(
+        "--commit", help="Commit message to use when pushing to the HuggingFace Hub"
+    )
+    parser.add_argument(
+        "--archive",
+        action="store_true",
+        help="""
+            Whether to compress the model directory before uploading it.
+            If True, the model directory will be zipped and the zip file will be uploaded.
+            If False, the model directory will be uploaded as is.""",
+    )
+    return parser.parse_args()
+
+
+def main():
+    upload(**vars(parse_args()))
+
+
+if __name__ == "__main__":
+    main()
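
Besides the CLI entry point, upload() can be called directly; a sketch with hypothetical paths and repo names:

from relik.common.upload import upload

# zips ./experiments/my-run into model.zip (plus an info.json with its md5)
# and pushes it to sapienzanlp/my-model on the Hub; names here are hypothetical
upload("./experiments/my-run", "my-model", organization="sapienzanlp", archive=True)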
relik/common/utils.py
ADDED
@@ -0,0 +1,609 @@
+import importlib.util
+import json
+import logging
+import os
+import shutil
+import tarfile
+import tempfile
+from functools import partial
+from hashlib import sha256
+from pathlib import Path
+from typing import Any, BinaryIO, Dict, List, Optional, Union
+from urllib.parse import urlparse
+from zipfile import ZipFile, is_zipfile
+
+import huggingface_hub
+import requests
+from tqdm import tqdm
+from filelock import FileLock
+from transformers.utils.hub import cached_file as hf_cached_file
+
+from relik.common.log import get_logger
+
+# name constants
+WEIGHTS_NAME = "weights.pt"
+ONNX_WEIGHTS_NAME = "weights.onnx"
+CONFIG_NAME = "config.yaml"
+LABELS_NAME = "labels.json"
+
+# SAPIENZANLP_USER_NAME = "sapienzanlp"
+SAPIENZANLP_USER_NAME = "riccorl"
+SAPIENZANLP_HF_MODEL_REPO_URL = "riccorl/{model_id}"
+SAPIENZANLP_HF_MODEL_REPO_ARCHIVE_URL = (
+    f"{SAPIENZANLP_HF_MODEL_REPO_URL}/resolve/main/model.zip"
+)
+# path constants
+SAPIENZANLP_CACHE_DIR = os.getenv("SAPIENZANLP_CACHE_DIR", Path.home() / ".sapienzanlp")
+SAPIENZANLP_DATE_FORMAT = "%Y-%m-%d %H-%M-%S"
+
+
+logger = get_logger(__name__)
+
+
+def sapienzanlp_model_urls(model_id: str) -> str:
+    """
+    Returns the URL for a possible SapienzaNLP valid model.
+
+    Args:
+        model_id (:obj:`str`):
+            A SapienzaNLP model id.
+
+    Returns:
+        :obj:`str`: The url for the model id.
+    """
+    # check if there is already the namespace of the user
+    if "/" in model_id:
+        return model_id
+    return SAPIENZANLP_HF_MODEL_REPO_URL.format(model_id=model_id)
+
+
+def is_package_available(package_name: str) -> bool:
+    """
+    Check if a package is available.
+
+    Args:
+        package_name (`str`): The name of the package to check.
+    """
+    return importlib.util.find_spec(package_name) is not None
+
+
+def load_json(path: Union[str, Path]) -> Any:
+    """
+    Load a json file provided in input.
+
+    Args:
+        path (`Union[str, Path]`): The path to the json file to load.
+
+    Returns:
+        `Any`: The loaded json file.
+    """
+    with open(path, encoding="utf8") as f:
+        return json.load(f)
+
+
+def dump_json(document: Any, path: Union[str, Path], indent: Optional[int] = None):
+    """
+    Dump input to json file.
+
+    Args:
+        document (`Any`): The document to dump.
+        path (`Union[str, Path]`): The path to dump the document to.
+        indent (`Optional[int]`): The indent to use for the json file.
+
+    """
+    with open(path, "w", encoding="utf8") as outfile:
+        json.dump(document, outfile, indent=indent)
+
+
+def get_md5(path: Path):
+    """
+    Get the MD5 value of a path.
+    """
+    import hashlib
+
+    with path.open("rb") as fin:
+        data = fin.read()
+    return hashlib.md5(data).hexdigest()
+
+
+def file_exists(path: Union[str, os.PathLike]) -> bool:
+    """
+    Check if the file at :obj:`path` exists.
+
+    Args:
+        path (:obj:`str`, :obj:`os.PathLike`):
+            Path to check.
+
+    Returns:
+        :obj:`bool`: :obj:`True` if the file exists.
+    """
+    return Path(path).exists()
+
+
+def dir_exists(path: Union[str, os.PathLike]) -> bool:
+    """
+    Check if the directory at :obj:`path` exists.
+
+    Args:
+        path (:obj:`str`, :obj:`os.PathLike`):
+            Path to check.
+
+    Returns:
+        :obj:`bool`: :obj:`True` if the directory exists.
+    """
+    return Path(path).is_dir()
+
+
+def is_remote_url(url_or_filename: Union[str, Path]):
+    """
+    Returns :obj:`True` if the input path is an url.
+
+    Args:
+        url_or_filename (:obj:`str`, :obj:`Path`):
+            path to check.
+
+    Returns:
+        :obj:`bool`: :obj:`True` if the input path is an url, :obj:`False` otherwise.
+
+    """
+    if isinstance(url_or_filename, Path):
+        url_or_filename = str(url_or_filename)
+    parsed = urlparse(url_or_filename)
+    return parsed.scheme in ("http", "https")
+
+
+def url_to_filename(resource: str, etag: str = None) -> str:
+    """
+    Convert a `resource` into a hashed filename in a repeatable way.
+    If `etag` is specified, append its hash to the resources's, delimited
+    by a period.
+    """
+    resource_bytes = resource.encode("utf-8")
+    resource_hash = sha256(resource_bytes)
+    filename = resource_hash.hexdigest()
+
+    if etag:
+        etag_bytes = etag.encode("utf-8")
+        etag_hash = sha256(etag_bytes)
+        filename += "." + etag_hash.hexdigest()
+
+    return filename
+
+
+def download_resource(
+    url: str,
+    temp_file: BinaryIO,
+    headers=None,
+):
+    """
+    Download remote file.
+    """
+
+    if headers is None:
+        headers = {}
+
+    r = requests.get(url, stream=True, headers=headers)
+    r.raise_for_status()
+    content_length = r.headers.get("Content-Length")
+    total = int(content_length) if content_length is not None else None
+    progress = tqdm(
+        unit="B",
+        unit_scale=True,
+        total=total,
+        desc="Downloading",
+        disable=logger.level in [logging.NOTSET],
+    )
+    for chunk in r.iter_content(chunk_size=1024):
+        if chunk:  # filter out keep-alive new chunks
+            progress.update(len(chunk))
+            temp_file.write(chunk)
+    progress.close()
+
+
+def download_and_cache(
+    url: Union[str, Path],
+    cache_dir: Union[str, Path] = None,
+    force_download: bool = False,
+):
+    if cache_dir is None:
+        cache_dir = SAPIENZANLP_CACHE_DIR
+    if isinstance(url, Path):
+        url = str(url)
+
+    # check if cache dir exists
+    Path(cache_dir).mkdir(parents=True, exist_ok=True)
+
+    # check if file is private
+    headers = {}
+    try:
+        r = requests.head(url, allow_redirects=False, timeout=10)
+        r.raise_for_status()
+    except requests.exceptions.HTTPError:
+        if r.status_code == 401:
+            hf_token = huggingface_hub.HfFolder.get_token()
+            if hf_token is None:
+                raise ValueError(
+                    "You need to login to HuggingFace to download this model "
+                    "(use the `huggingface-cli login` command)"
+                )
+            headers["Authorization"] = f"Bearer {hf_token}"
+
+    etag = None
+    try:
+        r = requests.head(url, allow_redirects=True, timeout=10, headers=headers)
+        r.raise_for_status()
+        etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
+        # We favor a custom header indicating the etag of the linked resource, and
+        # we fallback to the regular etag header.
+        # If we don't have any of those, raise an error.
+        if etag is None:
+            raise OSError(
+                "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
+            )
+        # In case of a redirect,
+        # save an extra redirect on the request.get call,
+        # and ensure we download the exact atomic version even if it changed
+        # between the HEAD and the GET (unlikely, but hey).
+        if 300 <= r.status_code <= 399:
+            url = r.headers["Location"]
+    except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
+        # Actually raise for those subclasses of ConnectionError
+        raise
+    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
+        # Otherwise, our Internet connection is down.
+        # etag is None
+        pass
+
+    # get filename from the url
+    filename = url_to_filename(url, etag)
+    # get cache path to put the file
+    cache_path = cache_dir / filename
+
+    # the file is already here, return it
+    if file_exists(cache_path) and not force_download:
+        logger.info(
+            f"{url} found in cache, set `force_download=True` to force the download"
+        )
+        return cache_path
+
+    cache_path = str(cache_path)
+    # Prevent parallel downloads of the same file with a lock.
+    lock_path = cache_path + ".lock"
+    with FileLock(lock_path):
+        # If the download just completed while the lock was activated.
+        if file_exists(cache_path) and not force_download:
+            # Even if returning early like here, the lock will be released.
+            return cache_path
+
+        temp_file_manager = partial(
+            tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False
+        )
+
+        # Download to temporary file, then copy to cache dir once finished.
+        # Otherwise, you get corrupt cache entries if the download gets interrupted.
+        with temp_file_manager() as temp_file:
+            logger.info(
+                f"{url} not found in cache or `force_download` set to `True`, downloading to {temp_file.name}"
+            )
+            download_resource(url, temp_file, headers)
+
+        logger.info(f"storing {url} in cache at {cache_path}")
+        os.replace(temp_file.name, cache_path)
+
+        # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
+        umask = os.umask(0o666)
+        os.umask(umask)
+        os.chmod(cache_path, 0o666 & ~umask)
+
+        logger.info(f"creating metadata file for {cache_path}")
+        meta = {"url": url}  # , "etag": etag}
+        meta_path = cache_path + ".json"
+        with open(meta_path, "w") as meta_file:
+            json.dump(meta, meta_file)
+
+    return cache_path
+
+
+def download_from_hf(
+    path_or_repo_id: Union[str, Path],
+    filenames: Optional[List[str]],
+    cache_dir: Union[str, Path] = None,
+    force_download: bool = False,
+    resume_download: bool = False,
+    proxies: Optional[Dict[str, str]] = None,
+    use_auth_token: Optional[Union[bool, str]] = None,
+    revision: Optional[str] = None,
+    local_files_only: bool = False,
+    subfolder: str = "",
+):
+    if isinstance(path_or_repo_id, Path):
+        path_or_repo_id = str(path_or_repo_id)
+
+    downloaded_paths = []
+    for filename in filenames:
+        downloaded_path = hf_cached_file(
+            path_or_repo_id,
+            filename,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            resume_download=resume_download,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            local_files_only=local_files_only,
+            subfolder=subfolder,
+        )
+        downloaded_paths.append(downloaded_path)
+
+    # we want the folder where the files are downloaded
+    # the best guess is the parent folder of the first file
+    probably_the_folder = Path(downloaded_paths[0]).parent
+    return probably_the_folder
+
+
+def model_name_or_path_resolver(model_name_or_dir: Union[str, os.PathLike]) -> str:
+    """
+    Resolve a model name or directory to a model archive name or directory.
+
+    Args:
+        model_name_or_dir (:obj:`str` or :obj:`os.PathLike`):
+            A model name or directory.
+
+    Returns:
+        :obj:`str`: The model archive name or directory.
+    """
+    if is_remote_url(model_name_or_dir):
+        # if model_name_or_dir is a URL
+        # download it and try to load
+        model_archive = model_name_or_dir
+    elif Path(model_name_or_dir).is_dir() or Path(model_name_or_dir).is_file():
+        # if model_name_or_dir is a local directory or
+        # an archive file try to load it
+        model_archive = model_name_or_dir
+    else:
+        # probably model_name_or_dir is a sapienzanlp model id
+        # guess the url and try to download
+        model_name_or_dir_ = model_name_or_dir
+        # raise ValueError(f"Providing a model id is not supported yet.")
+        model_archive = sapienzanlp_model_urls(model_name_or_dir_)
+
+    return model_archive
+
+
+def from_cache(
+    url_or_filename: Union[str, Path],
+    cache_dir: Union[str, Path] = None,
+    force_download: bool = False,
+    resume_download: bool = False,
+    proxies: Optional[Dict[str, str]] = None,
+    use_auth_token: Optional[Union[bool, str]] = None,
+    revision: Optional[str] = None,
+    local_files_only: bool = False,
+    subfolder: str = "",
+    filenames: Optional[List[str]] = None,
+) -> Path:
+    """
+    Given something that could be either a local path or a URL (or a SapienzaNLP model id),
+    determine which one and return a path to the corresponding file.
+
+    Args:
+        url_or_filename (:obj:`str` or :obj:`Path`):
+            A path to a local file or a URL (or a SapienzaNLP model id).
+        cache_dir (:obj:`str` or :obj:`Path`, `optional`):
+            Path to a directory in which a downloaded file will be cached.
+        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to re-download the file even if it already exists.
+        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to delete incompletely received files. Attempts to resume the download if such a file
+            exists.
+        proxies (:obj:`Dict[str, str]`, `optional`):
+            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
+            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+        use_auth_token (:obj:`Union[bool, str]`, `optional`):
+            Optional string or boolean to use as Bearer token for remote files. If :obj:`True`, will get token from
+            :obj:`~transformers.hf_api.HfApi`. If :obj:`str`, will use that string as token.
+        revision (:obj:`str`, `optional`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
+        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to raise an error if the file to be downloaded is local.
+        subfolder (:obj:`str`, `optional`):
+            In case the relevant file is in a subfolder of the URL, specify it here.
+        filenames (:obj:`List[str]`, `optional`):
+            List of filenames to look for in the directory structure.
+
+    Returns:
+        :obj:`Path`: Path to the cached file.
+    """
+
+    url_or_filename = model_name_or_path_resolver(url_or_filename)
+
+    if cache_dir is None:
+        cache_dir = SAPIENZANLP_CACHE_DIR
+
+    if file_exists(url_or_filename):
+        logger.info(f"{url_or_filename} is a local path or file")
+        output_path = url_or_filename
+    elif is_remote_url(url_or_filename):
+        # URL, so get it from the cache (downloading if necessary)
+        output_path = download_and_cache(
+            url_or_filename,
+            cache_dir=cache_dir,
+            force_download=force_download,
+        )
+    else:
+        if filenames is None:
+            filenames = [WEIGHTS_NAME, CONFIG_NAME, LABELS_NAME]
+        output_path = download_from_hf(
+            url_or_filename,
+            filenames,
+            cache_dir,
+            force_download,
+            resume_download,
+            proxies,
+            use_auth_token,
+            revision,
+            local_files_only,
+            subfolder,
+        )
+
+    # if is_hf_hub_url(url_or_filename):
+    #     # HuggingFace Hub
+    #     output_path = hf_hub_download_url(url_or_filename)
+    # elif is_remote_url(url_or_filename):
+    #     # URL, so get it from the cache (downloading if necessary)
+    #     output_path = download_and_cache(
+    #         url_or_filename,
+    #         cache_dir=cache_dir,
+    #         force_download=force_download,
+    #     )
+    # elif file_exists(url_or_filename):
+    #     logger.info(f"{url_or_filename} is a local path or file")
+    #     # File, and it exists.
+    #     output_path = url_or_filename
+    # elif urlparse(url_or_filename).scheme == "":
+    #     # File, but it doesn't exist.
+    #     raise EnvironmentError(f"file {url_or_filename} not found")
+    # else:
+    #     # Something unknown
+    #     raise ValueError(
+    #         f"unable to parse {url_or_filename} as a URL or as a local path"
+    #     )
+
+    if dir_exists(output_path) or (
+        not is_zipfile(output_path) and not tarfile.is_tarfile(output_path)
+    ):
+        return Path(output_path)
+
+    # Path where we extract compressed archives
+    # for now it will extract it in the same folder
+    # maybe implement extraction in the sapienzanlp folder
+    # when using local archive path?
+    logger.info("Extracting compressed archive")
+    output_dir, output_file = os.path.split(output_path)
+    output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
+    output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
+
+    # already extracted, do not extract
+    if (
+        os.path.isdir(output_path_extracted)
+        and os.listdir(output_path_extracted)
+        and not force_download
+    ):
+        return Path(output_path_extracted)
+
+    # Prevent parallel extractions
+    lock_path = output_path + ".lock"
+    with FileLock(lock_path):
+        shutil.rmtree(output_path_extracted, ignore_errors=True)
+        os.makedirs(output_path_extracted)
+        if is_zipfile(output_path):
+            with ZipFile(output_path, "r") as zip_file:
+                zip_file.extractall(output_path_extracted)
+                zip_file.close()
+        elif tarfile.is_tarfile(output_path):
+            tar_file = tarfile.open(output_path)
+            tar_file.extractall(output_path_extracted)
+            tar_file.close()
+        else:
+            raise EnvironmentError(
+                f"Archive format of {output_path} could not be identified"
+            )
+
+    # remove lock file, is it safe?
+    os.remove(lock_path)
+
+    return Path(output_path_extracted)
+
+
+def is_str_a_path(maybe_path: str) -> bool:
+    """
+    Check if a string is a path.
+
+    Args:
+        maybe_path (`str`): The string to check.
+
+    Returns:
+        `bool`: `True` if the string is a path, `False` otherwise.
+    """
+    # first check if it is a path
+    if Path(maybe_path).exists():
+        return True
+    # check if it is a relative path
+    if Path(os.path.join(os.getcwd(), maybe_path)).exists():
+        return True
+    # otherwise it is not a path
+    return False
+
+
+def relative_to_absolute_path(path: str) -> os.PathLike:
+    """
+    Convert a relative path to an absolute path.
+
+    Args:
+        path (`str`): The relative path to convert.
+
+    Returns:
+        `os.PathLike`: The absolute path.
+    """
+    if not is_str_a_path(path):
+        raise ValueError(f"{path} is not a path")
+    if Path(path).exists():
+        return Path(path).absolute()
+    if Path(os.path.join(os.getcwd(), path)).exists():
+        return Path(os.path.join(os.getcwd(), path)).absolute()
+    raise ValueError(f"{path} is not a path")
+
+
+def to_config(object_to_save: Any) -> Dict[str, Any]:
+    """
+    Convert an object to a dictionary.
+
+    Returns:
+        `Dict[str, Any]`: The dictionary representation of the object.
+    """
+
+    def obj_to_dict(obj):
+        match obj:
+            case dict():
+                data = {}
+                for k, v in obj.items():
+                    data[k] = obj_to_dict(v)
+                return data
+
+            case list() | tuple():
+                return [obj_to_dict(x) for x in obj]
return [obj_to_dict(x) for x in obj]
|
577 |
+
|
578 |
+
case object(__dict__=_):
|
579 |
+
data = {
|
580 |
+
"_target_": f"{obj.__class__.__module__}.{obj.__class__.__name__}",
|
581 |
+
}
|
582 |
+
for k, v in obj.__dict__.items():
|
583 |
+
if not k.startswith("_"):
|
584 |
+
data[k] = obj_to_dict(v)
|
585 |
+
return data
|
586 |
+
|
587 |
+
case _:
|
588 |
+
return obj
|
589 |
+
|
590 |
+
return obj_to_dict(object_to_save)
|
591 |
+
|
592 |
+
|
593 |
+
def get_callable_from_string(callable_fn: str) -> Any:
|
594 |
+
"""
|
595 |
+
Get a callable from a string.
|
596 |
+
|
597 |
+
Args:
|
598 |
+
callable_fn (`str`):
|
599 |
+
The string representation of the callable.
|
600 |
+
|
601 |
+
Returns:
|
602 |
+
`Any`: The callable.
|
603 |
+
"""
|
604 |
+
# separate the function name from the module name
|
605 |
+
module_name, function_name = callable_fn.rsplit(".", 1)
|
606 |
+
# import the module
|
607 |
+
module = importlib.import_module(module_name)
|
608 |
+
# get the function
|
609 |
+
return getattr(module, function_name)
|
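For reference, a minimal usage sketch of `get_callable_from_string`, the counterpart of the `_target_` strings emitted by `to_config`. The `os.path.join` target is only an illustration; any importable dotted path works:

from relik.common.utils import get_callable_from_string

# resolve a dotted path back into the object it names
join = get_callable_from_string("os.path.join")
assert join("a", "b") == "a/b"  # on POSIX systems

# the same mechanism resolves preprocessing functions passed to Relik by name,
# e.g. "relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing"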
relik/inference/__init__.py
ADDED
File without changes
relik/inference/annotator.py
ADDED
@@ -0,0 +1,422 @@
import os
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

import hydra
from omegaconf import OmegaConf
from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel
from rich.pretty import pprint

from relik.common.log import get_console_logger, get_logger
from relik.common.upload import upload
from relik.common.utils import CONFIG_NAME, from_cache, get_callable_from_string
from relik.inference.data.objects import EntitySpan, RelikOutput
from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
from relik.inference.data.window.manager import WindowManager
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
from relik.reader.relik_reader import RelikReader
from relik.retriever.data.utils import batch_generator
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.pytorch_modules.model import GoldenRetriever

logger = get_logger(__name__)
console_logger = get_console_logger()


class Relik:
    """
    Relik main class. It is a wrapper around a retriever and a reader.

    Args:
        retriever (`Optional[GoldenRetriever]`, `optional`):
            The retriever to use. If `None`, a retriever will be instantiated from the
            provided `question_encoder`, `passage_encoder` and `document_index`.
            Defaults to `None`.
        question_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
            The question encoder to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        passage_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
            The passage encoder to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        document_index (`Optional[Union[str, BaseDocumentIndex]]`, `optional`):
            The document index to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        reader (`Optional[Union[str, RelikReader]]`, `optional`):
            The reader to use. If a string is provided, a reader will be instantiated
            from it. Defaults to `None`.
        retriever_device (`str`, `optional`, defaults to `cpu`):
            The device to use for the retriever.
    """

    def __init__(
        self,
        retriever: GoldenRetriever | None = None,
        question_encoder: str | GoldenRetrieverModel | None = None,
        passage_encoder: str | GoldenRetrieverModel | None = None,
        document_index: str | BaseDocumentIndex | None = None,
        reader: str | RelikReader | None = None,
        device: str = "cpu",
        retriever_device: str | None = None,
        document_index_device: str | None = None,
        reader_device: str | None = None,
        precision: int = 32,
        retriever_precision: int | None = None,
        document_index_precision: int | None = None,
        reader_precision: int | None = None,
        reader_kwargs: dict | None = None,
        retriever_kwargs: dict | None = None,
        candidates_preprocessing_fn: str | Callable | None = None,
        top_k: int | None = None,
        window_size: int | None = None,
        window_stride: int | None = None,
        **kwargs,
    ) -> None:
        # retriever
        retriever_device = retriever_device or device
        document_index_device = document_index_device or device
        retriever_precision = retriever_precision or precision
        document_index_precision = document_index_precision or precision
        if retriever is None and question_encoder is None:
            raise ValueError(
                "Either `retriever` or `question_encoder` must be provided"
            )
        if retriever is None:
            self.retriever_kwargs = dict(
                question_encoder=question_encoder,
                passage_encoder=passage_encoder,
                document_index=document_index,
                device=retriever_device,
                precision=retriever_precision,
                index_device=document_index_device,
                index_precision=document_index_precision,
            )
            # overwrite the default retriever kwargs with the user-provided ones
            self.retriever_kwargs.update(retriever_kwargs or {})
            retriever = GoldenRetriever(**self.retriever_kwargs)
            retriever.training = False
            retriever.eval()
        self.retriever = retriever

        # reader
        self.reader_device = reader_device or device
        self.reader_precision = reader_precision or precision
        self.reader_kwargs = reader_kwargs
        if isinstance(reader, str):
            reader_kwargs = reader_kwargs or {}
            reader = RelikReaderForSpanExtraction(reader, **reader_kwargs)
        self.reader = reader

        # windowization stuff
        self.tokenizer = SpacyTokenizer(language="en")
        self.window_manager: WindowManager | None = None

        # candidates preprocessing
        # TODO: maybe move this logic somewhere else
        candidates_preprocessing_fn = candidates_preprocessing_fn or (lambda x: x)
        if isinstance(candidates_preprocessing_fn, str):
            candidates_preprocessing_fn = get_callable_from_string(
                candidates_preprocessing_fn
            )
        self.candidates_preprocessing_fn = candidates_preprocessing_fn

        # inference params
        self.top_k = top_k
        self.window_size = window_size
        self.window_stride = window_stride

    def __call__(
        self,
        text: Union[str, list],
        top_k: Optional[int] = None,
        window_size: Optional[int] = None,
        window_stride: Optional[int] = None,
        retriever_batch_size: Optional[int] = 32,
        reader_batch_size: Optional[int] = 32,
        return_also_windows: bool = False,
        **kwargs,
    ) -> Union[RelikOutput, list[RelikOutput]]:
        """
        Annotate a text with entities.

        Args:
            text (`str` or `list`):
                The text to annotate. If a list is provided, each element of the list
                will be annotated separately.
            top_k (`int`, `optional`, defaults to `None`):
                The number of candidates to retrieve for each window.
            window_size (`int`, `optional`, defaults to `None`):
                The size of the window. If `None`, the whole text will be annotated.
            window_stride (`int`, `optional`, defaults to `None`):
                The stride of the window. If `None`, there will be no overlap between windows.
            retriever_batch_size (`int`, `optional`, defaults to `32`):
                The batch size to use for the retriever.
            reader_batch_size (`int`, `optional`, defaults to `32`):
                The batch size to use for the reader.
            return_also_windows (`bool`, `optional`, defaults to `False`):
                Whether to return the windows in the output.
            **kwargs:
                Additional keyword arguments to pass to the retriever and the reader.

        Returns:
            `RelikOutput` or `list[RelikOutput]`:
                The annotated text. If a list was provided as input, a list of
                `RelikOutput` objects will be returned.
        """
        if top_k is None:
            top_k = self.top_k or 100
        if window_size is None:
            window_size = self.window_size
        if window_stride is None:
            window_stride = self.window_stride

        if isinstance(text, str):
            text = [text]

        if window_size is not None:
            if self.window_manager is None:
                self.window_manager = WindowManager(self.tokenizer)

            if window_size == "sentence":
                # todo: implement sentence windowizer
                raise NotImplementedError("Sentence windowizer not implemented yet")

            # if window_size < window_stride:
            #     raise ValueError(
            #         f"Window size ({window_size}) must be greater than window stride ({window_stride})"
            #     )

        # window generator
        windows = [
            window
            for doc_id, t in enumerate(text)
            for window in self.window_manager.create_windows(
                t,
                window_size=window_size,
                stride=window_stride,
                doc_id=doc_id,
            )
        ]

        # retrieve candidates first
        windows_candidates = []
        # TODO: Move batching inside retriever
        for batch in batch_generator(windows, batch_size=retriever_batch_size):
            retriever_out = self.retriever.retrieve([b.text for b in batch], k=top_k)
            windows_candidates.extend(
                [[p.label for p in predictions] for predictions in retriever_out]
            )

        # add passage to the windows
        for window, candidates in zip(windows, windows_candidates):
            window.window_candidates = [
                self.candidates_preprocessing_fn(c) for c in candidates
            ]

        windows = self.reader.read(samples=windows, max_batch_size=reader_batch_size)
        windows = self.window_manager.merge_windows(windows)

        # transform predictions into RelikOutput objects
        output = []
        for w in windows:
            sample_output = RelikOutput(
                text=text[w.doc_id],
                labels=sorted(
                    [
                        EntitySpan(
                            start=ss, end=se, label=sl, text=text[w.doc_id][ss:se]
                        )
                        for ss, se, sl in w.predicted_window_labels_chars
                    ],
                    key=lambda x: x.start,
                ),
            )
            output.append(sample_output)

        if return_also_windows:
            for i, sample_output in enumerate(output):
                sample_output.windows = [w for w in windows if w.doc_id == i]

        # if only one text was provided, return a single RelikOutput object
        if len(output) == 1:
            return output[0]

        return output

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_dir: Union[str, os.PathLike],
        config_kwargs: Optional[Dict] = None,
        config_file_name: str = CONFIG_NAME,
        *args,
        **kwargs,
    ) -> "Relik":
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)

        model_dir = from_cache(
            model_name_or_dir,
            filenames=[config_file_name],
            cache_dir=cache_dir,
            force_download=force_download,
        )

        config_path = model_dir / config_file_name
        if not config_path.exists():
            raise FileNotFoundError(
                f"Model configuration file not found at {config_path}."
            )

        # overwrite config with config_kwargs
        config = OmegaConf.load(config_path)
        if config_kwargs is not None:
            # TODO: check merging behavior
            config = OmegaConf.merge(config, OmegaConf.create(config_kwargs))
        # do we want to print the config? I like it
        pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True)

        # load relik from config
        relik = hydra.utils.instantiate(config, *args, **kwargs)

        return relik

    def save_pretrained(
        self,
        output_dir: Union[str, os.PathLike],
        config: Optional[Dict[str, Any]] = None,
        config_file_name: Optional[str] = None,
        save_weights: bool = False,
        push_to_hub: bool = False,
        model_id: Optional[str] = None,
        organization: Optional[str] = None,
        repo_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Save the configuration of Relik to the specified directory as a YAML file.

        Args:
            output_dir (`str`):
                The directory to save the configuration file to.
            config (`Optional[Dict[str, Any]]`, `optional`):
                The configuration to save. If `None`, the current configuration will be
                saved. Defaults to `None`.
            config_file_name (`Optional[str]`, `optional`):
                The name of the configuration file. Defaults to `config.yaml`.
            save_weights (`bool`, `optional`):
                Whether to save the weights of the model. Defaults to `False`.
            push_to_hub (`bool`, `optional`):
                Whether to push the saved model to the hub. Defaults to `False`.
            model_id (`Optional[str]`, `optional`):
                The id of the model to push to the hub. If `None`, the name of the
                directory will be used. Defaults to `None`.
            organization (`Optional[str]`, `optional`):
                The organization to push the model to. Defaults to `None`.
            repo_name (`Optional[str]`, `optional`):
                The name of the repository to push the model to. Defaults to `None`.
            **kwargs:
                Additional keyword arguments to pass to `OmegaConf.save`.
        """
        if config is None:
            # create a default config
            config = {
                "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}"
            }
            if self.retriever is not None:
                if self.retriever.question_encoder is not None:
                    config[
                        "question_encoder"
                    ] = self.retriever.question_encoder.name_or_path
                if self.retriever.passage_encoder is not None:
                    config[
                        "passage_encoder"
                    ] = self.retriever.passage_encoder.name_or_path
                if self.retriever.document_index is not None:
                    config["document_index"] = self.retriever.document_index.name_or_dir
            if self.reader is not None:
                config["reader"] = self.reader.model_path

            config["retriever_kwargs"] = self.retriever_kwargs
            config["reader_kwargs"] = self.reader_kwargs
            # store the fn as a dotted path so it can be saved and loaded later
            config[
                "candidates_preprocessing_fn"
            ] = f"{self.candidates_preprocessing_fn.__module__}.{self.candidates_preprocessing_fn.__name__}"

            # these are model-specific and should be saved
            config["top_k"] = self.top_k
            config["window_size"] = self.window_size
            config["window_stride"] = self.window_stride

        config_file_name = config_file_name or CONFIG_NAME

        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving relik config to {output_dir / config_file_name}")
        # pretty print the config
        pprint(config, console=console_logger, expand_all=True)
        OmegaConf.save(config, output_dir / config_file_name)

        if save_weights:
            model_id = model_id or output_dir.name
            retriever_model_id = model_id + "-retriever"
            # save weights
            logger.info(f"Saving retriever to {output_dir / retriever_model_id}")
            self.retriever.save_pretrained(
                output_dir / retriever_model_id,
                question_encoder_name=retriever_model_id + "-question-encoder",
                passage_encoder_name=retriever_model_id + "-passage-encoder",
                document_index_name=retriever_model_id + "-index",
                push_to_hub=push_to_hub,
                organization=organization,
                repo_name=repo_name,
                **kwargs,
            )
            reader_model_id = model_id + "-reader"
            logger.info(f"Saving reader to {output_dir / reader_model_id}")
            self.reader.save_pretrained(
                output_dir / reader_model_id,
                push_to_hub=push_to_hub,
                organization=organization,
                repo_name=repo_name,
                **kwargs,
            )

        if push_to_hub:
            # push to hub
            logger.info("Pushing to hub")
            model_id = model_id or output_dir.name
            upload(output_dir, model_id, organization=organization, repo_name=repo_name)


def main():
    from pprint import pprint

    relik = Relik(
        question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
        document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
        reader="riccorl/relik-reader-aida-deberta-small",
        device="cuda",
        precision=16,
        top_k=100,
        window_size=32,
        window_stride=16,
        candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
    )

    input_text = """
    Bernie Ecclestone, the former boss of Formula One, has admitted fraud after failing to declare more than £400m held in a trust in Singapore.
    The 92-year-old billionaire did not disclose the trust to the government in July 2015.
    Appearing at Southwark Crown Court on Thursday, he told the judge "I plead guilty" after having previously pleaded not guilty.
    Ecclestone had been due to go on trial next month.
    """

    preds = relik(input_text)
    pprint(preds)


if __name__ == "__main__":
    main()
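For reference, a minimal save/load sketch of the `Relik` wrapper above. The Hub identifiers are the ones used in `main()`; the `my-relik` output directory is only an illustration, and `from_pretrained` assumes the saved config file (the default `CONFIG_NAME`) is present in that directory:

from relik.inference.annotator import Relik

relik = Relik(
    question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
    document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
    reader="riccorl/relik-reader-aida-deberta-small",
    top_k=100,
    window_size=32,
    window_stride=16,
)
# writes only the YAML config (no weights, since save_weights=False by default)
relik.save_pretrained("my-relik")
# re-instantiates retriever and reader via hydra from the saved config
relik_reloaded = Relik.from_pretrained("my-relik")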
relik/inference/data/__init__.py
ADDED
File without changes
relik/inference/data/objects.py
ADDED
@@ -0,0 +1,64 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, NamedTuple, Optional

from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSample


@dataclass
class Word:
    """
    A word representation that includes text, index in the sentence, POS tag, lemma,
    dependency relation, and similar information.

    # Parameters
    text : `str`, optional
        The text representation.
    index : `int`, optional
        The word offset in the sentence.
    lemma : `str`, optional
        The lemma of this word.
    pos : `str`, optional
        The coarse-grained part of speech of this word.
    dep : `str`, optional
        The dependency relation for this word.
    """

    text: str
    index: int
    start_char: Optional[int] = None
    end_char: Optional[int] = None
    # preprocessing fields
    lemma: Optional[str] = None
    pos: Optional[str] = None
    dep: Optional[str] = None
    head: Optional[int] = None

    def __str__(self):
        return self.text

    def __repr__(self):
        return self.__str__()


class EntitySpan(NamedTuple):
    start: int
    end: int
    label: str
    text: str


@dataclass
class RelikOutput:
    text: str
    labels: List[EntitySpan]
    windows: Optional[List[RelikReaderSample]] = None
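A short sketch of how these containers fit together; the span offsets below are made up for illustration:

from relik.inference.data.objects import EntitySpan, RelikOutput

text = "Bernie Ecclestone ran Formula One."
labels = [
    EntitySpan(start=0, end=17, label="Bernie Ecclestone", text=text[0:17]),
    EntitySpan(start=22, end=33, label="Formula One", text=text[22:33]),
]
output = RelikOutput(text=text, labels=labels)  # `windows` defaults to None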
relik/inference/data/tokenizers/__init__.py
ADDED
@@ -0,0 +1,89 @@
SPACY_LANGUAGE_MAPPER = {
    "ca": "ca_core_news_sm",
    "da": "da_core_news_sm",
    "de": "de_core_news_sm",
    "el": "el_core_news_sm",
    "en": "en_core_web_sm",
    "es": "es_core_news_sm",
    "fr": "fr_core_news_sm",
    "it": "it_core_news_sm",
    "ja": "ja_core_news_sm",
    "lt": "lt_core_news_sm",
    "mk": "mk_core_news_sm",
    "nb": "nb_core_news_sm",
    "nl": "nl_core_news_sm",
    "pl": "pl_core_news_sm",
    "pt": "pt_core_news_sm",
    "ro": "ro_core_news_sm",
    "ru": "ru_core_news_sm",
    "xx": "xx_sent_ud_sm",
    "zh": "zh_core_web_sm",
    "ca_core_news_sm": "ca_core_news_sm",
    "ca_core_news_md": "ca_core_news_md",
    "ca_core_news_lg": "ca_core_news_lg",
    "ca_core_news_trf": "ca_core_news_trf",
    "da_core_news_sm": "da_core_news_sm",
    "da_core_news_md": "da_core_news_md",
    "da_core_news_lg": "da_core_news_lg",
    "da_core_news_trf": "da_core_news_trf",
    "de_core_news_sm": "de_core_news_sm",
    "de_core_news_md": "de_core_news_md",
    "de_core_news_lg": "de_core_news_lg",
    "de_dep_news_trf": "de_dep_news_trf",
    "el_core_news_sm": "el_core_news_sm",
    "el_core_news_md": "el_core_news_md",
    "el_core_news_lg": "el_core_news_lg",
    "en_core_web_sm": "en_core_web_sm",
    "en_core_web_md": "en_core_web_md",
    "en_core_web_lg": "en_core_web_lg",
    "en_core_web_trf": "en_core_web_trf",
    "es_core_news_sm": "es_core_news_sm",
    "es_core_news_md": "es_core_news_md",
    "es_core_news_lg": "es_core_news_lg",
    "es_dep_news_trf": "es_dep_news_trf",
    "fr_core_news_sm": "fr_core_news_sm",
    "fr_core_news_md": "fr_core_news_md",
    "fr_core_news_lg": "fr_core_news_lg",
    "fr_dep_news_trf": "fr_dep_news_trf",
    "it_core_news_sm": "it_core_news_sm",
    "it_core_news_md": "it_core_news_md",
    "it_core_news_lg": "it_core_news_lg",
    "ja_core_news_sm": "ja_core_news_sm",
    "ja_core_news_md": "ja_core_news_md",
    "ja_core_news_lg": "ja_core_news_lg",
    "ja_dep_news_trf": "ja_dep_news_trf",
    "lt_core_news_sm": "lt_core_news_sm",
    "lt_core_news_md": "lt_core_news_md",
    "lt_core_news_lg": "lt_core_news_lg",
    "mk_core_news_sm": "mk_core_news_sm",
    "mk_core_news_md": "mk_core_news_md",
    "mk_core_news_lg": "mk_core_news_lg",
    "nb_core_news_sm": "nb_core_news_sm",
    "nb_core_news_md": "nb_core_news_md",
    "nb_core_news_lg": "nb_core_news_lg",
    "nl_core_news_sm": "nl_core_news_sm",
    "nl_core_news_md": "nl_core_news_md",
    "nl_core_news_lg": "nl_core_news_lg",
    "pl_core_news_sm": "pl_core_news_sm",
    "pl_core_news_md": "pl_core_news_md",
    "pl_core_news_lg": "pl_core_news_lg",
    "pt_core_news_sm": "pt_core_news_sm",
    "pt_core_news_md": "pt_core_news_md",
    "pt_core_news_lg": "pt_core_news_lg",
    "ro_core_news_sm": "ro_core_news_sm",
    "ro_core_news_md": "ro_core_news_md",
    "ro_core_news_lg": "ro_core_news_lg",
    "ru_core_news_sm": "ru_core_news_sm",
    "ru_core_news_md": "ru_core_news_md",
    "ru_core_news_lg": "ru_core_news_lg",
    "xx_ent_wiki_sm": "xx_ent_wiki_sm",
    "xx_sent_ud_sm": "xx_sent_ud_sm",
    "zh_core_web_sm": "zh_core_web_sm",
    "zh_core_web_md": "zh_core_web_md",
    "zh_core_web_lg": "zh_core_web_lg",
    "zh_core_web_trf": "zh_core_web_trf",
}

# NOTE: imported at the bottom because spacy_tokenizer imports
# SPACY_LANGUAGE_MAPPER from this package, which would otherwise
# create a circular import
from relik.inference.data.tokenizers.regex_tokenizer import RegexTokenizer
from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer
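The mapper accepts either a bare language code or a full pipeline name, so both lookups below resolve (a quick illustration):

from relik.inference.data.tokenizers import SPACY_LANGUAGE_MAPPER

SPACY_LANGUAGE_MAPPER["en"]               # 'en_core_web_sm' (default pipeline for a language code)
SPACY_LANGUAGE_MAPPER["en_core_web_trf"]  # 'en_core_web_trf' (explicit pipeline names map to themselves)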
relik/inference/data/tokenizers/base_tokenizer.py
ADDED
@@ -0,0 +1,84 @@
from typing import List, Union

from relik.inference.data.objects import Word


class BaseTokenizer:
    """
    A :obj:`Tokenizer` splits strings of text into single words, optionally adds
    POS tags and performs lemmatization.
    """

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs
    ) -> List[List[Word]]:
        """
        Tokenize the input into single words.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to tag. It can be a single string, a batch of strings or pre-tokenized strings.
            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`List[List[Word]]`: The input text tokenized in single words.
        """
        raise NotImplementedError

    def tokenize(self, text: str) -> List[Word]:
        """
        Implements splitting words into tokens.

        Args:
            text (:obj:`str`):
                Text to tokenize.

        Returns:
            :obj:`List[Word]`: The input text tokenized in single words.
        """
        raise NotImplementedError

    def tokenize_batch(self, texts: List[str]) -> List[List[Word]]:
        """
        Implements batch splitting words into tokens.

        Args:
            texts (:obj:`List[str]`):
                Batch of text to tokenize.

        Returns:
            :obj:`List[List[Word]]`: The input batch tokenized in single words.
        """
        return [self.tokenize(text) for text in texts]

    @staticmethod
    def check_is_batched(
        texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
    ):
        """
        Check if input is batched or a single sample.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to check.
            is_split_into_words (:obj:`bool`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
        """
        return bool(
            (not is_split_into_words and isinstance(texts, (list, tuple)))
            or (
                is_split_into_words
                and isinstance(texts, (list, tuple))
                and texts
                and isinstance(texts[0], (list, tuple))
            )
        )
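A quick sanity check of the batching heuristic above, spelling out what counts as a batch for each input shape:

from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer

BaseTokenizer.check_is_batched("a b c", is_split_into_words=False)             # False: single string
BaseTokenizer.check_is_batched(["a b", "c d"], is_split_into_words=False)      # True: batch of strings
BaseTokenizer.check_is_batched(["a", "b"], is_split_into_words=True)           # False: one pre-tokenized sample
BaseTokenizer.check_is_batched([["a", "b"], ["c"]], is_split_into_words=True)  # True: batch of pre-tokenized samples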
relik/inference/data/tokenizers/regex_tokenizer.py
ADDED
@@ -0,0 +1,73 @@
import re
from typing import List, Union

from overrides import overrides

from relik.inference.data.objects import Word
from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer


class RegexTokenizer(BaseTokenizer):
    """
    A :obj:`Tokenizer` that splits the text based on a simple regex.
    """

    def __init__(self):
        super(RegexTokenizer, self).__init__()
        # regex for splitting on spaces and punctuation and new lines
        # self._regex = re.compile(r"\S+|[\[\](),.!?;:\"]|\\n")
        self._regex = re.compile(
            r"\w+|\$[\d\.]+|\S+", re.UNICODE | re.MULTILINE | re.DOTALL
        )

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs,
    ) -> List[List[Word]]:
        """
        Tokenize the input into single words by splitting using a simple regex.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to tag. It can be a single string, a batch of strings or pre-tokenized strings.
            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`List[List[Word]]`: The input text tokenized in single words.

        Example::

            >>> from relik.inference.data.tokenizers.regex_tokenizer import RegexTokenizer

            >>> regex_tokenizer = RegexTokenizer()
            >>> regex_tokenizer("Mary sold the car to John.")

        """
        # check if input is batched or a single sample
        is_batched = self.check_is_batched(texts, is_split_into_words)

        if is_batched:
            tokenized = self.tokenize_batch(texts)
        else:
            tokenized = self.tokenize(texts)

        return tokenized

    @overrides
    def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
        if not isinstance(text, (str, list)):
            raise ValueError(
                f"text must be either `str` or `list`, found: `{type(text)}`"
            )

        if isinstance(text, list):
            text = " ".join(text)
        return [
            Word(t[0], i, start_char=t[1], end_char=t[2])
            for i, t in enumerate(
                (m.group(0), m.start(), m.end()) for m in self._regex.finditer(text)
            )
        ]
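A short usage sketch showing the character offsets the regex tokenizer attaches to each `Word`:

from relik.inference.data.tokenizers.regex_tokenizer import RegexTokenizer

regex_tokenizer = RegexTokenizer()
words = regex_tokenizer.tokenize("Mary sold the car to John.")
print([(w.text, w.start_char, w.end_char) for w in words])
# the `\w+` alternative splits the trailing period into its own token:
# [('Mary', 0, 4), ('sold', 5, 9), ('the', 10, 13), ('car', 14, 17),
#  ('to', 18, 20), ('John', 21, 25), ('.', 25, 26)]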
relik/inference/data/tokenizers/spacy_tokenizer.py
ADDED
@@ -0,0 +1,228 @@
import logging
from typing import Dict, List, Tuple, Union

import spacy

# from ipa.common.utils import load_spacy
from overrides import overrides
from spacy.cli.download import download as spacy_download
from spacy.tokens import Doc

from relik.common.log import get_logger
from relik.inference.data.objects import Word
from relik.inference.data.tokenizers import SPACY_LANGUAGE_MAPPER
from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer

logger = get_logger(level=logging.DEBUG)

# Spacy and Stanza stuff

LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool, bool], spacy.Language] = {}


def load_spacy(
    language: str,
    pos_tags: bool = False,
    lemma: bool = False,
    parse: bool = False,
    split_on_spaces: bool = False,
) -> spacy.Language:
    """
    Download and load spacy model.

    Args:
        language (:obj:`str`):
            Language of the text to tokenize.
        pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs POS tagging with spacy model.
        lemma (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs lemmatization with spacy model.
        parse (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs dependency parsing with spacy model.
        split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, will split by spaces without performing tokenization.

    Returns:
        :obj:`spacy.Language`: The spacy model loaded.
    """
    exclude = ["vectors", "textcat", "ner"]
    if not pos_tags:
        exclude.append("tagger")
    if not lemma:
        exclude.append("lemmatizer")
    if not parse:
        exclude.append("parser")

    # check if the model is already loaded
    # if so, there is no need to reload it
    spacy_params = (language, pos_tags, lemma, parse, split_on_spaces)
    if spacy_params not in LOADED_SPACY_MODELS:
        try:
            spacy_tagger = spacy.load(language, exclude=exclude)
        except OSError:
            logger.warning(
                "Spacy model '%s' not found. Downloading and installing.", language
            )
            spacy_download(language)
            spacy_tagger = spacy.load(language, exclude=exclude)

        # if everything is disabled, return only the tokenizer
        # for faster tokenization
        # TODO: is it really faster?
        # if len(exclude) >= 6:
        #     spacy_tagger = spacy_tagger.tokenizer
        LOADED_SPACY_MODELS[spacy_params] = spacy_tagger

    return LOADED_SPACY_MODELS[spacy_params]


class SpacyTokenizer(BaseTokenizer):
    """
    A :obj:`Tokenizer` that uses spaCy to tokenize and preprocess the text. It returns :obj:`Word` objects.

    Args:
        language (:obj:`str`, optional, defaults to :obj:`en`):
            Language of the text to tokenize.
        return_pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs POS tagging with spacy model.
        return_lemmas (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs lemmatization with spacy model.
        return_deps (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs dependency parsing with spacy model.
        split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, will split by spaces without performing tokenization.
        use_gpu (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, will load the spaCy model on GPU.
    """

    def __init__(
        self,
        language: str = "en",
        return_pos_tags: bool = False,
        return_lemmas: bool = False,
        return_deps: bool = False,
        split_on_spaces: bool = False,
        use_gpu: bool = False,
    ):
        super(SpacyTokenizer, self).__init__()
        if language not in SPACY_LANGUAGE_MAPPER:
            raise ValueError(
                f"`{language}` language not supported. The supported "
                f"languages are: {list(SPACY_LANGUAGE_MAPPER.keys())}."
            )
        if use_gpu:
            # load the model on GPU
            # if the GPU is not available or not correctly configured,
            # it will raise an error
            spacy.require_gpu()
        self.spacy = load_spacy(
            SPACY_LANGUAGE_MAPPER[language],
            return_pos_tags,
            return_lemmas,
            return_deps,
            split_on_spaces,
        )
        self.split_on_spaces = split_on_spaces

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs,
    ) -> Union[List[Word], List[List[Word]]]:
        """
        Tokenize the input into single words using SpaCy models.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to tag. It can be a single string, a batch of strings or pre-tokenized strings.
            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`List[List[Word]]`: The input text tokenized in single words.

        Example::

            >>> from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer

            >>> spacy_tokenizer = SpacyTokenizer(language="en", return_pos_tags=True, return_lemmas=True)
            >>> spacy_tokenizer("Mary sold the car to John.")

        """
        # check if input is batched or a single sample
        is_batched = self.check_is_batched(texts, is_split_into_words)
        if is_batched:
            tokenized = self.tokenize_batch(texts)
        else:
            tokenized = self.tokenize(texts)
        return tokenized

    @overrides
    def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
        if self.split_on_spaces:
            if isinstance(text, str):
                text = text.split(" ")
            spaces = [True] * len(text)
            text = Doc(self.spacy.vocab, words=text, spaces=spaces)
        return self._clean_tokens(self.spacy(text))

    @overrides
    def tokenize_batch(
        self, texts: Union[List[str], List[List[str]]]
    ) -> List[List[Word]]:
        if self.split_on_spaces:
            if isinstance(texts[0], str):
                texts = [text.split(" ") for text in texts]
            spaces = [[True] * len(text) for text in texts]
            texts = [
                Doc(self.spacy.vocab, words=text, spaces=space)
                for text, space in zip(texts, spaces)
            ]
        return [self._clean_tokens(tokens) for tokens in self.spacy.pipe(texts)]

    @staticmethod
    def _clean_tokens(tokens: Doc) -> List[Word]:
        """
        Converts spaCy tokens to :obj:`Word`.

        Args:
            tokens (:obj:`spacy.tokens.Doc`):
                Tokens from SpaCy model.

        Returns:
            :obj:`List[Word]`: The SpaCy model output converted into :obj:`Word` objects.
        """
        words = [
            Word(
                token.text,
                token.i,
                token.idx,
                token.idx + len(token),
                token.lemma_,
                token.pos_,
                token.dep_,
                token.head.i,
            )
            for token in tokens
        ]
        return words


class WhitespaceSpacyTokenizer:
    """Simple white space tokenizer for SpaCy."""

    def __init__(self, vocab):
        self.vocab = vocab

    def __call__(self, text):
        if isinstance(text, str):
            words = text.split(" ")
        elif isinstance(text, list):
            words = text
        else:
            raise ValueError(
                f"text must be either `str` or `list`, found: `{type(text)}`"
            )
        spaces = [True] * len(words)
        return Doc(self.vocab, words=words, spaces=spaces)
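A minimal usage sketch of the spaCy-backed tokenizer. Note that `load_spacy` will download the `en_core_web_sm` pipeline on first use if it is not installed:

from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer

spacy_tokenizer = SpacyTokenizer(language="en", return_pos_tags=True, return_lemmas=True)
for word in spacy_tokenizer.tokenize("Mary sold the car to John."):
    # each Word carries the fields populated by the enabled pipeline components
    print(word.text, word.pos, word.lemma, word.start_char, word.end_char)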
relik/inference/data/tokenizers/whitespace_tokenizer.py
ADDED
@@ -0,0 +1,70 @@
import re
from typing import List, Union

from overrides import overrides

from relik.inference.data.objects import Word
from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer


class WhitespaceTokenizer(BaseTokenizer):
    """
    A :obj:`Tokenizer` that splits the text on spaces.
    """

    def __init__(self):
        super(WhitespaceTokenizer, self).__init__()
        self.whitespace_regex = re.compile(r"\S+")

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs,
    ) -> List[List[Word]]:
        """
        Tokenize the input into single words by splitting on spaces.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to tag. It can be a single string, a batch of strings or pre-tokenized strings.
            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`List[List[Word]]`: The input text tokenized in single words.

        Example::

            >>> from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer

            >>> whitespace_tokenizer = WhitespaceTokenizer()
            >>> whitespace_tokenizer("Mary sold the car to John .")

        """
        # check if input is batched or a single sample
        is_batched = self.check_is_batched(texts, is_split_into_words)

        if is_batched:
            tokenized = self.tokenize_batch(texts)
        else:
            tokenized = self.tokenize(texts)

        return tokenized

    @overrides
    def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
        if not isinstance(text, (str, list)):
            raise ValueError(
                f"text must be either `str` or `list`, found: `{type(text)}`"
            )

        if isinstance(text, list):
            text = " ".join(text)
        return [
            Word(t[0], i, start_char=t[1], end_char=t[2])
            for i, t in enumerate(
                (m.group(0), m.start(), m.end())
                for m in self.whitespace_regex.finditer(text)
            )
        ]
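Unlike `RegexTokenizer`, the `\S+` pattern never separates punctuation, so it expects punctuation to already stand alone (as in the docstring example above). A quick illustration of the difference:

from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer

whitespace_tokenizer = WhitespaceTokenizer()
[w.text for w in whitespace_tokenizer.tokenize("Mary sold the car to John .")]
# ['Mary', 'sold', 'the', 'car', 'to', 'John', '.']
[w.text for w in whitespace_tokenizer.tokenize("Mary sold the car to John.")]
# ['Mary', 'sold', 'the', 'car', 'to', 'John.']  <- punctuation stays attached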
relik/inference/data/window/__init__.py
ADDED
File without changes
relik/inference/data/window/manager.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import itertools
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import List, Optional, Set, Tuple
|
5 |
+
|
6 |
+
from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
|
7 |
+
from relik.reader.data.relik_reader_sample import RelikReaderSample
|
8 |
+
|
9 |
+
|
10 |
+
@dataclass
|
11 |
+
class Window:
|
12 |
+
doc_id: int
|
13 |
+
window_id: int
|
14 |
+
text: str
|
15 |
+
tokens: List[str]
|
16 |
+
doc_topic: Optional[str]
|
17 |
+
offset: int
|
18 |
+
token2char_start: dict
|
19 |
+
token2char_end: dict
|
20 |
+
window_candidates: Optional[List[str]] = None
|
21 |
+
|
22 |
+
|
23 |
+
class WindowManager:
|
24 |
+
def __init__(self, tokenizer: BaseTokenizer) -> None:
|
25 |
+
self.tokenizer = tokenizer
|
26 |
+
|
27 |
+
def tokenize(self, document: str) -> Tuple[List[str], List[Tuple[int, int]]]:
|
28 |
+
tokenized_document = self.tokenizer(document)
|
29 |
+
tokens = []
|
30 |
+
tokens_char_mapping = []
|
31 |
+
for token in tokenized_document:
|
32 |
+
tokens.append(token.text)
|
33 |
+
tokens_char_mapping.append((token.start_char, token.end_char))
|
34 |
+
return tokens, tokens_char_mapping
|
35 |
+
|
36 |
+
def create_windows(
|
37 |
+
self,
|
38 |
+
document: str,
|
39 |
+
window_size: int,
|
40 |
+
stride: int,
|
41 |
+
doc_id: int = 0,
|
42 |
+
doc_topic: str = None,
|
43 |
+
) -> List[RelikReaderSample]:
|
44 |
+
document_tokens, tokens_char_mapping = self.tokenize(document)
|
45 |
+
if doc_topic is None:
|
46 |
+
doc_topic = document_tokens[0] if len(document_tokens) > 0 else ""
|
47 |
+
document_windows = []
|
48 |
+
if len(document_tokens) <= window_size:
|
49 |
+
text = document
|
50 |
+
# relik_reader_sample = RelikReaderSample()
|
51 |
+
document_windows.append(
|
52 |
+
# Window(
|
53 |
+
RelikReaderSample(
|
54 |
+
doc_id=doc_id,
|
55 |
+
window_id=0,
|
56 |
+
text=text,
|
57 |
+
tokens=document_tokens,
|
58 |
+
doc_topic=doc_topic,
|
59 |
+
offset=0,
|
60 |
+
token2char_start={
|
61 |
+
str(i): tokens_char_mapping[i][0]
|
62 |
+
for i in range(len(document_tokens))
|
63 |
+
},
|
64 |
+
token2char_end={
|
65 |
+
str(i): tokens_char_mapping[i][1]
|
66 |
+
for i in range(len(document_tokens))
|
67 |
+
},
|
68 |
+
)
|
69 |
+
)
|
70 |
+
else:
|
71 |
+
for window_id, i in enumerate(range(0, len(document_tokens), stride)):
|
72 |
+
# if the last stride is smaller than the window size, then we can
|
73 |
+
# include more tokens form the previous window.
|
74 |
+
if i != 0 and i + window_size > len(document_tokens):
|
75 |
+
overflowing_tokens = i + window_size - len(document_tokens)
|
76 |
+
if overflowing_tokens >= stride:
|
77 |
+
break
|
78 |
+
i -= overflowing_tokens
|
79 |
+
|
80 |
+
involved_token_indices = list(
|
81 |
+
                range(i, min(i + window_size, len(document_tokens) - 1))
            )
            window_tokens = [document_tokens[j] for j in involved_token_indices]
            window_text_start = tokens_char_mapping[involved_token_indices[0]][0]
            window_text_end = tokens_char_mapping[involved_token_indices[-1]][1]
            text = document[window_text_start:window_text_end]
            document_windows.append(
                # Window(
                RelikReaderSample(
                    # dict(
                    doc_id=doc_id,
                    window_id=window_id,
                    text=text,
                    tokens=window_tokens,
                    doc_topic=doc_topic,
                    offset=window_text_start,
                    token2char_start={
                        str(i): tokens_char_mapping[ti][0]
                        for i, ti in enumerate(involved_token_indices)
                    },
                    token2char_end={
                        str(i): tokens_char_mapping[ti][1]
                        for i, ti in enumerate(involved_token_indices)
                    },
                    # )
                )
            )
        return document_windows

    def merge_windows(
        self, windows: List[RelikReaderSample]
    ) -> List[RelikReaderSample]:
        windows_by_doc_id = collections.defaultdict(list)
        for window in windows:
            windows_by_doc_id[window.doc_id].append(window)

        merged_window_by_doc = {
            doc_id: self.merge_doc_windows(doc_windows)
            for doc_id, doc_windows in windows_by_doc_id.items()
        }

        return list(merged_window_by_doc.values())

    def merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample:
        if len(windows) == 1:
            return windows[0]

        if len(windows) > 0 and getattr(windows[0], "offset", None) is not None:
            windows = sorted(windows, key=(lambda x: x.offset))

        window_accumulator = windows[0]

        for next_window in windows[1:]:
            window_accumulator = self._merge_window_pair(
                window_accumulator, next_window
            )

        return window_accumulator

    def _merge_tokens(
        self, window1: RelikReaderSample, window2: RelikReaderSample
    ) -> Tuple[list, dict, dict]:
        w1_tokens = window1.tokens[1:-1]
        w2_tokens = window2.tokens[1:-1]

        # find the token-level intersection: the longest suffix of window 1
        # that equals a prefix of window 2 (the shared stride region)
        tokens_intersection = None
        for k in reversed(range(1, len(w1_tokens))):
            if w1_tokens[-k:] == w2_tokens[:k]:
                tokens_intersection = k
                break
        assert tokens_intersection is not None, (
            f"{window1.doc_id} - {window1.sent_id} - {window1.offset}"
            + f" {window2.doc_id} - {window2.sent_id} - {window2.offset}\n"
            + f"w1 tokens: {w1_tokens}\n"
            + f"w2 tokens: {w2_tokens}\n"
        )

        final_tokens = (
            [window1.tokens[0]]  # CLS
            + w1_tokens
            + w2_tokens[tokens_intersection:]
            + [window1.tokens[-1]]  # SEP
        )

        w2_starting_offset = len(w1_tokens) - tokens_intersection

        def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict:
            final_t2c = dict()
            final_t2c.update(t2c1)
            for t, c in t2c2.items():
                t = int(t)
                if t < tokens_intersection:
                    continue
                final_t2c[str(t + w2_starting_offset)] = c
            return final_t2c

        return (
            final_tokens,
            merge_char_mapping(window1.token2char_start, window2.token2char_start),
            merge_char_mapping(window1.token2char_end, window2.token2char_end),
        )

    def _merge_span_annotation(
        self, span_annotation1: List[list], span_annotation2: List[list]
    ) -> List[list]:
        uniq_store = set()
        final_span_annotation_store = []
        for span_annotation in itertools.chain(span_annotation1, span_annotation2):
            span_annotation_id = tuple(span_annotation)
            if span_annotation_id not in uniq_store:
                uniq_store.add(span_annotation_id)
                final_span_annotation_store.append(span_annotation)
        return sorted(final_span_annotation_store, key=lambda x: x[0])

    def _merge_predictions(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> Tuple[Set[Tuple[int, int, str]], dict]:
        merged_predictions = window1.predicted_window_labels_chars.union(
            window2.predicted_window_labels_chars
        )

        span_title_probabilities = dict()
        # probabilities: keep the first probability seen for each span prediction
        for span_prediction, predicted_probs in itertools.chain(
            window1.probs_window_labels_chars.items(),
            window2.probs_window_labels_chars.items(),
        ):
            if span_prediction not in span_title_probabilities:
                span_title_probabilities[span_prediction] = predicted_probs

        return merged_predictions, span_title_probabilities

    def _merge_window_pair(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> RelikReaderSample:
        merging_output = dict()

        if getattr(window1, "doc_id", None) is not None:
            assert window1.doc_id == window2.doc_id

        if getattr(window1, "offset", None) is not None:
            assert (
                window1.offset < window2.offset
            ), f"window 2 offset ({window2.offset}) is smaller than window 1 offset ({window1.offset})"

        merging_output["doc_id"] = window1.doc_id
        merging_output["offset"] = window2.offset

        m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens(
            window1, window2
        )

        window_labels = None
        if getattr(window1, "window_labels", None) is not None:
            window_labels = self._merge_span_annotation(
                window1.window_labels, window2.window_labels
            )
        (
            predicted_window_labels_chars,
            probs_window_labels_chars,
        ) = self._merge_predictions(
            window1,
            window2,
        )

        merging_output.update(
            dict(
                tokens=m_tokens,
                token2char_start=m_token2char_start,
                token2char_end=m_token2char_end,
                window_labels=window_labels,
                predicted_window_labels_chars=predicted_window_labels_chars,
                probs_window_labels_chars=probs_window_labels_chars,
            )
        )

        return RelikReaderSample(**merging_output)
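The merge above hinges on _merge_tokens finding the longest suffix of one window that is a prefix of the next (the region shared because of the sliding-window stride). A toy illustration of that overlap search, with made-up tokens:

w1 = ["Obama", "went", "to", "Rome", "for"]
w2 = ["to", "Rome", "for", "a", "vacation"]

# longest k such that the last k tokens of w1 equal the first k tokens of w2
overlap = next(k for k in reversed(range(1, len(w1))) if w1[-k:] == w2[:k])
assert overlap == 3
assert w1 + w2[overlap:] == ["Obama", "went", "to", "Rome", "for", "a", "vacation"]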
relik/inference/gerbil.py
ADDED
@@ -0,0 +1,254 @@
import argparse
import json
import os
import re
import sys
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Iterator, List, Optional, Tuple

from relik.inference.annotator import Relik
from relik.inference.data.objects import RelikOutput

# sys.path += ['../']
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))


import logging

logger = logging.getLogger(__name__)


class GerbilAlbyManager:
    def __init__(
        self,
        annotator: Optional[Relik] = None,
        response_logger_dir: Optional[str] = None,
    ) -> None:
        self.annotator = annotator
        self.response_logger_dir = response_logger_dir
        self.predictions_counter = 0
        self.labels_mapping = None

    def annotate(self, document: str):
        relik_output: RelikOutput = self.annotator(document)
        annotations = [(ss, se, l) for ss, se, l, _ in relik_output.labels]
        if self.labels_mapping is not None:
            return [
                (ss, se, self.labels_mapping.get(l, l)) for ss, se, l in annotations
            ]
        return annotations

    def set_mapping_file(self, mapping_file_path: str):
        with open(mapping_file_path) as f:
            labels_mapping = json.load(f)
        self.labels_mapping = {v: k for k, v in labels_mapping.items()}

    def write_response_bundle(
        self,
        document: str,
        new_document: str,
        annotations: list,
        mapped_annotations: list,
    ) -> None:
        if self.response_logger_dir is None:
            return

        if not os.path.isdir(self.response_logger_dir):
            os.mkdir(self.response_logger_dir)

        with open(
            f"{self.response_logger_dir}/{self.predictions_counter}.json", "w"
        ) as f:
            out_json_obj = dict(
                document=document,
                new_document=new_document,
                annotations=annotations,
                mapped_annotations=mapped_annotations,
            )

            out_json_obj["span_annotations"] = [
                (ss, se, document[ss:se], label) for (ss, se, label) in annotations
            ]

            out_json_obj["span_mapped_annotations"] = [
                (ss, se, new_document[ss:se], label)
                for (ss, se, label) in mapped_annotations
            ]

            json.dump(out_json_obj, f, indent=2)

        self.predictions_counter += 1


manager = GerbilAlbyManager()


def preprocess_document(document: str) -> Tuple[str, List[Tuple[int, int]]]:
    pattern_subs = {
        "-LPR- ": " (",
        "-RPR-": ")",
        "\n\n": "\n",
        "-LRB-": "(",
        "-RRB-": ")",
        '","': ",",
    }

    document_acc = document
    curr_offset = 0
    char2offset = []

    matchings = re.finditer("({})".format("|".join(pattern_subs)), document)
    for span_matching in sorted(matchings, key=lambda x: x.span()[0]):
        span_start, span_end = span_matching.span()
        span_start -= curr_offset
        span_end -= curr_offset

        span_text = document_acc[span_start:span_end]
        span_sub = pattern_subs[span_text]
        document_acc = document_acc[:span_start] + span_sub + document_acc[span_end:]

        offset = len(span_text) - len(span_sub)
        curr_offset += offset

        char2offset.append((span_start + len(span_sub), curr_offset))

    return document_acc, char2offset


def map_back_annotations(
    annotations: List[Tuple[int, int, str]], char_mapping: List[Tuple[int, int]]
) -> Iterator[Tuple[int, int, str]]:
    def map_char(char_idx: int) -> int:
        current_offset = 0
        for offset_idx, offset_value in char_mapping:
            if char_idx >= offset_idx:
                current_offset = offset_value
            else:
                break
        return char_idx + current_offset

    for ss, se, label in annotations:
        yield map_char(ss), map_char(se), label


def annotate(document: str) -> List[Tuple[int, int, str]]:
    new_document, mapping = preprocess_document(document)
    logger.info("Mapping: " + str(mapping))
    logger.info("Document: " + str(document))
    annotations = [
        (cs, ce, label.replace(" ", "_"))
        for cs, ce, label in manager.annotate(new_document)
    ]
    logger.info("New document: " + str(new_document))
    mapped_annotations = (
        list(map_back_annotations(annotations, mapping))
        if len(mapping) > 0
        else annotations
    )

    logger.info(
        "Annotations: "
        + str([(ss, se, document[ss:se], ann) for ss, se, ann in mapped_annotations])
    )

    manager.write_response_bundle(
        document, new_document, mapped_annotations, annotations
    )

    if not all(
        [
            new_document[ss:se] == document[mss:mse]
            for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
        ]
    ):
        # computed only for debugging: the span pairs that differ after mapping back
        diff_mappings = [
            (new_document[ss:se], document[mss:mse])
            for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
        ]
        return None
    assert all(
        [
            document[mss:mse] == new_document[ss:se]
            for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
        ]
    ), (mapped_annotations, annotations)

    return [(cs, ce - cs, label) for cs, ce, label in mapped_annotations]


class GetHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        content_length = int(self.headers["Content-Length"])
        post_data = self.rfile.read(content_length)
        self.send_response(200)
        self.end_headers()
        doc_text = read_json(post_data)
        # try:
        response = annotate(doc_text)

        self.wfile.write(bytes(json.dumps(response), "utf-8"))
        return


def read_json(post_data):
    data = json.loads(post_data.decode("utf-8"))
    # logger.info("received data:", data)
    text = data["text"]
    # spans = [(int(j["start"]), int(j["length"])) for j in data["spans"]]
    return text


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--relik-model-name", required=True)
    parser.add_argument("--responses-log-dir")
    parser.add_argument("--log-file", default="logs/logging.txt")
    parser.add_argument("--mapping-file")
    return parser.parse_args()


def main():
    args = parse_args()

    # init manager
    manager.response_logger_dir = args.responses_log_dir
    # manager.annotator = Relik.from_pretrained(args.relik_model_name)

    print("Debugging: not using your relik model but a hardcoded one.")
    manager.annotator = Relik(
        question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
        document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
        reader="relik/reader/models/relik-reader-deberta-base-new-data",
        window_size=32,
        window_stride=16,
        candidates_preprocessing_fn=(lambda x: x.split("<def>")[0].strip()),
    )

    if args.mapping_file is not None:
        manager.set_mapping_file(args.mapping_file)

    port = 6654
    server = HTTPServer(("localhost", port), GetHandler)
    logger.info(f"Starting server at http://localhost:{port}")

    # Create a file handler and set its level
    file_handler = logging.FileHandler(args.log_file)
    file_handler.setLevel(logging.DEBUG)

    # Create a log formatter and set it on the handler
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(formatter)

    # Add the file handler to the logger
    logger.addHandler(file_handler)

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        exit(0)


if __name__ == "__main__":
    main()
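A minimal sketch of a client for the GERBIL bridge above, assuming the server from this file is running locally; the "text" request key comes from read_json and the (start, length, label) reply shape from annotate:

import json
from urllib.request import Request, urlopen

payload = json.dumps({"text": "Obama went to Rome."}).encode("utf-8")
request = Request("http://localhost:6654", data=payload)
with urlopen(request) as response:
    # the server replies with a JSON list of (start, length, label) triples
    print(json.loads(response.read().decode("utf-8")))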
relik/inference/preprocessing.py
ADDED
@@ -0,0 +1,4 @@
def wikipedia_title_and_openings_preprocessing(
    wikipedia_title_and_openings: str, separator: str = " <def>"
):
    return wikipedia_title_and_openings.split(separator, 1)[0]
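A quick sanity check of what this does to a retriever candidate; the "title <def> opening" format is the convention the serve code in this commit strips the same way:

from relik.inference.preprocessing import wikipedia_title_and_openings_preprocessing

candidate = "Rome <def> Rome is the capital of Italy."
assert wikipedia_title_and_openings_preprocessing(candidate) == "Rome"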
relik/inference/serve/__init__.py
ADDED
File without changes
|
relik/inference/serve/backend/__init__.py
ADDED
File without changes
|
relik/inference/serve/backend/relik.py
ADDED
@@ -0,0 +1,210 @@
import logging
from pathlib import Path
from typing import List, Optional, Union

from relik.common.utils import is_package_available
from relik.inference.annotator import Relik

if not is_package_available("fastapi"):
    raise ImportError(
        "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
    )
from fastapi import FastAPI, HTTPException

if not is_package_available("ray"):
    raise ImportError(
        "Ray is not installed. Please install Ray with `pip install relik[serve]`."
    )
from ray import serve

from relik.common.log import get_logger
from relik.inference.data.tokenizers import SpacyTokenizer, WhitespaceTokenizer
from relik.inference.data.window.manager import WindowManager
from relik.inference.serve.backend.utils import (
    RayParameterManager,
    ServerParameterManager,
)
from relik.retriever.data.utils import batch_generator

logger = get_logger(__name__, level=logging.INFO)

VERSION = {}  # type: ignore
with open(
    Path(__file__).parent.parent.parent.parent / "version.py", "r"
) as version_file:
    exec(version_file.read(), VERSION)

# Env variables for server
SERVER_MANAGER = ServerParameterManager()
RAY_MANAGER = RayParameterManager()

app = FastAPI(
    title="ReLiK",
    version=VERSION["VERSION"],
    description="ReLiK REST API",
)


@serve.deployment(
    ray_actor_options={
        "num_gpus": RAY_MANAGER.num_gpus
        if (
            SERVER_MANAGER.retriver_device == "cuda"
            or SERVER_MANAGER.reader_device == "cuda"
        )
        else 0
    },
    autoscaling_config={
        "min_replicas": RAY_MANAGER.min_replicas,
        "max_replicas": RAY_MANAGER.max_replicas,
    },
)
@serve.ingress(app)
class RelikServer:
    def __init__(
        self,
        question_encoder: str,
        document_index: str,
        passage_encoder: Optional[str] = None,
        reader_encoder: Optional[str] = None,
        top_k: int = 100,
        retriver_device: str = "cpu",
        reader_device: str = "cpu",
        index_device: Optional[str] = None,
        precision: int = 32,
        index_precision: Optional[int] = None,
        use_faiss: bool = False,
        window_batch_size: int = 32,
        window_size: int = 32,
        window_stride: int = 16,
        split_on_spaces: bool = False,
    ):
        # parameters
        self.question_encoder = question_encoder
        self.passage_encoder = passage_encoder
        self.reader_encoder = reader_encoder
        self.document_index = document_index
        self.top_k = top_k
        self.retriver_device = retriver_device
        self.index_device = index_device or retriver_device
        self.reader_device = reader_device
        self.precision = precision
        self.index_precision = index_precision or precision
        self.use_faiss = use_faiss
        self.window_batch_size = window_batch_size
        self.window_size = window_size
        self.window_stride = window_stride
        self.split_on_spaces = split_on_spaces

        # log stuff for debugging
        logger.info("Initializing RelikServer with parameters:")
        logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
        logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
        logger.info(f"READER_ENCODER: {self.reader_encoder}")
        logger.info(f"DOCUMENT_INDEX: {self.document_index}")
        logger.info(f"TOP_K: {self.top_k}")
        logger.info(f"RETRIEVER_DEVICE: {self.retriver_device}")
        logger.info(f"READER_DEVICE: {self.reader_device}")
        logger.info(f"INDEX_DEVICE: {self.index_device}")
        logger.info(f"PRECISION: {self.precision}")
        logger.info(f"INDEX_PRECISION: {self.index_precision}")
        logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
        logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")

        self.relik = Relik(
            question_encoder=self.question_encoder,
            passage_encoder=self.passage_encoder,
            document_index=self.document_index,
            reader=self.reader_encoder,
            retriever_device=self.retriver_device,
            document_index_device=self.index_device,
            reader_device=self.reader_device,
            retriever_precision=self.precision,
            document_index_precision=self.index_precision,
            reader_precision=self.precision,
        )

        # tokenizer and window manager used by the /api/gerbil endpoint
        # (same setup as in retriever.py; the original commit referenced
        # these attributes below without initializing them)
        if self.split_on_spaces:
            logger.info("Using WhitespaceTokenizer")
            self.tokenizer = WhitespaceTokenizer()
        else:
            logger.info("Using SpacyTokenizer")
            self.tokenizer = SpacyTokenizer(language="en")
        self.window_manager = WindowManager(tokenizer=self.tokenizer)

    # @serve.batch()
    async def handle_batch(self, documents: List[str]) -> List:
        return self.relik(
            documents,
            top_k=self.top_k,
            window_size=self.window_size,
            window_stride=self.window_stride,
            batch_size=self.window_batch_size,
        )

    @app.post("/api/entities")
    async def entities_endpoint(
        self,
        documents: Union[str, List[str]],
    ):
        try:
            # normalize input
            if isinstance(documents, str):
                documents = [documents]
            # get predictions from the end-to-end pipeline
            # (the original commit also normalized a `document_topics`
            # variable here, but it was never defined in this endpoint)
            return await self.handle_batch(documents)
        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}")

    @app.post("/api/gerbil")
    async def gerbil_endpoint(self, documents: Union[str, List[str]]):
        try:
            # normalize input
            if isinstance(documents, str):
                documents = [documents]

            # output list
            windows_passages = []
            # split documents into windows
            document_windows = [
                window
                for doc_id, document in enumerate(documents)
                for window in self.window_manager(
                    self.tokenizer,
                    document,
                    window_size=self.window_size,
                    stride=self.window_stride,
                    doc_id=doc_id,
                )
            ]

            # get text and topic from document windows and create new list
            model_inputs = [
                (window.text, window.doc_topic) for window in document_windows
            ]

            # batch generator
            for batch in batch_generator(
                model_inputs, batch_size=self.window_batch_size
            ):
                text, text_pair = zip(*batch)
                # NOTE: handle_batch_retriever is not defined in this module as
                # committed; the working retriever-only version of this endpoint
                # lives in retriever.py
                batch_predictions = await self.handle_batch_retriever(text, text_pair)
                windows_passages.extend(
                    [
                        [p.label for p in predictions]
                        for predictions in batch_predictions
                    ]
                )

            # add passages to document windows
            for window, passages in zip(document_windows, windows_passages):
                # clean up passages (remove everything after the first <def> tag, if present)
                passages = [c.split(" <def>", 1)[0] for c in passages]
                window.window_candidates = passages

            # return document windows
            return document_windows

        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}")


server = RelikServer.bind(**vars(SERVER_MANAGER))
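Once the deployment is bound and served, the entities endpoint can be exercised the same way the Streamlit frontend does (it posts the raw text as the JSON body); the host and port below assume Ray Serve's default HTTP proxy:

import requests

response = requests.post(
    "http://localhost:8000/api/entities",
    json="Obama went to Rome for a quick vacation.",
)
print(response.json())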
relik/inference/serve/backend/retriever.py
ADDED
@@ -0,0 +1,206 @@
import logging
from pathlib import Path
from typing import List, Optional, Union

from relik.common.utils import is_package_available

if not is_package_available("fastapi"):
    raise ImportError(
        "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
    )
from fastapi import FastAPI, HTTPException

if not is_package_available("ray"):
    raise ImportError(
        "Ray is not installed. Please install Ray with `pip install relik[serve]`."
    )
from ray import serve

from relik.common.log import get_logger
from relik.inference.data.tokenizers import SpacyTokenizer, WhitespaceTokenizer
from relik.inference.data.window.manager import WindowManager
from relik.inference.serve.backend.utils import (
    RayParameterManager,
    ServerParameterManager,
)
from relik.retriever.data.utils import batch_generator
from relik.retriever.pytorch_modules import GoldenRetriever

logger = get_logger(__name__, level=logging.INFO)

VERSION = {}  # type: ignore
# NOTE: relik.py in this same directory resolves version.py one directory
# higher (four .parent hops); one of the two paths is off in this commit
with open(Path(__file__).parent.parent.parent / "version.py", "r") as version_file:
    exec(version_file.read(), VERSION)

# Env variables for server
SERVER_MANAGER = ServerParameterManager()
RAY_MANAGER = RayParameterManager()

app = FastAPI(
    title="Golden Retriever",
    version=VERSION["VERSION"],
    description="Golden Retriever REST API",
)


@serve.deployment(
    ray_actor_options={
        "num_gpus": RAY_MANAGER.num_gpus if SERVER_MANAGER.device == "cuda" else 0
    },
    autoscaling_config={
        "min_replicas": RAY_MANAGER.min_replicas,
        "max_replicas": RAY_MANAGER.max_replicas,
    },
)
@serve.ingress(app)
class GoldenRetrieverServer:
    def __init__(
        self,
        question_encoder: str,
        document_index: str,
        passage_encoder: Optional[str] = None,
        top_k: int = 100,
        device: str = "cpu",
        index_device: Optional[str] = None,
        precision: int = 32,
        index_precision: Optional[int] = None,
        use_faiss: bool = False,
        window_batch_size: int = 32,
        window_size: int = 32,
        window_stride: int = 16,
        split_on_spaces: bool = False,
    ):
        # parameters
        self.question_encoder = question_encoder
        self.passage_encoder = passage_encoder
        self.document_index = document_index
        self.top_k = top_k
        self.device = device
        self.index_device = index_device or device
        self.precision = precision
        self.index_precision = index_precision or precision
        self.use_faiss = use_faiss
        self.window_batch_size = window_batch_size
        self.window_size = window_size
        self.window_stride = window_stride
        self.split_on_spaces = split_on_spaces

        # log stuff for debugging
        logger.info("Initializing GoldenRetrieverServer with parameters:")
        logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
        logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
        logger.info(f"DOCUMENT_INDEX: {self.document_index}")
        logger.info(f"TOP_K: {self.top_k}")
        logger.info(f"DEVICE: {self.device}")
        logger.info(f"INDEX_DEVICE: {self.index_device}")
        logger.info(f"PRECISION: {self.precision}")
        logger.info(f"INDEX_PRECISION: {self.index_precision}")
        logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
        logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")

        self.retriever = GoldenRetriever(
            question_encoder=self.question_encoder,
            passage_encoder=self.passage_encoder,
            document_index=self.document_index,
            device=self.device,
            index_device=self.index_device,
            index_precision=self.index_precision,
        )
        self.retriever.eval()

        if self.split_on_spaces:
            logger.info("Using WhitespaceTokenizer")
            self.tokenizer = WhitespaceTokenizer()
            # logger.info("Using RegexTokenizer")
            # self.tokenizer = RegexTokenizer()
        else:
            logger.info("Using SpacyTokenizer")
            self.tokenizer = SpacyTokenizer(language="en")

        self.window_manager = WindowManager(tokenizer=self.tokenizer)

    # @serve.batch()
    async def handle_batch(
        self, documents: List[str], document_topics: List[str]
    ) -> List:
        return self.retriever.retrieve(
            documents, text_pair=document_topics, k=self.top_k, precision=self.precision
        )

    @app.post("/api/retrieve")
    async def retrieve_endpoint(
        self,
        documents: Union[str, List[str]],
        document_topics: Optional[Union[str, List[str]]] = None,
    ):
        try:
            # normalize input
            if isinstance(documents, str):
                documents = [documents]
            if document_topics is not None:
                if isinstance(document_topics, str):
                    document_topics = [document_topics]
                assert len(documents) == len(document_topics)
            # get predictions
            return await self.handle_batch(documents, document_topics)
        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}")

    @app.post("/api/gerbil")
    async def gerbil_endpoint(self, documents: Union[str, List[str]]):
        try:
            # normalize input
            if isinstance(documents, str):
                documents = [documents]

            # output list
            windows_passages = []
            # split documents into windows
            document_windows = [
                window
                for doc_id, document in enumerate(documents)
                for window in self.window_manager(
                    self.tokenizer,
                    document,
                    window_size=self.window_size,
                    stride=self.window_stride,
                    doc_id=doc_id,
                )
            ]

            # get text and topic from document windows and create new list
            model_inputs = [
                (window.text, window.doc_topic) for window in document_windows
            ]

            # batch generator
            for batch in batch_generator(
                model_inputs, batch_size=self.window_batch_size
            ):
                text, text_pair = zip(*batch)
                batch_predictions = await self.handle_batch(text, text_pair)
                windows_passages.extend(
                    [
                        [p.label for p in predictions]
                        for predictions in batch_predictions
                    ]
                )

            # add passages to document windows
            for window, passages in zip(document_windows, windows_passages):
                # clean up passages (remove everything after the first <def> tag, if present)
                passages = [c.split(" <def>", 1)[0] for c in passages]
                window.window_candidates = passages

            # return document windows
            return document_windows

        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}")


# NOTE: as committed, ServerParameterManager carries reader-specific fields
# (retriver_device, reader_device, reader_encoder) that this constructor does
# not accept, while the `device` parameter has no counterpart in the manager;
# **vars(SERVER_MANAGER) matches the RelikServer in relik.py rather than
# this class
server = GoldenRetrieverServer.bind(**vars(SERVER_MANAGER))
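The retriever server can be queried in the same way; because retrieve_endpoint declares two body parameters, FastAPI expects them embedded in a JSON object (host and port again assume the default Ray Serve proxy):

import requests

response = requests.post(
    "http://localhost:8000/api/retrieve",
    json={"documents": "Obama went to Rome for a quick vacation."},
)
print(response.json())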
relik/inference/serve/backend/utils.py
ADDED
@@ -0,0 +1,29 @@
import os
from dataclasses import dataclass
from typing import Union


@dataclass
class ServerParameterManager:
    retriver_device: str = os.environ.get("RETRIEVER_DEVICE", "cpu")
    reader_device: str = os.environ.get("READER_DEVICE", "cpu")
    index_device: str = os.environ.get("INDEX_DEVICE", retriver_device)
    precision: Union[str, int] = os.environ.get("PRECISION", "fp32")
    index_precision: Union[str, int] = os.environ.get("INDEX_PRECISION", precision)
    question_encoder: str = os.environ.get("QUESTION_ENCODER", None)
    passage_encoder: str = os.environ.get("PASSAGE_ENCODER", None)
    document_index: str = os.environ.get("DOCUMENT_INDEX", None)
    reader_encoder: str = os.environ.get("READER_ENCODER", None)
    top_k: int = int(os.environ.get("TOP_K", 100))
    use_faiss: bool = os.environ.get("USE_FAISS", False)
    window_batch_size: int = int(os.environ.get("WINDOW_BATCH_SIZE", 32))
    window_size: int = int(os.environ.get("WINDOW_SIZE", 32))
    # read the stride from WINDOW_STRIDE (the original read WINDOW_SIZE here,
    # which silently tied the stride to the window size)
    window_stride: int = int(os.environ.get("WINDOW_STRIDE", 16))
    split_on_spaces: bool = os.environ.get("SPLIT_ON_SPACES", False)


class RayParameterManager:
    def __init__(self) -> None:
        self.num_gpus = int(os.environ.get("NUM_GPUS", 1))
        self.min_replicas = int(os.environ.get("MIN_REPLICAS", 1))
        self.max_replicas = int(os.environ.get("MAX_REPLICAS", 1))
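Because the dataclass defaults above are read from the environment at import time, configuration has to happen before the module is imported. A small sketch (the values are examples, not recommendations):

import os

os.environ["RETRIEVER_DEVICE"] = "cuda"
os.environ["WINDOW_STRIDE"] = "16"
os.environ["TOP_K"] = "100"

# import only after the environment is set, so the defaults pick it up
from relik.inference.serve.backend.utils import ServerParameterManager

params = ServerParameterManager()
print(params.retriver_device, params.window_stride, params.top_k)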
relik/inference/serve/frontend/__init__.py
ADDED
File without changes
|
relik/inference/serve/frontend/relik.py
ADDED
@@ -0,0 +1,231 @@
import os
import re
import time
from pathlib import Path

import requests
import streamlit as st
from spacy import displacy
from streamlit_extras.badges import badge
from streamlit_extras.stylable_container import stylable_container

RELIK = os.getenv("RELIK", "localhost:8000/api/entities")

import random


def get_random_color(ents):
    colors = {}
    random_colors = generate_pastel_colors(len(ents))
    for ent in ents:
        colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1))
    return colors


def floatrange(start, stop, steps):
    if int(steps) == 1:
        return [stop]
    return [
        start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps)
    ]


def hsl_to_rgb(h, s, l):
    def hue_2_rgb(v1, v2, v_h):
        while v_h < 0.0:
            v_h += 1.0
        while v_h > 1.0:
            v_h -= 1.0
        if 6 * v_h < 1.0:
            return v1 + (v2 - v1) * 6.0 * v_h
        if 2 * v_h < 1.0:
            return v2
        if 3 * v_h < 2.0:
            return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0
        return v1

    # if not (0 <= s <= 1): raise ValueError("s (saturation) parameter must be between 0 and 1.")
    # if not (0 <= l <= 1): raise ValueError("l (lightness) parameter must be between 0 and 1.")

    r, b, g = (l * 255,) * 3
    if s != 0.0:
        if l < 0.5:
            var_2 = l * (1.0 + s)
        else:
            var_2 = (l + s) - (s * l)
        var_1 = 2.0 * l - var_2
        r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0))
        g = 255 * hue_2_rgb(var_1, var_2, h)
        b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0))

    return int(round(r)), int(round(g)), int(round(b))


def generate_pastel_colors(n):
    """Return different pastel colours.

    Input:
        n (integer) : The number of colors to return

    Output:
        A list of colors in HTML notation (e.g. ['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc'])

    Example:
        >>> print(generate_pastel_colors(5))
        ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0']
    """
    if n == 0:
        return []

    # To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space)
    start_hue = 0.6  # 0=red 1/3=0.333=green 2/3=0.666=blue
    saturation = 1.0
    lightness = 0.8
    # We take points around the chromatic circle (hue):
    # (Note: we generate n+1 colors, then drop the last one ([:-1]) because
    # it equals the first one (hue 0 = hue 1))
    return [
        "#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness)
        for hue in floatrange(start_hue, start_hue + 1, n + 1)
    ][:-1]


def set_sidebar(css):
    white_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='{}'>{}</a>"
    with st.sidebar:
        st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
        st.image(
            "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
            use_column_width=True,
        )
        st.markdown("## ReLiK")
        st.write(
            f"""
            - {white_link_wrapper.format("#", "<i class='fa-solid fa-file'></i> Paper")}
            - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "<i class='fa-brands fa-github'></i> GitHub")}
            - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "<i class='fa-brands fa-docker'></i> Docker Hub")}
            """,
            unsafe_allow_html=True,
        )
        st.markdown("## Sapienza NLP")
        st.write(
            f"""
            - {white_link_wrapper.format("https://nlp.uniroma1.it", "<i class='fa-solid fa-globe'></i> Webpage")}
            - {white_link_wrapper.format("https://github.com/SapienzaNLP", "<i class='fa-brands fa-github'></i> GitHub")}
            - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "<i class='fa-brands fa-twitter'></i> Twitter")}
            - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "<i class='fa-brands fa-linkedin'></i> LinkedIn")}
            """,
            unsafe_allow_html=True,
        )


def get_el_annotations(response):
    # swap the "labels" key with the "ents" key expected by displacy
    response["ents"] = response.pop("labels")
    label_in_text = set(l["label"] for l in response["ents"])
    options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
    return response, options


def set_intro(css):
    # intro
    st.markdown("# ReLik")
    st.markdown(
        "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget"
    )
    # st.markdown(
    #     "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
    #     "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by "
    #     "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
    #     "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
    # )
    badge(type="github", name="sapienzanlp/relik")
    badge(type="pypi", name="relik")


def run_client():
    with open(Path(__file__).parent / "style.css") as f:
        css = f.read()

    st.set_page_config(
        page_title="ReLik",
        page_icon="🦮",
        layout="wide",
    )
    set_sidebar(css)
    set_intro(css)

    # text input
    text = st.text_area(
        "Enter Text Below:",
        value="Obama went to Rome for a quick vacation.",
        height=200,
        max_chars=500,
    )

    with stylable_container(
        key="annotate_button",
        css_styles="""
            button {
                background-color: #802433;
                color: white;
                border-radius: 25px;
            }
            """,
    ):
        submit = st.button("Annotate")
    # submit = st.button("Run")

    # ReLik API call
    if submit:
        text = text.strip()
        if text:
            st.markdown("####")
            st.markdown("#### Entity Linking")
            with st.spinner(text="In progress"):
                response = requests.post(RELIK, json=text)
                if response.status_code != 200:
                    st.error("Error: {}".format(response.status_code))
                else:
                    response = response.json()

                    # Entity Linking
                    # with stylable_container(
                    #     key="container_with_border",
                    #     css_styles="""
                    #     {
                    #         border: 1px solid rgba(49, 51, 63, 0.2);
                    #         border-radius: 0.5rem;
                    #         padding: 0.5rem;
                    #         padding-bottom: 2rem;
                    #     }
                    #     """,
                    # ):
                    #     st.markdown("##")
                    dict_of_ents, options = get_el_annotations(response=response)
                    display = displacy.render(
                        dict_of_ents, manual=True, style="ent", options=options
                    )
                    display = display.replace("\n", " ")
                    # wsd_display = re.sub(
                    #     r"(wiki::\d+\w)",
                    #     r"<a href='https://babelnet.org/synset?id=\g<1>&orig=\g<1>&lang={}'>\g<1></a>".format(
                    #         language.upper()
                    #     ),
                    #     wsd_display,
                    # )
                    with st.container():
                        st.write(display, unsafe_allow_html=True)

                    st.markdown("####")
                    st.markdown("#### Relation Extraction")

                    with st.container():
                        st.write("Coming :)", unsafe_allow_html=True)

        else:
            st.error("Please enter some text.")


if __name__ == "__main__":
    run_client()
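The rendering step relies on displacy's manual mode, which accepts plain dicts instead of a spaCy Doc. A minimal sketch of the payload shape that get_el_annotations produces (the spans and labels here are invented):

from spacy import displacy

doc = {
    "text": "Obama went to Rome for a quick vacation.",
    "ents": [
        {"start": 0, "end": 5, "label": "Barack Obama"},
        {"start": 14, "end": 18, "label": "Rome"},
    ],
}
# manual=True tells displacy to use the dict as-is rather than a parsed Doc
html = displacy.render(doc, manual=True, style="ent")
print(html)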
relik/inference/serve/frontend/style.css
ADDED
@@ -0,0 +1,33 @@
/* Sidebar */
.eczjsme11 {
  background-color: #802433;
}

.st-emotion-cache-10oheav h2 {
  color: white;
}

.st-emotion-cache-10oheav li {
  color: white;
}

/* Main */
a:link {
  text-decoration: none;
  color: white;
}

a:visited {
  text-decoration: none;
  color: white;
}

a:hover {
  text-decoration: none;
  color: rgba(255, 255, 255, 0.871);
}

a:active {
  text-decoration: none;
  color: white;
}
relik/reader/__init__.py
ADDED
File without changes
|
relik/reader/conf/config.yaml
ADDED
@@ -0,0 +1,14 @@
# Required to make the "experiments" dir the default one for the output of the models
hydra:
  run:
    dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

model_name: relik-reader-deberta-base # used to name the model in wandb and output dir
project_name: relik-reader # used to name the project in wandb


defaults:
  - _self_
  - training: base
  - model: base
  - data: base
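For orientation, a small sketch of composing this config tree from Python; the absolute path is a placeholder, the version_base keyword assumes Hydra >= 1.2, and the override names come from the defaults list above:

from hydra import compose, initialize_config_dir

with initialize_config_dir(
    config_dir="/abs/path/to/relik/reader/conf", version_base=None
):
    cfg = compose(config_name="config", overrides=["training=re", "data=re"])
    print(cfg.model_name)  # relik-reader-deberta-base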
relik/reader/conf/data/base.yaml
ADDED
@@ -0,0 +1,21 @@
train_dataset_path: "relik/reader/data/train.jsonl"
val_dataset_path: "relik/reader/data/testa.jsonl"

train_dataset:
  _target_: "relik.reader.relik_reader_data.RelikDataset"
  transformer_model: "${model.model.transformer_model}"
  materialize_samples: False
  shuffle_candidates: 0.5
  random_drop_gold_candidates: 0.05
  noise_param: 0.0
  for_inference: False
  tokens_per_batch: 4096
  special_symbols: null

val_dataset:
  _target_: "relik.reader.relik_reader_data.RelikDataset"
  transformer_model: "${model.model.transformer_model}"
  materialize_samples: False
  shuffle_candidates: False
  for_inference: True
  special_symbols: null
relik/reader/conf/data/re.yaml
ADDED
@@ -0,0 +1,54 @@
train_dataset_path: "relik/reader/data/nyt-alby+/train.jsonl"
val_dataset_path: "relik/reader/data/nyt-alby+/valid.jsonl"
test_dataset_path: "relik/reader/data/nyt-alby+/test.jsonl"

relations_definitions:
  /people/person/nationality: "nationality"
  /sports/sports_team/location: "sports team location"
  /location/country/administrative_divisions: "administrative divisions"
  /business/company/major_shareholders: "shareholders"
  /people/ethnicity/people: "ethnicity"
  /people/ethnicity/geographic_distribution: "geographic distribution"
  /business/company_shareholder/major_shareholder_of: "major shareholder"
  /location/location/contains: "location"
  /business/company/founders: "founders"
  /business/person/company: "company"
  /business/company/advisors: "advisor"
  /people/deceased_person/place_of_death: "place of death"
  /business/company/industry: "industry"
  /people/person/ethnicity: "ethnic background"
  /people/person/place_of_birth: "place of birth"
  /location/administrative_division/country: "country of an administration division"
  /people/person/place_lived: "place lived"
  /sports/sports_team_location/teams: "sports team"
  /people/person/children: "child"
  /people/person/religion: "religion"
  /location/neighborhood/neighborhood_of: "neighborhood"
  /location/country/capital: "capital"
  /business/company/place_founded: "company founded location"
  /people/person/profession: "occupation"

train_dataset:
  _target_: "relik.reader.relik_reader_re_data.RelikREDataset"
  transformer_model: "${model.model.transformer_model}"
  materialize_samples: False
  shuffle_candidates: False
  flip_candidates: 1.0
  noise_param: 0.0
  for_inference: False
  tokens_per_batch: 4096
  min_length: -1
  special_symbols: null
  relations_definitions: ${data.relations_definitions}
  sorting_fields:
    - "predictable_candidates"
val_dataset:
  _target_: "relik.reader.relik_reader_re_data.RelikREDataset"
  transformer_model: "${model.model.transformer_model}"
  materialize_samples: False
  shuffle_candidates: False
  flip_candidates: False
  for_inference: True
  min_length: -1
  special_symbols: null
  relations_definitions: ${data.relations_definitions}
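For intuition, relations_definitions verbalizes the raw NYT predicate identifiers before the reader sees them. A toy sketch of applying the mapping (the triple is invented):

relations_definitions = {
    "/people/person/place_of_birth": "place of birth",
    "/location/location/contains": "location",
}

triple = ("Italy", "/location/location/contains", "Rome")
subj, pred, obj = triple
# replace the predicate identifier with its natural-language verbalization
print((subj, relations_definitions[pred], obj))  # ('Italy', 'location', 'Rome')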
relik/reader/conf/training/base.yaml
ADDED
@@ -0,0 +1,12 @@
seed: 94

trainer:
  _target_: lightning.Trainer
  devices:
    - 0
  precision: "16-mixed"
  max_steps: 50000
  val_check_interval: 1.0
  num_sanity_val_steps: 0
  limit_val_batches: 1
  gradient_clip_val: 1.0
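Since the trainer node carries a _target_, Hydra can build the object directly. A minimal sketch, assuming the same placeholder config directory as in the compose example after config.yaml and that lightning is installed:

from hydra import compose, initialize_config_dir
from hydra.utils import instantiate

with initialize_config_dir(
    config_dir="/abs/path/to/relik/reader/conf", version_base=None
):
    cfg = compose(config_name="config")
# equivalent to lightning.Trainer(devices=[0], precision="16-mixed", max_steps=50000, ...)
trainer = instantiate(cfg.training.trainer)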
relik/reader/conf/training/re.yaml
ADDED
@@ -0,0 +1,12 @@
seed: 15

trainer:
  _target_: lightning.Trainer
  devices:
    - 0
  precision: "16-mixed"
  max_steps: 100000
  val_check_interval: 1.0
  num_sanity_val_steps: 0
  limit_val_batches: 1
  gradient_clip_val: 1.0
relik/reader/data/__init__.py
ADDED
File without changes
|
relik/reader/data/patches.py
ADDED
@@ -0,0 +1,51 @@
from typing import List

from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.utils.special_symbols import NME_SYMBOL


def merge_patches_predictions(sample) -> None:
    sample._d["predicted_window_labels"] = dict()
    predicted_window_labels = sample._d["predicted_window_labels"]

    sample._d["span_title_probabilities"] = dict()
    span_title_probabilities = sample._d["span_title_probabilities"]

    span2title = dict()
    for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]):
        # selecting span predictions
        for predicted_title, predicted_spans in patch_info[
            "predicted_window_labels"
        ].items():
            for pred_span in predicted_spans:
                pred_span = tuple(pred_span)
                curr_title = span2title.get(pred_span)
                if curr_title is None or curr_title == NME_SYMBOL:
                    span2title[pred_span] = predicted_title
                # else:
                #     print("Merging at patch level")

        # selecting span predictions probability
        for predicted_span, titles_probabilities in patch_info[
            "span_title_probabilities"
        ].items():
            if predicted_span not in span_title_probabilities:
                span_title_probabilities[predicted_span] = titles_probabilities

    for span, title in span2title.items():
        if title not in predicted_window_labels:
            predicted_window_labels[title] = list()
        predicted_window_labels[title].append(span)


def remove_duplicate_samples(
    samples: List[RelikReaderSample],
) -> List[RelikReaderSample]:
    seen_sample = set()
    samples_store = []
    for sample in samples:
        sample_id = f"{sample.doc_id}#{sample.sent_id}#{sample.offset}"
        if sample_id not in seen_sample:
            seen_sample.add(sample_id)
            samples_store.append(sample)
    return samples_store
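A toy check of remove_duplicate_samples; SimpleNamespace stands in for RelikReaderSample, since the function only reads doc_id, sent_id and offset:

from types import SimpleNamespace

from relik.reader.data.patches import remove_duplicate_samples

a = SimpleNamespace(doc_id=0, sent_id=0, offset=0)
b = SimpleNamespace(doc_id=0, sent_id=0, offset=0)  # duplicate of a
c = SimpleNamespace(doc_id=0, sent_id=1, offset=0)
assert len(remove_duplicate_samples([a, b, c])) == 2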
relik/reader/data/relik_reader_data.py
ADDED
@@ -0,0 +1,965 @@
import logging
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import numpy as np
import torch
from torch.utils.data import IterableDataset
from tqdm import tqdm
from transformers import AutoTokenizer, PreTrainedTokenizer

from relik.reader.data.relik_reader_data_utils import (
    add_noise_to_value,
    batchify,
    chunks,
    flatten,
)
from relik.reader.data.relik_reader_sample import (
    RelikReaderSample,
    load_relik_reader_samples,
)
from relik.reader.utils.special_symbols import NME_SYMBOL

logger = logging.getLogger(__name__)


def preprocess_dataset(
    input_dataset: Iterable[dict],
    transformer_model: str,
    add_topic: bool,
) -> Iterable[dict]:
    tokenizer = AutoTokenizer.from_pretrained(transformer_model)
    for dataset_elem in tqdm(input_dataset, desc="Preprocessing input dataset"):
        if len(dataset_elem["tokens"]) == 0:
            print(
                f"Dataset element with doc id: {dataset_elem['doc_id']}",
                f"and offset {dataset_elem['offset']} does not contain any token",
                "Skipping it",
            )
            continue

        new_dataset_elem = dict(
            doc_id=dataset_elem["doc_id"],
            offset=dataset_elem["offset"],
        )

        tokenization_out = tokenizer(
            dataset_elem["tokens"],
            return_offsets_mapping=True,
            add_special_tokens=False,
        )

        window_tokens = tokenization_out.input_ids
        window_tokens = flatten(window_tokens)

        offsets_mapping = [
            [
                (
                    ss + dataset_elem["token2char_start"][str(i)],
                    se + dataset_elem["token2char_start"][str(i)],
                )
                for ss, se in tokenization_out.offset_mapping[i]
            ]
            for i in range(len(dataset_elem["tokens"]))
        ]

        offsets_mapping = flatten(offsets_mapping)

        assert len(offsets_mapping) == len(window_tokens)

        window_tokens = (
            [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]
        )

        topic_offset = 0
        if add_topic:
            topic_tokens = tokenizer(
                dataset_elem["doc_topic"], add_special_tokens=False
            ).input_ids
            topic_offset = len(topic_tokens)
            new_dataset_elem["topic_tokens"] = topic_offset
            window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]

        new_dataset_elem.update(
            dict(
                tokens=window_tokens,
                token2char_start={
                    str(i): s
                    for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
                },
                token2char_end={
                    str(i): e
                    for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
                },
                window_candidates=dataset_elem["window_candidates"],
                window_candidates_scores=dataset_elem.get("window_candidates_scores"),
            )
        )

        if "window_labels" in dataset_elem:
            window_labels = [
                (s, e, l.replace("_", " ")) for s, e, l in dataset_elem["window_labels"]
            ]

            new_dataset_elem["window_labels"] = window_labels

            if not all(
                [
                    s in new_dataset_elem["token2char_start"].values()
                    for s, _, _ in new_dataset_elem["window_labels"]
                ]
            ):
                print(
                    "Mismatching token start char mapping with labels",
                    new_dataset_elem["token2char_start"],
                    new_dataset_elem["window_labels"],
                    dataset_elem["tokens"],
                )
                continue

            if not all(
                [
                    e in new_dataset_elem["token2char_end"].values()
                    for _, e, _ in new_dataset_elem["window_labels"]
                ]
            ):
                print(
                    "Mismatching token end char mapping with labels",
                    new_dataset_elem["token2char_end"],
                    new_dataset_elem["window_labels"],
                    dataset_elem["tokens"],
                )
                continue

        yield new_dataset_elem


def preprocess_sample(
    relik_sample: RelikReaderSample,
    tokenizer,
    lowercase_policy: float,
    add_topic: bool = False,
) -> None:
    if len(relik_sample.tokens) == 0:
        return

    if lowercase_policy > 0:
        lc_tokens = np.random.uniform(0, 1, len(relik_sample.tokens)) < lowercase_policy
        relik_sample.tokens = [
            t.lower() if lc else t for t, lc in zip(relik_sample.tokens, lc_tokens)
        ]

    tokenization_out = tokenizer(
        relik_sample.tokens,
        return_offsets_mapping=True,
        add_special_tokens=False,
    )

    window_tokens = tokenization_out.input_ids
    window_tokens = flatten(window_tokens)

    offsets_mapping = [
        [
            (
                ss + relik_sample.token2char_start[str(i)],
                se + relik_sample.token2char_start[str(i)],
            )
            for ss, se in tokenization_out.offset_mapping[i]
        ]
        for i in range(len(relik_sample.tokens))
    ]

    offsets_mapping = flatten(offsets_mapping)

    assert len(offsets_mapping) == len(window_tokens)

    window_tokens = [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]

    topic_offset = 0
    if add_topic:
        topic_tokens = tokenizer(
            relik_sample.doc_topic, add_special_tokens=False
        ).input_ids
        topic_offset = len(topic_tokens)
        relik_sample.topic_tokens = topic_offset
        window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]

    relik_sample._d.update(
        dict(
            tokens=window_tokens,
            token2char_start={
                str(i): s
                for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
            },
            token2char_end={
                str(i): e
                for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
            },
        )
    )

    if "window_labels" in relik_sample._d:
        relik_sample.window_labels = [
            (s, e, l.replace("_", " ")) for s, e, l in relik_sample.window_labels
        ]


class TokenizationOutput(NamedTuple):
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    token_type_ids: torch.Tensor
    prediction_mask: torch.Tensor
    special_symbols_mask: torch.Tensor


class RelikDataset(IterableDataset):
    def __init__(
        self,
        dataset_path: Optional[str],
        materialize_samples: bool,
        transformer_model: Union[str, PreTrainedTokenizer],
        special_symbols: List[str],
        shuffle_candidates: Optional[Union[bool, float]] = False,
        for_inference: bool = False,
        noise_param: float = 0.1,
        sorting_fields: Optional[str] = None,
        tokens_per_batch: int = 2048,
        batch_size: Optional[int] = None,
        max_batch_size: int = 128,
        section_size: int = 50_000,
        prebatch: bool = True,
        random_drop_gold_candidates: float = 0.0,
        use_nme: bool = True,
        max_subwords_per_candidate: int = 22,
        mask_by_instances: bool = False,
        min_length: int = 5,
        max_length: int = 2048,
        model_max_length: int = 1000,
        split_on_cand_overload: bool = True,
        skip_empty_training_samples: bool = False,
        drop_last: bool = False,
        samples: Optional[Iterator[RelikReaderSample]] = None,
        lowercase_policy: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dataset_path = dataset_path
        self.materialize_samples = materialize_samples
        self.samples: Optional[List[RelikReaderSample]] = None
        if self.materialize_samples:
            self.samples = list()

        if isinstance(transformer_model, str):
            self.tokenizer = self._build_tokenizer(transformer_model, special_symbols)
        else:
            self.tokenizer = transformer_model
        self.special_symbols = special_symbols
        self.shuffle_candidates = shuffle_candidates
        self.for_inference = for_inference
        self.noise_param = noise_param
        self.batching_fields = ["input_ids"]
        self.sorting_fields = (
            sorting_fields if sorting_fields is not None else self.batching_fields
        )

        self.tokens_per_batch = tokens_per_batch
        self.batch_size = batch_size
        self.max_batch_size = max_batch_size
        self.section_size = section_size
        self.prebatch = prebatch

        self.random_drop_gold_candidates = random_drop_gold_candidates
        self.use_nme = use_nme
        self.max_subwords_per_candidate = max_subwords_per_candidate
        self.mask_by_instances = mask_by_instances
        self.min_length = min_length
        self.max_length = max_length
        self.model_max_length = (
            model_max_length
            if model_max_length < self.tokenizer.model_max_length
            else self.tokenizer.model_max_length
        )

        # retrocompatibility workaround
        self.transformer_model = (
            transformer_model
            if isinstance(transformer_model, str)
            else transformer_model.name_or_path
        )
        self.split_on_cand_overload = split_on_cand_overload
        self.skip_empty_training_samples = skip_empty_training_samples
        self.drop_last = drop_last
        self.lowercase_policy = lowercase_policy
        # only overwrite the (possibly materialized) samples list when
        # explicit samples are provided
        if samples is not None:
            self.samples = samples

    def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]):
        return AutoTokenizer.from_pretrained(
            transformer_model,
            additional_special_tokens=list(special_symbols),
            add_prefix_space=True,
        )

    @property
    def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
        fields_batchers = {
            "input_ids": lambda x: batchify(
                x, padding_value=self.tokenizer.pad_token_id
            ),
            "attention_mask": lambda x: batchify(x, padding_value=0),
            "token_type_ids": lambda x: batchify(x, padding_value=0),
            "prediction_mask": lambda x: batchify(x, padding_value=1),
            "global_attention": lambda x: batchify(x, padding_value=0),
            "token2word": None,
            "sample": None,
            "special_symbols_mask": lambda x: batchify(x, padding_value=False),
            "start_labels": lambda x: batchify(x, padding_value=-100),
            "end_labels": lambda x: batchify(x, padding_value=-100),
            "predictable_candidates_symbols": None,
            "predictable_candidates": None,
            "patch_offset": None,
            "optimus_labels": None,
        }

        if "roberta" in self.transformer_model:
            del fields_batchers["token_type_ids"]

        return fields_batchers

    def _build_input_ids(
        self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]]
    ) -> List[int]:
        return (
            [self.tokenizer.cls_token_id]
            + sentence_input_ids
            + [self.tokenizer.sep_token_id]
            + flatten(candidates_input_ids)
            + [self.tokenizer.sep_token_id]
        )

    def _get_special_symbols_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        special_symbols_mask = input_ids >= (
            len(self.tokenizer) - len(self.special_symbols)
        )
        special_symbols_mask[0] = True
        return special_symbols_mask

    def _build_tokenizer_essentials(
        self, input_ids, original_sequence, sample
    ) -> TokenizationOutput:
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)

        total_sequence_len = len(input_ids)
        predictable_sentence_len = len(original_sequence)

        # token type ids
        token_type_ids = torch.cat(
            [
                input_ids.new_zeros(
                    predictable_sentence_len + 2
                ),  # original sentence bpes + CLS and SEP
                input_ids.new_ones(total_sequence_len - predictable_sentence_len - 2),
            ]
        )

        # prediction mask -> boolean on tokens that are predictable

        prediction_mask = torch.tensor(
            [1]
            + ([0] * predictable_sentence_len)
            + ([1] * (total_sequence_len - predictable_sentence_len - 1))
        )

        # add topic tokens to the prediction mask so that they cannot be predicted
        # or optimized during training
        topic_tokens = getattr(sample, "topic_tokens", None)
        if topic_tokens is not None:
            prediction_mask[1 : 1 + topic_tokens] = 1

        # If mask by instances is active, the prediction mask is applied to everything
        # that is not marked as an instance in the training set.
        if self.mask_by_instances:
            char_start2token = {
                cs: int(tok) for tok, cs in sample.token2char_start.items()
            }
            char_end2token = {ce: int(tok) for tok, ce in sample.token2char_end.items()}
            instances_mask = torch.ones_like(prediction_mask)
            for _, span_info in sample.instance_id2span_data.items():
                span_info = span_info[0]
                token_start = char_start2token[span_info[0]] + 1  # +1 for the CLS
                token_end = char_end2token[span_info[1]] + 1  # +1 for the CLS
                instances_mask[token_start : token_end + 1] = 0

            prediction_mask += instances_mask
            prediction_mask[prediction_mask > 1] = 1

        assert len(prediction_mask) == len(input_ids)

        # special symbols mask
        special_symbols_mask = self._get_special_symbols_mask(input_ids)

        return TokenizationOutput(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
        )

    def _build_labels(
        self,
        sample,
        tokenization_output: TokenizationOutput,
        predictable_candidates: List[str],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        start_labels = [0] * len(tokenization_output.input_ids)
        end_labels = [0] * len(tokenization_output.input_ids)

        char_start2token = {v: int(k) for k, v in sample.token2char_start.items()}
        char_end2token = {v: int(k) for k, v in sample.token2char_end.items()}
        for cs, ce, gold_candidate_title in sample.window_labels:
            if gold_candidate_title not in predictable_candidates:
                if self.use_nme:
                    gold_candidate_title = NME_SYMBOL
                else:
                    continue
            # +1 is to account for the CLS token
            start_bpe = char_start2token[cs] + 1
            end_bpe = char_end2token[ce] + 1
            class_index = predictable_candidates.index(gold_candidate_title)
            if (
                start_labels[start_bpe] == 0 and end_labels[end_bpe] == 0
            ):  # avoid two entities sharing the same start or end subword
                start_labels[start_bpe] = class_index + 1  # +1 for the NONE class
                end_labels[end_bpe] = class_index + 1  # +1 for the NONE class
            else:
                print(
                    "Found entity with the same last subword, it will not be included."
                )
                print(
                    cs,
                    ce,
                    gold_candidate_title,
                    start_labels,
                    end_labels,
                    sample.doc_id,
                )

        ignored_labels_indices = tokenization_output.prediction_mask == 1

        start_labels = torch.tensor(start_labels, dtype=torch.long)
        start_labels[ignored_labels_indices] = -100

        end_labels = torch.tensor(end_labels, dtype=torch.long)
        end_labels[ignored_labels_indices] = -100

        return start_labels, end_labels

    def produce_sample_bag(
        self, sample, predictable_candidates: List[str], candidates_starting_offset: int
    ) -> Optional[Tuple[dict, list, int]]:
        # input sentence tokenization
        input_subwords = sample.tokens[1:-1]  # removing special tokens
        candidates_symbols = self.special_symbols[candidates_starting_offset:]

        predictable_candidates = list(predictable_candidates)
        original_predictable_candidates = list(predictable_candidates)

        # add NME as a possible candidate
        if self.use_nme:
            predictable_candidates.insert(0, NME_SYMBOL)

        # candidates encoding
        candidates_symbols = candidates_symbols[: len(predictable_candidates)]
        candidates_encoding_result = self.tokenizer.batch_encode_plus(
            [
                "{} {}".format(cs, ct) if ct != NME_SYMBOL else NME_SYMBOL
                for cs, ct in zip(candidates_symbols, predictable_candidates)
            ],
            add_special_tokens=False,
        ).input_ids

        if (
            self.max_subwords_per_candidate is not None
            and self.max_subwords_per_candidate > 0
        ):
            candidates_encoding_result = [
                cer[: self.max_subwords_per_candidate]
                for cer in candidates_encoding_result
            ]

        # drop candidates if the number of input tokens is too long for the model
        if (
            sum(map(len, candidates_encoding_result))
            + len(input_subwords)
            + 20  # + 20 special tokens
            > self.model_max_length
        ):
            acceptable_tokens_from_candidates = (
                self.model_max_length - 20 - len(input_subwords)
            )
            i = 0
            cum_len = 0
            while (
                cum_len + len(candidates_encoding_result[i])
                < acceptable_tokens_from_candidates
            ):
                cum_len += len(candidates_encoding_result[i])
                i += 1

            candidates_encoding_result = candidates_encoding_result[:i]
            candidates_symbols = candidates_symbols[:i]
            predictable_candidates = predictable_candidates[:i]

        # final input_ids build
        input_ids = self._build_input_ids(
            sentence_input_ids=input_subwords,
            candidates_input_ids=candidates_encoding_result,
        )

        # complete input building (e.g. attention / prediction mask)
        tokenization_output = self._build_tokenizer_essentials(
            input_ids, input_subwords, sample
        )

        output_dict = {
            "input_ids": tokenization_output.input_ids,
            "attention_mask": tokenization_output.attention_mask,
            "token_type_ids": tokenization_output.token_type_ids,
            "prediction_mask": tokenization_output.prediction_mask,
            "special_symbols_mask": tokenization_output.special_symbols_mask,
            "sample": sample,
            "predictable_candidates_symbols": candidates_symbols,
            "predictable_candidates": predictable_candidates,
        }

        # labels creation
        if sample.window_labels is not None:
            start_labels, end_labels = self._build_labels(
                sample,
                tokenization_output,
                predictable_candidates,
            )
            output_dict.update(start_labels=start_labels, end_labels=end_labels)

        if (
            "roberta" in self.transformer_model
            or "longformer" in self.transformer_model
        ):
            del output_dict["token_type_ids"]

        predictable_candidates_set = set(predictable_candidates)
        remaining_candidates = [
            candidate
            for candidate in original_predictable_candidates
            if candidate not in predictable_candidates_set
        ]
        total_used_candidates = (
            candidates_starting_offset
            + len(predictable_candidates)
            - (1 if self.use_nme else 0)
        )

        if self.use_nme:
            assert predictable_candidates[0] == NME_SYMBOL

        return output_dict, remaining_candidates, total_used_candidates

    def __iter__(self):
        dataset_iterator = self.dataset_iterator_func()

        current_dataset_elements = []

        i = None
        for i, dataset_elem in enumerate(dataset_iterator, start=1):
            if (
                self.section_size is not None
                and len(current_dataset_elements) == self.section_size
            ):
                for batch in self.materialize_batches(current_dataset_elements):
                    yield batch
                current_dataset_elements = []

            current_dataset_elements.append(dataset_elem)

            if i % 50_000 == 0:
                logger.info(f"Processed {i} elements")

        if len(current_dataset_elements) != 0:
            for batch in self.materialize_batches(current_dataset_elements):
                yield batch

        if i is not None:
            logger.info(f"Dataset finished: {i} elements processed")
        else:
            logger.warning("Dataset empty")

    def dataset_iterator_func(self):
        skipped_instances = 0
        data_samples = (
            load_relik_reader_samples(self.dataset_path)
            if self.samples is None
            else self.samples
        )
        for sample in data_samples:
            preprocess_sample(
                sample, self.tokenizer, lowercase_policy=self.lowercase_policy
            )
            current_patch = 0
            sample_bag, used_candidates = None, None
            remaining_candidates = list(sample.window_candidates)

            if not self.for_inference:
                # randomly drop gold candidates at training time
                if (
                    self.random_drop_gold_candidates > 0.0
                    and np.random.uniform() < self.random_drop_gold_candidates
                    and len(set(ct for _, _, ct in sample.window_labels)) > 1
                ):
                    # selecting candidates to drop
                    np.random.shuffle(sample.window_labels)
                    n_dropped_candidates = np.random.randint(
                        0, len(sample.window_labels) - 1
                    )
                    dropped_candidates = [
                        label_elem[-1]
                        for label_elem in sample.window_labels[:n_dropped_candidates]
                    ]
                    dropped_candidates = set(dropped_candidates)

                    # NMEs should never be dropped
                    if NME_SYMBOL in dropped_candidates:
                        dropped_candidates.remove(NME_SYMBOL)

                    # sample update
                    sample.window_labels = [
                        (s, e, _l)
                        if _l not in dropped_candidates
                        else (s, e, NME_SYMBOL)
                        for s, e, _l in sample.window_labels
                    ]
                    remaining_candidates = [
                        wc
                        for wc in remaining_candidates
                        if wc not in dropped_candidates
                    ]

                # shuffle candidates
                if (
                    isinstance(self.shuffle_candidates, bool)
                    and self.shuffle_candidates
                ) or (
                    isinstance(self.shuffle_candidates, float)
                    and np.random.uniform() < self.shuffle_candidates
                ):
                    np.random.shuffle(remaining_candidates)

            while len(remaining_candidates) != 0:
                sample_bag = self.produce_sample_bag(
                    sample,
                    predictable_candidates=remaining_candidates,
                    candidates_starting_offset=used_candidates
                    if used_candidates is not None
                    else 0,
                )
                if sample_bag is not None:
                    sample_bag, remaining_candidates, used_candidates = sample_bag
                    if (
                        self.for_inference
                        or not self.skip_empty_training_samples
                        or (
                            (
                                sample_bag.get("start_labels") is not None
                                and torch.any(sample_bag["start_labels"] > 1).item()
                            )
                            or (
                                sample_bag.get("optimus_labels") is not None
                                and len(sample_bag["optimus_labels"]) > 0
                            )
                        )
                    ):
                        sample_bag["patch_offset"] = current_patch
                        current_patch += 1
                        yield sample_bag
                    else:
                        skipped_instances += 1
                        if skipped_instances % 1000 == 0 and skipped_instances != 0:
                            logger.info(
                                f"Skipped {skipped_instances} instances since they did not have any gold labels..."
                            )

                # if split_on_cand_overload is disabled, keep only the first
                # window of fitting candidates
                if not self.split_on_cand_overload:
                    break

    def preshuffle_elements(self, dataset_elements: List):
        # This shuffling makes the whole operation non-deterministic even when
        # the sorting function is deterministic for a given collection and order.
        # The aim is to avoid building the exact same batches on every epoch.
        if not self.for_inference:
            dataset_elements = np.random.permutation(dataset_elements)

        sorting_fn = (
            lambda elem: add_noise_to_value(
                sum(len(elem[k]) for k in self.sorting_fields),
                noise_param=self.noise_param,
            )
            if not self.for_inference
            else sum(len(elem[k]) for k in self.sorting_fields)
        )

        dataset_elements = sorted(dataset_elements, key=sorting_fn)

        if self.for_inference:
            return dataset_elements

        ds = list(chunks(dataset_elements, 64))
        np.random.shuffle(ds)
        return flatten(ds)

    def materialize_batches(
        self, dataset_elements: List[Dict[str, Any]]
    ) -> Generator[Dict[str, Any], None, None]:
        if self.prebatch:
            dataset_elements = self.preshuffle_elements(dataset_elements)

        current_batch = []

        # function that creates a batch from the 'current_batch' list
        def output_batch() -> Dict[str, Any]:
            assert (
                len(
                    set([len(elem["predictable_candidates"]) for elem in current_batch])
                )
                == 1
            ), " ".join(
                map(
                    str, [len(elem["predictable_candidates"]) for elem in current_batch]
                )
            )

            batch_dict = dict()

            de_values_by_field = {
                fn: [de[fn] for de in current_batch if fn in de]
                for fn in self.fields_batcher
            }

            # in case fields batchers are provided but the batch
            # contains no elements for that field
            de_values_by_field = {
                fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
            }

            assert len(set([len(v) for v in de_values_by_field.values()]))

            # todo: maybe we should warn the user about fields
            # being filtered out due to "None" instances
            de_values_by_field = {
                fn: fvs
                for fn, fvs in de_values_by_field.items()
                if all([fv is not None for fv in fvs])
            }

            for field_name, field_values in de_values_by_field.items():
                field_batch = (
                    self.fields_batcher[field_name](field_values)
                    if self.fields_batcher[field_name] is not None
                    else field_values
                )

                batch_dict[field_name] = field_batch

            return batch_dict

        max_len_discards, min_len_discards = 0, 0

        should_token_batch = self.batch_size is None

        curr_pred_elements = -1
        for de in dataset_elements:
            if (
                should_token_batch
                and self.max_batch_size != -1
                and len(current_batch) == self.max_batch_size
            ) or (not should_token_batch and len(current_batch) == self.batch_size):
                yield output_batch()
                current_batch = []
                curr_pred_elements = -1

            too_long_fields = [
                k
                for k in de
                if self.max_length != -1
                and torch.is_tensor(de[k])
                and len(de[k]) > self.max_length
            ]
            if len(too_long_fields) > 0:
                max_len_discards += 1
                continue

            too_short_fields = [
                k
                for k in de
                if self.min_length != -1
                and torch.is_tensor(de[k])
                and len(de[k]) < self.min_length
            ]
            if len(too_short_fields) > 0:
                min_len_discards += 1
                continue

            if should_token_batch:
                de_len = sum(len(de[k]) for k in self.batching_fields)

                future_max_len = max(
                    de_len,
                    max(
                        [
                            sum(len(bde[k]) for k in self.batching_fields)
                            for bde in current_batch
                        ],
                        default=0,
                    ),
                )

                future_tokens_per_batch = future_max_len * (len(current_batch) + 1)

                num_predictable_candidates = len(de["predictable_candidates"])

                if len(current_batch) > 0 and (
                    future_tokens_per_batch >= self.tokens_per_batch
                    or (
                        num_predictable_candidates != curr_pred_elements
                        and curr_pred_elements != -1
                    )
                ):
                    yield output_batch()
                    current_batch = []

            current_batch.append(de)
            curr_pred_elements = len(de["predictable_candidates"])

        if len(current_batch) != 0 and not self.drop_last:
            yield output_batch()

        if max_len_discards > 0:
            if self.for_inference:
                logger.warning(
                    f"Inference mode is True but {max_len_discards} samples longer than max length were "
                    f"found. The {max_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
                    f", this can INVALIDATE results. This might happen if the max length was not set to -1 or if the "
                    f"sample length exceeds the maximum length supported by the current model."
                )
            else:
                logger.warning(
                    f"During iteration, {max_len_discards} elements were "
                    f"discarded since longer than max length {self.max_length}"
                )

        if min_len_discards > 0:
            if self.for_inference:
                logger.warning(
                    f"Inference mode is True but {min_len_discards} samples shorter than min length were "
                    f"found. The {min_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
                    f", this can INVALIDATE results. This might happen if the min length was not set to -1 or if the "
                    f"sample length is shorter than the minimum length supported by the current model."
                )
            else:
                logger.warning(
                    f"During iteration, {min_len_discards} elements were "
                    f"discarded since shorter than min length {self.min_length}"
                )

    @staticmethod
    def convert_tokens_to_char_annotations(
        sample: RelikReaderSample,
        remove_nmes: bool = True,
    ) -> RelikReaderSample:
        """
        Converts the token annotations to char annotations.

        Args:
            sample (:obj:`RelikReaderSample`):
                The sample to convert.
            remove_nmes (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether to remove the NMEs from the annotations.
        Returns:
            :obj:`RelikReaderSample`: The converted sample.
        """
        char_annotations = set()
        for (
            predicted_entity,
            predicted_spans,
        ) in sample.predicted_window_labels.items():
            if predicted_entity == NME_SYMBOL and remove_nmes:
                continue

            for span_start, span_end in predicted_spans:
                span_start = sample.token2char_start[str(span_start)]
                span_end = sample.token2char_end[str(span_end)]

                char_annotations.add((span_start, span_end, predicted_entity))

        char_probs_annotations = dict()
        for (
            span_start,
            span_end,
        ), candidates_probs in sample.span_title_probabilities.items():
            span_start = sample.token2char_start[str(span_start)]
            span_end = sample.token2char_end[str(span_end)]
            char_probs_annotations[(span_start, span_end)] = {
                title for title, _ in candidates_probs
            }

        sample.predicted_window_labels_chars = char_annotations
        sample.probs_window_labels_chars = char_probs_annotations

        return sample

    @staticmethod
    def merge_patches_predictions(sample) -> None:
        sample._d["predicted_window_labels"] = dict()
        predicted_window_labels = sample._d["predicted_window_labels"]

        sample._d["span_title_probabilities"] = dict()
        span_title_probabilities = sample._d["span_title_probabilities"]

        span2title = dict()
        for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]):
            # selecting span predictions
            for predicted_title, predicted_spans in patch_info[
                "predicted_window_labels"
            ].items():
                for pred_span in predicted_spans:
                    pred_span = tuple(pred_span)
                    curr_title = span2title.get(pred_span)
                    if curr_title is None or curr_title == NME_SYMBOL:
                        span2title[pred_span] = predicted_title
                    # else:
                    #     print("Merging at patch level")

            # selecting span predictions probability
            for predicted_span, titles_probabilities in patch_info[
                "span_title_probabilities"
            ].items():
                if predicted_span not in span_title_probabilities:
                    span_title_probabilities[predicted_span] = titles_probabilities

        for span, title in span2title.items():
            if title not in predicted_window_labels:
                predicted_window_labels[title] = list()
            predicted_window_labels[title].append(span)
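
For orientation, a minimal consumption sketch for the dataset above. This is illustrative only, not part of the commit: the JSONL path, encoder checkpoint, and special-symbol naming scheme are hypothetical placeholders. Note that collation happens inside the dataset itself, so the DataLoader is used with batch_size=None.

# NOTE: illustrative sketch, not part of this commit.
from torch.utils.data import DataLoader

from relik.reader.data.relik_reader_data import RelikDataset

# hypothetical candidate-marker names; the real symbols come from
# relik.reader.utils.special_symbols
special_symbols = [f"[E-{i}]" for i in range(100)]

dataset = RelikDataset(
    dataset_path="data/train.windowed.jsonl",  # hypothetical path
    materialize_samples=False,
    transformer_model="microsoft/deberta-v3-base",  # any HF encoder checkpoint
    special_symbols=special_symbols,
    tokens_per_batch=2048,
    for_inference=False,
)

# batch_size=None disables the DataLoader's own batching: the dataset
# already yields fully collated dicts of padded tensors
loader = DataLoader(dataset, batch_size=None, num_workers=0)
for batch in loader:
    print(batch["input_ids"].shape)
    break
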
relik/reader/data/relik_reader_data_utils.py
ADDED
@@ -0,0 +1,51 @@
from typing import List

import numpy as np
import torch


def flatten(lsts: List[list]) -> list:
    acc_lst = list()
    for lst in lsts:
        acc_lst.extend(lst)
    return acc_lst


def batchify(tensors: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor:
    return torch.nn.utils.rnn.pad_sequence(
        tensors, batch_first=True, padding_value=padding_value
    )


def batchify_matrices(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
    x = max([t.shape[0] for t in tensors])
    y = max([t.shape[1] for t in tensors])
    out_matrix = torch.zeros((len(tensors), x, y))
    out_matrix += padding_value
    for i, tensor in enumerate(tensors):
        out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1]] = tensor
    return out_matrix


def batchify_tensor(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
    x = max([t.shape[0] for t in tensors])
    y = max([t.shape[1] for t in tensors])
    rest = tensors[0].shape[2]
    out_matrix = torch.zeros((len(tensors), x, y, rest))
    out_matrix += padding_value
    for i, tensor in enumerate(tensors):
        out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1], :] = tensor
    return out_matrix


def chunks(lst: list, chunk_size: int) -> List[list]:
    chunks_acc = list()
    for i in range(0, len(lst), chunk_size):
        chunks_acc.append(lst[i : i + chunk_size])
    return chunks_acc


def add_noise_to_value(value: int, noise_param: float):
    noise_value = value * noise_param
    noise = np.random.uniform(-noise_value, noise_value)
    return max(1, value + noise)
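
These helpers are small enough that a usage sketch doubles as documentation (illustrative, not part of the commit):

# NOTE: illustrative sketch, not part of this commit.
import torch

from relik.reader.data.relik_reader_data_utils import (
    add_noise_to_value,
    batchify,
    chunks,
    flatten,
)

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
padded = batchify(seqs, padding_value=0)          # shape (2, 3), second row right-padded
parts = chunks(list(range(7)), chunk_size=3)      # [[0, 1, 2], [3, 4, 5], [6]]
flat = flatten(parts)                             # [0, 1, 2, 3, 4, 5, 6]
noisy = add_noise_to_value(100, noise_param=0.1)  # a float in (90, 110), floored at 1
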
relik/reader/data/relik_reader_sample.py
ADDED
@@ -0,0 +1,49 @@
import json
from typing import Iterable


class RelikReaderSample:
    def __init__(self, **kwargs):
        super().__setattr__("_d", {})
        self._d = kwargs

    def __getattribute__(self, item):
        return super(RelikReaderSample, self).__getattribute__(item)

    def __getattr__(self, item):
        if item.startswith("__") and item.endswith("__"):
            # this is likely some python library-specific variable (such as __deepcopy__ for copy)
            # better follow standard behavior here
            raise AttributeError(item)
        elif item in self._d:
            return self._d[item]
        else:
            return None

    def __setattr__(self, key, value):
        if key in self._d:
            self._d[key] = value
        else:
            super().__setattr__(key, value)

    def to_jsons(self) -> str:
        if "predicted_window_labels" in self._d:
            new_obj = {
                k: v
                for k, v in self._d.items()
                if k != "predicted_window_labels" and k != "span_title_probabilities"
            }
            # predicted_window_labels_chars holds (char_start, char_end, title)
            # triples, serialized here as plain lists
            new_obj["predicted_window_labels"] = [
                [ss, se, pred_title]
                for ss, se, pred_title in self.predicted_window_labels_chars
            ]
            return json.dumps(new_obj)
        else:
            return json.dumps(self._d)


def load_relik_reader_samples(path: str) -> Iterable[RelikReaderSample]:
    with open(path) as f:
        for line in f:
            jsonl_line = json.loads(line.strip())
            relik_reader_sample = RelikReaderSample(**jsonl_line)
            yield relik_reader_sample
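
RelikReaderSample is a thin attribute view over a dict: stored fields read back as attributes, and missing fields resolve to None rather than raising. A quick sketch of this behavior (illustrative, not part of the commit; the file path is a placeholder):

# NOTE: illustrative sketch, not part of this commit.
from relik.reader.data.relik_reader_sample import (
    RelikReaderSample,
    load_relik_reader_samples,
)

sample = RelikReaderSample(doc_id=0, tokens=["ReLiK", "reads", "windows"])
print(sample.tokens)         # stored fields come back as attributes
print(sample.window_labels)  # missing fields resolve to None instead of raising

for s in load_relik_reader_samples("data/dev.windowed.jsonl"):  # hypothetical path
    print(s.doc_id)
    break
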
relik/reader/lightning_modules/__init__.py
ADDED
File without changes
relik/reader/lightning_modules/relik_reader_pl_module.py
ADDED
@@ -0,0 +1,50 @@
from typing import Any, Optional

import lightning
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler

from relik.reader.relik_reader_core import RelikReaderCoreModel


class RelikReaderPLModule(lightning.LightningModule):
    def __init__(
        self,
        cfg: dict,
        transformer_model: str,
        additional_special_symbols: int,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        *args: Any,
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        self.relik_reader_core_model = RelikReaderCoreModel(
            transformer_model,
            additional_special_symbols,
            num_layers,
            activation,
            linears_hidden_size,
            use_last_k_layers,
            training=training,
        )
        self.optimizer_factory = None

    def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        relik_output = self.relik_reader_core_model(**batch)
        self.log("train-loss", relik_output["loss"])
        return relik_output["loss"]

    def validation_step(
        self, batch: dict, *args: Any, **kwargs: Any
    ) -> Optional[STEP_OUTPUT]:
        return

    def set_optimizer_factory(self, optimizer_factory) -> None:
        self.optimizer_factory = optimizer_factory

    def configure_optimizers(self) -> OptimizerLRScheduler:
        return self.optimizer_factory(self.relik_reader_core_model)
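
The module deliberately leaves optimizer construction to an injected factory via set_optimizer_factory, which configure_optimizers then calls on the wrapped core model. A minimal wiring sketch, assuming an AdamW factory and a stock Lightning Trainer (both assumptions, not the project's actual training recipe):

# NOTE: illustrative sketch, not part of this commit; checkpoint name,
# symbol count, and trainer settings are hypothetical placeholders.
import torch
import lightning

from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule

module = RelikReaderPLModule(
    cfg={},
    transformer_model="microsoft/deberta-v3-base",  # hypothetical checkpoint
    additional_special_symbols=101,                 # hypothetical symbol count
    training=True,
)

# inject the optimizer factory the module delegates to
module.set_optimizer_factory(
    lambda model: torch.optim.AdamW(model.parameters(), lr=1e-5)
)

trainer = lightning.Trainer(max_steps=10, accelerator="auto")
# trainer.fit(module, train_dataloaders=loader)  # loader as in the RelikDataset sketch
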
relik/reader/lightning_modules/relik_reader_re_pl_module.py
ADDED
@@ -0,0 +1,54 @@
from typing import Any, Optional

import lightning
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler

from relik.reader.relik_reader_re import RelikReaderForTripletExtraction


class RelikReaderREPLModule(lightning.LightningModule):
    def __init__(
        self,
        cfg: dict,
        transformer_model: str,
        additional_special_symbols: int,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        *args: Any,
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()

        self.relik_reader_re_model = RelikReaderForTripletExtraction(
            transformer_model,
            additional_special_symbols,
            num_layers,
            activation,
            linears_hidden_size,
            use_last_k_layers,
            training=training,
        )
        self.optimizer_factory = None

    def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        relik_output = self.relik_reader_re_model(**batch)
        self.log("train-loss", relik_output["loss"])
        self.log("train-start_loss", relik_output["ned_start_loss"])
        self.log("train-end_loss", relik_output["ned_end_loss"])
        self.log("train-relation_loss", relik_output["re_loss"])
        return relik_output["loss"]

    def validation_step(
        self, batch: dict, *args: Any, **kwargs: Any
    ) -> Optional[STEP_OUTPUT]:
        return

    def set_optimizer_factory(self, optimizer_factory) -> None:
        self.optimizer_factory = optimizer_factory

    def configure_optimizers(self) -> OptimizerLRScheduler:
        return self.optimizer_factory(self.relik_reader_re_model)
relik/reader/pytorch_modules/__init__.py
ADDED
File without changes