This view is limited to 50 files because it contains too many changes. See raw diff.
- .dockerignore +9 -0
- .gitattributes +2 -0
- .gitignore +41 -0
- CHANGELOG.md +111 -0
- CHANGELOG_CN.md +111 -0
- Dockerfile +65 -0
- LICENSE +674 -0
- Makefile +13 -0
- README_CN.md +413 -0
- devscripts/make_readme.py +98 -0
- devscripts/utils.py +42 -0
- docker_prepare.py +28 -0
- fonts/Arial-Unicode-Regular.ttf +3 -0
- fonts/anime_ace.ttf +3 -0
- fonts/anime_ace_3.ttf +3 -0
- fonts/comic shanns 2.ttf +3 -0
- fonts/msgothic.ttc +3 -0
- fonts/msyh.ttc +3 -0
- manga_translator/__init__.py +7 -0
- manga_translator/__main__.py +79 -0
- manga_translator/args.py +182 -0
- manga_translator/colorization/__init__.py +28 -0
- manga_translator/colorization/common.py +24 -0
- manga_translator/colorization/manga_colorization_v2.py +74 -0
- manga_translator/colorization/manga_colorization_v2_utils/denoising/denoiser.py +118 -0
- manga_translator/colorization/manga_colorization_v2_utils/denoising/functions.py +102 -0
- manga_translator/colorization/manga_colorization_v2_utils/denoising/models.py +100 -0
- manga_translator/colorization/manga_colorization_v2_utils/denoising/utils.py +66 -0
- manga_translator/colorization/manga_colorization_v2_utils/networks/extractor.py +127 -0
- manga_translator/colorization/manga_colorization_v2_utils/networks/models.py +319 -0
- manga_translator/colorization/manga_colorization_v2_utils/utils/utils.py +44 -0
- manga_translator/detection/__init__.py +37 -0
- manga_translator/detection/common.py +146 -0
- manga_translator/detection/craft.py +200 -0
- manga_translator/detection/craft_utils/refiner.py +65 -0
- manga_translator/detection/craft_utils/vgg16_bn.py +71 -0
- manga_translator/detection/ctd.py +186 -0
- manga_translator/detection/ctd_utils/__init__.py +5 -0
- manga_translator/detection/ctd_utils/basemodel.py +250 -0
- manga_translator/detection/ctd_utils/textmask.py +174 -0
- manga_translator/detection/ctd_utils/utils/db_utils.py +706 -0
- manga_translator/detection/ctd_utils/utils/imgproc_utils.py +180 -0
- manga_translator/detection/ctd_utils/utils/io_utils.py +54 -0
- manga_translator/detection/ctd_utils/utils/weight_init.py +103 -0
- manga_translator/detection/ctd_utils/utils/yolov5_utils.py +243 -0
- manga_translator/detection/ctd_utils/yolov5/common.py +289 -0
- manga_translator/detection/ctd_utils/yolov5/yolo.py +311 -0
- manga_translator/detection/dbnet_convnext.py +596 -0
- manga_translator/detection/default.py +103 -0
- manga_translator/detection/default_utils/CRAFT_resnet34.py +153 -0
.dockerignore
ADDED
@@ -0,0 +1,9 @@
+result
+*.ckpt
+*.pt
+.vscode
+*.onnx
+__pycache__
+ocrs
+models/*
+test/testdata/bboxes
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.ttc filter=lfs diff=lfs merge=lfs -text
+*.ttf filter=lfs diff=lfs merge=lfs -text
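These two added attribute lines route the repository's font binaries (the `fonts/` files listed above) through Git LFS. As a hedged illustration of how such entries are usually produced, assuming Git LFS is installed and initialized in the clone:

```sh
# Appends the corresponding "filter=lfs diff=lfs merge=lfs -text" lines to .gitattributes
git lfs track "*.ttc" "*.ttf"

# Stage the updated attributes together with the font files
git add .gitattributes fonts/
```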
.gitignore
ADDED
@@ -0,0 +1,41 @@
+result
+*.ckpt
+*.pt
+.vscode
+*.onnx
+__pycache__
+ocrs
+Manga
+Manga-translated
+/models
+.env
+*.local
+*.local.*
+test/testdata
+.idea
+pyvenv.cfg
+Scripts
+Lib
+include
+share
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+.history
CHANGELOG.md
ADDED
@@ -0,0 +1,111 @@
+# Changelogs
+
+### 2023-11-11
+
+1. Added new OCR model `48px`
+
+### 2023-05-08
+
+1. Added [4x-UltraSharp](https://mega.nz/folder/qZRBmaIY#nIG8KyWFcGNTuMX_XNbJ_g) upscaler
+
+### 2023-04-30
+
+1. Countless bug fixes and refactoring
+2. Add [CRAFT](https://github.com/clovaai/CRAFT-pytorch) detector, enable with `--detector craft`
+
+### 2022-06-15
+
+1. Added new inpainting model LaMa MPE by [dmMaze](https://github.com/dmMaze) and set as default
+
+### 2022-04-23
+
+Project version is now at beta-0.3
+
+1. Added English text renderer by [dmMaze](https://github.com/dmMaze)
+2. Added new CTC based OCR engine, significant speed improvement
+3. The new OCR model now supports Korean
+
+### 2022-03-19
+
+1. Use new font rendering method by [pokedexter](https://github.com/pokedexter)
+2. Added manual translation UI by [rspreet92](https://github.com/rspreet92)
+
+### 2022-01-24
+
+1. Added text detection model by [dmMaze](https://github.com/dmMaze)
+
+### 2021-08-21
+
+1. New MST based text region merge algorithm, huge text region merge improvement
+2. Add baidu translator in demo mode
+3. Add google translator in demo mode
+4. Various bugfixes
+
+### 2021-07-29
+
+1. Web demo adds translator, detection resolution and target language option
+2. Slight text color extraction improvement
+
+### 2021-07-26
+
+Major upgrades for all components, now we are on beta! \
+Note in this version all English texts are detected as capital letters, \
+You need Python >= 3.8 for `cached_property` to work
+
+1. Detection model upgrade
+2. OCR model upgrade, better at text color extraction
+3. Inpainting model upgrade
+4. Major text rendering improvement, faster rendering and higher quality text with shadow
+5. Slight mask generation improvement
+6. Various bugfixes
+7. Default detection resolution has been dialed back to 1536 from 2048
+
+### 2021-07-09
+
+1. Fix erroneous image rendering when inpainting is not used
+
+### 2021-06-18
+
+1. Support manual translation
+2. Support detection and rendering of angled texts
+
+### 2021-06-13
+
+1. Text mask completion is now based on CRF, mask quality is drastically improved
+
+### 2021-06-10
+
+1. Improve text rendering
+
+### 2021-06-09
+
+1. New text region based text direction detection method
+2. Support running demo as web service
+
+### 2021-05-20
+
+1. Text detection model is now based on DBNet with ResNet34 backbone
+2. OCR model is now trained with more English sentences
+3. Inpaint model is now based on [AOT](https://arxiv.org/abs/2104.01431) which requires far less memory
+4. Default inpainting resolution is now increased to 2048, thanks to the new inpainting model
+5. Support merging hyphenated English words
+
+### 2021-05-11
+
+1. Add youdao translator and set as default
+
+### 2021-05-06
+
+1. Text detection model is now based on DBNet with ResNet101 backbone
+2. OCR model is now deeper
+3. Default detection resolution has been increased to 2048 from 1536
+
+Note this version is slightly better at handling English texts; other than that it is worse in every other way
+
+### 2021-03-04
+
+1. Added inpainting model
+
+### 2021-02-17
+
+1. First version launched
CHANGELOG_CN.md
ADDED
@@ -0,0 +1,111 @@
+# 更新日志 (中文)
+
+### 2023-11-11
+
+1. 添加了新的OCR模型`48px`
+
+### 2023-05-08
+
+1. 添加了[4x-UltraSharp](https://mega.nz/folder/qZRBmaIY#nIG8KyWFcGNTuMX_XNbJ_g)超分辨率
+
+### 2023-04-30
+
+1. 无数bug修复和重构
+2. 添加了[CRAFT](https://github.com/clovaai/CRAFT-pytorch)文本检测器,使用`--detector craft`启用
+
+### 2022-06-15
+
+1. 增加了来自[dmMaze](https://github.com/dmMaze)的LaMa MPE图像修补模型
+
+### 2022-04-23
+
+版本更新为beta-0.3
+
+1. 增加了来自[dmMaze](https://github.com/dmMaze)的英语文本渲染器
+2. 增加了基于CTC的OCR模型,识别速度大幅提升
+3. 新OCR模型增加韩语识别支持
+
+### 2022-03-19
+
+1. 增加了来自[pokedexter](https://github.com/pokedexter)的新文本渲染器
+2. 增加了来自[rspreet92](https://github.com/rspreet92)的人工翻译页面
+
+### 2022-01-24
+
+1. 增加了来自[dmMaze](https://github.com/dmMaze)的文本检测模型
+
+### 2021-08-21
+
+1. 文本区域合并算法更新,现已实现几乎完美的文本行合并
+2. 增加演示模式百度翻译支持
+3. 增加演示模式谷歌翻译支持
+4. 各类 bug 修复
+
+### 2021-07-29
+
+1. 网页版增加翻译器、分辨率和目标语言选项
+2. 文本颜色提取小幅提升
+
+### 2021-07-26
+
+程序所有组件都大幅升级,本程序现已进入 beta 版本! \
+注意:该版本所有英文检测只会输出大写字母。\
+你需要 Python>=3.8 版本才能运行
+
+1. 检测模型升级
+2. OCR 模型升级,文本颜色抽取质量大幅提升
+3. 图像修补模型升级
+4. 文本渲染升级,渲染更快,并支持更高质量的文本和文本阴影渲染
+5. 文字掩膜补全算法小幅提升
+6. 各类 BUG 修复
+7. 默认检测分辨率为 1536
+
+### 2021-07-09
+
+1. 修复不使用 inpainting 时图片错误
+
+### 2021-06-18
+
+1. 增加手动翻译选项
+2. 支持倾斜文本的识别和渲染
+
+### 2021-06-13
+
+1. 文字掩膜补全算法更新为基于 CRF 算法,补全质量大幅提升
+
+### 2021-06-10
+
+1. 完善文本渲染
+
+### 2021-06-09
+
+1. 使用基于区域的文本方向检测,文本方向检测效果大幅提升
+2. 增加 web 服务功能
+
+### 2021-05-20
+
+1. 检测模型更新为基于 ResNet34 的 DBNet
+2. OCR 模型更新,增加更多英语语料训练
+3. 图像修补模型升级到基于[AOT](https://arxiv.org/abs/2104.01431)的模型,占用更少显存
+4. 图像修补默认分辨率增加到 2048
+5. 支持多行英语单词合并
+
+### 2021-05-11
+
+1. 增加并默认使用有道翻译
+
+### 2021-05-06
+
+1. 检测模型更新为基于 ResNet101 的 DBNet
+2. OCR 模型更新更深
+3. 默认检测分辨率增加到 2048
+
+注意这个版本除了英文检测稍微好一些,其他方面都不如之前版本
+
+### 2021-03-04
+
+1. 添加图片修补模型
+
+### 2021-02-17
+
+1. 初步版本发布
Dockerfile
ADDED
@@ -0,0 +1,65 @@
+FROM pytorch/pytorch:latest
+
+RUN useradd -m -u 1000 user
+
+WORKDIR /app
+
+RUN apt-get update
+RUN DEBIAN_FRONTEND=noninteractive TZ=asia/shanghai apt-get -y install tzdata
+# Set cache environment variables
+ENV TRANSFORMERS_CACHE=/app/cache
+ENV DEEPL_AUTH_KEY="6e4907cd-8926-42e7-aa5d-7561363c82b1:fx"
+ENV OPENAI_API_KEY="sk-yuBWvBk2lTQoJFYP24A03515D46041429f907dE81cC3F04e"
+ENV OPENAI_HTTP_PROXY="https://www.ygxdapi.top"
+RUN mkdir -p /app/cache
+# Assume root to install required dependencies
+RUN apt-get install -y git g++ ffmpeg libsm6 libxext6 libvulkan-dev
+
+
+# Install pip dependencies
+
+COPY --chown=user requirements.txt /app/requirements.txt
+
+RUN pip install -r /app/requirements.txt
+RUN pip install torchvision --force-reinstall
+RUN pip install "numpy<2.0"
+# RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
+
+RUN apt-get remove -y g++ && \
+    apt-get autoremove -y
+
+# Copy app
+COPY --chown=user . /app
+
+# Prepare models
+RUN python -u docker_prepare.py
+
+RUN rm -rf /tmp
+
+# Add /app to Python module path
+ENV PYTHONPATH="${PYTHONPATH}:/app"
+
+WORKDIR /app
+RUN mkdir -p /app/result && chmod 777 /app/result
+RUN mkdir -p /app/models/translators && chmod 777 /app/models/translators
+RUN mkdir -p /app/models/upscaling && chmod 777 /app/models/upscaling
+RUN mkdir -p /app/cache/models && chmod 777 /app/cache/models
+RUN mkdir -p /app/cache/.locks && chmod 777 /app/cache/.locks
+RUN mkdir -p /app/cache/models--kha-white--manga-ocr-base && chmod 777 /app/cache/models--kha-white--manga-ocr-base
+RUN mkdir -p /app && chmod 777 /app
+
+ENTRYPOINT ["python", "-m", "manga_translator", "-v", "--mode", "web", "--host", "0.0.0.0", "--port", "7860", "--font-size", "28", "--font-size-offset", "5", "--unclip-ratio", "1.1", "--det-invert"]
+# # ENTRYPOINT ["python", "-m", "manga_translator", "-v", "--mode", "web", "--host", "0.0.0.0", "--port", "7860", "--use-cuda", "--use-inpainting"]
+
+
+# Use the specified base image
+# FROM zyddnys/manga-image-translator:main
+
+# Copy the needed files into the container
+# COPY ./../../translate_demo.py /app/translate_demo.py
+
+# # Expose the port
+# EXPOSE 7860
+
+# # Run the command
+# CMD ["--verbose", "--log-web", "--mode", "web", "--use-inpainting", "--use-cuda", "--host=0.0.0.0", "--port=7860"]
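For orientation, a hedged sketch of how this image would typically be built and started locally; the local tag name and the host-side port are illustrative assumptions, not values taken from the diff. The ENTRYPOINT above already launches `python -m manga_translator --mode web` on port 7860, so only a port mapping is needed at run time.

```sh
# Build from the repository root (the tag name is an assumption)
docker build -t manga-image-translator .

# Start the web server; map container port 7860 to the host
docker run --rm -p 7860:7860 manga-image-translator
```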
LICENSE
ADDED
@@ -0,0 +1,674 @@
(Adds the full, unmodified text of the GNU General Public License, Version 3, 29 June 2007, as published by the Free Software Foundation — 674 lines; see <https://www.gnu.org/licenses/>.)
Makefile
ADDED
@@ -0,0 +1,13 @@
build-image:
	docker rmi manga-image-translator || true
	docker build . --tag=manga-image-translator

run-web-server:
	docker run --gpus all -p 5003:5003 --ipc=host --rm zyddnys/manga-image-translator:main \
		--target-lang=ENG \
		--manga2eng \
		--verbose \
		--mode=web \
		--use-gpu \
		--host=0.0.0.0 \
		--port=5003
README_CN.md
ADDED
@@ -0,0 +1,413 @@
# Manga Image Translator (Chinese README)

> Translate text in all kinds of images with one click\
> [English](README.md) | [Changelog](CHANGELOG_CN.md) \
> Feel free to join our Discord <https://discord.gg/Ak8APNy4vb>

Designed for the large number of images in chat groups and on image boards that are unlikely to ever get a human translation, so that a Japanese novice like me can roughly understand them\
Mainly supports Japanese, Chinese, English and Korean\
Supports image inpainting and text rendering\
This project is the v2 version of [求闻转译志](https://github.com/PatchyVideo/MMDOCR-HighPerformance)

**This is only a preliminary version and we need your help to improve it**\
The project is currently just a simple demo with plenty of rough edges; we need your help to make it better!

## Support Us

Please help us pay for GPU servers, thank you!

- Ko-fi: <https://ko-fi.com/voilelabs>
- Patreon: <https://www.patreon.com/voilelabs>
- 爱发电: <https://afdian.net/@voilelabs>

## Online Demo

Official demo (maintained by zyddnys): <https://cotrans.touhou.ai/>\
Mirror (maintained by Eidenz): <https://manga.eidenz.com/>\
Browser userscript (maintained by QiroNT): <https://greasyfork.org/scripts/437569>

- Note: if the online demo is unreachable, Google GCP is probably restarting my server again; please wait for the service to come back up.
- The online demo runs the latest version of the main branch.

## Usage

```bash
# First, make sure Python 3.8 or later and the corresponding build tools are installed
$ python --version
Python 3.8.13

# Clone the repository
$ git clone https://github.com/zyddnys/manga-image-translator.git

# Install dependencies
$ pip install -r requirements.txt
```

Note: `pydensecrf` and some other pip packages may require the operating system's build tools (e.g. build-essential on Debian).

[Optional when using Google Translate]\
Apply for a Youdao or DeepL API account and put your `APP_KEY` and `APP_SECRET` or `AUTH_KEY` into `translators/key.py`.

### Translators

| Name | Needs API key | Works offline | Notes |
| -------------- | ------- | ------- | ----------------------------------------------------- |
| google | | | |
| youdao | ✔️ | | Requires `YOUDAO_APP_KEY` and `YOUDAO_SECRET_KEY` |
| baidu | ✔️ | | Requires `BAIDU_APP_ID` and `BAIDU_SECRET_KEY` |
| deepl | ✔️ | | Requires `DEEPL_AUTH_KEY` |
| caiyun | ✔️ | | Requires `CAIYUN_TOKEN` |
| gpt3 | ✔️ | | Implements text-davinci-003. Requires `OPENAI_API_KEY` |
| gpt3.5 | ✔️ | | Implements gpt-3.5-turbo. Requires `OPENAI_API_KEY` |
| gpt4 | ✔️ | | Implements gpt-4. Requires `OPENAI_API_KEY` |
| papago | | | |
| sakura | | | Requires `SAKURA_API_BASE` |
| offline | | ✔️ | Automatically picks an available offline model; this entry is only a selector |
| sugoi | | ✔️ | Can only translate into English |
| m2m100 | | ✔️ | Supports every language |
| m2m100_big | | ✔️ | The `big` variant is the full-size model, the plain one is the slimmed-down version |
| none | | ✔️ | Translate into empty text |
| mbart50 | | ✔️ | |
| original | | ✔️ | Keep the original text |

### Language Codes

Values accepted by the `--target-lang` argument

```yaml
CHS: Chinese (Simplified)
CHT: Chinese (Traditional)
CSY: Czech
NLD: Dutch
ENG: English
FRA: French
DEU: German
HUN: Hungarian
ITA: Italian
JPN: Japanese
KOR: Korean
PLK: Polish
PTB: Portuguese (Brazil)
ROM: Romanian
RUS: Russian
ESP: Spanish
TRK: Turkish
VIN: Vietnamese
ARA: Arabic
SRP: Serbian
HRV: Croatian
THA: Thai
IND: Indonesian
FIL: Filipino (Tagalog)
```

<!-- Auto generated start (See devscripts/make_readme.py) -->
## 选项

```text
-h, --help                               show this help message and exit
-m, --mode {demo,batch,web,web_client,ws,api}
                                         Run demo in single image demo mode (demo), batch
                                         translation mode (batch), web service mode (web)
-i, --input INPUT [INPUT ...]            Path to an image file if using demo mode, or path to an
                                         image folder if using batch mode
-o, --dest DEST                          Path to the destination folder for translated images in
                                         batch mode
-l, --target-lang {CHS,CHT,CSY,NLD,ENG,FRA,DEU,HUN,ITA,JPN,KOR,PLK,PTB,ROM,RUS,ESP,TRK,UKR,VIN,ARA,CNR,SRP,HRV,THA,IND,FIL}
                                         Destination language
-v, --verbose                            Print debug info and save intermediate images in result
                                         folder
-f, --format {png,webp,jpg,xcf,psd,pdf}  Output format of the translation.
--attempts ATTEMPTS                      Retry attempts on encountered error. -1 means infinite
                                         times.
--ignore-errors                          Skip image on encountered error.
--overwrite                              Overwrite already translated images in batch mode.
--skip-no-text                           Skip image without text (Will not be saved).
--model-dir MODEL_DIR                    Model directory (by default ./models in project root)
--use-gpu                                Turn on/off gpu (automatic selection between mps or cuda)
--use-gpu-limited                        Turn on/off gpu (excluding offline translator)
--detector {default,ctd,craft,none}      Text detector used for creating a text mask from an
                                         image, DO NOT use craft for manga, it's not designed
                                         for it
--ocr {32px,48px,48px_ctc,mocr}          Optical character recognition (OCR) model to use
--use-mocr-merge                         Use bbox merge when Manga OCR inference.
--inpainter {default,lama_large,lama_mpe,sd,none,original}
                                         Inpainting model to use
--upscaler {waifu2x,esrgan,4xultrasharp} Upscaler to use. --upscale-ratio has to be set for it
                                         to take effect
--upscale-ratio UPSCALE_RATIO            Image upscale ratio applied before detection. Can
                                         improve text detection.
--colorizer {mc2}                        Colorization model to use.
--translator {google,youdao,baidu,deepl,papago,caiyun,gpt3,gpt3.5,gpt4,none,original,offline,nllb,nllb_big,sugoi,jparacrawl,jparacrawl_big,m2m100,sakura}
                                         Language translator to use
--translator-chain TRANSLATOR_CHAIN      Output of one translator goes in another. Example:
                                         --translator-chain "google:JPN;sugoi:ENG".
--selective-translation SELECTIVE_TRANSLATION
                                         Select a translator based on detected language in
                                         image. Note the first translation service acts as
                                         default if the language isn't defined. Example:
                                         --translator-chain "google:JPN;sugoi:ENG".
--revert-upscaling                       Downscales the previously upscaled image after
                                         translation back to original size (Use with --upscale-
                                         ratio).
--detection-size DETECTION_SIZE          Size of image used for detection
--det-rotate                             Rotate the image for detection. Might improve
                                         detection.
--det-auto-rotate                        Rotate the image for detection to prefer vertical
                                         textlines. Might improve detection.
--det-invert                             Invert the image colors for detection. Might improve
                                         detection.
--det-gamma-correct                      Applies gamma correction for detection. Might improve
                                         detection.
--unclip-ratio UNCLIP_RATIO              How much to extend text skeleton to form bounding box
--box-threshold BOX_THRESHOLD            Threshold for bbox generation
--text-threshold TEXT_THRESHOLD          Threshold for text detection
--min-text-length MIN_TEXT_LENGTH        Minimum text length of a text region
--no-text-lang-skip                      Dont skip text that is seemingly already in the target
                                         language.
--inpainting-size INPAINTING_SIZE        Size of image used for inpainting (too large will
                                         result in OOM)
--inpainting-precision {fp32,fp16,bf16}  Inpainting precision for lama, use bf16 while you can.
--colorization-size COLORIZATION_SIZE    Size of image used for colorization. Set to -1 to use
                                         full image size
--denoise-sigma DENOISE_SIGMA            Used by colorizer and affects color strength, range
                                         from 0 to 255 (default 30). -1 turns it off.
--mask-dilation-offset MASK_DILATION_OFFSET By how much to extend the text mask to remove left-over
                                         text pixels of the original image.
--font-size FONT_SIZE                    Use fixed font size for rendering
--font-size-offset FONT_SIZE_OFFSET      Offset font size by a given amount, positive number
                                         increase font size and vice versa
--font-size-minimum FONT_SIZE_MINIMUM    Minimum output font size. Default is
                                         image_sides_sum/200
--font-color FONT_COLOR                  Overwrite the text fg/bg color detected by the OCR
                                         model. Use hex string without the "#" such as FFFFFF
                                         for a white foreground or FFFFFF:000000 to also have a
                                         black background around the text.
--line-spacing LINE_SPACING              Line spacing is font_size * this value. Default is 0.01
                                         for horizontal text and 0.2 for vertical.
--force-horizontal                       Force text to be rendered horizontally
--force-vertical                         Force text to be rendered vertically
--align-left                             Align rendered text left
--align-center                           Align rendered text centered
--align-right                            Align rendered text right
--uppercase                              Change text to uppercase
--lowercase                              Change text to lowercase
--no-hyphenation                         If renderer should be splitting up words using a hyphen
                                         character (-)
--manga2eng                              Render english text translated from manga with some
                                         additional typesetting. Ignores some other argument
                                         options
--gpt-config GPT_CONFIG                  Path to GPT config file, more info in README
--use-mtpe                               Turn on/off machine translation post editing (MTPE) on
                                         the command line (works only on linux right now)
--save-text                              Save extracted text and translations into a text file.
--save-text-file SAVE_TEXT_FILE          Like --save-text but with a specified file path.
--filter-text FILTER_TEXT                Filter regions by their text with a regex. Example
                                         usage: --text-filter ".*badtext.*"
--skip-lang                              Skip translation if source image is one of the provide languages,
                                         use comma to separate multiple languages. Example: JPN,ENG
--prep-manual                            Prepare for manual typesetting by outputting blank,
                                         inpainted images, plus copies of the original for
                                         reference
--font-path FONT_PATH                    Path to font file
--gimp-font GIMP_FONT                    Font family to use for gimp rendering.
--host HOST                              Used by web module to decide which host to attach to
--port PORT                              Used by web module to decide which port to attach to
--nonce NONCE                            Used by web module as secret for securing internal web
                                         server communication
--ws-url WS_URL                          Server URL for WebSocket mode
--save-quality SAVE_QUALITY              Quality of saved JPEG image, range from 0 to 100 with
                                         100 being best
--ignore-bubble IGNORE_BUBBLE            The threshold for ignoring text in non bubble areas,
                                         with valid values ranging from 1 to 50, does not ignore
                                         others. Recommendation 5 to 10. If it is too low,
                                         normal bubble areas may be ignored, and if it is too
                                         large, non bubble areas may be considered normal
                                         bubbles
```

<!-- Auto generated end -->

### Run from the command line

```bash
# If the machine has an NVIDIA GPU with CUDA support, add the `--use-gpu` flag
# Use `--use-gpu-limited` to run the VRAM-heavy translation step on the CPU and reduce VRAM usage
# Use `--translator=<translator name>` to choose a translator
# Use `--target-lang=<language code>` to choose the target language
# Replace <path_to_image_file> with the path to your image
# If the image is small or blurry, an upscaler can increase its size and quality and thereby improve detection and translation
$ python -m manga_translator --verbose --use-gpu --translator=google --target-lang=CHS -i <path_to_image_file>
# Results are saved to the `result` folder
```

#### Batch translation from the command line

```bash
# Other options are the same as above
# Use `--mode batch` to enable batch translation mode
# Replace <path_to_image_folder> with the path to your image folder
$ python -m manga_translator --verbose --mode batch --use-gpu --translator=google --target-lang=CHS -i <path_to_image_folder>
# Results are saved to the `<path_to_image_folder>-translated` folder
```

### Run in the browser (web server)

```bash
# Other options are the same as above
# Use `--mode web` to start web server mode
$ python -m manga_translator --verbose --mode web --use-gpu
# The service will be available at http://127.0.0.1:5003
```

The server offers two request modes: synchronous and asynchronous.\
In synchronous mode your HTTP POST request blocks until the translation is finished.\
In asynchronous mode your HTTP POST immediately returns a `task_id`, which you can use to poll the translation status periodically.

#### Synchronous mode

1. POST a form containing the image in a field named `file` to <http://127.0.0.1:5003/run>
2. Wait for the response
3. Use the returned `task_id` to fetch the result from the `result` folder, e.g. by exposing it through Nginx

#### Asynchronous mode

1. POST a form containing the image in a field named `file` to <http://127.0.0.1:5003/submit>
2. You will receive a `task_id`
3. Using this `task_id`, periodically POST the JSON payload `{"taskid": <task_id>}` to <http://127.0.0.1:5003/task-state>
4. The translation is done once the returned state is `finished`, `error` or `error-lang`
5. Fetch the result from the `result` folder, e.g. by exposing it through Nginx

#### Manual translation

Manual translation lets you fill in the translated text by hand instead of using machine translation.

POST a form containing the image in a field named `file` to <http://127.0.0.1:5003/manual-translate> and wait for the response.

You will receive a JSON response like this:

```json
{
    "task_id": "12c779c9431f954971cae720eb104499",
    "status": "pending",
    "trans_result": [
        {
            "s": "☆上司来ちゃった……",
            "t": ""
        }
    ]
}
```

Fill the translated text into the `t` field:

```json
{
    "task_id": "12c779c9431f954971cae720eb104499",
    "status": "pending",
    "trans_result": [
        {
            "s": "☆上司来ちゃった……",
            "t": "☆上司来了..."
        }
    ]
}
```

Send that JSON to <http://127.0.0.1:5003/post-manual-result> and wait for the response.\
Afterwards you can fetch the result from the `result` folder using the returned `task_id`, e.g. by exposing it through Nginx.

## Next steps

A list of things that still need doing; contributions are welcome!

1. Use a diffusion-model-based inpainting algorithm, although inpainting would become much slower
2. ~~[IMPORTANT, HELP WANTED] The current text rendering engine is barely passable and clearly worse than Adobe's renderer; we need your help to improve text rendering!~~
3. ~~I tried extracting text color inside the OCR model but every attempt failed; for now DPGMM is used to extract text color, with mediocre results. I will keep trying to improve color extraction; if you have good suggestions please open an issue~~
4. ~~Text detection does not handle English and Korean well yet; I will train a new text detection model once the inpainting model is trained.~~ ~~Korean support is in progress~~
5. Text rendering areas are determined by the detected text rather than by the speech bubbles, which makes it possible to handle images without bubbles but also makes English typesetting hard; no good solution yet.
6. [Ryota et al.](https://arxiv.org/abs/2012.14271) propose collecting paired manga as training data and training a model that translates with image context. In the future we could consider encoding large numbers of images with a VQVAE and feeding them into the NMT encoder to assist translation, instead of extracting per-panel tags, which would cover a wider range of images. That requires collecting a large amount of paired translated manga/image data and training a VQVAE model.
7. 求闻转译志 was designed for video; in the future this project should be optimized to handle video as well, extracting text colors to generate ass subtitles and further assisting Touhou video subtitling groups. It could even edit video frames to remove hard subtitles.
8. ~~Mask generation refinement combined with classical algorithms; currently testing CRF-based approaches.~~
9. ~~Merging of tilted text regions is not supported yet~~

## Samples

The samples below may not be updated frequently and may not reflect the current main branch.

<table>
  <thead>
    <tr>
      <th align="center" width="50%">Original</th>
      <th align="center" width="50%">Translated</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td align="center" width="50%">
        <a href="https://user-images.githubusercontent.com/31543482/232265329-6a560438-e887-4f7f-b6a1-a61b8648f781.png">
          <img alt="佐藤さんは知っていた - 猫麦" src="https://user-images.githubusercontent.com/31543482/232265329-6a560438-e887-4f7f-b6a1-a61b8648f781.png" />
        </a>
        <br />
        <a href="https://twitter.com/09ra_19ra/status/1647079591109103617/photo/1">(Source @09ra_19ra)</a>
      </td>
      <td align="center" width="50%">
        <a href="https://user-images.githubusercontent.com/31543482/232265339-514c843a-0541-4a24-b3bc-1efa6915f757.png">
          <img alt="Output" src="https://user-images.githubusercontent.com/31543482/232265339-514c843a-0541-4a24-b3bc-1efa6915f757.png" />
        </a>
        <br />
        <a href="https://user-images.githubusercontent.com/31543482/232265376-01a4557d-8120-4b6b-b062-f271df177770.png">(Mask)</a>
      </td>
    </tr>
    <tr>
      <td align="center" width="50%">
        <a href="https://user-images.githubusercontent.com/31543482/232265479-a15c43b5-0f00-489c-9b04-5dfbcd48c432.png">
          <img alt="Gris finds out she's of royal blood - VERTI" src="https://user-images.githubusercontent.com/31543482/232265479-a15c43b5-0f00-489c-9b04-5dfbcd48c432.png" />
        </a>
        <br />
        <a href="https://twitter.com/VERTIGRIS_ART/status/1644365184142647300/photo/1">(Source @VERTIGRIS_ART)</a>
      </td>
      <td align="center" width="50%">
        <a href="https://user-images.githubusercontent.com/31543482/232265480-f8ba7a28-846f-46e7-8041-3dcb1afe3f67.png">
          <img alt="Output" src="https://user-images.githubusercontent.com/31543482/232265480-f8ba7a28-846f-46e7-8041-3dcb1afe3f67.png" />
        </a>
        <br />
        <code>--detector ctd</code>
        <a href="https://user-images.githubusercontent.com/31543482/232265483-99ad20af-dca8-4b78-90f9-a6599eb0e70b.png">(Mask)</a>
      </td>
    </tr>
    <tr>
      <td align="center" width="50%">
        <a href="https://user-images.githubusercontent.com/31543482/232264684-5a7bcf8e-707b-4925-86b0-4212382f1680.png">
          <img alt="陰キャお嬢様の新学期🏫📔🌸 (#3) - ひづき夜宵🎀💜" src="https://user-images.githubusercontent.com/31543482/232264684-5a7bcf8e-707b-4925-86b0-4212382f1680.png" />
        </a>
        <br />
        <a href="https://twitter.com/hiduki_yayoi/status/1645186427712573440/photo/2">(Source @hiduki_yayoi)</a>
      </td>
      <td align="center" width="50%">
        <a href="https://user-images.githubusercontent.com/31543482/232264644-39db36c8-a8d9-4009-823d-bf85ca0609bf.png">
          <img alt="Output" src="https://user-images.githubusercontent.com/31543482/232264644-39db36c8-a8d9-4009-823d-bf85ca0609bf.png" />
        </a>
        <br />
        <code>--translator none</code>
        <a href="https://user-images.githubusercontent.com/31543482/232264671-bc8dd9d0-8675-4c6d-8f86-0d5b7a342233.png">(Mask)</a>
      </td>
    </tr>
    <tr>
      <td align="center" width="50%">
        <a href="https://user-images.githubusercontent.com/31543482/232265794-5ea8a0cb-42fe-4438-80b7-3bf7eaf0ff2c.png">
          <img alt="幼なじみの高校デビューの癖がすごい (#1) - 神吉李花☪️🐧" src="https://user-images.githubusercontent.com/31543482/232265794-5ea8a0cb-42fe-4438-80b7-3bf7eaf0ff2c.png" />
        </a>
        <br />
        <a href="https://twitter.com/rikak/status/1642727617886556160/photo/1">(Source @rikak)</a>
      </td>
      <td align="center" width="50%">
        <a href="https://user-images.githubusercontent.com/31543482/232265795-4bc47589-fd97-4073-8cf4-82ae216a88bc.png">
          <img alt="Output" src="https://user-images.githubusercontent.com/31543482/232265795-4bc47589-fd97-4073-8cf4-82ae216a88bc.png" />
        </a>
        <br />
        <a href="https://user-images.githubusercontent.com/31543482/232265800-6bdc7973-41fe-4d7e-a554-98ea7ca7a137.png">(Mask)</a>
      </td>
    </tr>
  </tbody>
</table>
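The asynchronous web-server workflow described in README_CN.md above (submit the image, then poll `/task-state`) can be driven by a small client. Below is a minimal sketch using the third-party `requests` package; the host, port, file name and polling interval are placeholders, and the exact key names in the JSON responses beyond what the README shows are assumptions.

```python
# Minimal client for the asynchronous web mode; assumes a server started with `--mode web`.
import time
import requests

BASE = 'http://127.0.0.1:5003'

def translate_async(image_path: str, poll_interval: float = 2.0) -> str:
    # 1. Submit the image as a multipart form field named `file`
    with open(image_path, 'rb') as f:
        resp = requests.post(f'{BASE}/submit', files={'file': f})
    resp.raise_for_status()
    task_id = resp.json()['task_id']  # key name assumed from the README's JSON example

    # 2. Poll /task-state until the task reaches a terminal state
    while True:
        state = requests.post(f'{BASE}/task-state', json={'taskid': task_id}).json()
        if state.get('state') in ('finished', 'error', 'error-lang'):  # 'state' key is an assumption
            return task_id
        time.sleep(poll_interval)

# task_id = translate_async('sample.jpg')
# The translated image can then be fetched from the `result` folder, e.g. via Nginx.
```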
devscripts/make_readme.py
ADDED
@@ -0,0 +1,98 @@
#!/usr/bin/env python3

# Adapted from https://github.com/yt-dlp/yt-dlp/tree/master/devscripts

import os
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import functools
import re

from devscripts.utils import read_file, write_file
from manga_translator.args import HelpFormatter, parser

READMES = (
    [
        'README.md',
        '## Options',
        '<!-- Auto generated end -->',
    ],
    [
        'README_CN.md',
        '## 选项',
        '<!-- Auto generated end -->',
    ],
)

ALLOWED_OVERSHOOT = 2
DISABLE_PATCH = object()

HelpFormatter.INDENT_INCREMENT = 0
HelpFormatter.MAX_HELP_POSITION = 45
HelpFormatter.WIDTH = 100

def take_section(text, start=None, end=None, *, shift=0):
    return text[
        text.index(start) + shift if start else None:
        text.index(end) + shift if end else None
    ]


def apply_patch(text, patch):
    return text if patch[0] is DISABLE_PATCH else re.sub(*patch, text)


options = take_section(parser.format_help(), '\noptions:', shift=len('\noptions:'))

max_width = max(map(len, options.split('\n')))
switch_col_width = len(re.search(r'(?m)^\s{5,}', options).group())
delim = f'\n{" " * switch_col_width}'

PATCHES = (
    # ( # Headings
    #     r'(?m)^  (\w.+\n)(    (?=\w))?',
    #     r'## \1'
    # ),
    ( # Fixup `--date` formatting
        rf'(?m)( --date DATE.+({delim}[^\[]+)*)\[.+({delim}.+)*$',
        (rf'\1[now|today|yesterday][-N[day|week|month|year]].{delim}'
         f'E.g. "--date today-2weeks" downloads only{delim}'
         'videos uploaded on the same day two weeks ago'),
    ),
    ( # Do not split URLs
        rf'({delim[:-1]})? (?P<label>\[\S+\] )?(?P<url>https?({delim})?:({delim})?/({delim})?/(({delim})?\S+)+)\s',
        lambda mobj: ''.join((delim, mobj.group('label') or '', re.sub(r'\s+', '', mobj.group('url')), '\n'))
    ),
    ( # Do not split "words"
        rf'(?m)({delim}\S+)+$',
        lambda mobj: ''.join((delim, mobj.group(0).replace(delim, '')))
    ),
    # ( # Allow overshooting last line
    #     rf'(?m)^(?P<prev>.+)${delim}(?P<current>.+)$(?!{delim})',
    #     lambda mobj: (mobj.group().replace(delim, ' ')
    #                   if len(mobj.group()) - len(delim) + 1 <= max_width + ALLOWED_OVERSHOOT
    #                   else mobj.group())
    # ),
    # ( # Avoid newline when a space is available b/w switch and description
    #     DISABLE_PATCH,  # This creates issues with prepare_manpage
    #     r'(?m)^(\s{4}-.{%d})(%s)' % (switch_col_width - 6, delim),
    #     r'\1 '
    # ),
    # ( # Replace brackets with a Markdown link
    #     r'SponsorBlock API \((http.+)\)',
    #     r'[SponsorBlock API](\1)'
    # ),
)

for file, options_start, options_end in READMES:
    readme = read_file(file)

    write_file(file, ''.join((
        take_section(readme, end=options_start, shift=len(options_start)),
        '\n\n```text',
        functools.reduce(apply_patch, PATCHES, options),
        '```\n\n',
        take_section(readme, options_end),
    )))
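The script above keeps everything up to and including the options heading, regenerates the fenced ```text``` block from the argument parser's help output, and keeps everything from the `<!-- Auto generated end -->` marker onwards. A small illustration of how `take_section` slices a README-like string; the sample text is made up for the example:

```python
# Illustration of take_section() from devscripts/make_readme.py.
def take_section(text, start=None, end=None, *, shift=0):
    return text[
        text.index(start) + shift if start else None:
        text.index(end) + shift if end else None
    ]

readme = "intro\n## Options\nOLD OPTIONS\n<!-- Auto generated end -->\nfooter\n"

# Everything up to and including the '## Options' heading:
head = take_section(readme, end='## Options', shift=len('## Options'))
# Everything from the end marker onwards:
tail = take_section(readme, '<!-- Auto generated end -->')

assert head == "intro\n## Options"
assert tail.startswith("<!-- Auto generated end -->")
```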
devscripts/utils.py
ADDED
@@ -0,0 +1,42 @@
# Adopted from https://github.com/yt-dlp/yt-dlp/tree/master/devscripts

import argparse
import functools
import subprocess


def read_file(fname):
    with open(fname, encoding='utf-8') as f:
        return f.read()


def write_file(fname, content, mode='w'):
    with open(fname, mode, encoding='utf-8') as f:
        return f.write(content)


def get_filename_args(has_infile=False, default_outfile=None):
    parser = argparse.ArgumentParser()
    if has_infile:
        parser.add_argument('infile', help='Input file')
    kwargs = {'nargs': '?', 'default': default_outfile} if default_outfile else {}
    parser.add_argument('outfile', **kwargs, help='Output file')

    opts = parser.parse_args()
    if has_infile:
        return opts.infile, opts.outfile
    return opts.outfile


def compose_functions(*functions):
    return lambda x: functools.reduce(lambda y, f: f(y), functions, x)


def run_process(*args, **kwargs):
    kwargs.setdefault('text', True)
    kwargs.setdefault('check', True)
    kwargs.setdefault('capture_output', True)
    if kwargs['text']:
        kwargs.setdefault('encoding', 'utf-8')
        kwargs.setdefault('errors', 'replace')
    return subprocess.run(args, **kwargs)
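A short usage sketch for the helpers above; the command and file names are arbitrary examples, not anything the repository actually runs:

```python
from devscripts.utils import compose_functions, run_process, read_file, write_file

# compose_functions applies its arguments left to right:
strip_and_upper = compose_functions(str.strip, str.upper)
assert strip_and_upper('  hello ') == 'HELLO'

# run_process wraps subprocess.run with text mode, captured output and check=True:
result = run_process('git', 'rev-parse', '--short', 'HEAD')
print(result.stdout.strip())

# read_file/write_file are thin UTF-8 wrappers around open():
write_file('example.txt', read_file('README.md')[:100])
```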
docker_prepare.py
ADDED
@@ -0,0 +1,28 @@
import asyncio

from manga_translator.utils import ModelWrapper
from manga_translator.detection import DETECTORS
from manga_translator.ocr import OCRS
from manga_translator.inpainting import INPAINTERS

async def download(dict):
    for key, value in dict.items():
        if issubclass(value, ModelWrapper):
            print(' -- Downloading', key)
            try:
                inst = value()
                await inst.download()
            except Exception as e:
                print('Failed to download', key, value)
                print(e)

async def main():
    await download(DETECTORS)
    await download(OCRS)
    await download({
        k: v for k, v in INPAINTERS.items()
        if k not in ['sd']
    })

if __name__ == '__main__':
    asyncio.run(main())
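`docker_prepare.py` pre-downloads the detector, OCR and inpainter weights at image build time, skipping the `sd` inpainter via the filter above. The same `download` helper could be reused to warm up only a subset of models; a rough sketch, assuming it is run from the repository root and that the keys `default` and `lama_large` exist in the respective registries (they appear as choices in args.py):

```python
# Sketch: reuse docker_prepare.download() to fetch only selected models.
import asyncio

from docker_prepare import download
from manga_translator.detection import DETECTORS
from manga_translator.inpainting import INPAINTERS

async def warmup_subset():
    await download({k: v for k, v in DETECTORS.items() if k == 'default'})
    await download({k: v for k, v in INPAINTERS.items() if k == 'lama_large'})

if __name__ == '__main__':
    asyncio.run(warmup_subset())
```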
fonts/Arial-Unicode-Regular.ttf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:14f28249244f00c13348cb211c8a83c3e6e44dcf1874ebcb083efbfc0b9d5387
size 23892708
fonts/anime_ace.ttf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e3e311d48c305e79757cc0051aca591b735eb57002f78035969cbfc5ca4a5125
size 108036
fonts/anime_ace_3.ttf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e9b7c40b5389c511a950234fe0add8a11da9563b468e0e8a88219ccbf2257f83
size 58236
fonts/comic shanns 2.ttf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64590b794cab741937889d379b205ae126ca4f3ed5cbe4f19839d2bfac246da6
size 73988
fonts/msgothic.ttc
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ef9044f54896c6d045a425e62e38b3232d49facc5549a12837d077ff0bf74298
size 9176636
fonts/msyh.ttc
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e4b3b9d058750fb80899c24f68e35beda606ca92694eff0e9f7f91eec7a846aa
size 19647736
manga_translator/__init__.py
ADDED
@@ -0,0 +1,7 @@
import colorama
from dotenv import load_dotenv

colorama.init(autoreset=True)
load_dotenv()

from .manga_translator import *
manga_translator/__main__.py
ADDED
@@ -0,0 +1,79 @@
import os
import asyncio
import logging
from argparse import Namespace

from .manga_translator import (
    MangaTranslator,
    MangaTranslatorWeb,
    MangaTranslatorWS,
    MangaTranslatorAPI,
    set_main_logger,
)
from .args import parser
from .utils import (
    BASE_PATH,
    init_logging,
    get_logger,
    set_log_level,
    natural_sort,
)

# TODO: Dynamic imports to reduce ram usage in web(-server) mode. Will require dealing with args.py imports.

async def dispatch(args: Namespace):
    args_dict = vars(args)

    logger.info(f'Running in {args.mode} mode')

    if args.mode in ('demo', 'batch'):
        if not args.input:
            raise Exception('No input image was supplied. Use -i <image_path>')
        translator = MangaTranslator(args_dict)
        if args.mode == 'demo':
            if len(args.input) != 1 or not os.path.isfile(args.input[0]):
                raise FileNotFoundError(f'Invalid single image file path for demo mode: "{" ".join(args.input)}". Use `-m batch`.')
            dest = os.path.join(BASE_PATH, 'result/final.png')
            args.overwrite = True # Do overwrite result/final.png file
            await translator.translate_path(args.input[0], dest, args_dict)
        else: # batch
            dest = args.dest
            for path in natural_sort(args.input):
                await translator.translate_path(path, dest, args_dict)

    elif args.mode == 'web':
        from .server.web_main import dispatch
        await dispatch(args.host, args.port, translation_params=args_dict)

    elif args.mode == 'web_client':
        translator = MangaTranslatorWeb(args_dict)
        await translator.listen(args_dict)

    elif args.mode == 'ws':
        translator = MangaTranslatorWS(args_dict)
        await translator.listen(args_dict)

    elif args.mode == 'api':
        translator = MangaTranslatorAPI(args_dict)
        await translator.listen(args_dict)

if __name__ == '__main__':
    args = None
    init_logging()
    try:
        args = parser.parse_args()
        set_log_level(level=logging.DEBUG if args.verbose else logging.INFO)
        logger = get_logger(args.mode)
        set_main_logger(logger)
        if args.mode != 'web':
            logger.debug(args)

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(dispatch(args))
    except KeyboardInterrupt:
        if not args or args.mode != 'web':
            print()
    except Exception as e:
        logger.error(f'{e.__class__.__name__}: {e}',
                     exc_info=e if args and args.verbose else None)
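`__main__.py` only wires the parsed arguments into one of the translator classes, so the demo branch can also be reproduced programmatically. A rough sketch under the assumption that it is run from the repository root; `DEFAULT_ARGS` comes from `manga_translator/args.py` shown below, the paths are placeholders, and the `translate_path` call mirrors the demo branch above rather than documenting a stable public API:

```python
# Sketch: run a single-image translation without going through __main__.py.
import asyncio

from manga_translator.args import DEFAULT_ARGS
from manga_translator.manga_translator import MangaTranslator

async def run_once(image_path: str, dest: str):
    params = dict(DEFAULT_ARGS)  # start from the parser defaults
    params.update({'translator': 'google', 'target_lang': 'ENG', 'overwrite': True})
    translator = MangaTranslator(params)
    await translator.translate_path(image_path, dest, params)

if __name__ == '__main__':
    asyncio.run(run_once('input.png', 'result/final.png'))
```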
manga_translator/args.py
ADDED
@@ -0,0 +1,182 @@
import argparse
import os
from urllib.parse import unquote

from .detection import DETECTORS
from .ocr import OCRS
from .inpainting import INPAINTERS
from .translators import VALID_LANGUAGES, TRANSLATORS, TranslatorChain
from .upscaling import UPSCALERS
from .colorization import COLORIZERS
from .save import OUTPUT_FORMATS

def url_decode(s):
    s = unquote(s)
    if s.startswith('file:///'):
        s = s[len('file://'):]
    return s

# Additional argparse types
def path(string):
    if not string:
        return ''
    s = url_decode(os.path.expanduser(string))
    if not os.path.exists(s):
        raise argparse.ArgumentTypeError(f'No such file or directory: "{string}"')
    return s

def file_path(string):
    if not string:
        return ''
    s = url_decode(os.path.expanduser(string))
    if not os.path.exists(s):
        raise argparse.ArgumentTypeError(f'No such file: "{string}"')
    return s

def dir_path(string):
    if not string:
        return ''
    s = url_decode(os.path.expanduser(string))
    if not os.path.exists(s):
        raise argparse.ArgumentTypeError(f'No such directory: "{string}"')
    return s

# def choice_chain(choices):
#     """Argument type for string chains from choices separated by ':'. Example: 'choice1:choice2:choice3'"""
#     def _func(string):
#         if choices is not None:
#             for s in string.split(':') or ['']:
#                 if s not in choices:
#                     raise argparse.ArgumentTypeError(f'Invalid choice: %s (choose from %s)' % (s, ', '.join(map(repr, choices))))
#         return string
#     return _func

def translator_chain(string):
    try:
        return TranslatorChain(string)
    except ValueError as e:
        raise argparse.ArgumentTypeError(e)
    except Exception:
        raise argparse.ArgumentTypeError(f'Invalid translator_chain value: "{string}". Example usage: --translator "google:sugoi" -l "JPN:ENG"')


class HelpFormatter(argparse.HelpFormatter):
    INDENT_INCREMENT = 2
    MAX_HELP_POSITION = 24
    WIDTH = None

    def __init__(self, prog: str, indent_increment: int = 2, max_help_position: int = 24, width: int = None):
        super().__init__(prog, self.INDENT_INCREMENT, self.MAX_HELP_POSITION, self.WIDTH)

    def _format_action_invocation(self, action: argparse.Action) -> str:
        if action.option_strings:

            # if the Optional doesn't take a value, format is:
            #    -s, --long
            if action.nargs == 0:
                return ', '.join(action.option_strings)

            # if the Optional takes a value, format is:
            #    -s, --long ARGS
            else:
                default = self._get_default_metavar_for_optional(action)
                args_string = self._format_args(action, default)
                return ', '.join(action.option_strings) + ' ' + args_string
        else:
            return super()._format_action_invocation(action)


parser = argparse.ArgumentParser(prog='manga_translator', description='Seamlessly translate mangas into a chosen language', formatter_class=HelpFormatter)
parser.add_argument('-m', '--mode', default='batch', type=str, choices=['demo', 'batch', 'web', 'web_client', 'ws', 'api'], help='Run demo in single image demo mode (demo), batch translation mode (batch), web service mode (web)')
parser.add_argument('-i', '--input', default=None, type=path, nargs='+', help='Path to an image file if using demo mode, or path to an image folder if using batch mode')
parser.add_argument('-o', '--dest', default='', type=str, help='Path to the destination folder for translated images in batch mode')
parser.add_argument('-l', '--target-lang', default='CHS', type=str, choices=VALID_LANGUAGES, help='Destination language')
parser.add_argument('-v', '--verbose', action='store_true', help='Print debug info and save intermediate images in result folder')
parser.add_argument('-f', '--format', default=None, choices=OUTPUT_FORMATS, help='Output format of the translation.')
parser.add_argument('--attempts', default=0, type=int, help='Retry attempts on encountered error. -1 means infinite times.')
parser.add_argument('--ignore-errors', action='store_true', help='Skip image on encountered error.')
parser.add_argument('--overwrite', action='store_true', help='Overwrite already translated images in batch mode.')
parser.add_argument('--skip-no-text', action='store_true', help='Skip image without text (Will not be saved).')
parser.add_argument('--model-dir', default=None, type=dir_path, help='Model directory (by default ./models in project root)')
parser.add_argument('--skip-lang', default=None, type=str, help='Skip translation if source image is one of the provide languages, use comma to separate multiple languages. Example: JPN,ENG')

g = parser.add_mutually_exclusive_group()
g.add_argument('--use-gpu', action='store_true', help='Turn on/off gpu (auto switch between mps and cuda)')
g.add_argument('--use-gpu-limited', action='store_true', help='Turn on/off gpu (excluding offline translator)')

parser.add_argument('--detector', default='default', type=str, choices=DETECTORS, help='Text detector used for creating a text mask from an image, DO NOT use craft for manga, it\'s not designed for it')
parser.add_argument('--ocr', default='48px', type=str, choices=OCRS, help='Optical character recognition (OCR) model to use')
parser.add_argument('--use-mocr-merge', action='store_true', help='Use bbox merge when Manga OCR inference.')
parser.add_argument('--inpainter', default='lama_large', type=str, choices=INPAINTERS, help='Inpainting model to use')
parser.add_argument('--upscaler', default='esrgan', type=str, choices=UPSCALERS, help='Upscaler to use. --upscale-ratio has to be set for it to take effect')
parser.add_argument('--upscale-ratio', default=None, type=float, help='Image upscale ratio applied before detection. Can improve text detection.')
parser.add_argument('--colorizer', default=None, type=str, choices=COLORIZERS, help='Colorization model to use.')

g = parser.add_mutually_exclusive_group()
g.add_argument('--translator', default='google', type=str, choices=TRANSLATORS, help='Language translator to use')
g.add_argument('--translator-chain', default=None, type=translator_chain, help='Output of one translator goes in another. Example: --translator-chain "google:JPN;sugoi:ENG".')
g.add_argument('--selective-translation', default=None, type=translator_chain, help='Select a translator based on detected language in image. Note the first translation service acts as default if the language isn\'t defined. Example: --translator-chain "google:JPN;sugoi:ENG".')

parser.add_argument('--revert-upscaling', action='store_true', help='Downscales the previously upscaled image after translation back to original size (Use with --upscale-ratio).')
parser.add_argument('--detection-size', default=1536, type=int, help='Size of image used for detection')
parser.add_argument('--det-rotate', action='store_true', help='Rotate the image for detection. Might improve detection.')
parser.add_argument('--det-auto-rotate', action='store_true', help='Rotate the image for detection to prefer vertical textlines. Might improve detection.')
parser.add_argument('--det-invert', action='store_true', help='Invert the image colors for detection. Might improve detection.')
parser.add_argument('--det-gamma-correct', action='store_true', help='Applies gamma correction for detection. Might improve detection.')
parser.add_argument('--unclip-ratio', default=2.3, type=float, help='How much to extend text skeleton to form bounding box')
parser.add_argument('--box-threshold', default=0.7, type=float, help='Threshold for bbox generation')
parser.add_argument('--text-threshold', default=0.5, type=float, help='Threshold for text detection')
parser.add_argument('--min-text-length', default=0, type=int, help='Minimum text length of a text region')
parser.add_argument('--no-text-lang-skip', action='store_true', help='Dont skip text that is seemingly already in the target language.')
parser.add_argument('--inpainting-size', default=2048, type=int, help='Size of image used for inpainting (too large will result in OOM)')
parser.add_argument('--inpainting-precision', default='fp32', type=str, help='Inpainting precision for lama, use bf16 while you can.', choices=['fp32', 'fp16', 'bf16'])
parser.add_argument('--colorization-size', default=576, type=int, help='Size of image used for colorization. Set to -1 to use full image size')
parser.add_argument('--denoise-sigma', default=30, type=int, help='Used by colorizer and affects color strength, range from 0 to 255 (default 30). -1 turns it off.')
parser.add_argument('--mask-dilation-offset', default=0, type=int, help='By how much to extend the text mask to remove left-over text pixels of the original image.')

parser.add_argument('--disable-font-border', action='store_true', help='Disable font border')
parser.add_argument('--font-size', default=None, type=int, help='Use fixed font size for rendering')
parser.add_argument('--font-size-offset', default=0, type=int, help='Offset font size by a given amount, positive number increase font size and vice versa')
parser.add_argument('--font-size-minimum', default=-1, type=int, help='Minimum output font size. Default is image_sides_sum/200')
parser.add_argument('--font-color', default=None, type=str, help='Overwrite the text fg/bg color detected by the OCR model. Use hex string without the "#" such as FFFFFF for a white foreground or FFFFFF:000000 to also have a black background around the text.')
parser.add_argument('--line-spacing', default=None, type=float, help='Line spacing is font_size * this value. Default is 0.01 for horizontal text and 0.2 for vertical.')

g = parser.add_mutually_exclusive_group()
g.add_argument('--force-horizontal', action='store_true', help='Force text to be rendered horizontally')
g.add_argument('--force-vertical', action='store_true', help='Force text to be rendered vertically')

g = parser.add_mutually_exclusive_group()
g.add_argument('--align-left', action='store_true', help='Align rendered text left')
g.add_argument('--align-center', action='store_true', help='Align rendered text centered')
g.add_argument('--align-right', action='store_true', help='Align rendered text right')

g = parser.add_mutually_exclusive_group()
g.add_argument('--uppercase', action='store_true', help='Change text to uppercase')
g.add_argument('--lowercase', action='store_true', help='Change text to lowercase')

parser.add_argument('--no-hyphenation', action='store_true', help='If renderer should be splitting up words using a hyphen character (-)')
parser.add_argument('--manga2eng', action='store_true', help='Render english text translated from manga with some additional typesetting. Ignores some other argument options')
parser.add_argument('--gpt-config', type=file_path, help='Path to GPT config file, more info in README')
parser.add_argument('--use-mtpe', action='store_true', help='Turn on/off machine translation post editing (MTPE) on the command line (works only on linux right now)')

g = parser.add_mutually_exclusive_group()
g.add_argument('--save-text', action='store_true', help='Save extracted text and translations into a text file.')
g.add_argument('--save-text-file', default='', type=str, help='Like --save-text but with a specified file path.')

parser.add_argument('--filter-text', default=None, type=str, help='Filter regions by their text with a regex. Example usage: --text-filter ".*badtext.*"')
parser.add_argument('--prep-manual', action='store_true', help='Prepare for manual typesetting by outputting blank, inpainted images, plus copies of the original for reference')
parser.add_argument('--font-path', default='', type=file_path, help='Path to font file')
parser.add_argument('--gimp-font', default='Sans-serif', type=str, help='Font family to use for gimp rendering.')
parser.add_argument('--host', default='127.0.0.1', type=str, help='Used by web module to decide which host to attach to')
parser.add_argument('--port', default=5003, type=int, help='Used by web module to decide which port to attach to')
parser.add_argument('--nonce', default=os.getenv('MT_WEB_NONCE', ''), type=str, help='Used by web module as secret for securing internal web server communication')
# parser.add_argument('--log-web', action='store_true', help='Used by web module to decide if web logs should be surfaced')
parser.add_argument('--ws-url', default='ws://localhost:5000', type=str, help='Server URL for WebSocket mode')
parser.add_argument('--save-quality', default=100, type=int, help='Quality of saved JPEG image, range from 0 to 100 with 100 being best')
parser.add_argument('--ignore-bubble', default=0, type=int, help='The threshold for ignoring text in non bubble areas, with valid values ranging from 1 to 50, does not ignore others. Recommendation 5 to 10. If it is too low, normal bubble areas may be ignored, and if it is too large, non bubble areas may be considered normal bubbles')

parser.add_argument('--kernel-size', default=3, type=int, help='Set the convolution kernel size of the text erasure area to completely clean up text residues')


# Generates a dict with the default value for each argument
DEFAULT_ARGS = vars(parser.parse_args([]))
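Because `DEFAULT_ARGS` is built by parsing an empty argument list, it contains one entry per option with that option's default value, which is how the other entry points obtain a complete parameter dict. A small sketch of inspecting and overriding those defaults; the overridden values are arbitrary examples:

```python
# Sketch: inspect and override the defaults produced by manga_translator/args.py.
from manga_translator.args import DEFAULT_ARGS, parser

print(DEFAULT_ARGS['detector'])        # 'default'
print(DEFAULT_ARGS['inpainter'])       # 'lama_large'
print(DEFAULT_ARGS['detection_size'])  # 1536

# Parsing a real command line yields the same keys, with user overrides applied:
args = parser.parse_args(['--translator', 'sugoi', '--target-lang', 'ENG', '--use-gpu'])
params = {**DEFAULT_ARGS, **vars(args)}
```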
manga_translator/colorization/__init__.py
ADDED
@@ -0,0 +1,28 @@
from PIL import Image

from .common import CommonColorizer, OfflineColorizer
from .manga_colorization_v2 import MangaColorizationV2

COLORIZERS = {
    'mc2': MangaColorizationV2,
}
colorizer_cache = {}

def get_colorizer(key: str, *args, **kwargs) -> CommonColorizer:
    if key not in COLORIZERS:
        raise ValueError(f'Could not find colorizer for: "{key}". Choose from the following: %s' % ','.join(COLORIZERS))
    if not colorizer_cache.get(key):
        upscaler = COLORIZERS[key]
        colorizer_cache[key] = upscaler(*args, **kwargs)
    return colorizer_cache[key]

async def prepare(key: str):
    upscaler = get_colorizer(key)
    if isinstance(upscaler, OfflineColorizer):
        await upscaler.download()

async def dispatch(key: str, device: str = 'cpu', **kwargs) -> Image.Image:
    colorizer = get_colorizer(key)
    if isinstance(colorizer, OfflineColorizer):
        await colorizer.load(device)
    return await colorizer.colorize(**kwargs)
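`get_colorizer` caches one instance per key, `prepare` downloads weights ahead of time, and `dispatch` loads the model onto the requested device before running it. A usage sketch; the keyword names follow `CommonColorizer.colorize` in `common.py` below, and the file path is a placeholder:

```python
# Sketch: colorize a single image through the module-level dispatch().
import asyncio
from PIL import Image

from manga_translator.colorization import prepare, dispatch

async def colorize_file(path: str) -> Image.Image:
    await prepare('mc2')          # download weights if they are missing
    image = Image.open(path)
    return await dispatch(
        'mc2',
        device='cpu',
        image=image,
        colorization_size=576,    # default working size used by args.py
        denoise_sigma=30,
    )

# colored = asyncio.run(colorize_file('page.png'))
```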
manga_translator/colorization/common.py
ADDED
@@ -0,0 +1,24 @@
from PIL import Image
from abc import abstractmethod

from ..utils import InfererModule, ModelWrapper

class CommonColorizer(InfererModule):
    _VALID_UPSCALE_RATIOS = None

    async def colorize(self, image: Image.Image, colorization_size: int, **kwargs) -> Image.Image:
        return await self._colorize(image, colorization_size, **kwargs)

    @abstractmethod
    async def _colorize(self, image: Image.Image, colorization_size: int, **kwargs) -> Image.Image:
        pass

class OfflineColorizer(CommonColorizer, ModelWrapper):
    _MODEL_SUB_DIR = 'colorization'

    async def _colorize(self, *args, **kwargs):
        return await self.infer(*args, **kwargs)

    @abstractmethod
    async def _infer(self, image: Image.Image, colorization_size: int, **kwargs) -> Image.Image:
        pass
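The two base classes split the public entry point (`colorize`) from the backend hook (`_colorize`, and `_infer` for offline models), while `OfflineColorizer` adds weight download and loading through `ModelWrapper`. A sketch of what a new entry could look like; the class name, URL, file name and hash are all hypothetical, and the `_load`/`_unload`/`_infer` hooks mirror the ones used by `MangaColorizationV2` below:

```python
# Hypothetical example of plugging a new colorizer into this hierarchy.
from PIL import Image

from manga_translator.colorization.common import OfflineColorizer

class MyColorizer(OfflineColorizer):
    # Downloaded into the colorization model directory by ModelWrapper (names are made up)
    _MODEL_MAPPING = {
        'model': {
            'url': 'https://example.com/my-colorizer.ckpt',
            'file': 'my-colorizer.ckpt',
            'hash': '<sha256 of the checkpoint>',
        },
    }

    async def _load(self, device: str):
        self.device = device  # a real implementation loads the checkpoint here

    async def _unload(self):
        pass

    async def _infer(self, image: Image.Image, colorization_size: int, **kwargs) -> Image.Image:
        # A real implementation resizes, runs the network and converts back to PIL
        return image.convert('RGB')

# To expose it, an entry such as COLORIZERS['my'] = MyColorizer would also be
# needed in colorization/__init__.py.
```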
manga_translator/colorization/manga_colorization_v2.py
ADDED
@@ -0,0 +1,74 @@
import os
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor

from .common import OfflineColorizer
from .manga_colorization_v2_utils.networks.models import Colorizer
from .manga_colorization_v2_utils.denoising.denoiser import FFDNetDenoiser
from .manga_colorization_v2_utils.utils.utils import resize_pad


# https://github.com/qweasdd/manga-colorization-v2
class MangaColorizationV2(OfflineColorizer):
    _MODEL_SUB_DIR = os.path.join(OfflineColorizer._MODEL_SUB_DIR, 'manga-colorization-v2')
    _MODEL_MAPPING = {
        # Models were in google drive so had to upload to github
        'generator': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/manga-colorization-v2-generator.zip',
            'file': 'generator.zip',
            'hash': '087e6a0bc02770e732a52f33878b71a272a6123c9ac649e9b5bfb75e39e5c1d5',
        },
        'denoiser': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/manga-colorization-v2-net_rgb.pth',
            'file': 'net_rgb.pth',
            'hash': '0fe98bfd2ac870b15f360661b1c4789eecefc6dc2e4462842a0dd15e149a0433',
        }
    }

    async def _load(self, device: str):
        self.device = device
        self.colorizer = Colorizer().to(device)
        self.colorizer.generator.load_state_dict(
            torch.load(self._get_file_path('generator.zip'), map_location=self.device))
        self.colorizer = self.colorizer.eval()
        self.denoiser = FFDNetDenoiser(device, _weights_dir=self.model_dir)

    async def _unload(self):
        del self.colorizer
        del self.denoiser

    async def _infer(self, image: Image.Image, colorization_size: int, denoise_sigma=25, **kwargs) -> Image.Image:
        # Size has to be multiple of 32
        img = np.array(image.convert('RGBA'))
        max_size = min(*img.shape[:2])
        max_size -= max_size % 32
        if colorization_size > 0:
            size = min(max_size, colorization_size - (colorization_size % 32))
        else:
            # size<=576 gives best results
            size = min(max_size, 576)

        if 0 <= denoise_sigma and denoise_sigma <= 255:
            img = self.denoiser.get_denoised_image(img, sigma=denoise_sigma)

        img, current_pad = resize_pad(img, size)

        transform = ToTensor()
        current_image = transform(img).unsqueeze(0).to(self.device)
        current_hint = torch.zeros(1, 4, current_image.shape[2], current_image.shape[3]).float().to(self.device)

        with torch.no_grad():
            fake_color, _ = self.colorizer(torch.cat([current_image, current_hint], 1))
            fake_color = fake_color.detach()

        result = fake_color[0].detach().cpu().permute(1, 2, 0) * 0.5 + 0.5

        if current_pad[0] != 0:
            result = result[:-current_pad[0]]
        if current_pad[1] != 0:
            result = result[:, :-current_pad[1]]

        colored_image = result.numpy() * 255
        return Image.fromarray(colored_image.astype(np.uint8))
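The working size in `_infer` is always a multiple of 32: it is capped by the image's shorter side (rounded down to a multiple of 32) and by the requested `colorization_size` (also rounded down), falling back to at most 576 when no positive size is given. A small worked example of just that computation, kept separate from the class:

```python
# Reproduces the size selection from MangaColorizationV2._infer().
def pick_size(height: int, width: int, colorization_size: int) -> int:
    max_size = min(height, width)
    max_size -= max_size % 32
    if colorization_size > 0:
        return min(max_size, colorization_size - (colorization_size % 32))
    return min(max_size, 576)  # size <= 576 gives the best results

assert pick_size(1000, 700, 576) == 576   # shorter side 700 -> 672, capped by 576
assert pick_size(1000, 700, -1) == 576    # fallback default
assert pick_size(500, 300, 576) == 288    # shorter side 300 -> 288 wins
```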
manga_translator/colorization/manga_colorization_v2_utils/denoising/denoiser.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Denoise an image with the FFDNet denoising method

Copyright (C) 2018, Matias Tassano <[email protected]>

This program is free software: you can use, modify and/or
redistribute it under the terms of the GNU General Public
License as published by the Free Software Foundation, either
version 3 of the License, or (at your option) any later
version. You should have received a copy of this license along
this program. If not, see <http://www.gnu.org/licenses/>.
"""
import os
import argparse
import time

import numpy as np
import cv2
import torch
import torch.nn as nn
from torch.autograd import Variable
from .models import FFDNet
from .utils import normalize, variable_to_cv2_image, remove_dataparallel_wrapper, is_rgb

class FFDNetDenoiser:
    def __init__(self, _device, _sigma = 25, _weights_dir = 'denoising/models/', _in_ch = 3):
        self.sigma = _sigma / 255
        self.weights_dir = _weights_dir
        self.channels = _in_ch
        self.device = _device
        self.model = FFDNet(num_input_channels = _in_ch)
        self.load_weights()
        self.model.eval()

    def load_weights(self):
        weights_name = 'net_rgb.pth' if self.channels == 3 else 'net_gray.pth'
        weights_path = os.path.join(self.weights_dir, weights_name)
        if self.device == 'cuda':
            # data paralles only for cuda , no need for mps devices
            state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
            self.model = nn.DataParallel(self.model, device_ids = [0]).to(self.device)
        else:
            # MPS devices don't support DataParallel
            state_dict = torch.load(weights_path, map_location=self.device)
            # CPU mode: remove the DataParallel wrapper
            state_dict = remove_dataparallel_wrapper(state_dict)
        self.model.load_state_dict(state_dict)

    def get_denoised_image(self, imorig, sigma = None):

        if sigma is not None:
            cur_sigma = sigma / 255
        else:
            cur_sigma = self.sigma

        if len(imorig.shape) < 3 or imorig.shape[2] == 1:
            imorig = np.repeat(np.expand_dims(imorig, 2), 3, 2)

        imorig = imorig[..., :3]

        if (max(imorig.shape[0], imorig.shape[1]) > 1200):
            ratio = max(imorig.shape[0], imorig.shape[1]) / 1200
            imorig = cv2.resize(imorig, (int(imorig.shape[1] / ratio), int(imorig.shape[0] / ratio)), interpolation = cv2.INTER_AREA)

        imorig = imorig.transpose(2, 0, 1)

        if (imorig.max() > 1.2):
            imorig = normalize(imorig)
        imorig = np.expand_dims(imorig, 0)

        # Handle odd sizes
        expanded_h = False
        expanded_w = False
        sh_im = imorig.shape
        if sh_im[2]%2 == 1:
            expanded_h = True
            imorig = np.concatenate((imorig, imorig[:, :, -1, :][:, :, np.newaxis, :]), axis=2)

        if sh_im[3]%2 == 1:
            expanded_w = True
            imorig = np.concatenate((imorig, imorig[:, :, :, -1][:, :, :, np.newaxis]), axis=3)

        imorig = torch.Tensor(imorig)

        # Sets data type according to CPU or GPU modes
        if self.device == 'cuda':
            dtype = torch.cuda.FloatTensor
        else:
            # for mps devices is still floatTensor
            dtype = torch.FloatTensor

        imnoisy = imorig#.clone()

        with torch.no_grad():
            imorig, imnoisy = imorig.type(dtype), imnoisy.type(dtype)
            nsigma = torch.FloatTensor([cur_sigma]).type(dtype)

            # Estimate noise and subtract it from the input image
            im_noise_estim = self.model(imnoisy, nsigma)
            outim = torch.clamp(imnoisy - im_noise_estim, 0., 1.)

        if expanded_h:
            # imorig = imorig[:, :, :-1, :]
            outim = outim[:, :, :-1, :]
            # imnoisy = imnoisy[:, :, :-1, :]

        if expanded_w:
            # imorig = imorig[:, :, :, :-1]
            outim = outim[:, :, :, :-1]
            # imnoisy = imnoisy[:, :, :, :-1]

        return variable_to_cv2_image(outim)
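A minimal usage sketch for the class above. It assumes a `net_rgb.pth` weight file already sits in a local `weights/` directory; the path, file name and device here are illustrative, not the project's fixed layout:

import cv2
# from ...denoising.denoiser import FFDNetDenoiser   # import path depends on how the package is installed

denoiser = FFDNetDenoiser('cpu', _sigma=25, _weights_dir='weights/')   # expects weights/net_rgb.pth
img = cv2.imread('page.png')                                           # HxWx3 uint8 image
clean = denoiser.get_denoised_image(img, sigma=25)                     # uint8 result, long side capped at 1200px
cv2.imwrite('page_denoised.png', clean)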
manga_translator/colorization/manga_colorization_v2_utils/denoising/functions.py
ADDED
@@ -0,0 +1,102 @@
"""
Functions implementing custom NN layers

Copyright (C) 2018, Matias Tassano <[email protected]>

This program is free software: you can use, modify and/or
redistribute it under the terms of the GNU General Public
License as published by the Free Software Foundation, either
version 3 of the License, or (at your option) any later
version. You should have received a copy of this license along
this program. If not, see <http://www.gnu.org/licenses/>.
"""
import torch
from torch.autograd import Function, Variable

def concatenate_input_noise_map(input, noise_sigma):
    r"""Implements the first layer of FFDNet. This function returns a
    torch.autograd.Variable composed of the concatenation of the downsampled
    input image and the noise map. Each image of the batch of size CxHxW gets
    converted to an array of size 4*CxH/2xW/2. Each of the pixels of the
    non-overlapped 2x2 patches of the input image are placed in the new array
    along the first dimension.

    Args:
        input: batch containing CxHxW images
        noise_sigma: the value of the pixels of the CxH/2xW/2 noise map
    """
    # noise_sigma is a list of length batch_size
    N, C, H, W = input.size()
    dtype = input.type()
    sca = 2
    sca2 = sca*sca
    Cout = sca2*C
    Hout = H//sca
    Wout = W//sca
    idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]

    # Fill the downsampled image with zeros
    if 'cuda' in dtype:
        downsampledfeatures = torch.cuda.FloatTensor(N, Cout, Hout, Wout).fill_(0)
    else:
        # cpu and mps are the same
        downsampledfeatures = torch.FloatTensor(N, Cout, Hout, Wout).fill_(0)

    # Build the CxH/2xW/2 noise map
    noise_map = noise_sigma.view(N, 1, 1, 1).repeat(1, C, Hout, Wout)

    # Populate output
    for idx in range(sca2):
        downsampledfeatures[:, idx:Cout:sca2, :, :] = \
            input[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]

    # concatenate de-interleaved mosaic with noise map
    return torch.cat((noise_map, downsampledfeatures), 1)

class UpSampleFeaturesFunction(Function):
    r"""Extends PyTorch's modules by implementing a torch.autograd.Function.
    This class implements the forward and backward methods of the last layer
    of FFDNet. It basically performs the inverse of
    concatenate_input_noise_map(): it converts each of the images of a
    batch of size CxH/2xW/2 to images of size C/4xHxW
    """
    @staticmethod
    def forward(ctx, input):
        N, Cin, Hin, Win = input.size()
        dtype = input.type()
        sca = 2
        sca2 = sca*sca
        Cout = Cin//sca2
        Hout = Hin*sca
        Wout = Win*sca
        idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]

        assert (Cin%sca2 == 0), 'Invalid input dimensions: number of channels should be divisible by 4'

        result = torch.zeros((N, Cout, Hout, Wout)).type(dtype)
        for idx in range(sca2):
            result[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca] = input[:, idx:Cin:sca2, :, :]

        return result

    @staticmethod
    def backward(ctx, grad_output):
        N, Cg_out, Hg_out, Wg_out = grad_output.size()
        dtype = grad_output.data.type()
        sca = 2
        sca2 = sca*sca
        Cg_in = sca2*Cg_out
        Hg_in = Hg_out//sca
        Wg_in = Wg_out//sca
        idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]

        # Build output
        grad_input = torch.zeros((N, Cg_in, Hg_in, Wg_in)).type(dtype)
        # Populate output
        for idx in range(sca2):
            grad_input[:, idx:Cg_in:sca2, :, :] = grad_output.data[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]

        return Variable(grad_input)

# Alias functions
upsamplefeatures = UpSampleFeaturesFunction.apply
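The two operations above are exact inverses of each other: a 2x2 pixel de-interleave with a prepended noise map, and the matching re-interleave. A quick shape check, illustrative only (any small tensor works):

import torch

x = torch.rand(1, 3, 8, 8)                     # N x C x H x W
sigma = torch.tensor([0.1])
stacked = concatenate_input_noise_map(x, sigma)
print(stacked.shape)                           # torch.Size([1, 15, 4, 4]) -> C noise channels + 4*C mosaic
restored = upsamplefeatures(stacked[:, 3:])    # drop the noise map, invert the 2x2 de-interleaving
print(torch.allclose(restored, x))             # True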
manga_translator/colorization/manga_colorization_v2_utils/denoising/models.py
ADDED
@@ -0,0 +1,100 @@
"""
Definition of the FFDNet model and its custom layers

Copyright (C) 2018, Matias Tassano <[email protected]>

This program is free software: you can use, modify and/or
redistribute it under the terms of the GNU General Public
License as published by the Free Software Foundation, either
version 3 of the License, or (at your option) any later
version. You should have received a copy of this license along
this program. If not, see <http://www.gnu.org/licenses/>.
"""
import torch.nn as nn
from torch.autograd import Variable
from . import functions

class UpSampleFeatures(nn.Module):
    r"""Implements the last layer of FFDNet
    """
    def __init__(self):
        super(UpSampleFeatures, self).__init__()
    def forward(self, x):
        return functions.upsamplefeatures(x)

class IntermediateDnCNN(nn.Module):
    r"""Implements the middel part of the FFDNet architecture, which
    is basically a DnCNN net
    """
    def __init__(self, input_features, middle_features, num_conv_layers):
        super(IntermediateDnCNN, self).__init__()
        self.kernel_size = 3
        self.padding = 1
        self.input_features = input_features
        self.num_conv_layers = num_conv_layers
        self.middle_features = middle_features
        if self.input_features == 5:
            self.output_features = 4 #Grayscale image
        elif self.input_features == 15:
            self.output_features = 12 #RGB image
        else:
            raise Exception('Invalid number of input features')

        layers = []
        layers.append(nn.Conv2d(in_channels=self.input_features,\
                                out_channels=self.middle_features,\
                                kernel_size=self.kernel_size,\
                                padding=self.padding,\
                                bias=False))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(self.num_conv_layers-2):
            layers.append(nn.Conv2d(in_channels=self.middle_features,\
                                    out_channels=self.middle_features,\
                                    kernel_size=self.kernel_size,\
                                    padding=self.padding,\
                                    bias=False))
            layers.append(nn.BatchNorm2d(self.middle_features))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=self.middle_features,\
                                out_channels=self.output_features,\
                                kernel_size=self.kernel_size,\
                                padding=self.padding,\
                                bias=False))
        self.itermediate_dncnn = nn.Sequential(*layers)
    def forward(self, x):
        out = self.itermediate_dncnn(x)
        return out

class FFDNet(nn.Module):
    r"""Implements the FFDNet architecture
    """
    def __init__(self, num_input_channels):
        super(FFDNet, self).__init__()
        self.num_input_channels = num_input_channels
        if self.num_input_channels == 1:
            # Grayscale image
            self.num_feature_maps = 64
            self.num_conv_layers = 15
            self.downsampled_channels = 5
            self.output_features = 4
        elif self.num_input_channels == 3:
            # RGB image
            self.num_feature_maps = 96
            self.num_conv_layers = 12
            self.downsampled_channels = 15
            self.output_features = 12
        else:
            raise Exception('Invalid number of input features')

        self.intermediate_dncnn = IntermediateDnCNN(\
            input_features=self.downsampled_channels,\
            middle_features=self.num_feature_maps,\
            num_conv_layers=self.num_conv_layers)
        self.upsamplefeatures = UpSampleFeatures()

    def forward(self, x, noise_sigma):
        concat_noise_x = functions.concatenate_input_noise_map(x.data, noise_sigma.data)
        concat_noise_x = Variable(concat_noise_x)
        h_dncnn = self.intermediate_dncnn(concat_noise_x)
        pred_noise = self.upsamplefeatures(h_dncnn)
        return pred_noise
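A small sketch of the tensor plumbing through FFDNet with randomly initialised weights; the values are meaningless, only the shapes matter, and H and W must be even because of the 2x2 pixel-unshuffle:

import torch

net = FFDNet(num_input_channels=3)             # RGB variant: 12 conv layers, 96 feature maps
x = torch.rand(1, 3, 128, 128)
sigma = torch.tensor([25.0 / 255.0])
with torch.no_grad():
    noise = net(x, sigma)                      # predicted noise, same shape as the input
print(noise.shape)                             # torch.Size([1, 3, 128, 128])
denoised = (x - noise).clamp(0.0, 1.0)         # same subtraction the denoiser wrapper performs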
manga_translator/colorization/manga_colorization_v2_utils/denoising/utils.py
ADDED
@@ -0,0 +1,66 @@
"""
Different utilities such as orthogonalization of weights, initialization of
loggers, etc

Copyright (C) 2018, Matias Tassano <[email protected]>

This program is free software: you can use, modify and/or
redistribute it under the terms of the GNU General Public
License as published by the Free Software Foundation, either
version 3 of the License, or (at your option) any later
version. You should have received a copy of this license along
this program. If not, see <http://www.gnu.org/licenses/>.
"""
import numpy as np
import cv2


def variable_to_cv2_image(varim):
    r"""Converts a torch.autograd.Variable to an OpenCV image

    Args:
        varim: a torch.autograd.Variable
    """
    nchannels = varim.size()[1]
    if nchannels == 1:
        res = (varim.data.cpu().numpy()[0, 0, :]*255.).clip(0, 255).astype(np.uint8)
    elif nchannels == 3:
        res = varim.data.cpu().numpy()[0]
        res = cv2.cvtColor(res.transpose(1, 2, 0), cv2.COLOR_RGB2BGR)
        res = (res*255.).clip(0, 255).astype(np.uint8)
    else:
        raise Exception('Number of color channels not supported')
    return res


def normalize(data):
    return np.float32(data/255.)

def remove_dataparallel_wrapper(state_dict):
    r"""Converts a DataParallel model to a normal one by removing the "module."
    wrapper in the module dictionary

    Args:
        state_dict: a torch.nn.DataParallel state dictionary
    """
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, vl in state_dict.items():
        name = k[7:] # remove 'module.' of DataParallel
        new_state_dict[name] = vl

    return new_state_dict

def is_rgb(im_path):
    r""" Returns True if the image in im_path is an RGB image
    """
    from skimage.io import imread
    rgb = False
    im = imread(im_path)
    if (len(im.shape) == 3):
        if not(np.allclose(im[...,0], im[...,1]) and np.allclose(im[...,2], im[...,1])):
            rgb = True
    print("rgb: {}".format(rgb))
    print("im shape: {}".format(im.shape))
    return rgb
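A tiny illustration of the key renaming done by `remove_dataparallel_wrapper`; the tensor value is a placeholder:

from collections import OrderedDict
import torch

dp_style = OrderedDict([('module.conv1.weight', torch.zeros(1))])   # keys as saved from nn.DataParallel
plain = remove_dataparallel_wrapper(dp_style)
print(list(plain.keys()))                                           # ['conv1.weight']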
manga_translator/colorization/manga_colorization_v2_utils/networks/extractor.py
ADDED
@@ -0,0 +1,127 @@
import torch
import torch.nn as nn
import math

'''https://github.com/blandocs/Tag2Pix/blob/master/model/pretrained.py'''

# Pretrained version
class Selayer(nn.Module):
    def __init__(self, inplanes):
        super(Selayer, self).__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.global_avgpool(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.sigmoid(out)

        return x * out


class BottleneckX_Origin(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None):
        super(BottleneckX_Origin, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes * 2)

        self.conv2 = nn.Conv2d(planes * 2, planes * 2, kernel_size=3, stride=stride,
                               padding=1, groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(planes * 2)

        self.conv3 = nn.Conv2d(planes * 2, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)

        self.selayer = Selayer(planes * 4)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.selayer(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class SEResNeXt_Origin(nn.Module):
    def __init__(self, block, layers, input_channels=3, cardinality=32, num_classes=1000):
        super(SEResNeXt_Origin, self).__init__()
        self.cardinality = cardinality
        self.inplanes = 64
        self.input_channels = input_channels

        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, self.cardinality, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, self.cardinality))

        return nn.Sequential(*layers)

    def forward(self, x):

        x = self.conv1(x)
        x = self.bn1(x)
        x1 = self.relu(x)

        x2 = self.layer1(x1)

        x3 = self.layer2(x2)

        x4 = self.layer3(x3)

        return x1, x2, x3, x4
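For reference, a shape sketch of the four feature maps the extractor returns when configured the way the generator in the next file uses it (single-channel input, random weights; a 256x256 input is assumed for the numbers below):

import torch

enc = SEResNeXt_Origin(BottleneckX_Origin, [3, 4, 6, 3], input_channels=1)
x = torch.rand(1, 1, 256, 256)
x1, x2, x3, x4 = enc(x)
print(x1.shape, x2.shape, x3.shape, x4.shape)
# (1, 64, 128, 128) (1, 256, 128, 128) (1, 512, 64, 64) (1, 1024, 32, 32)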
manga_translator/colorization/manga_colorization_v2_utils/networks/models.py
ADDED
@@ -0,0 +1,319 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as M
import math
from torch import Tensor
from torch.nn import Parameter

from .extractor import SEResNeXt_Origin, BottleneckX_Origin

'''https://github.com/orashi/AlacGAN/blob/master/models/standard.py'''

def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)
        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)

class Selayer(nn.Module):
    def __init__(self, inplanes):
        super(Selayer, self).__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.global_avgpool(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.sigmoid(out)

        return x * out

class SelayerSpectr(nn.Module):
    def __init__(self, inplanes):
        super(SelayerSpectr, self).__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = SpectralNorm(nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1))
        self.conv2 = SpectralNorm(nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1))
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.global_avgpool(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.sigmoid(out)

        return x * out

class ResNeXtBottleneck(nn.Module):
    def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
        super(ResNeXtBottleneck, self).__init__()
        D = out_channels // 2
        self.out_channels = out_channels
        self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_conv = nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
                                   groups=cardinality,
                                   bias=False)
        self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.shortcut = nn.Sequential()
        if stride != 1:
            self.shortcut.add_module('shortcut',
                                     nn.AvgPool2d(2, stride=2))

        self.selayer = Selayer(out_channels)

    def forward(self, x):
        bottleneck = self.conv_reduce.forward(x)
        bottleneck = F.leaky_relu(bottleneck, 0.2, True)
        bottleneck = self.conv_conv.forward(bottleneck)
        bottleneck = F.leaky_relu(bottleneck, 0.2, True)
        bottleneck = self.conv_expand.forward(bottleneck)
        bottleneck = self.selayer(bottleneck)

        x = self.shortcut.forward(x)
        return x + bottleneck

class SpectrResNeXtBottleneck(nn.Module):
    def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
        super(SpectrResNeXtBottleneck, self).__init__()
        D = out_channels // 2
        self.out_channels = out_channels
        self.conv_reduce = SpectralNorm(nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False))
        self.conv_conv = SpectralNorm(nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
                                                groups=cardinality,
                                                bias=False))
        self.conv_expand = SpectralNorm(nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False))
        self.shortcut = nn.Sequential()
        if stride != 1:
            self.shortcut.add_module('shortcut',
                                     nn.AvgPool2d(2, stride=2))

        self.selayer = SelayerSpectr(out_channels)

    def forward(self, x):
        bottleneck = self.conv_reduce.forward(x)
        bottleneck = F.leaky_relu(bottleneck, 0.2, True)
        bottleneck = self.conv_conv.forward(bottleneck)
        bottleneck = F.leaky_relu(bottleneck, 0.2, True)
        bottleneck = self.conv_expand.forward(bottleneck)
        bottleneck = self.selayer(bottleneck)

        x = self.shortcut.forward(x)
        return x + bottleneck

class FeatureConv(nn.Module):
    def __init__(self, input_dim=512, output_dim=512):
        super(FeatureConv, self).__init__()

        no_bn = True

        seq = []
        seq.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
        if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
        seq.append(nn.ReLU(inplace=True))
        seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
        if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
        seq.append(nn.ReLU(inplace=True))
        seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
        seq.append(nn.ReLU(inplace=True))

        self.network = nn.Sequential(*seq)

    def forward(self, x):
        return self.network(x)

class Generator(nn.Module):
    def __init__(self, ngf=64):
        super(Generator, self).__init__()

        self.encoder = SEResNeXt_Origin(BottleneckX_Origin, [3, 4, 6, 3], num_classes= 370, input_channels=1)

        self.to0 = self._make_encoder_block_first(5, 32)
        self.to1 = self._make_encoder_block(32, 64)
        self.to2 = self._make_encoder_block(64, 92)
        self.to3 = self._make_encoder_block(92, 128)
        self.to4 = self._make_encoder_block(128, 256)

        self.deconv_for_decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1), # output is 64 * 64
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1), # output is 128 * 128
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 32, 3, stride=1, padding=1, output_padding=0), # output is 256 * 256
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 3, 3, stride=1, padding=1, output_padding=0), # output is 256 * 256
            nn.Tanh(),
        )

        tunnel4 = nn.Sequential(*[ResNeXtBottleneck(512, 512, cardinality=32, dilate=1) for _ in range(20)])

        self.tunnel4 = nn.Sequential(nn.Conv2d(1024 + 128, 512, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel4,
                                     nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     ) # 64

        depth = 2
        tunnel = [ResNeXtBottleneck(256, 256, cardinality=32, dilate=1) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(256, 256, cardinality=32, dilate=2) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(256, 256, cardinality=32, dilate=4) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(256, 256, cardinality=32, dilate=2),
                   ResNeXtBottleneck(256, 256, cardinality=32, dilate=1)]
        tunnel3 = nn.Sequential(*tunnel)

        self.tunnel3 = nn.Sequential(nn.Conv2d(512 + 256, 256, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel3,
                                     nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     ) # 128

        tunnel = [ResNeXtBottleneck(128, 128, cardinality=32, dilate=1) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(128, 128, cardinality=32, dilate=2) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(128, 128, cardinality=32, dilate=4) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(128, 128, cardinality=32, dilate=2),
                   ResNeXtBottleneck(128, 128, cardinality=32, dilate=1)]
        tunnel2 = nn.Sequential(*tunnel)

        self.tunnel2 = nn.Sequential(nn.Conv2d(128 + 256 + 64, 128, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel2,
                                     nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     )

        tunnel = [ResNeXtBottleneck(64, 64, cardinality=16, dilate=1)]
        tunnel += [ResNeXtBottleneck(64, 64, cardinality=16, dilate=2)]
        tunnel += [ResNeXtBottleneck(64, 64, cardinality=16, dilate=4)]
        tunnel += [ResNeXtBottleneck(64, 64, cardinality=16, dilate=2),
                   ResNeXtBottleneck(64, 64, cardinality=16, dilate=1)]
        tunnel1 = nn.Sequential(*tunnel)

        self.tunnel1 = nn.Sequential(nn.Conv2d(64 + 32, 64, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel1,
                                     nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     )

        self.exit = nn.Sequential(nn.Conv2d(64 + 32, 32, kernel_size=3, stride=1, padding=1),
                                  nn.LeakyReLU(0.2, True),
                                  nn.Conv2d(32, 3, kernel_size= 1, stride = 1, padding = 0))

    def _make_encoder_block(self, inplanes, planes):
        return nn.Sequential(
            nn.Conv2d(inplanes, planes, 3, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(planes, planes, 3, 1, 1),
            nn.LeakyReLU(0.2),
        )

    def _make_encoder_block_first(self, inplanes, planes):
        return nn.Sequential(
            nn.Conv2d(inplanes, planes, 3, 1, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(planes, planes, 3, 1, 1),
            nn.LeakyReLU(0.2),
        )

    def forward(self, sketch):

        x0 = self.to0(sketch)
        aux_out = self.to1(x0)
        aux_out = self.to2(aux_out)
        aux_out = self.to3(aux_out)

        x1, x2, x3, x4 = self.encoder(sketch[:, 0:1])

        out = self.tunnel4(torch.cat([x4, aux_out], 1))

        x = self.tunnel3(torch.cat([out, x3], 1))

        x = self.tunnel2(torch.cat([x, x2, x1], 1))

        x = torch.tanh(self.exit(torch.cat([x, x0], 1)))

        decoder_output = self.deconv_for_decoder(out)

        return x, decoder_output


class Colorizer(nn.Module):
    def __init__(self):
        super(Colorizer, self).__init__()

        self.generator = Generator()

    def forward(self, x, extractor_grad = False):
        fake, guide = self.generator(x)
        return fake, guide
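A minimal shape check of the colorizer above with random weights. The real pipeline feeds a grayscale page plus a 4-channel colour-hint map, resized so height and width are multiples of 32 (see `_infer` earlier); 256x256 is just an illustrative size:

import torch

model = Colorizer().eval()
gray = torch.rand(1, 1, 256, 256)            # grayscale page/sketch
hint = torch.zeros(1, 4, 256, 256)           # empty colour hints, as in _infer
with torch.no_grad():
    fake_color, guide = model(torch.cat([gray, hint], 1))
print(fake_color.shape, guide.shape)         # torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])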
manga_translator/colorization/manga_colorization_v2_utils/utils/utils.py
ADDED
@@ -0,0 +1,44 @@
import numpy as np
import cv2

def resize_pad(img, size = 256):

    if len(img.shape) == 2:
        img = np.expand_dims(img, 2)

    if img.shape[2] == 1:
        img = np.repeat(img, 3, 2)

    if img.shape[2] == 4:
        img = img[:, :, :3]

    pad = None

    if (img.shape[0] < img.shape[1]):
        height = img.shape[0]
        ratio = height / (size * 1.5)
        width = int(np.ceil(img.shape[1] / ratio))
        img = cv2.resize(img, (width, int(size * 1.5)), interpolation = cv2.INTER_AREA)

        new_width = width + (32 - width % 32)

        pad = (0, new_width - width)

        img = np.pad(img, ((0, 0), (0, pad[1]), (0, 0)), 'maximum')
    else:
        width = img.shape[1]
        ratio = width / size
        height = int(np.ceil(img.shape[0] / ratio))
        img = cv2.resize(img, (size, height), interpolation = cv2.INTER_AREA)

        new_height = height + (32 - height % 32)

        pad = (new_height - height, 0)

        img = np.pad(img, ((0, pad[0]), (0, 0), (0, 0)), 'maximum')

    if (img.dtype == 'float32'):
        np.clip(img, 0, 1, out = img)

    return img[:, :, :1], pad
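A worked example of `resize_pad` for a typical portrait-orientation page, just tracing the arithmetic above (the input size is illustrative):

import numpy as np

page = np.zeros((1600, 1100, 3), dtype=np.uint8)   # taller than wide -> the 'else' branch
out, pad = resize_pad(page, size=576)
print(out.shape, pad)
# width scaled to 576, height to ceil(1600 / (1100/576)) = 838, then padded up to 864:
# (864, 576, 1) (26, 0)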
manga_translator/detection/__init__.py
ADDED
@@ -0,0 +1,37 @@
import numpy as np

from .default import DefaultDetector
from .dbnet_convnext import DBConvNextDetector
from .ctd import ComicTextDetector
from .craft import CRAFTDetector
from .none import NoneDetector
from .common import CommonDetector, OfflineDetector

DETECTORS = {
    'default': DefaultDetector,
    'dbconvnext': DBConvNextDetector,
    'ctd': ComicTextDetector,
    'craft': CRAFTDetector,
    'none': NoneDetector,
}
detector_cache = {}

def get_detector(key: str, *args, **kwargs) -> CommonDetector:
    if key not in DETECTORS:
        raise ValueError(f'Could not find detector for: "{key}". Choose from the following: %s' % ','.join(DETECTORS))
    if not detector_cache.get(key):
        detector = DETECTORS[key]
        detector_cache[key] = detector(*args, **kwargs)
    return detector_cache[key]

async def prepare(detector_key: str):
    detector = get_detector(detector_key)
    if isinstance(detector, OfflineDetector):
        await detector.download()

async def dispatch(detector_key: str, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float, unclip_ratio: float,
                   invert: bool, gamma_correct: bool, rotate: bool, auto_rotate: bool = False, device: str = 'cpu', verbose: bool = False):
    detector = get_detector(detector_key)
    if isinstance(detector, OfflineDetector):
        await detector.load(device)
    return await detector.detect(image, detect_size, text_threshold, box_threshold, unclip_ratio, invert, gamma_correct, rotate, auto_rotate, verbose)
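A hedged usage sketch of the module-level `prepare`/`dispatch` API above; the file name and threshold values are illustrative, not the project's canonical defaults:

import asyncio
import cv2

async def main():
    img = cv2.imread('page.png')                      # any manga page
    await prepare('default')                          # download model weights if missing
    textlines, raw_mask, mask = await dispatch(
        'default', img, detect_size=1536,
        text_threshold=0.5, box_threshold=0.7, unclip_ratio=2.3,
        invert=False, gamma_correct=False, rotate=False, device='cpu')
    print(len(textlines), 'text regions')

asyncio.run(main())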
manga_translator/detection/common.py
ADDED
@@ -0,0 +1,146 @@
from abc import abstractmethod
from typing import List, Tuple
from collections import Counter
import numpy as np
import cv2

from ..utils import InfererModule, ModelWrapper, Quadrilateral


class CommonDetector(InfererModule):

    async def detect(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float, unclip_ratio: float,
                     invert: bool, gamma_correct: bool, rotate: bool, auto_rotate: bool = False, verbose: bool = False):
        '''
        Returns textblock list and text mask.
        '''

        # Apply filters
        img_h, img_w = image.shape[:2]
        orig_image = image.copy()
        minimum_image_size = 400
        # Automatically add border if image too small (instead of simply resizing due to them more likely containing large fonts)
        add_border = min(img_w, img_h) < minimum_image_size
        if rotate:
            self.logger.debug('Adding rotation')
            image = self._add_rotation(image)
        if add_border:
            self.logger.debug('Adding border')
            image = self._add_border(image, minimum_image_size)
        if invert:
            self.logger.debug('Adding inversion')
            image = self._add_inversion(image)
        if gamma_correct:
            self.logger.debug('Adding gamma correction')
            image = self._add_gamma_correction(image)
        # if True:
        #     self.logger.debug('Adding histogram equalization')
        #     image = self._add_histogram_equalization(image)

        # cv2.imwrite('histogram.png', image)
        # cv2.waitKey(0)

        # Run detection
        textlines, raw_mask, mask = await self._detect(image, detect_size, text_threshold, box_threshold, unclip_ratio, verbose)
        textlines = list(filter(lambda x: x.area > 1, textlines))

        # Remove filters
        if add_border:
            textlines, raw_mask, mask = self._remove_border(image, img_w, img_h, textlines, raw_mask, mask)
        if auto_rotate:
            # Rotate if horizontal aspect ratios are prevalent to potentially improve detection
            if len(textlines) > 0:
                orientations = ['h' if txtln.aspect_ratio > 1 else 'v' for txtln in textlines]
                majority_orientation = Counter(orientations).most_common(1)[0][0]
            else:
                majority_orientation = 'h'
            if majority_orientation == 'h':
                self.logger.info('Rerunning detection with 90° rotation')
                return await self.detect(orig_image, detect_size, text_threshold, box_threshold, unclip_ratio, invert, gamma_correct,
                                         rotate=(not rotate), auto_rotate=False, verbose=verbose)
        if rotate:
            textlines, raw_mask, mask = self._remove_rotation(textlines, raw_mask, mask, img_w, img_h)

        return textlines, raw_mask, mask

    @abstractmethod
    async def _detect(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
                      unclip_ratio: float, verbose: bool = False) -> Tuple[List[Quadrilateral], np.ndarray, np.ndarray]:
        pass

    def _add_border(self, image: np.ndarray, target_side_length: int):
        old_h, old_w = image.shape[:2]
        new_w = new_h = max(old_w, old_h, target_side_length)
        new_image = np.zeros([new_h, new_w, 3]).astype(np.uint8)
        # new_image[:] = np.array([255, 255, 255], np.uint8)
        x, y = 0, 0
        # x, y = (new_h - old_h) // 2, (new_w - old_w) // 2
        new_image[y:y+old_h, x:x+old_w] = image
        return new_image

    def _remove_border(self, image: np.ndarray, old_w: int, old_h: int, textlines: List[Quadrilateral], raw_mask, mask):
        new_h, new_w = image.shape[:2]
        raw_mask = cv2.resize(raw_mask, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        raw_mask = raw_mask[:old_h, :old_w]
        if mask is not None:
            mask = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
            mask = mask[:old_h, :old_w]

        # Filter out regions within the border and clamp the points of the remaining regions
        new_textlines = []
        for txtln in textlines:
            if txtln.xyxy[0] >= old_w and txtln.xyxy[1] >= old_h:
                continue
            points = txtln.pts
            points[:,0] = np.clip(points[:,0], 0, old_w)
            points[:,1] = np.clip(points[:,1], 0, old_h)
            new_txtln = Quadrilateral(points, txtln.text, txtln.prob)
            new_textlines.append(new_txtln)
        return new_textlines, raw_mask, mask

    def _add_rotation(self, image: np.ndarray):
        return np.rot90(image, k=-1)

    def _remove_rotation(self, textlines, raw_mask, mask, img_w, img_h):
        raw_mask = np.ascontiguousarray(np.rot90(raw_mask))
        if mask is not None:
            mask = np.ascontiguousarray(np.rot90(mask).astype(np.uint8))

        for i, txtln in enumerate(textlines):
            rotated_pts = txtln.pts[:,[1,0]]
            rotated_pts[:,1] = -rotated_pts[:,1] + img_h
            textlines[i] = Quadrilateral(rotated_pts, txtln.text, txtln.prob)
        return textlines, raw_mask, mask

    def _add_inversion(self, image: np.ndarray):
        return cv2.bitwise_not(image)

    def _add_gamma_correction(self, image: np.ndarray):
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        mid = 0.5
        mean = np.mean(gray)
        gamma = np.log(mid * 255) / np.log(mean)
        img_gamma = np.power(image, gamma).clip(0,255).astype(np.uint8)
        return img_gamma

    def _add_histogram_equalization(self, image: np.ndarray):
        img_yuv = cv2.cvtColor(image, cv2.COLOR_BGR2YUV)

        # equalize the histogram of the Y channel
        img_yuv[:,:,0] = cv2.equalizeHist(img_yuv[:,:,0])

        # convert the YUV image back to RGB format
        img_output = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2BGR)
        return img_output


class OfflineDetector(CommonDetector, ModelWrapper):
    _MODEL_SUB_DIR = 'detection'

    async def _detect(self, *args, **kwargs):
        return await self.infer(*args, **kwargs)

    @abstractmethod
    async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
                     unclip_ratio: float, verbose: bool = False):
        pass
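For intuition, the gamma-correction filter above picks gamma so that the page's mean brightness maps to mid-gray (0.5 * 255 = 127.5). A standalone sketch of that math, with illustrative mean values:

import numpy as np

for mean in (200.0, 60.0):                      # bright page, dark page
    gamma = np.log(0.5 * 255) / np.log(mean)    # solves mean**gamma == 127.5
    print(round(gamma, 3), round(mean ** gamma, 1))
# 0.915 127.5   (bright page is compressed towards mid-gray)
# 1.184 127.5   (dark page is stretched towards mid-gray)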
manga_translator/detection/craft.py
ADDED
@@ -0,0 +1,200 @@
"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import shutil
import numpy as np
import torch
import cv2
import einops
from typing import List, Tuple

from .default_utils.DBNet_resnet34 import TextDetection as TextDetectionDefault
from .default_utils import imgproc, dbnet_utils, craft_utils
from .common import OfflineDetector
from ..utils import TextBlock, Quadrilateral, det_rearrange_forward
from shapely.geometry import Polygon, MultiPoint
from shapely import affinity

from .craft_utils.vgg16_bn import vgg16_bn, init_weights
from .craft_utils.refiner import RefineNet

class double_conv(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class CRAFT(nn.Module):
    def __init__(self, pretrained=False, freeze=False):
        super(CRAFT, self).__init__()

        """ Base network """
        self.basenet = vgg16_bn(pretrained, freeze)

        """ U network """
        self.upconv1 = double_conv(1024, 512, 256)
        self.upconv2 = double_conv(512, 256, 128)
        self.upconv3 = double_conv(256, 128, 64)
        self.upconv4 = double_conv(128, 64, 32)

        num_class = 2
        self.conv_cls = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, num_class, kernel_size=1),
        )

        init_weights(self.upconv1.modules())
        init_weights(self.upconv2.modules())
        init_weights(self.upconv3.modules())
        init_weights(self.upconv4.modules())
        init_weights(self.conv_cls.modules())

    def forward(self, x):
        """ Base network """
        sources = self.basenet(x)

        """ U network """
        y = torch.cat([sources[0], sources[1]], dim=1)
        y = self.upconv1(y)

        y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[2]], dim=1)
        y = self.upconv2(y)

        y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[3]], dim=1)
        y = self.upconv3(y)

        y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[4]], dim=1)
        feature = self.upconv4(y)

        y = self.conv_cls(feature)

        return y.permute(0,2,3,1), feature


from collections import OrderedDict
def copyStateDict(state_dict):
    if list(state_dict.keys())[0].startswith("module"):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict

class CRAFTDetector(OfflineDetector):
    _MODEL_MAPPING = {
        'refiner': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/craft_refiner_CTW1500.pth',
            'hash': 'f7000cd3e9c76f2231b62b32182212203f73c08dfaa12bb16ffb529948a01399',
            'file': 'craft_refiner_CTW1500.pth',
        },
        'craft': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/craft_mlt_25k.pth',
            'hash': '4a5efbfb48b4081100544e75e1e2b57f8de3d84f213004b14b85fd4b3748db17',
            'file': 'craft_mlt_25k.pth',
        }
    }

    def __init__(self, *args, **kwargs):
        os.makedirs(self.model_dir, exist_ok=True)
        if os.path.exists('craft_mlt_25k.pth'):
            shutil.move('craft_mlt_25k.pth', self._get_file_path('craft_mlt_25k.pth'))
        if os.path.exists('craft_refiner_CTW1500.pth'):
            shutil.move('craft_refiner_CTW1500.pth', self._get_file_path('craft_refiner_CTW1500.pth'))
        super().__init__(*args, **kwargs)

    async def _load(self, device: str):
        self.model = CRAFT()
        self.model.load_state_dict(copyStateDict(torch.load(self._get_file_path('craft_mlt_25k.pth'), map_location='cpu')))
        self.model.eval()
        self.model_refiner = RefineNet()
        self.model_refiner.load_state_dict(copyStateDict(torch.load(self._get_file_path('craft_refiner_CTW1500.pth'), map_location='cpu')))
        self.model_refiner.eval()
        self.device = device
        if device == 'cuda' or device == 'mps':
            self.model = self.model.to(self.device)
            self.model_refiner = self.model_refiner.to(self.device)
        global MODEL
        MODEL = self.model

    async def _unload(self):
        del self.model

    async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
                     unclip_ratio: float, verbose: bool = False):

        img_resized, target_ratio, size_heatmap, pad_w, pad_h = imgproc.resize_aspect_ratio(image, detect_size, interpolation = cv2.INTER_CUBIC, mag_ratio = 1)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
        x = x.unsqueeze(0).to(self.device)          # [c, h, w] to [b, c, h, w]

        with torch.no_grad() :
            y, feature = self.model(x)

            # make score and link map
            score_text = y[0,:,:,0].cpu().data.numpy()
            score_link = y[0,:,:,1].cpu().data.numpy()

            # refine link
            y_refiner = self.model_refiner(y, feature)
            score_link = y_refiner[0,:,:,0].cpu().data.numpy()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, box_threshold, box_threshold, True)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None: polys[k] = boxes[k]

        mask = np.zeros(shape = (image.shape[0], image.shape[1]), dtype = np.uint8)

        for poly in polys :
            mask = cv2.fillPoly(mask, [poly.reshape((-1, 1, 2)).astype(np.int32)], color = 255)

        polys_ret = []
        for i in range(len(polys)) :
            poly = MultiPoint(polys[i])
            if poly.area > 10 :
                rect = poly.minimum_rotated_rectangle
                rect = affinity.scale(rect, xfact = 1.2, yfact = 1.2)
                polys_ret.append(np.roll(np.asarray(list(rect.exterior.coords)[:4]), 2))

        kern = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
        mask = cv2.dilate(mask, kern)

        textlines = [Quadrilateral(pts.astype(int), '', 1) for pts in polys_ret]
        textlines = list(filter(lambda q: q.area > 16, textlines))

        return textlines, mask, None
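The shapely post-processing used in `_infer` above (minimum rotated rectangle, scaled up by 20%) can be tried in isolation; a hedged sketch with a hand-made point set:

import numpy as np
from shapely.geometry import MultiPoint
from shapely import affinity

pts = np.array([[10, 10], [60, 12], [58, 30], [12, 28]], dtype=float)   # rough text-region corners
rect = MultiPoint(pts).minimum_rotated_rectangle                        # tightest rotated bounding box
rect = affinity.scale(rect, xfact=1.2, yfact=1.2)                       # grow by 20%, as in _infer
quad = np.asarray(list(rect.exterior.coords)[:4])
print(quad.round(1))                                                    # 4 corners of the enlarged box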
manga_translator/detection/craft_utils/refiner.py
ADDED
@@ -0,0 +1,65 @@
"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from .vgg16_bn import init_weights


class RefineNet(nn.Module):
    def __init__(self):
        super(RefineNet, self).__init__()

        self.last_conv = nn.Sequential(
            nn.Conv2d(34, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)
        )

        self.aspp1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, dilation=6, padding=6), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1)
        )

        self.aspp2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, dilation=12, padding=12), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1)
        )

        self.aspp3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, dilation=18, padding=18), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1)
        )

        self.aspp4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, dilation=24, padding=24), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1)
        )

        init_weights(self.last_conv.modules())
        init_weights(self.aspp1.modules())
        init_weights(self.aspp2.modules())
        init_weights(self.aspp3.modules())
        init_weights(self.aspp4.modules())

    def forward(self, y, upconv4):
        refine = torch.cat([y.permute(0, 3, 1, 2), upconv4], dim=1)
        refine = self.last_conv(refine)

        aspp1 = self.aspp1(refine)
        aspp2 = self.aspp2(refine)
        aspp3 = self.aspp3(refine)
        aspp4 = self.aspp4(refine)

        # out = torch.add([aspp1, aspp2, aspp3, aspp4], dim=1)
        out = aspp1 + aspp2 + aspp3 + aspp4
        return out.permute(0, 2, 3, 1)  # , refine.permute(0,2,3,1)
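RefineNet concatenates the two CRAFT score maps (NHWC) with an intermediate feature map into a 34-channel tensor, then sums four ASPP branches with dilations 6/12/18/24 into a single refined link map. A quick shape check, assuming the module is importable under the path shown above; the 32-channel feature map and the 128x128 spatial size are assumptions implied by the 34-channel input, and random tensors stand in for real CRAFT outputs:

import torch
from manga_translator.detection.craft_utils.refiner import RefineNet

net = RefineNet().eval()
y = torch.rand(1, 128, 128, 2)          # CRAFT region + affinity maps, NHWC
upconv4 = torch.rand(1, 32, 128, 128)   # intermediate CRAFT feature map, NCHW

with torch.no_grad():
    refined = net(y, upconv4)

print(refined.shape)  # torch.Size([1, 128, 128, 1]): refined link score, NHWC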
manga_translator/detection/craft_utils/vgg16_bn.py
ADDED
@@ -0,0 +1,71 @@
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.init as init
from torchvision import models

def init_weights(modules):
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()

class vgg16_bn(torch.nn.Module):
    def __init__(self, pretrained=True, freeze=True):
        super(vgg16_bn, self).__init__()
        vgg_pretrained_features = models.vgg16_bn().features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(12):  # conv2_2
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 19):  # conv3_3
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(19, 29):  # conv4_3
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(29, 39):  # conv5_3
            self.slice4.add_module(str(x), vgg_pretrained_features[x])

        # fc6, fc7 without atrous conv
        self.slice5 = torch.nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
            nn.Conv2d(1024, 1024, kernel_size=1)
        )

        if not pretrained:
            init_weights(self.slice1.modules())
            init_weights(self.slice2.modules())
            init_weights(self.slice3.modules())
            init_weights(self.slice4.modules())

        init_weights(self.slice5.modules())  # no pretrained model for fc6 and fc7

        if freeze:
            for param in self.slice1.parameters():  # only first conv
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu2_2 = h
        h = self.slice2(h)
        h_relu3_2 = h
        h = self.slice3(h)
        h_relu4_3 = h
        h = self.slice4(h)
        h_relu5_3 = h
        h = self.slice5(h)
        h_fc7 = h
        vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
        out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
        return out
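The backbone returns five intermediate activations as a namedtuple, from the dilated fc7 head down to relu2_2. A small sketch of the expected channel counts and strides for a 256x256 input; the printed shapes are derived from the slice boundaries above, not measured on pretrained weights, and the import path is assumed to match the file location:

import torch
from manga_translator.detection.craft_utils.vgg16_bn import vgg16_bn

backbone = vgg16_bn(pretrained=False, freeze=True).eval()
x = torch.rand(1, 3, 256, 256)

with torch.no_grad():
    out = backbone(x)

for name, feat in zip(out._fields, out):
    print(name, tuple(feat.shape))
# fc7      (1, 1024, 16, 16)
# relu5_3  (1, 512, 16, 16)
# relu4_3  (1, 512, 32, 32)
# relu3_2  (1, 256, 64, 64)
# relu2_2  (1, 128, 128, 128)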
manga_translator/detection/ctd.py
ADDED
@@ -0,0 +1,186 @@
import os
import shutil
import numpy as np
import einops
from typing import Union, Tuple
import cv2
import torch

from .ctd_utils.basemodel import TextDetBase, TextDetBaseDNN
from .ctd_utils.utils.yolov5_utils import non_max_suppression
from .ctd_utils.utils.db_utils import SegDetectorRepresenter
from .ctd_utils.utils.imgproc_utils import letterbox
from .ctd_utils.textmask import REFINEMASK_INPAINT, refine_mask
from .common import OfflineDetector
from ..utils import Quadrilateral, det_rearrange_forward

def preprocess_img(img, input_size=(1024, 1024), device='cpu', bgr2rgb=True, half=False, to_tensor=True):
    if bgr2rgb:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_in, ratio, (dw, dh) = letterbox(img, new_shape=input_size, auto=False, stride=64)
    if to_tensor:
        img_in = img_in.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img_in = np.array([np.ascontiguousarray(img_in)]).astype(np.float32) / 255
        if to_tensor:
            img_in = torch.from_numpy(img_in).to(device)
            if half:
                img_in = img_in.half()
    return img_in, ratio, int(dw), int(dh)

def postprocess_mask(img: Union[torch.Tensor, np.ndarray], thresh=None):
    # img = img.permute(1, 2, 0)
    if isinstance(img, torch.Tensor):
        img = img.squeeze_()
        if img.device != 'cpu':
            img = img.detach().cpu()
        img = img.numpy()
    else:
        img = img.squeeze()
    if thresh is not None:
        img = img > thresh
        img = img * 255
    # if isinstance(img, torch.Tensor):

    return img.astype(np.uint8)

def postprocess_yolo(det, conf_thresh, nms_thresh, resize_ratio, sort_func=None):
    det = non_max_suppression(det, conf_thresh, nms_thresh)[0]
    # bbox = det[..., 0:4]
    if det.device != 'cpu':
        det = det.detach_().cpu().numpy()
    det[..., [0, 2]] = det[..., [0, 2]] * resize_ratio[0]
    det[..., [1, 3]] = det[..., [1, 3]] * resize_ratio[1]
    if sort_func is not None:
        det = sort_func(det)

    blines = det[..., 0:4].astype(np.int32)
    confs = np.round(det[..., 4], 3)
    cls = det[..., 5].astype(np.int32)
    return blines, cls, confs


class ComicTextDetector(OfflineDetector):
    _MODEL_MAPPING = {
        'model-cuda': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/comictextdetector.pt',
            'hash': '1f90fa60aeeb1eb82e2ac1167a66bf139a8a61b8780acd351ead55268540cccb',
            'file': '.',
        },
        'model-cpu': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/comictextdetector.pt.onnx',
            'hash': '1a86ace74961413cbd650002e7bb4dcec4980ffa21b2f19b86933372071d718f',
            'file': '.',
        },
    }

    def __init__(self, *args, **kwargs):
        os.makedirs(self.model_dir, exist_ok=True)
        if os.path.exists('comictextdetector.pt'):
            shutil.move('comictextdetector.pt', self._get_file_path('comictextdetector.pt'))
        if os.path.exists('comictextdetector.pt.onnx'):
            shutil.move('comictextdetector.pt.onnx', self._get_file_path('comictextdetector.pt.onnx'))
        super().__init__(*args, **kwargs)

    async def _load(self, device: str, input_size=1024, half=False, nms_thresh=0.35, conf_thresh=0.4):
        self.device = device
        if self.device == 'cuda' or self.device == 'mps':
            self.model = TextDetBase(self._get_file_path('comictextdetector.pt'), device=self.device, act='leaky')
            self.model.to(self.device)
            self.backend = 'torch'
        else:
            model_path = self._get_file_path('comictextdetector.pt.onnx')
            self.model = cv2.dnn.readNetFromONNX(model_path)
            self.model = TextDetBaseDNN(input_size, model_path)
            self.backend = 'opencv'

        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        self.input_size = input_size
        self.half = half
        self.conf_thresh = conf_thresh
        self.nms_thresh = nms_thresh
        self.seg_rep = SegDetectorRepresenter(thresh=0.3)

    async def _unload(self):
        del self.model

    def det_batch_forward_ctd(self, batch: np.ndarray, device: str) -> Tuple[np.ndarray, np.ndarray]:
        if isinstance(self.model, TextDetBase):
            batch = einops.rearrange(batch.astype(np.float32) / 255., 'n h w c -> n c h w')
            batch = torch.from_numpy(batch).to(device)
            _, mask, lines = self.model(batch)
            mask = mask.detach().cpu().numpy()
            lines = lines.detach().cpu().numpy()
        elif isinstance(self.model, TextDetBaseDNN):
            mask_lst, line_lst = [], []
            for b in batch:
                _, mask, lines = self.model(b)
                if mask.shape[1] == 2:  # some version of opencv spit out reversed result
                    tmp = mask
                    mask = lines
                    lines = tmp
                mask_lst.append(mask)
                line_lst.append(lines)
            lines, mask = np.concatenate(line_lst, 0), np.concatenate(mask_lst, 0)
        else:
            raise NotImplementedError
        return lines, mask

    @torch.no_grad()
    async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
                     unclip_ratio: float, verbose: bool = False):

        # keep_undetected_mask = False
        # refine_mode = REFINEMASK_INPAINT

        im_h, im_w = image.shape[:2]
        lines_map, mask = det_rearrange_forward(image, self.det_batch_forward_ctd, self.input_size[0], 4, self.device, verbose)
        # blks = []
        # resize_ratio = [1, 1]
        if lines_map is None:
            img_in, ratio, dw, dh = preprocess_img(image, input_size=self.input_size, device=self.device, half=self.half, to_tensor=self.backend=='torch')
            blks, mask, lines_map = self.model(img_in)

            if self.backend == 'opencv':
                if mask.shape[1] == 2:  # some version of opencv spit out reversed result
                    tmp = mask
                    mask = lines_map
                    lines_map = tmp
            mask = mask.squeeze()
            # resize_ratio = (im_w / (self.input_size[0] - dw), im_h / (self.input_size[1] - dh))
            # blks = postprocess_yolo(blks, self.conf_thresh, self.nms_thresh, resize_ratio)
            mask = mask[..., :mask.shape[0]-dh, :mask.shape[1]-dw]
            lines_map = lines_map[..., :lines_map.shape[2]-dh, :lines_map.shape[3]-dw]

        mask = postprocess_mask(mask)
        lines, scores = self.seg_rep(None, lines_map, height=im_h, width=im_w)
        box_thresh = 0.6
        idx = np.where(scores[0] > box_thresh)
        lines, scores = lines[0][idx], scores[0][idx]

        # map output to input img
        mask = cv2.resize(mask, (im_w, im_h), interpolation=cv2.INTER_LINEAR)

        # if lines.size == 0:
        #     lines = []
        # else:
        #     lines = lines.astype(np.int32)

        # YOLO was used for finding bboxes which to order the lines into. This is now solved
        # through the textline merger, which seems to work more reliably.
        # The YOLO language detection seems unnecessary as it could never be as good as
        # using the OCR extracted string directly.
        # Doing it for increasing the textline merge accuracy doesn't really work either,
        # as the merge could be postponed until after the OCR finishes.

        textlines = [Quadrilateral(pts.astype(int), '', score) for pts, score in zip(lines, scores)]
        mask_refined = refine_mask(image, mask, textlines, refine_mode=None)

        return textlines, mask_refined, None

        # blk_list = group_output(blks, lines, im_w, im_h, mask)
        # mask_refined = refine_mask(image, mask, blk_list, refine_mode=refine_mode)
        # if keep_undetected_mask:
        #     mask_refined = refine_undetected_mask(image, mask, mask_refined, blk_list, refine_mode=refine_mode)

        # return blk_list, mask, mask_refined
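When det_rearrange_forward declines to handle the image, _infer falls back to a single letterboxed forward pass, so the padded border (dw, dh) has to be cropped off before the mask is resized back to page coordinates. A rough sketch of that cropping and resizing with dummy arrays; the sizes and the padding layout are made up for illustration, and the real values come from letterbox():

import cv2
import numpy as np

im_h, im_w = 700, 500   # hypothetical page size
dw, dh = 293, 0         # hypothetical letterbox padding to remove

mask = np.random.rand(1024, 1024).astype(np.float32)  # stand-in for the model's mask output
mask = mask[:1024 - dh, :1024 - dw]                    # drop the padded border
mask = (mask * 255).astype(np.uint8)                   # what postprocess_mask() produces
mask = cv2.resize(mask, (im_w, im_h), interpolation=cv2.INTER_LINEAR)
print(mask.shape)                                      # (700, 500): back in page coordinates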
manga_translator/detection/ctd_utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .basemodel import TextDetBase, TextDetBaseDNN
from .utils.yolov5_utils import non_max_suppression
from .utils.db_utils import SegDetectorRepresenter
from .utils.imgproc_utils import letterbox
from .textmask import refine_mask, refine_undetected_mask, REFINEMASK_INPAINT, REFINEMASK_ANNOTATION
manga_translator/detection/ctd_utils/basemodel.py
ADDED
@@ -0,0 +1,250 @@
1 |
+
import cv2
|
2 |
+
import copy
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from .utils.yolov5_utils import fuse_conv_and_bn
|
7 |
+
from .utils.weight_init import init_weights
|
8 |
+
from .yolov5.yolo import load_yolov5_ckpt
|
9 |
+
from .yolov5.common import C3, Conv
|
10 |
+
|
11 |
+
TEXTDET_MASK = 0
|
12 |
+
TEXTDET_DET = 1
|
13 |
+
TEXTDET_INFERENCE = 2
|
14 |
+
|
15 |
+
class double_conv_up_c3(nn.Module):
|
16 |
+
def __init__(self, in_ch, mid_ch, out_ch, act=True):
|
17 |
+
super(double_conv_up_c3, self).__init__()
|
18 |
+
self.conv = nn.Sequential(
|
19 |
+
C3(in_ch+mid_ch, mid_ch, act=act),
|
20 |
+
nn.ConvTranspose2d(mid_ch, out_ch, kernel_size=4, stride = 2, padding=1, bias=False),
|
21 |
+
nn.BatchNorm2d(out_ch),
|
22 |
+
nn.ReLU(inplace=True),
|
23 |
+
)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
return self.conv(x)
|
27 |
+
|
28 |
+
class double_conv_c3(nn.Module):
|
29 |
+
def __init__(self, in_ch, out_ch, stride=1, act=True):
|
30 |
+
super(double_conv_c3, self).__init__()
|
31 |
+
if stride > 1:
|
32 |
+
self.down = nn.AvgPool2d(2,stride=2) if stride > 1 else None
|
33 |
+
self.conv = C3(in_ch, out_ch, act=act)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
if self.down is not None:
|
37 |
+
x = self.down(x)
|
38 |
+
x = self.conv(x)
|
39 |
+
return x
|
40 |
+
|
41 |
+
class UnetHead(nn.Module):
|
42 |
+
def __init__(self, act=True) -> None:
|
43 |
+
|
44 |
+
super(UnetHead, self).__init__()
|
45 |
+
self.down_conv1 = double_conv_c3(512, 512, 2, act=act)
|
46 |
+
self.upconv0 = double_conv_up_c3(0, 512, 256, act=act)
|
47 |
+
self.upconv2 = double_conv_up_c3(256, 512, 256, act=act)
|
48 |
+
self.upconv3 = double_conv_up_c3(0, 512, 256, act=act)
|
49 |
+
self.upconv4 = double_conv_up_c3(128, 256, 128, act=act)
|
50 |
+
self.upconv5 = double_conv_up_c3(64, 128, 64, act=act)
|
51 |
+
self.upconv6 = nn.Sequential(
|
52 |
+
nn.ConvTranspose2d(64, 1, kernel_size=4, stride = 2, padding=1, bias=False),
|
53 |
+
nn.Sigmoid()
|
54 |
+
)
|
55 |
+
|
56 |
+
def forward(self, f160, f80, f40, f20, f3, forward_mode=TEXTDET_MASK):
|
57 |
+
# input: 640@3
|
58 |
+
d10 = self.down_conv1(f3) # 512@10
|
59 |
+
u20 = self.upconv0(d10) # 256@10
|
60 |
+
u40 = self.upconv2(torch.cat([f20, u20], dim = 1)) # 256@40
|
61 |
+
|
62 |
+
if forward_mode == TEXTDET_DET:
|
63 |
+
return f80, f40, u40
|
64 |
+
else:
|
65 |
+
u80 = self.upconv3(torch.cat([f40, u40], dim = 1)) # 256@80
|
66 |
+
u160 = self.upconv4(torch.cat([f80, u80], dim = 1)) # 128@160
|
67 |
+
u320 = self.upconv5(torch.cat([f160, u160], dim = 1)) # 64@320
|
68 |
+
mask = self.upconv6(u320)
|
69 |
+
if forward_mode == TEXTDET_MASK:
|
70 |
+
return mask
|
71 |
+
else:
|
72 |
+
return mask, [f80, f40, u40]
|
73 |
+
|
74 |
+
def init_weight(self, init_func):
|
75 |
+
self.apply(init_func)
|
76 |
+
|
77 |
+
class DBHead(nn.Module):
|
78 |
+
def __init__(self, in_channels, k = 50, shrink_with_sigmoid=True, act=True):
|
79 |
+
super().__init__()
|
80 |
+
self.k = k
|
81 |
+
self.shrink_with_sigmoid = shrink_with_sigmoid
|
82 |
+
self.upconv3 = double_conv_up_c3(0, 512, 256, act=act)
|
83 |
+
self.upconv4 = double_conv_up_c3(128, 256, 128, act=act)
|
84 |
+
self.conv = nn.Sequential(
|
85 |
+
nn.Conv2d(128, in_channels, 1),
|
86 |
+
nn.BatchNorm2d(in_channels),
|
87 |
+
nn.ReLU(inplace=True)
|
88 |
+
)
|
89 |
+
self.binarize = nn.Sequential(
|
90 |
+
nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
|
91 |
+
nn.BatchNorm2d(in_channels // 4),
|
92 |
+
nn.ReLU(inplace=True),
|
93 |
+
nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
|
94 |
+
nn.BatchNorm2d(in_channels // 4),
|
95 |
+
nn.ReLU(inplace=True),
|
96 |
+
nn.ConvTranspose2d(in_channels // 4, 1, 2, 2)
|
97 |
+
)
|
98 |
+
self.thresh = self._init_thresh(in_channels)
|
99 |
+
|
100 |
+
def forward(self, f80, f40, u40, shrink_with_sigmoid=True, step_eval=False):
|
101 |
+
shrink_with_sigmoid = self.shrink_with_sigmoid
|
102 |
+
u80 = self.upconv3(torch.cat([f40, u40], dim = 1)) # 256@80
|
103 |
+
x = self.upconv4(torch.cat([f80, u80], dim = 1)) # 128@160
|
104 |
+
x = self.conv(x)
|
105 |
+
threshold_maps = self.thresh(x)
|
106 |
+
x = self.binarize(x)
|
107 |
+
shrink_maps = torch.sigmoid(x)
|
108 |
+
|
109 |
+
if self.training:
|
110 |
+
binary_maps = self.step_function(shrink_maps, threshold_maps)
|
111 |
+
if shrink_with_sigmoid:
|
112 |
+
return torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
|
113 |
+
else:
|
114 |
+
return torch.cat((shrink_maps, threshold_maps, binary_maps, x), dim=1)
|
115 |
+
else:
|
116 |
+
if step_eval:
|
117 |
+
return self.step_function(shrink_maps, threshold_maps)
|
118 |
+
else:
|
119 |
+
return torch.cat((shrink_maps, threshold_maps), dim=1)
|
120 |
+
|
121 |
+
def init_weight(self, init_func):
|
122 |
+
self.apply(init_func)
|
123 |
+
|
124 |
+
def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
|
125 |
+
in_channels = inner_channels
|
126 |
+
if serial:
|
127 |
+
in_channels += 1
|
128 |
+
self.thresh = nn.Sequential(
|
129 |
+
nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
|
130 |
+
nn.BatchNorm2d(inner_channels // 4),
|
131 |
+
nn.ReLU(inplace=True),
|
132 |
+
self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
|
133 |
+
nn.BatchNorm2d(inner_channels // 4),
|
134 |
+
nn.ReLU(inplace=True),
|
135 |
+
self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
|
136 |
+
nn.Sigmoid())
|
137 |
+
return self.thresh
|
138 |
+
|
139 |
+
def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
|
140 |
+
if smooth:
|
141 |
+
inter_out_channels = out_channels
|
142 |
+
if out_channels == 1:
|
143 |
+
inter_out_channels = in_channels
|
144 |
+
module_list = [
|
145 |
+
nn.Upsample(scale_factor=2, mode='nearest'),
|
146 |
+
nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)]
|
147 |
+
if out_channels == 1:
|
148 |
+
module_list.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=1, bias=True))
|
149 |
+
return nn.Sequential(module_list)
|
150 |
+
else:
|
151 |
+
return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
|
152 |
+
|
153 |
+
def step_function(self, x, y):
|
154 |
+
return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
|
155 |
+
|
156 |
+
class TextDetector(nn.Module):
|
157 |
+
def __init__(self, weights, map_location='cpu', forward_mode=TEXTDET_MASK, act=True):
|
158 |
+
super(TextDetector, self).__init__()
|
159 |
+
|
160 |
+
yolov5s_backbone = load_yolov5_ckpt(weights=weights, map_location=map_location)
|
161 |
+
yolov5s_backbone.eval()
|
162 |
+
out_indices = [1, 3, 5, 7, 9]
|
163 |
+
yolov5s_backbone.out_indices = out_indices
|
164 |
+
yolov5s_backbone.model = yolov5s_backbone.model[:max(out_indices)+1]
|
165 |
+
self.act = act
|
166 |
+
self.seg_net = UnetHead(act=act)
|
167 |
+
self.backbone = yolov5s_backbone
|
168 |
+
self.dbnet = None
|
169 |
+
self.forward_mode = forward_mode
|
170 |
+
|
171 |
+
def train_mask(self):
|
172 |
+
self.forward_mode = TEXTDET_MASK
|
173 |
+
self.backbone.eval()
|
174 |
+
self.seg_net.train()
|
175 |
+
|
176 |
+
def initialize_db(self, unet_weights):
|
177 |
+
self.dbnet = DBHead(64, act=self.act)
|
178 |
+
self.seg_net.load_state_dict(torch.load(unet_weights, map_location='cpu')['weights'])
|
179 |
+
self.dbnet.init_weight(init_weights)
|
180 |
+
self.dbnet.upconv3 = copy.deepcopy(self.seg_net.upconv3)
|
181 |
+
self.dbnet.upconv4 = copy.deepcopy(self.seg_net.upconv4)
|
182 |
+
del self.seg_net.upconv3
|
183 |
+
del self.seg_net.upconv4
|
184 |
+
del self.seg_net.upconv5
|
185 |
+
del self.seg_net.upconv6
|
186 |
+
# del self.seg_net.conv_mask
|
187 |
+
|
188 |
+
def train_db(self):
|
189 |
+
self.forward_mode = TEXTDET_DET
|
190 |
+
self.backbone.eval()
|
191 |
+
self.seg_net.eval()
|
192 |
+
self.dbnet.train()
|
193 |
+
|
194 |
+
def forward(self, x):
|
195 |
+
forward_mode = self.forward_mode
|
196 |
+
with torch.no_grad():
|
197 |
+
outs = self.backbone(x)
|
198 |
+
if forward_mode == TEXTDET_MASK:
|
199 |
+
return self.seg_net(*outs, forward_mode=forward_mode)
|
200 |
+
elif forward_mode == TEXTDET_DET:
|
201 |
+
with torch.no_grad():
|
202 |
+
outs = self.seg_net(*outs, forward_mode=forward_mode)
|
203 |
+
return self.dbnet(*outs)
|
204 |
+
|
205 |
+
def get_base_det_models(model_path, device='cpu', half=False, act='leaky'):
|
206 |
+
textdetector_dict = torch.load(model_path, map_location=device)
|
207 |
+
blk_det = load_yolov5_ckpt(textdetector_dict['blk_det'], map_location=device)
|
208 |
+
text_seg = UnetHead(act=act)
|
209 |
+
text_seg.load_state_dict(textdetector_dict['text_seg'])
|
210 |
+
text_det = DBHead(64, act=act)
|
211 |
+
text_det.load_state_dict(textdetector_dict['text_det'])
|
212 |
+
if half:
|
213 |
+
return blk_det.eval().half(), text_seg.eval().half(), text_det.eval().half()
|
214 |
+
return blk_det.eval().to(device), text_seg.eval().to(device), text_det.eval().to(device)
|
215 |
+
|
216 |
+
class TextDetBase(nn.Module):
|
217 |
+
def __init__(self, model_path, device='cpu', half=False, fuse=False, act='leaky'):
|
218 |
+
super(TextDetBase, self).__init__()
|
219 |
+
self.blk_det, self.text_seg, self.text_det = get_base_det_models(model_path, device, half, act=act)
|
220 |
+
if fuse:
|
221 |
+
self.fuse()
|
222 |
+
|
223 |
+
def fuse(self):
|
224 |
+
def _fuse(model):
|
225 |
+
for m in model.modules():
|
226 |
+
if isinstance(m, (Conv)) and hasattr(m, 'bn'):
|
227 |
+
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
228 |
+
delattr(m, 'bn') # remove batchnorm
|
229 |
+
m.forward = m.forward_fuse # update forward
|
230 |
+
return model
|
231 |
+
self.text_seg = _fuse(self.text_seg)
|
232 |
+
self.text_det = _fuse(self.text_det)
|
233 |
+
|
234 |
+
def forward(self, features):
|
235 |
+
blks, features = self.blk_det(features, detect=True)
|
236 |
+
mask, features = self.text_seg(*features, forward_mode=TEXTDET_INFERENCE)
|
237 |
+
lines = self.text_det(*features, step_eval=False)
|
238 |
+
return blks[0], mask, lines
|
239 |
+
|
240 |
+
class TextDetBaseDNN:
|
241 |
+
def __init__(self, input_size, model_path):
|
242 |
+
self.input_size = input_size
|
243 |
+
self.model = cv2.dnn.readNetFromONNX(model_path)
|
244 |
+
self.uoln = self.model.getUnconnectedOutLayersNames()
|
245 |
+
|
246 |
+
def __call__(self, im_in):
|
247 |
+
blob = cv2.dnn.blobFromImage(im_in, scalefactor=1 / 255.0, size=(self.input_size, self.input_size))
|
248 |
+
self.model.setInput(blob)
|
249 |
+
blks, mask, lines_map = self.model.forward(self.uoln)
|
250 |
+
return blks, mask, lines_map
|
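DBHead.step_function in the listing above is the differentiable-binarization trick: instead of a hard shrink_map > threshold_map comparison it applies a steep sigmoid with slope k = 50, so the binary map stays differentiable during training. A tiny numeric sketch with made-up map values:

import torch

k = 50
shrink = torch.tensor([0.2, 0.5, 0.8])  # predicted shrink (probability) map values
thresh = torch.tensor([0.5, 0.5, 0.5])  # predicted threshold map values
binary = torch.reciprocal(1 + torch.exp(-k * (shrink - thresh)))
print(binary)  # ~[0.0, 0.5, 1.0]: behaves like a hard threshold but has gradients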
manga_translator/detection/ctd_utils/textmask.py
ADDED
@@ -0,0 +1,174 @@
1 |
+
from typing import List
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from .utils.imgproc_utils import union_area, enlarge_window
|
6 |
+
from ...utils import TextBlock, Quadrilateral
|
7 |
+
|
8 |
+
WHITE = (255, 255, 255)
|
9 |
+
BLACK = (0, 0, 0)
|
10 |
+
LANG_ENG = 0
|
11 |
+
LANG_JPN = 1
|
12 |
+
|
13 |
+
REFINEMASK_INPAINT = 0
|
14 |
+
REFINEMASK_ANNOTATION = 1
|
15 |
+
|
16 |
+
def get_topk_color(color_list, bins, k=3, color_var=10, bin_tol=0.001):
|
17 |
+
idx = np.argsort(bins * -1)
|
18 |
+
color_list, bins = color_list[idx], bins[idx]
|
19 |
+
top_colors = [color_list[0]]
|
20 |
+
bin_tol = np.sum(bins) * bin_tol
|
21 |
+
if len(color_list) > 1:
|
22 |
+
for color, bin in zip(color_list[1:], bins[1:]):
|
23 |
+
if np.abs(np.array(top_colors) - color).min() > color_var:
|
24 |
+
top_colors.append(color)
|
25 |
+
if len(top_colors) >= k or bin < bin_tol:
|
26 |
+
break
|
27 |
+
return top_colors
|
28 |
+
|
29 |
+
def minxor_thresh(threshed, mask, dilate=False):
|
30 |
+
neg_threshed = 255 - threshed
|
31 |
+
e_size = 1
|
32 |
+
if dilate:
|
33 |
+
element = cv2.getStructuringElement(cv2.MORPH_RECT, (2 * e_size + 1, 2 * e_size + 1),(e_size, e_size))
|
34 |
+
neg_threshed = cv2.dilate(neg_threshed, element, iterations=1)
|
35 |
+
threshed = cv2.dilate(threshed, element, iterations=1)
|
36 |
+
neg_xor_sum = cv2.bitwise_xor(neg_threshed, mask).sum()
|
37 |
+
xor_sum = cv2.bitwise_xor(threshed, mask).sum()
|
38 |
+
if neg_xor_sum < xor_sum:
|
39 |
+
return neg_threshed, neg_xor_sum
|
40 |
+
else:
|
41 |
+
return threshed, xor_sum
|
42 |
+
|
43 |
+
def get_otsuthresh_masklist(img, pred_mask, per_channel=False) -> List[np.ndarray]:
|
44 |
+
channels = [img[..., 0], img[..., 1], img[..., 2]]
|
45 |
+
mask_list = []
|
46 |
+
for c in channels:
|
47 |
+
_, threshed = cv2.threshold(c, 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)
|
48 |
+
threshed, xor_sum = minxor_thresh(threshed, pred_mask, dilate=False)
|
49 |
+
mask_list.append([threshed, xor_sum])
|
50 |
+
mask_list.sort(key=lambda x: x[1])
|
51 |
+
if per_channel:
|
52 |
+
return mask_list
|
53 |
+
else:
|
54 |
+
return [mask_list[0]]
|
55 |
+
|
56 |
+
def get_topk_masklist(im_grey, pred_mask):
|
57 |
+
if len(im_grey.shape) == 3 and im_grey.shape[-1] == 3:
|
58 |
+
im_grey = cv2.cvtColor(im_grey, cv2.COLOR_BGR2GRAY)
|
59 |
+
msk = np.ascontiguousarray(pred_mask)
|
60 |
+
candidate_grey_px = im_grey[np.where(cv2.erode(msk, np.ones((3,3), np.uint8), iterations=1) > 127)]
|
61 |
+
bin, his = np.histogram(candidate_grey_px, bins=255)
|
62 |
+
topk_color = get_topk_color(his, bin, color_var=10, k=3)
|
63 |
+
color_range = 30
|
64 |
+
mask_list = list()
|
65 |
+
for ii, color in enumerate(topk_color):
|
66 |
+
c_top = min(color+color_range, 255)
|
67 |
+
c_bottom = c_top - 2 * color_range
|
68 |
+
threshed = cv2.inRange(im_grey, c_bottom, c_top)
|
69 |
+
threshed, xor_sum = minxor_thresh(threshed, msk)
|
70 |
+
mask_list.append([threshed, xor_sum])
|
71 |
+
return mask_list
|
72 |
+
|
73 |
+
def merge_mask_list(mask_list, pred_mask, blk: Quadrilateral = None, pred_thresh=30, text_window=None, filter_with_lines=False, refine_mode=REFINEMASK_INPAINT):
|
74 |
+
mask_list.sort(key=lambda x: x[1])
|
75 |
+
linemask = None
|
76 |
+
if blk is not None and filter_with_lines:
|
77 |
+
linemask = np.zeros_like(pred_mask)
|
78 |
+
lines = blk.pts.astype(np.int64)
|
79 |
+
for line in lines:
|
80 |
+
line[..., 0] -= text_window[0]
|
81 |
+
line[..., 1] -= text_window[1]
|
82 |
+
cv2.fillPoly(linemask, [line], 255)
|
83 |
+
linemask = cv2.dilate(linemask, np.ones((3, 3), np.uint8), iterations=3)
|
84 |
+
|
85 |
+
if pred_thresh > 0:
|
86 |
+
e_size = 1
|
87 |
+
element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * e_size + 1, 2 * e_size + 1),(e_size, e_size))
|
88 |
+
pred_mask = cv2.erode(pred_mask, element, iterations=1)
|
89 |
+
_, pred_mask = cv2.threshold(pred_mask, 60, 255, cv2.THRESH_BINARY)
|
90 |
+
connectivity = 8
|
91 |
+
mask_merged = np.zeros_like(pred_mask)
|
92 |
+
for ii, (candidate_mask, xor_sum) in enumerate(mask_list):
|
93 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(candidate_mask, connectivity, cv2.CV_16U)
|
94 |
+
for label_index, stat, centroid in zip(range(num_labels), stats, centroids):
|
95 |
+
if label_index != 0: # skip background label
|
96 |
+
x, y, w, h, area = stat
|
97 |
+
if w * h < 3:
|
98 |
+
continue
|
99 |
+
x1, y1, x2, y2 = x, y, x+w, y+h
|
100 |
+
label_local = labels[y1: y2, x1: x2]
|
101 |
+
label_coordinates = np.where(label_local==label_index)
|
102 |
+
tmp_merged = np.zeros_like(label_local, np.uint8)
|
103 |
+
tmp_merged[label_coordinates] = 255
|
104 |
+
tmp_merged = cv2.bitwise_or(mask_merged[y1: y2, x1: x2], tmp_merged)
|
105 |
+
xor_merged = cv2.bitwise_xor(tmp_merged, pred_mask[y1: y2, x1: x2]).sum()
|
106 |
+
xor_origin = cv2.bitwise_xor(mask_merged[y1: y2, x1: x2], pred_mask[y1: y2, x1: x2]).sum()
|
107 |
+
if xor_merged < xor_origin:
|
108 |
+
mask_merged[y1: y2, x1: x2] = tmp_merged
|
109 |
+
|
110 |
+
if refine_mode == REFINEMASK_INPAINT:
|
111 |
+
mask_merged = cv2.dilate(mask_merged, np.ones((5, 5), np.uint8), iterations=1)
|
112 |
+
# fill holes
|
113 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(255-mask_merged, connectivity, cv2.CV_16U)
|
114 |
+
sorted_area = np.sort(stats[:, -1])
|
115 |
+
if len(sorted_area) > 1:
|
116 |
+
area_thresh = sorted_area[-2]
|
117 |
+
else:
|
118 |
+
area_thresh = sorted_area[-1]
|
119 |
+
for label_index, stat, centroid in zip(range(num_labels), stats, centroids):
|
120 |
+
x, y, w, h, area = stat
|
121 |
+
if area < area_thresh:
|
122 |
+
x1, y1, x2, y2 = x, y, x+w, y+h
|
123 |
+
label_local = labels[y1: y2, x1: x2]
|
124 |
+
label_coordinates = np.where(label_local==label_index)
|
125 |
+
tmp_merged = np.zeros_like(label_local, np.uint8)
|
126 |
+
tmp_merged[label_coordinates] = 255
|
127 |
+
tmp_merged = cv2.bitwise_or(mask_merged[y1: y2, x1: x2], tmp_merged)
|
128 |
+
xor_merged = cv2.bitwise_xor(tmp_merged, pred_mask[y1: y2, x1: x2]).sum()
|
129 |
+
xor_origin = cv2.bitwise_xor(mask_merged[y1: y2, x1: x2], pred_mask[y1: y2, x1: x2]).sum()
|
130 |
+
if xor_merged < xor_origin:
|
131 |
+
mask_merged[y1: y2, x1: x2] = tmp_merged
|
132 |
+
return mask_merged
|
133 |
+
|
134 |
+
|
135 |
+
def refine_undetected_mask(img: np.ndarray, mask_pred: np.ndarray, mask_refined: np.ndarray, blk_list: List[TextBlock], refine_mode=REFINEMASK_INPAINT):
|
136 |
+
mask_pred[np.where(mask_refined > 30)] = 0
|
137 |
+
_, pred_mask_t = cv2.threshold(mask_pred, 30, 255, cv2.THRESH_BINARY)
|
138 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(pred_mask_t, 4, cv2.CV_16U)
|
139 |
+
valid_labels = np.where(stats[:, -1] > 50)[0]
|
140 |
+
seg_blk_list = []
|
141 |
+
if len(valid_labels) > 0:
|
142 |
+
for lab_index in valid_labels[1:]:
|
143 |
+
x, y, w, h, area = stats[lab_index]
|
144 |
+
bx1, by1 = x, y
|
145 |
+
bx2, by2 = x+w, y+h
|
146 |
+
bbox = [bx1, by1, bx2, by2]
|
147 |
+
bbox_score = -1
|
148 |
+
for blk in blk_list:
|
149 |
+
bbox_s = union_area(blk.xyxy, bbox)
|
150 |
+
if bbox_s > bbox_score:
|
151 |
+
bbox_score = bbox_s
|
152 |
+
if bbox_score / w / h < 0.5:
|
153 |
+
seg_blk_list.append(TextBlock(bbox))
|
154 |
+
if len(seg_blk_list) > 0:
|
155 |
+
mask_refined = cv2.bitwise_or(mask_refined, refine_mask(img, mask_pred, seg_blk_list, refine_mode=refine_mode))
|
156 |
+
return mask_refined
|
157 |
+
|
158 |
+
def refine_mask(img: np.ndarray, pred_mask: np.ndarray, blk_list: List[Quadrilateral], refine_mode: int = REFINEMASK_INPAINT) -> np.ndarray:
|
159 |
+
mask_refined = np.zeros_like(pred_mask)
|
160 |
+
for blk in blk_list:
|
161 |
+
bx1, by1, bx2, by2 = enlarge_window(blk.xyxy, img.shape[1], img.shape[0])
|
162 |
+
im = np.ascontiguousarray(img[by1: by2, bx1: bx2])
|
163 |
+
msk = np.ascontiguousarray(pred_mask[by1: by2, bx1: bx2])
|
164 |
+
|
165 |
+
mask_list = get_topk_masklist(im, msk)
|
166 |
+
mask_list += get_otsuthresh_masklist(im, msk, per_channel=False)
|
167 |
+
mask_merged = merge_mask_list(mask_list, msk, blk=blk, text_window=[bx1, by1, bx2, by2], refine_mode=refine_mode)
|
168 |
+
mask_refined[by1: by2, bx1: bx2] = cv2.bitwise_or(mask_refined[by1: by2, bx1: bx2], mask_merged)
|
169 |
+
# cv2.imshow('im', im)
|
170 |
+
# cv2.imshow('msk', msk)
|
171 |
+
# cv2.imshow('mask_refined', mask_refined[by1: by2, bx1: bx2])
|
172 |
+
# cv2.waitKey(0)
|
173 |
+
|
174 |
+
return mask_refined
|
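The mask refinement above repeatedly relies on one heuristic: between a candidate binarization and its inverse, keep whichever differs least from the network's predicted mask, measured by summed XOR (minxor_thresh, and again per connected component in merge_mask_list). A self-contained sketch of that criterion on synthetic data:

import cv2
import numpy as np

pred_mask = np.zeros((64, 64), np.uint8)
pred_mask[16:48, 16:48] = 255            # what the model thinks is text

grey = np.full((64, 64), 200, np.uint8)  # bright background
grey[16:48, 16:48] = 30                  # dark glyph pixels

_, threshed = cv2.threshold(grey, 1, 255, cv2.THRESH_OTSU + cv2.THRESH_BINARY)
neg_threshed = 255 - threshed
xor = cv2.bitwise_xor(threshed, pred_mask).sum()
neg_xor = cv2.bitwise_xor(neg_threshed, pred_mask).sum()
best = neg_threshed if neg_xor < xor else threshed  # the inverted threshold matches better here
print(int(xor), int(neg_xor))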
manga_translator/detection/ctd_utils/utils/db_utils.py
ADDED
@@ -0,0 +1,706 @@
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import pyclipper
|
4 |
+
from shapely.geometry import Polygon
|
5 |
+
from collections import namedtuple
|
6 |
+
import warnings
|
7 |
+
import torch
|
8 |
+
warnings.filterwarnings('ignore')
|
9 |
+
|
10 |
+
|
11 |
+
def iou_rotate(box_a, box_b, method='union'):
|
12 |
+
rect_a = cv2.minAreaRect(box_a)
|
13 |
+
rect_b = cv2.minAreaRect(box_b)
|
14 |
+
r1 = cv2.rotatedRectangleIntersection(rect_a, rect_b)
|
15 |
+
if r1[0] == 0:
|
16 |
+
return 0
|
17 |
+
else:
|
18 |
+
inter_area = cv2.contourArea(r1[1])
|
19 |
+
area_a = cv2.contourArea(box_a)
|
20 |
+
area_b = cv2.contourArea(box_b)
|
21 |
+
union_area = area_a + area_b - inter_area
|
22 |
+
if union_area == 0 or inter_area == 0:
|
23 |
+
return 0
|
24 |
+
if method == 'union':
|
25 |
+
iou = inter_area / union_area
|
26 |
+
elif method == 'intersection':
|
27 |
+
iou = inter_area / min(area_a, area_b)
|
28 |
+
else:
|
29 |
+
raise NotImplementedError
|
30 |
+
return iou
|
31 |
+
|
32 |
+
class SegDetectorRepresenter():
|
33 |
+
def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5):
|
34 |
+
self.min_size = 3
|
35 |
+
self.thresh = thresh
|
36 |
+
self.box_thresh = box_thresh
|
37 |
+
self.max_candidates = max_candidates
|
38 |
+
self.unclip_ratio = unclip_ratio
|
39 |
+
|
40 |
+
def __call__(self, batch, pred, is_output_polygon=False, height=None, width=None):
|
41 |
+
'''
|
42 |
+
batch: (image, polygons, ignore_tags
|
43 |
+
batch: a dict produced by dataloaders.
|
44 |
+
image: tensor of shape (N, C, H, W).
|
45 |
+
polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
|
46 |
+
ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
|
47 |
+
shape: the original shape of images.
|
48 |
+
filename: the original filenames of images.
|
49 |
+
pred:
|
50 |
+
binary: text region segmentation map, with shape (N, H, W)
|
51 |
+
thresh: [if exists] thresh hold prediction with shape (N, H, W)
|
52 |
+
thresh_binary: [if exists] binarized with threshold, (N, H, W)
|
53 |
+
'''
|
54 |
+
pred = pred[:, 0, :, :]
|
55 |
+
segmentation = self.binarize(pred)
|
56 |
+
boxes_batch = []
|
57 |
+
scores_batch = []
|
58 |
+
# print(pred.size())
|
59 |
+
batch_size = pred.size(0) if isinstance(pred, torch.Tensor) else pred.shape[0]
|
60 |
+
|
61 |
+
if height is None:
|
62 |
+
height = pred.shape[1]
|
63 |
+
if width is None:
|
64 |
+
width = pred.shape[2]
|
65 |
+
|
66 |
+
for batch_index in range(batch_size):
|
67 |
+
if is_output_polygon:
|
68 |
+
boxes, scores = self.polygons_from_bitmap(pred[batch_index], segmentation[batch_index], width, height)
|
69 |
+
else:
|
70 |
+
boxes, scores = self.boxes_from_bitmap(pred[batch_index], segmentation[batch_index], width, height)
|
71 |
+
boxes_batch.append(boxes)
|
72 |
+
scores_batch.append(scores)
|
73 |
+
return boxes_batch, scores_batch
|
74 |
+
|
75 |
+
def binarize(self, pred) -> np.ndarray:
|
76 |
+
return pred > self.thresh
|
77 |
+
|
78 |
+
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
79 |
+
'''
|
80 |
+
_bitmap: single map with shape (H, W),
|
81 |
+
whose values are binarized as {0, 1}
|
82 |
+
'''
|
83 |
+
|
84 |
+
assert len(_bitmap.shape) == 2
|
85 |
+
bitmap = _bitmap.cpu().numpy() # The first channel
|
86 |
+
pred = pred.cpu().detach().numpy()
|
87 |
+
height, width = bitmap.shape
|
88 |
+
boxes = []
|
89 |
+
scores = []
|
90 |
+
|
91 |
+
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
92 |
+
|
93 |
+
for contour in contours[:self.max_candidates]:
|
94 |
+
epsilon = 0.005 * cv2.arcLength(contour, True)
|
95 |
+
approx = cv2.approxPolyDP(contour, epsilon, True)
|
96 |
+
points = approx.reshape((-1, 2))
|
97 |
+
if points.shape[0] < 4:
|
98 |
+
continue
|
99 |
+
# _, sside = self.get_mini_boxes(contour)
|
100 |
+
# if sside < self.min_size:
|
101 |
+
# continue
|
102 |
+
score = self.box_score_fast(pred, contour.squeeze(1))
|
103 |
+
if self.box_thresh > score:
|
104 |
+
continue
|
105 |
+
|
106 |
+
if points.shape[0] > 2:
|
107 |
+
box = self.unclip(points, unclip_ratio=self.unclip_ratio)
|
108 |
+
if len(box) > 1:
|
109 |
+
continue
|
110 |
+
else:
|
111 |
+
continue
|
112 |
+
box = box.reshape(-1, 2)
|
113 |
+
_, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
|
114 |
+
if sside < self.min_size + 2:
|
115 |
+
continue
|
116 |
+
|
117 |
+
if not isinstance(dest_width, int):
|
118 |
+
dest_width = dest_width.item()
|
119 |
+
dest_height = dest_height.item()
|
120 |
+
|
121 |
+
box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
|
122 |
+
box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
|
123 |
+
boxes.append(box)
|
124 |
+
scores.append(score)
|
125 |
+
return boxes, scores
|
126 |
+
|
127 |
+
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
128 |
+
'''
|
129 |
+
_bitmap: single map with shape (H, W),
|
130 |
+
whose values are binarized as {0, 1}
|
131 |
+
'''
|
132 |
+
|
133 |
+
assert len(_bitmap.shape) == 2
|
134 |
+
if isinstance(pred, torch.Tensor):
|
135 |
+
bitmap = _bitmap.cpu().numpy() # The first channel
|
136 |
+
pred = pred.cpu().detach().numpy()
|
137 |
+
else:
|
138 |
+
bitmap = _bitmap
|
139 |
+
# cv2.imwrite('tmp.png', (bitmap*255).astype(np.uint8))
|
140 |
+
height, width = bitmap.shape
|
141 |
+
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
142 |
+
num_contours = min(len(contours), self.max_candidates)
|
143 |
+
boxes = np.zeros((num_contours, 4, 2), dtype=np.int16)
|
144 |
+
scores = np.zeros((num_contours,), dtype=np.float32)
|
145 |
+
|
146 |
+
for index in range(num_contours):
|
147 |
+
contour = contours[index].squeeze(1)
|
148 |
+
points, sside = self.get_mini_boxes(contour)
|
149 |
+
# if sside < self.min_size:
|
150 |
+
# continue
|
151 |
+
if sside < 2:
|
152 |
+
continue
|
153 |
+
points = np.array(points)
|
154 |
+
score = self.box_score_fast(pred, contour)
|
155 |
+
# if self.box_thresh > score:
|
156 |
+
# continue
|
157 |
+
|
158 |
+
box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(-1, 1, 2)
|
159 |
+
box, sside = self.get_mini_boxes(box)
|
160 |
+
# if sside < 5:
|
161 |
+
# continue
|
162 |
+
box = np.array(box)
|
163 |
+
if not isinstance(dest_width, int):
|
164 |
+
dest_width = dest_width.item()
|
165 |
+
dest_height = dest_height.item()
|
166 |
+
|
167 |
+
box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
|
168 |
+
box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
|
169 |
+
boxes[index, :, :] = box.astype(np.int16)
|
170 |
+
scores[index] = score
|
171 |
+
return boxes, scores
|
172 |
+
|
173 |
+
def unclip(self, box, unclip_ratio=1.5):
|
174 |
+
poly = Polygon(box)
|
175 |
+
distance = poly.area * unclip_ratio / poly.length
|
176 |
+
offset = pyclipper.PyclipperOffset()
|
177 |
+
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
178 |
+
expanded = np.array(offset.Execute(distance))
|
179 |
+
return expanded
|
180 |
+
|
181 |
+
def get_mini_boxes(self, contour):
|
182 |
+
bounding_box = cv2.minAreaRect(contour)
|
183 |
+
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
|
184 |
+
|
185 |
+
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
|
186 |
+
if points[1][1] > points[0][1]:
|
187 |
+
index_1 = 0
|
188 |
+
index_4 = 1
|
189 |
+
else:
|
190 |
+
index_1 = 1
|
191 |
+
index_4 = 0
|
192 |
+
if points[3][1] > points[2][1]:
|
193 |
+
index_2 = 2
|
194 |
+
index_3 = 3
|
195 |
+
else:
|
196 |
+
index_2 = 3
|
197 |
+
index_3 = 2
|
198 |
+
|
199 |
+
box = [points[index_1], points[index_2], points[index_3], points[index_4]]
|
200 |
+
return box, min(bounding_box[1])
|
201 |
+
|
202 |
+
def box_score_fast(self, bitmap, _box):
|
203 |
+
h, w = bitmap.shape[:2]
|
204 |
+
box = _box.copy()
|
205 |
+
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)
|
206 |
+
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
|
207 |
+
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
|
208 |
+
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1)
|
209 |
+
|
210 |
+
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
|
211 |
+
box[:, 0] = box[:, 0] - xmin
|
212 |
+
box[:, 1] = box[:, 1] - ymin
|
213 |
+
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
|
214 |
+
if bitmap.dtype == np.float16:
|
215 |
+
bitmap = bitmap.astype(np.float32)
|
216 |
+
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
217 |
+
|
218 |
+
class AverageMeter(object):
|
219 |
+
"""Computes and stores the average and current value"""
|
220 |
+
|
221 |
+
def __init__(self):
|
222 |
+
self.reset()
|
223 |
+
|
224 |
+
def reset(self):
|
225 |
+
self.val = 0
|
226 |
+
self.avg = 0
|
227 |
+
self.sum = 0
|
228 |
+
self.count = 0
|
229 |
+
|
230 |
+
def update(self, val, n=1):
|
231 |
+
self.val = val
|
232 |
+
self.sum += val * n
|
233 |
+
self.count += n
|
234 |
+
self.avg = self.sum / self.count
|
235 |
+
return self
|
236 |
+
|
237 |
+
|
238 |
+
class DetectionIoUEvaluator(object):
|
239 |
+
def __init__(self, is_output_polygon=False, iou_constraint=0.5, area_precision_constraint=0.5):
|
240 |
+
self.is_output_polygon = is_output_polygon
|
241 |
+
self.iou_constraint = iou_constraint
|
242 |
+
self.area_precision_constraint = area_precision_constraint
|
243 |
+
|
244 |
+
def evaluate_image(self, gt, pred):
|
245 |
+
|
246 |
+
def get_union(pD, pG):
|
247 |
+
return Polygon(pD).union(Polygon(pG)).area
|
248 |
+
|
249 |
+
def get_intersection_over_union(pD, pG):
|
250 |
+
return get_intersection(pD, pG) / get_union(pD, pG)
|
251 |
+
|
252 |
+
def get_intersection(pD, pG):
|
253 |
+
return Polygon(pD).intersection(Polygon(pG)).area
|
254 |
+
|
255 |
+
def compute_ap(confList, matchList, numGtCare):
|
256 |
+
correct = 0
|
257 |
+
AP = 0
|
258 |
+
if len(confList) > 0:
|
259 |
+
confList = np.array(confList)
|
260 |
+
matchList = np.array(matchList)
|
261 |
+
sorted_ind = np.argsort(-confList)
|
262 |
+
confList = confList[sorted_ind]
|
263 |
+
matchList = matchList[sorted_ind]
|
264 |
+
for n in range(len(confList)):
|
265 |
+
match = matchList[n]
|
266 |
+
if match:
|
267 |
+
correct += 1
|
268 |
+
AP += float(correct) / (n + 1)
|
269 |
+
|
270 |
+
if numGtCare > 0:
|
271 |
+
AP /= numGtCare
|
272 |
+
|
273 |
+
return AP
|
274 |
+
|
275 |
+
perSampleMetrics = {}
|
276 |
+
|
277 |
+
matchedSum = 0
|
278 |
+
|
279 |
+
Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
|
280 |
+
|
281 |
+
numGlobalCareGt = 0
|
282 |
+
numGlobalCareDet = 0
|
283 |
+
|
284 |
+
arrGlobalConfidences = []
|
285 |
+
arrGlobalMatches = []
|
286 |
+
|
287 |
+
recall = 0
|
288 |
+
precision = 0
|
289 |
+
hmean = 0
|
290 |
+
|
291 |
+
detMatched = 0
|
292 |
+
|
293 |
+
iouMat = np.empty([1, 1])
|
294 |
+
|
295 |
+
gtPols = []
|
296 |
+
detPols = []
|
297 |
+
|
298 |
+
gtPolPoints = []
|
299 |
+
detPolPoints = []
|
300 |
+
|
301 |
+
# Array of Ground Truth Polygons' keys marked as don't Care
|
302 |
+
gtDontCarePolsNum = []
|
303 |
+
# Array of Detected Polygons' matched with a don't Care GT
|
304 |
+
detDontCarePolsNum = []
|
305 |
+
|
306 |
+
pairs = []
|
307 |
+
detMatchedNums = []
|
308 |
+
|
309 |
+
arrSampleConfidences = []
|
310 |
+
arrSampleMatch = []
|
311 |
+
|
312 |
+
evaluationLog = ""
|
313 |
+
|
314 |
+
for n in range(len(gt)):
|
315 |
+
points = gt[n]['points']
|
316 |
+
# transcription = gt[n]['text']
|
317 |
+
dontCare = gt[n]['ignore']
|
318 |
+
|
319 |
+
if not Polygon(points).is_valid or not Polygon(points).is_simple:
|
320 |
+
continue
|
321 |
+
|
322 |
+
gtPol = points
|
323 |
+
gtPols.append(gtPol)
|
324 |
+
gtPolPoints.append(points)
|
325 |
+
if dontCare:
|
326 |
+
gtDontCarePolsNum.append(len(gtPols) - 1)
|
327 |
+
|
328 |
+
evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(
|
329 |
+
gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")
|
330 |
+
|
331 |
+
for n in range(len(pred)):
|
332 |
+
points = pred[n]['points']
|
333 |
+
if not Polygon(points).is_valid or not Polygon(points).is_simple:
|
334 |
+
continue
|
335 |
+
|
336 |
+
detPol = points
|
337 |
+
detPols.append(detPol)
|
338 |
+
detPolPoints.append(points)
|
339 |
+
if len(gtDontCarePolsNum) > 0:
|
340 |
+
for dontCarePol in gtDontCarePolsNum:
|
341 |
+
dontCarePol = gtPols[dontCarePol]
|
342 |
+
intersected_area = get_intersection(dontCarePol, detPol)
|
343 |
+
pdDimensions = Polygon(detPol).area
|
344 |
+
precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
|
345 |
+
if (precision > self.area_precision_constraint):
|
346 |
+
detDontCarePolsNum.append(len(detPols) - 1)
|
347 |
+
break
|
348 |
+
|
349 |
+
evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(
|
350 |
+
detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")
|
351 |
+
|
352 |
+
if len(gtPols) > 0 and len(detPols) > 0:
|
353 |
+
# Calculate IoU and precision matrixs
|
354 |
+
outputShape = [len(gtPols), len(detPols)]
|
355 |
+
iouMat = np.empty(outputShape)
|
356 |
+
gtRectMat = np.zeros(len(gtPols), np.int8)
|
357 |
+
detRectMat = np.zeros(len(detPols), np.int8)
|
358 |
+
if self.is_output_polygon:
|
359 |
+
for gtNum in range(len(gtPols)):
|
360 |
+
for detNum in range(len(detPols)):
|
361 |
+
pG = gtPols[gtNum]
|
362 |
+
pD = detPols[detNum]
|
363 |
+
iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
|
364 |
+
else:
|
365 |
+
# gtPols = np.float32(gtPols)
|
366 |
+
# detPols = np.float32(detPols)
|
367 |
+
for gtNum in range(len(gtPols)):
|
368 |
+
for detNum in range(len(detPols)):
|
369 |
+
pG = np.float32(gtPols[gtNum])
|
370 |
+
pD = np.float32(detPols[detNum])
|
371 |
+
iouMat[gtNum, detNum] = iou_rotate(pD, pG)
|
372 |
+
for gtNum in range(len(gtPols)):
|
373 |
+
for detNum in range(len(detPols)):
|
374 |
+
if gtRectMat[gtNum] == 0 and detRectMat[
|
375 |
+
detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
|
376 |
+
if iouMat[gtNum, detNum] > self.iou_constraint:
|
377 |
+
gtRectMat[gtNum] = 1
|
378 |
+
detRectMat[detNum] = 1
|
379 |
+
detMatched += 1
|
380 |
+
pairs.append({'gt': gtNum, 'det': detNum})
|
381 |
+
detMatchedNums.append(detNum)
|
382 |
+
evaluationLog += "Match GT #" + \
|
383 |
+
str(gtNum) + " with Det #" + str(detNum) + "\n"
|
384 |
+
|
385 |
+
numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
|
386 |
+
numDetCare = (len(detPols) - len(detDontCarePolsNum))
|
387 |
+
if numGtCare == 0:
|
388 |
+
recall = float(1)
|
389 |
+
precision = float(0) if numDetCare > 0 else float(1)
|
390 |
+
else:
|
391 |
+
recall = float(detMatched) / numGtCare
|
392 |
+
precision = 0 if numDetCare == 0 else float(
|
393 |
+
detMatched) / numDetCare
|
394 |
+
|
395 |
+
hmean = 0 if (precision + recall) == 0 else 2.0 * \
|
396 |
+
precision * recall / (precision + recall)
|
397 |
+
|
398 |
+
matchedSum += detMatched
|
399 |
+
numGlobalCareGt += numGtCare
|
400 |
+
numGlobalCareDet += numDetCare
|
401 |
+
|
402 |
+
perSampleMetrics = {
|
403 |
+
'precision': precision,
|
404 |
+
'recall': recall,
|
405 |
+
'hmean': hmean,
|
406 |
+
'pairs': pairs,
|
407 |
+
'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
|
408 |
+
'gtPolPoints': gtPolPoints,
|
409 |
+
'detPolPoints': detPolPoints,
|
410 |
+
'gtCare': numGtCare,
|
411 |
+
'detCare': numDetCare,
|
412 |
+
'gtDontCare': gtDontCarePolsNum,
|
413 |
+
'detDontCare': detDontCarePolsNum,
|
414 |
+
'detMatched': detMatched,
|
415 |
+
'evaluationLog': evaluationLog
|
416 |
+
}
|
417 |
+
|
418 |
+
return perSampleMetrics
|
419 |
+
|
420 |
+
def combine_results(self, results):
|
421 |
+
numGlobalCareGt = 0
|
422 |
+
numGlobalCareDet = 0
|
423 |
+
matchedSum = 0
|
424 |
+
for result in results:
|
425 |
+
numGlobalCareGt += result['gtCare']
|
426 |
+
numGlobalCareDet += result['detCare']
|
427 |
+
matchedSum += result['detMatched']
|
428 |
+
|
429 |
+
methodRecall = 0 if numGlobalCareGt == 0 else float(
|
430 |
+
matchedSum) / numGlobalCareGt
|
431 |
+
methodPrecision = 0 if numGlobalCareDet == 0 else float(
|
432 |
+
matchedSum) / numGlobalCareDet
|
433 |
+
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
|
434 |
+
methodRecall * methodPrecision / (
|
435 |
+
methodRecall + methodPrecision)
|
436 |
+
|
437 |
+
methodMetrics = {'precision': methodPrecision,
|
438 |
+
'recall': methodRecall, 'hmean': methodHmean}
|
439 |
+
|
440 |
+
return methodMetrics
|
441 |
+
|
442 |
+
class QuadMetric():
|
443 |
+
def __init__(self, is_output_polygon=False):
|
444 |
+
self.is_output_polygon = is_output_polygon
|
445 |
+
self.evaluator = DetectionIoUEvaluator(is_output_polygon=is_output_polygon)
|
446 |
+
|
447 |
+
def measure(self, batch, output, box_thresh=0.6):
|
448 |
+
'''
|
449 |
+
batch: (image, polygons, ignore_tags
|
450 |
+
batch: a dict produced by dataloaders.
|
451 |
+
image: tensor of shape (N, C, H, W).
|
452 |
+
polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
|
453 |
+
ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
|
454 |
+
shape: the original shape of images.
|
455 |
+
filename: the original filenames of images.
|
456 |
+
output: (polygons, ...)
|
457 |
+
'''
|
458 |
+
results = []
|
459 |
+
gt_polyons_batch = batch['text_polys']
|
460 |
+
ignore_tags_batch = batch['ignore_tags']
|
461 |
+
pred_polygons_batch = np.array(output[0])
|
462 |
+
pred_scores_batch = np.array(output[1])
|
463 |
+
for polygons, pred_polygons, pred_scores, ignore_tags in zip(gt_polyons_batch, pred_polygons_batch, pred_scores_batch, ignore_tags_batch):
|
464 |
+
gt = [dict(points=np.int64(polygons[i]), ignore=ignore_tags[i]) for i in range(len(polygons))]
|
465 |
+
if self.is_output_polygon:
|
466 |
+
pred = [dict(points=pred_polygons[i]) for i in range(len(pred_polygons))]
|
467 |
+
else:
|
468 |
+
pred = []
|
469 |
+
# print(pred_polygons.shape)
|
470 |
+
for i in range(pred_polygons.shape[0]):
|
471 |
+
if pred_scores[i] >= box_thresh:
|
472 |
+
# print(pred_polygons[i,:,:].tolist())
|
473 |
+
pred.append(dict(points=pred_polygons[i, :, :].astype(np.int32)))
|
474 |
+
# pred = [dict(points=pred_polygons[i,:,:].tolist()) if pred_scores[i] >= box_thresh for i in range(pred_polygons.shape[0])]
|
475 |
+
results.append(self.evaluator.evaluate_image(gt, pred))
|
476 |
+
return results
|
477 |
+
|
478 |
+
def validate_measure(self, batch, output, box_thresh=0.6):
|
479 |
+
return self.measure(batch, output, box_thresh)
|
480 |
+
|
481 |
+
def evaluate_measure(self, batch, output):
|
482 |
+
return self.measure(batch, output), np.linspace(0, batch['image'].shape[0]).tolist()
|
483 |
+
|
484 |
+
def gather_measure(self, raw_metrics):
|
485 |
+
raw_metrics = [image_metrics
|
486 |
+
for batch_metrics in raw_metrics
|
487 |
+
for image_metrics in batch_metrics]
|
488 |
+
|
489 |
+
result = self.evaluator.combine_results(raw_metrics)
|
490 |
+
|
491 |
+
precision = AverageMeter()
|
492 |
+
recall = AverageMeter()
|
493 |
+
fmeasure = AverageMeter()
|
494 |
+
|
495 |
+
precision.update(result['precision'], n=len(raw_metrics))
|
496 |
+
recall.update(result['recall'], n=len(raw_metrics))
|
497 |
+
fmeasure_score = 2 * precision.val * recall.val / (precision.val + recall.val + 1e-8)
|
498 |
+
fmeasure.update(fmeasure_score)
|
499 |
+
|
500 |
+
return {
|
501 |
+
'precision': precision,
|
502 |
+
'recall': recall,
|
503 |
+
'fmeasure': fmeasure
|
504 |
+
}
|
505 |
+
|
506 |
+
def shrink_polygon_py(polygon, shrink_ratio):
|
507 |
+
"""
|
508 |
+
对框进行缩放,返回去的比例为1/shrink_ratio 即可
|
509 |
+
"""
|
510 |
+
cx = polygon[:, 0].mean()
|
511 |
+
cy = polygon[:, 1].mean()
|
512 |
+
polygon[:, 0] = cx + (polygon[:, 0] - cx) * shrink_ratio
|
513 |
+
polygon[:, 1] = cy + (polygon[:, 1] - cy) * shrink_ratio
|
514 |
+
return polygon
|
515 |
+
|
516 |
+
|
517 |
+
def shrink_polygon_pyclipper(polygon, shrink_ratio):
|
518 |
+
from shapely.geometry import Polygon
|
519 |
+
import pyclipper
|
520 |
+
polygon_shape = Polygon(polygon)
|
521 |
+
distance = polygon_shape.area * (1 - np.power(shrink_ratio, 2)) / polygon_shape.length
|
522 |
+
subject = [tuple(l) for l in polygon]
|
523 |
+
padding = pyclipper.PyclipperOffset()
|
524 |
+
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
525 |
+
shrunk = padding.Execute(-distance)
|
526 |
+
if shrunk == []:
|
527 |
+
shrunk = np.array(shrunk)
|
528 |
+
else:
|
529 |
+
shrunk = np.array(shrunk[0]).reshape(-1, 2)
|
530 |
+
return shrunk
|
531 |
+
|
532 |
+
class MakeShrinkMap():
|
533 |
+
r'''
|
534 |
+
Making binary mask from detection data with ICDAR format.
|
535 |
+
Typically following the process of class `MakeICDARData`.
|
536 |
+
'''
|
537 |
+
|
538 |
+
def __init__(self, min_text_size=4, shrink_ratio=0.4, shrink_type='pyclipper'):
|
539 |
+
shrink_func_dict = {'py': shrink_polygon_py, 'pyclipper': shrink_polygon_pyclipper}
|
540 |
+
self.shrink_func = shrink_func_dict[shrink_type]
|
541 |
+
self.min_text_size = min_text_size
|
542 |
+
self.shrink_ratio = shrink_ratio
|
543 |
+
|
544 |
+
def __call__(self, data: dict) -> dict:
|
545 |
+
"""
|
546 |
+
从scales中随机选择一个尺度,对图片和文本框进行缩放
|
547 |
+
:param data: {'imgs':,'text_polys':,'texts':,'ignore_tags':}
|
548 |
+
:return:
|
549 |
+
"""
|
550 |
+
image = data['imgs']
|
551 |
+
text_polys = data['text_polys']
|
552 |
+
ignore_tags = data['ignore_tags']
|
553 |
+
|
554 |
+
h, w = image.shape[:2]
|
555 |
+
text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
|
556 |
+
gt = np.zeros((h, w), dtype=np.float32)
|
557 |
+
mask = np.ones((h, w), dtype=np.float32)
|
558 |
+
for i in range(len(text_polys)):
|
559 |
+
polygon = text_polys[i]
|
560 |
+
height = max(polygon[:, 1]) - min(polygon[:, 1])
|
561 |
+
width = max(polygon[:, 0]) - min(polygon[:, 0])
|
562 |
+
if ignore_tags[i] or min(height, width) < self.min_text_size:
|
563 |
+
cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
|
564 |
+
ignore_tags[i] = True
|
565 |
+
else:
|
566 |
+
shrunk = self.shrink_func(polygon, self.shrink_ratio)
|
567 |
+
if shrunk.size == 0:
|
568 |
+
cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
|
569 |
+
ignore_tags[i] = True
|
570 |
+
continue
|
571 |
+
cv2.fillPoly(gt, [shrunk.astype(np.int32)], 1)
|
572 |
+
|
573 |
+
data['shrink_map'] = gt
|
574 |
+
data['shrink_mask'] = mask
|
575 |
+
return data
|
576 |
+
|
577 |
+
def validate_polygons(self, polygons, ignore_tags, h, w):
|
578 |
+
'''
|
579 |
+
polygons (numpy.array, required): of shape (num_instances, num_points, 2)
|
580 |
+
'''
|
581 |
+
if len(polygons) == 0:
|
582 |
+
return polygons, ignore_tags
|
583 |
+
assert len(polygons) == len(ignore_tags)
|
584 |
+
for polygon in polygons:
|
585 |
+
polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
|
586 |
+
polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
|
587 |
+
|
588 |
+
for i in range(len(polygons)):
|
589 |
+
area = self.polygon_area(polygons[i])
|
590 |
+
if abs(area) < 1:
|
591 |
+
ignore_tags[i] = True
|
592 |
+
if area > 0:
|
593 |
+
polygons[i] = polygons[i][::-1, :]
|
594 |
+
return polygons, ignore_tags
|
595 |
+
|
596 |
+
def polygon_area(self, polygon):
|
597 |
+
return cv2.contourArea(polygon)
|
598 |
+
|
599 |
+
|
600 |
+
class MakeBorderMap():
|
601 |
+
def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7):
|
602 |
+
self.shrink_ratio = shrink_ratio
|
603 |
+
self.thresh_min = thresh_min
|
604 |
+
self.thresh_max = thresh_max
|
605 |
+
|
606 |
+
def __call__(self, data: dict) -> dict:
|
607 |
+
"""
|
608 |
+
从scales中随机选择一个尺度,对图片和文本框进行缩放
|
609 |
+
:param data: {'imgs':,'text_polys':,'texts':,'ignore_tags':}
|
610 |
+
:return:
|
611 |
+
"""
|
612 |
+
im = data['imgs']
|
613 |
+
text_polys = data['text_polys']
|
614 |
+
ignore_tags = data['ignore_tags']
|
615 |
+
|
616 |
+
canvas = np.zeros(im.shape[:2], dtype=np.float32)
|
617 |
+
mask = np.zeros(im.shape[:2], dtype=np.float32)
|
618 |
+
|
619 |
+
for i in range(len(text_polys)):
|
620 |
+
if ignore_tags[i]:
|
621 |
+
continue
|
622 |
+
self.draw_border_map(text_polys[i], canvas, mask=mask)
|
623 |
+
canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
|
624 |
+
|
625 |
+
data['threshold_map'] = canvas
|
626 |
+
data['threshold_mask'] = mask
|
627 |
+
return data
|
628 |
+
|
629 |
+
def draw_border_map(self, polygon, canvas, mask):
|
630 |
+
polygon = np.array(polygon)
|
631 |
+
assert polygon.ndim == 2
|
632 |
+
assert polygon.shape[1] == 2
|
633 |
+
|
634 |
+
polygon_shape = Polygon(polygon)
|
635 |
+
if polygon_shape.area <= 0:
|
636 |
+
return
|
637 |
+
distance = polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
|
638 |
+
subject = [tuple(l) for l in polygon]
|
639 |
+
padding = pyclipper.PyclipperOffset()
|
640 |
+
padding.AddPath(subject, pyclipper.JT_ROUND,
|
641 |
+
pyclipper.ET_CLOSEDPOLYGON)
|
642 |
+
|
643 |
+
padded_polygon = np.array(padding.Execute(distance)[0])
|
644 |
+
cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
|
645 |
+
|
646 |
+
xmin = padded_polygon[:, 0].min()
|
647 |
+
xmax = padded_polygon[:, 0].max()
|
648 |
+
ymin = padded_polygon[:, 1].min()
|
649 |
+
ymax = padded_polygon[:, 1].max()
|
650 |
+
width = xmax - xmin + 1
|
651 |
+
height = ymax - ymin + 1
|
652 |
+
|
653 |
+
polygon[:, 0] = polygon[:, 0] - xmin
|
654 |
+
polygon[:, 1] = polygon[:, 1] - ymin
|
655 |
+
|
656 |
+
xs = np.broadcast_to(
|
657 |
+
np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))
|
658 |
+
ys = np.broadcast_to(
|
659 |
+
np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))
|
660 |
+
|
661 |
+
distance_map = np.zeros(
|
662 |
+
(polygon.shape[0], height, width), dtype=np.float32)
|
663 |
+
for i in range(polygon.shape[0]):
|
664 |
+
j = (i + 1) % polygon.shape[0]
|
665 |
+
absolute_distance = self.distance(xs, ys, polygon[i], polygon[j])
|
666 |
+
distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
|
667 |
+
distance_map = distance_map.min(axis=0)
|
668 |
+
|
669 |
+
xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
|
670 |
+
xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
|
671 |
+
ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
|
672 |
+
ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
|
673 |
+
canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
|
674 |
+
1 - distance_map[
|
675 |
+
ymin_valid - ymin:ymax_valid - ymax + height,
|
676 |
+
xmin_valid - xmin:xmax_valid - xmax + width],
|
677 |
+
canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
|
678 |
+
|
679 |
+
def distance(self, xs, ys, point_1, point_2):
|
680 |
+
'''
|
681 |
+
compute the distance from point to a line
|
682 |
+
ys: coordinates in the first axis
|
683 |
+
xs: coordinates in the second axis
|
684 |
+
point_1, point_2: (x, y), the end of the line
|
685 |
+
'''
|
686 |
+
height, width = xs.shape[:2]
|
687 |
+
square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
|
688 |
+
square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
|
689 |
+
square_distance = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1])
|
690 |
+
|
691 |
+
cosin = (square_distance - square_distance_1 - square_distance_2) / (2 * np.sqrt(square_distance_1 * square_distance_2))
|
692 |
+
square_sin = 1 - np.square(cosin)
|
693 |
+
square_sin = np.nan_to_num(square_sin)
|
694 |
+
|
695 |
+
result = np.sqrt(square_distance_1 * square_distance_2 * square_sin / square_distance)
|
696 |
+
result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin < 0]
|
697 |
+
return result
|
698 |
+
|
699 |
+
def extend_line(self, point_1, point_2, result):
|
700 |
+
ex_point_1 = (int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))),
|
701 |
+
int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio))))
|
702 |
+
cv2.line(result, tuple(ex_point_1), tuple(point_1), 4096.0, 1, lineType=cv2.LINE_AA, shift=0)
|
703 |
+
ex_point_2 = (int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))),
|
704 |
+
int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio))))
|
705 |
+
cv2.line(result, tuple(ex_point_2), tuple(point_2), 4096.0, 1, lineType=cv2.LINE_AA, shift=0)
|
706 |
+
return ex_point_1, ex_point_2
|
manga_translator/detection/ctd_utils/utils/imgproc_utils.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import random
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
def hex2bgr(hex):
|
7 |
+
gmask = 254 << 8
|
8 |
+
rmask = 254
|
9 |
+
b = hex >> 16
|
10 |
+
g = (hex & gmask) >> 8
|
11 |
+
r = hex & rmask
|
12 |
+
return np.stack([b, g, r]).transpose()
|
13 |
+
|
14 |
+
def union_area(bboxa, bboxb):
|
15 |
+
x1 = max(bboxa[0], bboxb[0])
|
16 |
+
y1 = max(bboxa[1], bboxb[1])
|
17 |
+
x2 = min(bboxa[2], bboxb[2])
|
18 |
+
y2 = min(bboxa[3], bboxb[3])
|
19 |
+
if y2 < y1 or x2 < x1:
|
20 |
+
return -1
|
21 |
+
return (y2 - y1) * (x2 - x1)
|
22 |
+
|
23 |
+
def get_yololabel_strings(clslist, labellist):
|
24 |
+
content = ''
|
25 |
+
for cls, xywh in zip(clslist, labellist):
|
26 |
+
content += str(int(cls)) + ' ' + ' '.join([str(e) for e in xywh]) + '\n'
|
27 |
+
if len(content) != 0:
|
28 |
+
content = content[:-1]
|
29 |
+
return content
|
30 |
+
|
31 |
+
# 4 points bbox to 8 points polygon
|
32 |
+
def xywh2xyxypoly(xywh, to_int=True):
|
33 |
+
xyxypoly = np.tile(xywh[:, [0, 1]], 4)
|
34 |
+
xyxypoly[:, [2, 4]] += xywh[:, [2]]
|
35 |
+
xyxypoly[:, [5, 7]] += xywh[:, [3]]
|
36 |
+
if to_int:
|
37 |
+
xyxypoly = xyxypoly.astype(np.int64)
|
38 |
+
return xyxypoly
|
39 |
+
|
40 |
+
def xyxy2yolo(xyxy, w: int, h: int):
|
41 |
+
if xyxy == [] or xyxy == np.array([]) or len(xyxy) == 0:
|
42 |
+
return None
|
43 |
+
if isinstance(xyxy, list):
|
44 |
+
xyxy = np.array(xyxy)
|
45 |
+
if len(xyxy.shape) == 1:
|
46 |
+
xyxy = np.array([xyxy])
|
47 |
+
yolo = np.copy(xyxy).astype(np.float64)
|
48 |
+
yolo[:, [0, 2]] = yolo[:, [0, 2]] / w
|
49 |
+
yolo[:, [1, 3]] = yolo[:, [1, 3]] / h
|
50 |
+
yolo[:, [2, 3]] -= yolo[:, [0, 1]]
|
51 |
+
yolo[:, [0, 1]] += yolo[:, [2, 3]] / 2
|
52 |
+
return yolo
|
53 |
+
|
54 |
+
def yolo_xywh2xyxy(xywh: np.array, w: int, h: int, to_int=True):
|
55 |
+
if xywh is None:
|
56 |
+
return None
|
57 |
+
if len(xywh) == 0:
|
58 |
+
return None
|
59 |
+
if len(xywh.shape) == 1:
|
60 |
+
xywh = np.array([xywh])
|
61 |
+
xywh[:, [0, 2]] *= w
|
62 |
+
xywh[:, [1, 3]] *= h
|
63 |
+
xywh[:, [0, 1]] -= xywh[:, [2, 3]] / 2
|
64 |
+
xywh[:, [2, 3]] += xywh[:, [0, 1]]
|
65 |
+
if to_int:
|
66 |
+
xywh = xywh.astype(np.int64)
|
67 |
+
return xywh
|
68 |
+
|
69 |
+
def letterbox(im, new_shape=(640, 640), color=(0, 0, 0), auto=False, scaleFill=False, scaleup=True, stride=128):
|
70 |
+
# Resize and pad image while meeting stride-multiple constraints
|
71 |
+
shape = im.shape[:2] # current shape [height, width]
|
72 |
+
if not isinstance(new_shape, tuple):
|
73 |
+
new_shape = (new_shape, new_shape)
|
74 |
+
|
75 |
+
# Scale ratio (new / old)
|
76 |
+
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
77 |
+
if not scaleup: # only scale down, do not scale up (for better val mAP)
|
78 |
+
r = min(r, 1.0)
|
79 |
+
|
80 |
+
# Compute padding
|
81 |
+
ratio = r, r # width, height ratios
|
82 |
+
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
83 |
+
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
|
84 |
+
if auto: # minimum rectangle
|
85 |
+
dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
|
86 |
+
elif scaleFill: # stretch
|
87 |
+
dw, dh = 0.0, 0.0
|
88 |
+
new_unpad = (new_shape[1], new_shape[0])
|
89 |
+
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
|
90 |
+
|
91 |
+
# dw /= 2 # divide padding into 2 sides
|
92 |
+
# dh /= 2
|
93 |
+
dh, dw = int(dh), int(dw)
|
94 |
+
|
95 |
+
if shape[::-1] != new_unpad: # resize
|
96 |
+
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
|
97 |
+
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
98 |
+
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
99 |
+
im = cv2.copyMakeBorder(im, 0, dh, 0, dw, cv2.BORDER_CONSTANT, value=color) # add border
|
100 |
+
return im, ratio, (dw, dh)
|
101 |
+
|
102 |
+
def resize_keepasp(im, new_shape=640, scaleup=True, interpolation=cv2.INTER_LINEAR, stride=None):
|
103 |
+
shape = im.shape[:2] # current shape [height, width]
|
104 |
+
|
105 |
+
if new_shape is not None:
|
106 |
+
if not isinstance(new_shape, tuple):
|
107 |
+
new_shape = (new_shape, new_shape)
|
108 |
+
else:
|
109 |
+
new_shape = shape
|
110 |
+
|
111 |
+
# Scale ratio (new / old)
|
112 |
+
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
113 |
+
if not scaleup: # only scale down, do not scale up (for better val mAP)
|
114 |
+
r = min(r, 1.0)
|
115 |
+
|
116 |
+
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
117 |
+
|
118 |
+
if stride is not None:
|
119 |
+
h, w = new_unpad
|
120 |
+
if new_shape[0] % stride != 0:
|
121 |
+
new_h = (stride - (new_shape[0] % stride)) + h
|
122 |
+
else:
|
123 |
+
new_h = h
|
124 |
+
if w % stride != 0:
|
125 |
+
new_w = (stride - (w % stride)) + w
|
126 |
+
else:
|
127 |
+
new_w = w
|
128 |
+
new_unpad = (new_h, new_w)
|
129 |
+
|
130 |
+
if shape[::-1] != new_unpad: # resize
|
131 |
+
im = cv2.resize(im, new_unpad, interpolation=interpolation)
|
132 |
+
return im
|
133 |
+
|
134 |
+
def enlarge_window(rect, im_w, im_h, ratio=2.5, aspect_ratio=1.0) -> List:
|
135 |
+
assert ratio > 1.0
|
136 |
+
|
137 |
+
x1, y1, x2, y2 = rect
|
138 |
+
w = x2 - x1
|
139 |
+
h = y2 - y1
|
140 |
+
|
141 |
+
# https://numpy.org/doc/stable/reference/generated/numpy.roots.html
|
142 |
+
coeff = [aspect_ratio, w+h*aspect_ratio, (1-ratio)*w*h]
|
143 |
+
roots = np.roots(coeff)
|
144 |
+
roots.sort()
|
145 |
+
delta = int(round(roots[-1] / 2 ))
|
146 |
+
delta_w = int(delta * aspect_ratio)
|
147 |
+
delta_w = min(x1, im_w - x2, delta_w)
|
148 |
+
delta = min(y1, im_h - y2, delta)
|
149 |
+
rect = np.array([x1-delta_w, y1-delta, x2+delta_w, y2+delta], dtype=np.int64)
|
150 |
+
return rect.tolist()
|
151 |
+
|
152 |
+
def draw_connected_labels(num_labels, labels, stats, centroids, names="draw_connected_labels", skip_background=True):
|
153 |
+
labdraw = np.zeros((labels.shape[0], labels.shape[1], 3), dtype=np.uint8)
|
154 |
+
max_ind = 0
|
155 |
+
if isinstance(num_labels, int):
|
156 |
+
num_labels = range(num_labels)
|
157 |
+
|
158 |
+
# for ind, lab in enumerate((range(num_labels))):
|
159 |
+
for lab in num_labels:
|
160 |
+
if skip_background and lab == 0:
|
161 |
+
continue
|
162 |
+
randcolor = (random.randint(0,255), random.randint(0,255), random.randint(0,255))
|
163 |
+
labdraw[np.where(labels==lab)] = randcolor
|
164 |
+
maxr, minr = 0.5, 0.001
|
165 |
+
maxw, maxh = stats[max_ind][2] * maxr, stats[max_ind][3] * maxr
|
166 |
+
minarea = labdraw.shape[0] * labdraw.shape[1] * minr
|
167 |
+
|
168 |
+
stat = stats[lab]
|
169 |
+
bboxarea = stat[2] * stat[3]
|
170 |
+
if stat[2] < maxw and stat[3] < maxh and bboxarea > minarea:
|
171 |
+
pix = np.zeros((labels.shape[0], labels.shape[1]), dtype=np.uint8)
|
172 |
+
pix[np.where(labels==lab)] = 255
|
173 |
+
|
174 |
+
rect = cv2.minAreaRect(cv2.findNonZero(pix))
|
175 |
+
box = np.int0(cv2.boxPoints(rect))
|
176 |
+
labdraw = cv2.drawContours(labdraw, [box], 0, randcolor, 2)
|
177 |
+
labdraw = cv2.circle(labdraw, (int(centroids[lab][0]),int(centroids[lab][1])), radius=5, color=(random.randint(0,255), random.randint(0,255), random.randint(0,255)), thickness=-1)
|
178 |
+
|
179 |
+
cv2.imshow(names, labdraw)
|
180 |
+
return labdraw
|
manga_translator/detection/ctd_utils/utils/io_utils.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import glob
|
4 |
+
from pathlib import Path
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import json
|
8 |
+
|
9 |
+
IMG_EXT = ['.bmp', '.jpg', '.png', '.jpeg']
|
10 |
+
|
11 |
+
NP_BOOL_TYPES = (np.bool_, np.bool8)
|
12 |
+
NP_FLOAT_TYPES = (np.float_, np.float16, np.float32, np.float64)
|
13 |
+
NP_INT_TYPES = (np.int_, np.int8, np.int16, np.int32, np.int64, np.uint, np.uint8, np.uint16, np.uint32, np.uint64)
|
14 |
+
|
15 |
+
# https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
|
16 |
+
class NumpyEncoder(json.JSONEncoder):
|
17 |
+
def default(self, obj):
|
18 |
+
if isinstance(obj, np.ndarray):
|
19 |
+
return obj.tolist()
|
20 |
+
elif isinstance(obj, np.ScalarType):
|
21 |
+
if isinstance(obj, NP_BOOL_TYPES):
|
22 |
+
return bool(obj)
|
23 |
+
elif isinstance(obj, NP_FLOAT_TYPES):
|
24 |
+
return float(obj)
|
25 |
+
elif isinstance(obj, NP_INT_TYPES):
|
26 |
+
return int(obj)
|
27 |
+
return json.JSONEncoder.default(self, obj)
|
28 |
+
|
29 |
+
def find_all_imgs(img_dir, abs_path=False):
|
30 |
+
imglist = list()
|
31 |
+
for filep in glob.glob(osp.join(img_dir, "*")):
|
32 |
+
filename = osp.basename(filep)
|
33 |
+
file_suffix = Path(filename).suffix
|
34 |
+
if file_suffix.lower() not in IMG_EXT:
|
35 |
+
continue
|
36 |
+
if abs_path:
|
37 |
+
imglist.append(filep)
|
38 |
+
else:
|
39 |
+
imglist.append(filename)
|
40 |
+
return imglist
|
41 |
+
|
42 |
+
def imread(imgpath, read_type=cv2.IMREAD_COLOR):
|
43 |
+
# img = cv2.imread(imgpath, read_type)
|
44 |
+
# if img is None:
|
45 |
+
img = cv2.imdecode(np.fromfile(imgpath, dtype=np.uint8), read_type)
|
46 |
+
return img
|
47 |
+
|
48 |
+
def imwrite(img_path, img, ext='.png'):
|
49 |
+
suffix = Path(img_path).suffix
|
50 |
+
if suffix != '':
|
51 |
+
img_path = img_path.replace(suffix, ext)
|
52 |
+
else:
|
53 |
+
img_path += ext
|
54 |
+
cv2.imencode(ext, img)[1].tofile(img_path)
|
manga_translator/detection/ctd_utils/utils/weight_init.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
|
4 |
+
def constant_init(module, val, bias=0):
|
5 |
+
nn.init.constant_(module.weight, val)
|
6 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
7 |
+
nn.init.constant_(module.bias, bias)
|
8 |
+
|
9 |
+
def xavier_init(module, gain=1, bias=0, distribution='normal'):
|
10 |
+
assert distribution in ['uniform', 'normal']
|
11 |
+
if distribution == 'uniform':
|
12 |
+
nn.init.xavier_uniform_(module.weight, gain=gain)
|
13 |
+
else:
|
14 |
+
nn.init.xavier_normal_(module.weight, gain=gain)
|
15 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
16 |
+
nn.init.constant_(module.bias, bias)
|
17 |
+
|
18 |
+
|
19 |
+
def normal_init(module, mean=0, std=1, bias=0):
|
20 |
+
nn.init.normal_(module.weight, mean, std)
|
21 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
22 |
+
nn.init.constant_(module.bias, bias)
|
23 |
+
|
24 |
+
|
25 |
+
def uniform_init(module, a=0, b=1, bias=0):
|
26 |
+
nn.init.uniform_(module.weight, a, b)
|
27 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
28 |
+
nn.init.constant_(module.bias, bias)
|
29 |
+
|
30 |
+
|
31 |
+
def kaiming_init(module,
|
32 |
+
a=0,
|
33 |
+
is_rnn=False,
|
34 |
+
mode='fan_in',
|
35 |
+
nonlinearity='leaky_relu',
|
36 |
+
bias=0,
|
37 |
+
distribution='normal'):
|
38 |
+
assert distribution in ['uniform', 'normal']
|
39 |
+
if distribution == 'uniform':
|
40 |
+
if is_rnn:
|
41 |
+
for name, param in module.named_parameters():
|
42 |
+
if 'bias' in name:
|
43 |
+
nn.init.constant_(param, bias)
|
44 |
+
elif 'weight' in name:
|
45 |
+
nn.init.kaiming_uniform_(param,
|
46 |
+
a=a,
|
47 |
+
mode=mode,
|
48 |
+
nonlinearity=nonlinearity)
|
49 |
+
else:
|
50 |
+
nn.init.kaiming_uniform_(module.weight,
|
51 |
+
a=a,
|
52 |
+
mode=mode,
|
53 |
+
nonlinearity=nonlinearity)
|
54 |
+
|
55 |
+
else:
|
56 |
+
if is_rnn:
|
57 |
+
for name, param in module.named_parameters():
|
58 |
+
if 'bias' in name:
|
59 |
+
nn.init.constant_(param, bias)
|
60 |
+
elif 'weight' in name:
|
61 |
+
nn.init.kaiming_normal_(param,
|
62 |
+
a=a,
|
63 |
+
mode=mode,
|
64 |
+
nonlinearity=nonlinearity)
|
65 |
+
else:
|
66 |
+
nn.init.kaiming_normal_(module.weight,
|
67 |
+
a=a,
|
68 |
+
mode=mode,
|
69 |
+
nonlinearity=nonlinearity)
|
70 |
+
|
71 |
+
if not is_rnn and hasattr(module, 'bias') and module.bias is not None:
|
72 |
+
nn.init.constant_(module.bias, bias)
|
73 |
+
|
74 |
+
|
75 |
+
def bilinear_kernel(in_channels, out_channels, kernel_size):
|
76 |
+
factor = (kernel_size + 1) // 2
|
77 |
+
if kernel_size % 2 == 1:
|
78 |
+
center = factor - 1
|
79 |
+
else:
|
80 |
+
center = factor - 0.5
|
81 |
+
og = (torch.arange(kernel_size).reshape(-1, 1),
|
82 |
+
torch.arange(kernel_size).reshape(1, -1))
|
83 |
+
filt = (1 - torch.abs(og[0] - center) / factor) * \
|
84 |
+
(1 - torch.abs(og[1] - center) / factor)
|
85 |
+
weight = torch.zeros((in_channels, out_channels,
|
86 |
+
kernel_size, kernel_size))
|
87 |
+
weight[range(in_channels), range(out_channels), :, :] = filt
|
88 |
+
return weight
|
89 |
+
|
90 |
+
|
91 |
+
def init_weights(m):
|
92 |
+
# for m in modules:
|
93 |
+
|
94 |
+
if isinstance(m, nn.Conv2d):
|
95 |
+
kaiming_init(m)
|
96 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
97 |
+
constant_init(m, 1)
|
98 |
+
elif isinstance(m, nn.Linear):
|
99 |
+
xavier_init(m)
|
100 |
+
elif isinstance(m, (nn.LSTM, nn.LSTMCell)):
|
101 |
+
kaiming_init(m, is_rnn=True)
|
102 |
+
# elif isinstance(m, nn.ConvTranspose2d):
|
103 |
+
# m.weight.data.copy_(bilinear_kernel(m.in_channels, m.out_channels, 4));
|
manga_translator/detection/ctd_utils/utils/yolov5_utils.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import time
|
8 |
+
import torchvision
|
9 |
+
|
10 |
+
def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
|
11 |
+
# scales img(bs,3,y,x) by ratio constrained to gs-multiple
|
12 |
+
if ratio == 1.0:
|
13 |
+
return img
|
14 |
+
else:
|
15 |
+
h, w = img.shape[2:]
|
16 |
+
s = (int(h * ratio), int(w * ratio)) # new size
|
17 |
+
img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
|
18 |
+
if not same_shape: # pad/crop img
|
19 |
+
h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
|
20 |
+
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
|
21 |
+
|
22 |
+
def fuse_conv_and_bn(conv, bn):
|
23 |
+
# Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
24 |
+
fusedconv = nn.Conv2d(conv.in_channels,
|
25 |
+
conv.out_channels,
|
26 |
+
kernel_size=conv.kernel_size,
|
27 |
+
stride=conv.stride,
|
28 |
+
padding=conv.padding,
|
29 |
+
groups=conv.groups,
|
30 |
+
bias=True).requires_grad_(False).to(conv.weight.device)
|
31 |
+
|
32 |
+
# prepare filters
|
33 |
+
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
34 |
+
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
35 |
+
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
|
36 |
+
|
37 |
+
# prepare spatial bias
|
38 |
+
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
|
39 |
+
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
|
40 |
+
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
|
41 |
+
|
42 |
+
return fusedconv
|
43 |
+
|
44 |
+
def check_anchor_order(m):
|
45 |
+
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
|
46 |
+
a = m.anchors.prod(-1).view(-1) # anchor area
|
47 |
+
da = a[-1] - a[0] # delta a
|
48 |
+
ds = m.stride[-1] - m.stride[0] # delta s
|
49 |
+
if da.sign() != ds.sign(): # same order
|
50 |
+
m.anchors[:] = m.anchors.flip(0)
|
51 |
+
|
52 |
+
def initialize_weights(model):
|
53 |
+
for m in model.modules():
|
54 |
+
t = type(m)
|
55 |
+
if t is nn.Conv2d:
|
56 |
+
pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
57 |
+
elif t is nn.BatchNorm2d:
|
58 |
+
m.eps = 1e-3
|
59 |
+
m.momentum = 0.03
|
60 |
+
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
|
61 |
+
m.inplace = True
|
62 |
+
|
63 |
+
def make_divisible(x, divisor):
|
64 |
+
# Returns nearest x divisible by divisor
|
65 |
+
if isinstance(divisor, torch.Tensor):
|
66 |
+
divisor = int(divisor.max()) # to int
|
67 |
+
return math.ceil(x / divisor) * divisor
|
68 |
+
|
69 |
+
def intersect_dicts(da, db, exclude=()):
|
70 |
+
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
|
71 |
+
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
|
72 |
+
|
73 |
+
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False):
|
74 |
+
# Check version vs. required version
|
75 |
+
from packaging import version
|
76 |
+
current, minimum = (version.parse(x) for x in (current, minimum))
|
77 |
+
result = (current == minimum) if pinned else (current >= minimum) # bool
|
78 |
+
if hard: # assert min requirements met
|
79 |
+
assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
|
80 |
+
else:
|
81 |
+
return result
|
82 |
+
|
83 |
+
class Colors:
|
84 |
+
# Ultralytics color palette https://ultralytics.com/
|
85 |
+
def __init__(self):
|
86 |
+
# hex = matplotlib.colors.TABLEAU_COLORS.values()
|
87 |
+
hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
|
88 |
+
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
|
89 |
+
self.palette = [self.hex2rgb('#' + c) for c in hex]
|
90 |
+
self.n = len(self.palette)
|
91 |
+
|
92 |
+
def __call__(self, i, bgr=False):
|
93 |
+
c = self.palette[int(i) % self.n]
|
94 |
+
return (c[2], c[1], c[0]) if bgr else c
|
95 |
+
|
96 |
+
@staticmethod
|
97 |
+
def hex2rgb(h): # rgb order (PIL)
|
98 |
+
return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
|
99 |
+
|
100 |
+
def box_iou(box1, box2):
|
101 |
+
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
|
102 |
+
"""
|
103 |
+
Return intersection-over-union (Jaccard index) of boxes.
|
104 |
+
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
105 |
+
Arguments:
|
106 |
+
box1 (Tensor[N, 4])
|
107 |
+
box2 (Tensor[M, 4])
|
108 |
+
Returns:
|
109 |
+
iou (Tensor[N, M]): the NxM matrix containing the pairwise
|
110 |
+
IoU values for every element in boxes1 and boxes2
|
111 |
+
"""
|
112 |
+
|
113 |
+
def box_area(box):
|
114 |
+
# box = 4xn
|
115 |
+
return (box[2] - box[0]) * (box[3] - box[1])
|
116 |
+
|
117 |
+
area1 = box_area(box1.T)
|
118 |
+
area2 = box_area(box2.T)
|
119 |
+
|
120 |
+
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
|
121 |
+
inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
|
122 |
+
return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
|
123 |
+
|
124 |
+
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
|
125 |
+
labels=(), max_det=300):
|
126 |
+
"""Runs Non-Maximum Suppression (NMS) on inference results
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
|
130 |
+
"""
|
131 |
+
|
132 |
+
if isinstance(prediction, np.ndarray):
|
133 |
+
prediction = torch.from_numpy(prediction)
|
134 |
+
|
135 |
+
nc = prediction.shape[2] - 5 # number of classes
|
136 |
+
xc = prediction[..., 4] > conf_thres # candidates
|
137 |
+
|
138 |
+
# Checks
|
139 |
+
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
|
140 |
+
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
|
141 |
+
|
142 |
+
# Settings
|
143 |
+
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
|
144 |
+
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
|
145 |
+
time_limit = 10.0 # seconds to quit after
|
146 |
+
redundant = True # require redundant detections
|
147 |
+
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
148 |
+
merge = False # use merge-NMS
|
149 |
+
|
150 |
+
t = time.time()
|
151 |
+
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
|
152 |
+
for xi, x in enumerate(prediction): # image index, image inference
|
153 |
+
# Apply constraints
|
154 |
+
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
|
155 |
+
x = x[xc[xi]] # confidence
|
156 |
+
|
157 |
+
# Cat apriori labels if autolabelling
|
158 |
+
if labels and len(labels[xi]):
|
159 |
+
l = labels[xi]
|
160 |
+
v = torch.zeros((len(l), nc + 5), device=x.device)
|
161 |
+
v[:, :4] = l[:, 1:5] # box
|
162 |
+
v[:, 4] = 1.0 # conf
|
163 |
+
v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
|
164 |
+
x = torch.cat((x, v), 0)
|
165 |
+
|
166 |
+
# If none remain process next image
|
167 |
+
if not x.shape[0]:
|
168 |
+
continue
|
169 |
+
|
170 |
+
# Compute conf
|
171 |
+
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
|
172 |
+
|
173 |
+
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
174 |
+
box = xywh2xyxy(x[:, :4])
|
175 |
+
|
176 |
+
# Detections matrix nx6 (xyxy, conf, cls)
|
177 |
+
if multi_label:
|
178 |
+
i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
|
179 |
+
x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
|
180 |
+
else: # best class only
|
181 |
+
conf, j = x[:, 5:].max(1, keepdim=True)
|
182 |
+
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
|
183 |
+
|
184 |
+
# Filter by class
|
185 |
+
if classes is not None:
|
186 |
+
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
|
187 |
+
|
188 |
+
# Apply finite constraint
|
189 |
+
# if not torch.isfinite(x).all():
|
190 |
+
# x = x[torch.isfinite(x).all(1)]
|
191 |
+
|
192 |
+
# Check shape
|
193 |
+
n = x.shape[0] # number of boxes
|
194 |
+
if not n: # no boxes
|
195 |
+
continue
|
196 |
+
elif n > max_nms: # excess boxes
|
197 |
+
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
|
198 |
+
|
199 |
+
# Batched NMS
|
200 |
+
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
201 |
+
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
202 |
+
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
203 |
+
if i.shape[0] > max_det: # limit detections
|
204 |
+
i = i[:max_det]
|
205 |
+
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|
206 |
+
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
|
207 |
+
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
|
208 |
+
weights = iou * scores[None] # box weights
|
209 |
+
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
|
210 |
+
if redundant:
|
211 |
+
i = i[iou.sum(1) > 1] # require redundancy
|
212 |
+
|
213 |
+
output[xi] = x[i]
|
214 |
+
if (time.time() - t) > time_limit:
|
215 |
+
print(f'WARNING: NMS time limit {time_limit}s exceeded')
|
216 |
+
break # time limit exceeded
|
217 |
+
|
218 |
+
return output
|
219 |
+
|
220 |
+
def xywh2xyxy(x):
|
221 |
+
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
222 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
223 |
+
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
|
224 |
+
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
|
225 |
+
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
|
226 |
+
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
|
227 |
+
return y
|
228 |
+
|
229 |
+
DEFAULT_LANG_LIST = ['eng', 'ja']
|
230 |
+
def draw_bbox(pred, img, lang_list=None):
|
231 |
+
if lang_list is None:
|
232 |
+
lang_list = DEFAULT_LANG_LIST
|
233 |
+
lw = max(round(sum(img.shape) / 2 * 0.003), 2) # line width
|
234 |
+
pred = pred.astype(np.int32)
|
235 |
+
colors = Colors()
|
236 |
+
img = np.copy(img)
|
237 |
+
for ii, obj in enumerate(pred):
|
238 |
+
p1, p2 = (obj[0], obj[1]), (obj[2], obj[3])
|
239 |
+
label = lang_list[obj[-1]] + str(ii+1)
|
240 |
+
cv2.rectangle(img, p1, p2, colors(obj[-1], bgr=True), lw, lineType=cv2.LINE_AA)
|
241 |
+
t_w, t_h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=lw)[0]
|
242 |
+
cv2.putText(img, label, (p1[0], p1[1] + t_h + 2), 0, lw / 3, colors(obj[-1], bgr=True), max(lw-1, 1), cv2.LINE_AA)
|
243 |
+
return img
|
manga_translator/detection/ctd_utils/yolov5/common.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
2 |
+
"""
|
3 |
+
Common modules
|
4 |
+
"""
|
5 |
+
|
6 |
+
import json
|
7 |
+
import math
|
8 |
+
import platform
|
9 |
+
import warnings
|
10 |
+
from collections import OrderedDict, namedtuple
|
11 |
+
from copy import copy
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
import cv2
|
15 |
+
import numpy as np
|
16 |
+
import requests
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
from PIL import Image
|
20 |
+
from torch.cuda import amp
|
21 |
+
|
22 |
+
from ..utils.yolov5_utils import make_divisible, initialize_weights, check_anchor_order, check_version, fuse_conv_and_bn
|
23 |
+
|
24 |
+
def autopad(k, p=None): # kernel, padding
|
25 |
+
# Pad to 'same'
|
26 |
+
if p is None:
|
27 |
+
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
|
28 |
+
return p
|
29 |
+
|
30 |
+
class Conv(nn.Module):
|
31 |
+
# Standard convolution
|
32 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
|
33 |
+
super().__init__()
|
34 |
+
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
|
35 |
+
self.bn = nn.BatchNorm2d(c2)
|
36 |
+
if isinstance(act, bool):
|
37 |
+
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
|
38 |
+
elif isinstance(act, str):
|
39 |
+
if act == 'leaky':
|
40 |
+
self.act = nn.LeakyReLU(0.1, inplace=True)
|
41 |
+
elif act == 'relu':
|
42 |
+
self.act = nn.ReLU(inplace=True)
|
43 |
+
else:
|
44 |
+
self.act = None
|
45 |
+
def forward(self, x):
|
46 |
+
return self.act(self.bn(self.conv(x)))
|
47 |
+
|
48 |
+
def forward_fuse(self, x):
|
49 |
+
return self.act(self.conv(x))
|
50 |
+
|
51 |
+
|
52 |
+
class DWConv(Conv):
|
53 |
+
# Depth-wise convolution class
|
54 |
+
def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
|
55 |
+
super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
|
56 |
+
|
57 |
+
|
58 |
+
class TransformerLayer(nn.Module):
|
59 |
+
# Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
|
60 |
+
def __init__(self, c, num_heads):
|
61 |
+
super().__init__()
|
62 |
+
self.q = nn.Linear(c, c, bias=False)
|
63 |
+
self.k = nn.Linear(c, c, bias=False)
|
64 |
+
self.v = nn.Linear(c, c, bias=False)
|
65 |
+
self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
|
66 |
+
self.fc1 = nn.Linear(c, c, bias=False)
|
67 |
+
self.fc2 = nn.Linear(c, c, bias=False)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
|
71 |
+
x = self.fc2(self.fc1(x)) + x
|
72 |
+
return x
|
73 |
+
|
74 |
+
|
75 |
+
class TransformerBlock(nn.Module):
|
76 |
+
# Vision Transformer https://arxiv.org/abs/2010.11929
|
77 |
+
def __init__(self, c1, c2, num_heads, num_layers):
|
78 |
+
super().__init__()
|
79 |
+
self.conv = None
|
80 |
+
if c1 != c2:
|
81 |
+
self.conv = Conv(c1, c2)
|
82 |
+
self.linear = nn.Linear(c2, c2) # learnable position embedding
|
83 |
+
self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
|
84 |
+
self.c2 = c2
|
85 |
+
|
86 |
+
def forward(self, x):
|
87 |
+
if self.conv is not None:
|
88 |
+
x = self.conv(x)
|
89 |
+
b, _, w, h = x.shape
|
90 |
+
p = x.flatten(2).permute(2, 0, 1)
|
91 |
+
return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
|
92 |
+
|
93 |
+
|
94 |
+
class Bottleneck(nn.Module):
|
95 |
+
# Standard bottleneck
|
96 |
+
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, act=True): # ch_in, ch_out, shortcut, groups, expansion
|
97 |
+
super().__init__()
|
98 |
+
c_ = int(c2 * e) # hidden channels
|
99 |
+
self.cv1 = Conv(c1, c_, 1, 1, act=act)
|
100 |
+
self.cv2 = Conv(c_, c2, 3, 1, g=g, act=act)
|
101 |
+
self.add = shortcut and c1 == c2
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
105 |
+
|
106 |
+
|
107 |
+
class BottleneckCSP(nn.Module):
|
108 |
+
# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
|
109 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
110 |
+
super().__init__()
|
111 |
+
c_ = int(c2 * e) # hidden channels
|
112 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
113 |
+
self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
|
114 |
+
self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
|
115 |
+
self.cv4 = Conv(2 * c_, c2, 1, 1)
|
116 |
+
self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
|
117 |
+
self.act = nn.SiLU()
|
118 |
+
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
|
119 |
+
|
120 |
+
def forward(self, x):
|
121 |
+
y1 = self.cv3(self.m(self.cv1(x)))
|
122 |
+
y2 = self.cv2(x)
|
123 |
+
return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
|
124 |
+
|
125 |
+
|
126 |
+
class C3(nn.Module):
|
127 |
+
# CSP Bottleneck with 3 convolutions
|
128 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, act=True): # ch_in, ch_out, number, shortcut, groups, expansion
|
129 |
+
super().__init__()
|
130 |
+
c_ = int(c2 * e) # hidden channels
|
131 |
+
self.cv1 = Conv(c1, c_, 1, 1, act=act)
|
132 |
+
self.cv2 = Conv(c1, c_, 1, 1, act=act)
|
133 |
+
self.cv3 = Conv(2 * c_, c2, 1, act=act) # act=FReLU(c2)
|
134 |
+
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0, act=act) for _ in range(n)))
|
135 |
+
# self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
|
136 |
+
|
137 |
+
def forward(self, x):
|
138 |
+
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
|
139 |
+
|
140 |
+
|
141 |
+
class C3TR(C3):
|
142 |
+
# C3 module with TransformerBlock()
|
143 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
|
144 |
+
super().__init__(c1, c2, n, shortcut, g, e)
|
145 |
+
c_ = int(c2 * e)
|
146 |
+
self.m = TransformerBlock(c_, c_, 4, n)
|
147 |
+
|
148 |
+
|
149 |
+
class C3SPP(C3):
|
150 |
+
# C3 module with SPP()
|
151 |
+
def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
|
152 |
+
super().__init__(c1, c2, n, shortcut, g, e)
|
153 |
+
c_ = int(c2 * e)
|
154 |
+
self.m = SPP(c_, c_, k)
|
155 |
+
|
156 |
+
|
157 |
+
class C3Ghost(C3):
|
158 |
+
# C3 module with GhostBottleneck()
|
159 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
|
160 |
+
super().__init__(c1, c2, n, shortcut, g, e)
|
161 |
+
c_ = int(c2 * e) # hidden channels
|
162 |
+
self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
|
163 |
+
|
164 |
+
|
165 |
+
class SPP(nn.Module):
|
166 |
+
# Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
|
167 |
+
def __init__(self, c1, c2, k=(5, 9, 13)):
|
168 |
+
super().__init__()
|
169 |
+
c_ = c1 // 2 # hidden channels
|
170 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
171 |
+
self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
|
172 |
+
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
|
173 |
+
|
174 |
+
def forward(self, x):
|
175 |
+
x = self.cv1(x)
|
176 |
+
with warnings.catch_warnings():
|
177 |
+
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
|
178 |
+
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
|
179 |
+
|
180 |
+
|
181 |
+
class SPPF(nn.Module):
|
182 |
+
# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
|
183 |
+
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
|
184 |
+
super().__init__()
|
185 |
+
c_ = c1 // 2 # hidden channels
|
186 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
187 |
+
self.cv2 = Conv(c_ * 4, c2, 1, 1)
|
188 |
+
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
|
189 |
+
|
190 |
+
def forward(self, x):
|
191 |
+
x = self.cv1(x)
|
192 |
+
with warnings.catch_warnings():
|
193 |
+
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
|
194 |
+
y1 = self.m(x)
|
195 |
+
y2 = self.m(y1)
|
196 |
+
return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
|
197 |
+
|
198 |
+
|
199 |
+
class Focus(nn.Module):
|
200 |
+
# Focus wh information into c-space
|
201 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
|
202 |
+
super().__init__()
|
203 |
+
self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
|
204 |
+
# self.contract = Contract(gain=2)
|
205 |
+
|
206 |
+
def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
|
207 |
+
return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
|
208 |
+
# return self.conv(self.contract(x))
|
209 |
+
|
210 |
+
|
211 |
+
class GhostConv(nn.Module):
|
212 |
+
# Ghost Convolution https://github.com/huawei-noah/ghostnet
|
213 |
+
def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
|
214 |
+
super().__init__()
|
215 |
+
c_ = c2 // 2 # hidden channels
|
216 |
+
self.cv1 = Conv(c1, c_, k, s, None, g, act)
|
217 |
+
self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
|
218 |
+
|
219 |
+
def forward(self, x):
|
220 |
+
y = self.cv1(x)
|
221 |
+
return torch.cat([y, self.cv2(y)], 1)
|
222 |
+
|
223 |
+
|
224 |
+
class GhostBottleneck(nn.Module):
|
225 |
+
# Ghost Bottleneck https://github.com/huawei-noah/ghostnet
|
226 |
+
def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
|
227 |
+
super().__init__()
|
228 |
+
c_ = c2 // 2
|
229 |
+
self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
|
230 |
+
DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
|
231 |
+
GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
|
232 |
+
self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
|
233 |
+
Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
|
234 |
+
|
235 |
+
def forward(self, x):
|
236 |
+
return self.conv(x) + self.shortcut(x)
|
237 |
+
|
238 |
+
|
239 |
+
class Contract(nn.Module):
|
240 |
+
# Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
|
241 |
+
def __init__(self, gain=2):
|
242 |
+
super().__init__()
|
243 |
+
self.gain = gain
|
244 |
+
|
245 |
+
def forward(self, x):
|
246 |
+
b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
|
247 |
+
s = self.gain
|
248 |
+
x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
|
249 |
+
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
|
250 |
+
return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
|
251 |
+
|
252 |
+
|
253 |
+
class Expand(nn.Module):
|
254 |
+
# Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
|
255 |
+
def __init__(self, gain=2):
|
256 |
+
super().__init__()
|
257 |
+
self.gain = gain
|
258 |
+
|
259 |
+
def forward(self, x):
|
260 |
+
b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
|
261 |
+
s = self.gain
|
262 |
+
x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
|
263 |
+
x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
|
264 |
+
return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
|
265 |
+
|
266 |
+
|
267 |
+
class Concat(nn.Module):
|
268 |
+
# Concatenate a list of tensors along dimension
|
269 |
+
def __init__(self, dimension=1):
|
270 |
+
super().__init__()
|
271 |
+
self.d = dimension
|
272 |
+
|
273 |
+
def forward(self, x):
|
274 |
+
return torch.cat(x, self.d)
|
275 |
+
|
276 |
+
|
277 |
+
class Classify(nn.Module):
|
278 |
+
# Classification head, i.e. x(b,c1,20,20) to x(b,c2)
|
279 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
|
280 |
+
super().__init__()
|
281 |
+
self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
|
282 |
+
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
|
283 |
+
self.flat = nn.Flatten()
|
284 |
+
|
285 |
+
def forward(self, x):
|
286 |
+
z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
|
287 |
+
return self.flat(self.conv(z)) # flatten to x(b,c2)
|
288 |
+
|
289 |
+
|
manga_translator/detection/ctd_utils/yolov5/yolo.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from operator import mod
|
2 |
+
from cv2 import imshow
|
3 |
+
# from utils.yolov5_utils import scale_img
|
4 |
+
from copy import deepcopy
|
5 |
+
from .common import *
|
6 |
+
|
7 |
+
class Detect(nn.Module):
|
8 |
+
stride = None # strides computed during build
|
9 |
+
onnx_dynamic = False # ONNX export parameter
|
10 |
+
|
11 |
+
def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
|
12 |
+
super().__init__()
|
13 |
+
self.nc = nc # number of classes
|
14 |
+
self.no = nc + 5 # number of outputs per anchor
|
15 |
+
self.nl = len(anchors) # number of detection layers
|
16 |
+
self.na = len(anchors[0]) // 2 # number of anchors
|
17 |
+
self.grid = [torch.zeros(1)] * self.nl # init grid
|
18 |
+
self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
|
19 |
+
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
|
20 |
+
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
|
21 |
+
self.inplace = inplace # use in-place ops (e.g. slice assignment)
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
z = [] # inference output
|
25 |
+
for i in range(self.nl):
|
26 |
+
x[i] = self.m[i](x[i]) # conv
|
27 |
+
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
|
28 |
+
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
|
29 |
+
|
30 |
+
if not self.training: # inference
|
31 |
+
if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
|
32 |
+
self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
|
33 |
+
|
34 |
+
y = x[i].sigmoid()
|
35 |
+
if self.inplace:
|
36 |
+
y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy
|
37 |
+
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
38 |
+
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
|
39 |
+
xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy
|
40 |
+
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
41 |
+
y = torch.cat((xy, wh, y[..., 4:]), -1)
|
42 |
+
z.append(y.view(bs, -1, self.no))
|
43 |
+
|
44 |
+
return x if self.training else (torch.cat(z, 1), x)
|
45 |
+
|
46 |
+
def _make_grid(self, nx=20, ny=20, i=0):
|
47 |
+
d = self.anchors[i].device
|
48 |
+
if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
|
49 |
+
yv, xv = torch.meshgrid([torch.arange(ny, device=d), torch.arange(nx, device=d)], indexing='ij')
|
50 |
+
else:
|
51 |
+
yv, xv = torch.meshgrid([torch.arange(ny, device=d), torch.arange(nx, device=d)])
|
52 |
+
grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()
|
53 |
+
anchor_grid = (self.anchors[i].clone() * self.stride[i]) \
|
54 |
+
.view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()
|
55 |
+
return grid, anchor_grid
|
56 |
+
|
57 |
+
class Model(nn.Module):
|
58 |
+
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
|
59 |
+
super().__init__()
|
60 |
+
self.out_indices = None
|
61 |
+
if isinstance(cfg, dict):
|
62 |
+
self.yaml = cfg # model dict
|
63 |
+
else: # is *.yaml
|
64 |
+
import yaml # for torch hub
|
65 |
+
self.yaml_file = Path(cfg).name
|
66 |
+
with open(cfg, encoding='ascii', errors='ignore') as f:
|
67 |
+
self.yaml = yaml.safe_load(f) # model dict
|
68 |
+
|
69 |
+
# Define model
|
70 |
+
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
71 |
+
if nc and nc != self.yaml['nc']:
|
72 |
+
# LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
73 |
+
self.yaml['nc'] = nc # override yaml value
|
74 |
+
if anchors:
|
75 |
+
# LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
|
76 |
+
self.yaml['anchors'] = round(anchors) # override yaml value
|
77 |
+
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
|
78 |
+
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
|
79 |
+
self.inplace = self.yaml.get('inplace', True)
|
80 |
+
|
81 |
+
# Build strides, anchors
|
82 |
+
m = self.model[-1] # Detect()
|
83 |
+
# with torch.no_grad():
|
84 |
+
if isinstance(m, Detect):
|
85 |
+
s = 256 # 2x min stride
|
86 |
+
m.inplace = self.inplace
|
87 |
+
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
|
88 |
+
m.anchors /= m.stride.view(-1, 1, 1)
|
89 |
+
check_anchor_order(m)
|
90 |
+
self.stride = m.stride
|
91 |
+
self._initialize_biases() # only run once
|
92 |
+
|
93 |
+
# Init weights, biases
|
94 |
+
initialize_weights(self)
|
95 |
+
|
96 |
+
def forward(self, x, augment=False, profile=False, visualize=False, detect=False):
|
97 |
+
# if augment:
|
98 |
+
# return self._forward_augment(x) # augmented inference, None
|
99 |
+
return self._forward_once(x, profile, visualize, detect=detect) # single-scale inference, train
|
100 |
+
|
101 |
+
# def _forward_augment(self, x):
|
102 |
+
# img_size = x.shape[-2:] # height, width
|
103 |
+
# s = [1, 0.83, 0.67] # scales
|
104 |
+
# f = [None, 3, None] # flips (2-ud, 3-lr)
|
105 |
+
# y = [] # outputs
|
106 |
+
# for si, fi in zip(s, f):
|
107 |
+
# xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
|
108 |
+
# yi = self._forward_once(xi)[0] # forward
|
109 |
+
# # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
|
110 |
+
# yi = self._descale_pred(yi, fi, si, img_size)
|
111 |
+
# y.append(yi)
|
112 |
+
# y = self._clip_augmented(y) # clip augmented tails
|
113 |
+
# return torch.cat(y, 1), None # augmented inference, train
|
114 |
+
|
115 |
+
def _forward_once(self, x, profile=False, visualize=False, detect=False):
|
116 |
+
y, dt = [], [] # outputs
|
117 |
+
z = []
|
118 |
+
for ii, m in enumerate(self.model):
|
119 |
+
if m.f != -1: # if not from previous layer
|
120 |
+
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
|
121 |
+
if profile:
|
122 |
+
self._profile_one_layer(m, x, dt)
|
123 |
+
x = m(x) # run
|
124 |
+
y.append(x if m.i in self.save else None) # save output
|
125 |
+
if self.out_indices is not None:
|
126 |
+
if m.i in self.out_indices:
|
127 |
+
z.append(x)
|
128 |
+
if self.out_indices is not None:
|
129 |
+
if detect:
|
130 |
+
return x, z
|
131 |
+
else:
|
132 |
+
return z
|
133 |
+
else:
|
134 |
+
return x
|
135 |
+
|
136 |
+
def _descale_pred(self, p, flips, scale, img_size):
|
137 |
+
# de-scale predictions following augmented inference (inverse operation)
|
138 |
+
if self.inplace:
|
139 |
+
p[..., :4] /= scale # de-scale
|
140 |
+
if flips == 2:
|
141 |
+
p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
|
142 |
+
elif flips == 3:
|
143 |
+
p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
|
144 |
+
else:
|
145 |
+
x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
|
146 |
+
if flips == 2:
|
147 |
+
y = img_size[0] - y # de-flip ud
|
148 |
+
elif flips == 3:
|
149 |
+
x = img_size[1] - x # de-flip lr
|
150 |
+
p = torch.cat((x, y, wh, p[..., 4:]), -1)
|
151 |
+
return p
|
152 |
+
|
153 |
+
def _clip_augmented(self, y):
|
154 |
+
# Clip YOLOv5 augmented inference tails
|
155 |
+
nl = self.model[-1].nl # number of detection layers (P3-P5)
|
156 |
+
g = sum(4 ** x for x in range(nl)) # grid points
|
157 |
+
e = 1 # exclude layer count
|
158 |
+
i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices
|
159 |
+
y[0] = y[0][:, :-i] # large
|
160 |
+
i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
|
161 |
+
y[-1] = y[-1][:, i:] # small
|
162 |
+
return y
|
163 |
+
|
164 |
+
def _profile_one_layer(self, m, x, dt):
|
165 |
+
c = isinstance(m, Detect) # is final layer, copy input as inplace fix
|
166 |
+
for _ in range(10):
|
167 |
+
m(x.copy() if c else x)
|
168 |
+
|
169 |
+
|
170 |
+
def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
|
171 |
+
# https://arxiv.org/abs/1708.02002 section 3.3
|
172 |
+
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
|
173 |
+
m = self.model[-1] # Detect() module
|
174 |
+
for mi, s in zip(m.m, m.stride): # from
|
175 |
+
b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
|
176 |
+
b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
|
177 |
+
b.data[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # cls
|
178 |
+
mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
|
179 |
+
|
180 |
+
def _print_biases(self):
|
181 |
+
m = self.model[-1] # Detect() module
|
182 |
+
for mi in m.m: # from
|
183 |
+
b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
|
184 |
+
|
185 |
+
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
|
186 |
+
for m in self.model.modules():
|
187 |
+
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
|
188 |
+
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
189 |
+
delattr(m, 'bn') # remove batchnorm
|
190 |
+
m.forward = m.forward_fuse # update forward
|
191 |
+
# self.info()
|
192 |
+
return self
|
193 |
+
|
194 |
+
# def info(self, verbose=False, img_size=640): # print model information
|
195 |
+
# model_info(self, verbose, img_size)
|
196 |
+
|
197 |
+
def _apply(self, fn):
|
198 |
+
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
199 |
+
self = super()._apply(fn)
|
200 |
+
m = self.model[-1] # Detect()
|
201 |
+
if isinstance(m, Detect):
|
202 |
+
m.stride = fn(m.stride)
|
203 |
+
m.grid = list(map(fn, m.grid))
|
204 |
+
if isinstance(m.anchor_grid, list):
|
205 |
+
m.anchor_grid = list(map(fn, m.anchor_grid))
|
206 |
+
return self
|
207 |
+
|
208 |
+
def parse_model(d, ch): # model_dict, input_channels(3)
|
209 |
+
# LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
|
210 |
+
anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
|
211 |
+
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
|
212 |
+
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
|
213 |
+
|
214 |
+
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
|
215 |
+
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
|
216 |
+
m = eval(m) if isinstance(m, str) else m # eval strings
|
217 |
+
for j, a in enumerate(args):
|
218 |
+
try:
|
219 |
+
args[j] = eval(a) if isinstance(a, str) else a # eval strings
|
220 |
+
except NameError:
|
221 |
+
pass
|
222 |
+
|
223 |
+
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
|
224 |
+
if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
|
225 |
+
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]:
|
226 |
+
c1, c2 = ch[f], args[0]
|
227 |
+
if c2 != no: # if not output
|
228 |
+
c2 = make_divisible(c2 * gw, 8)
|
229 |
+
|
230 |
+
args = [c1, c2, *args[1:]]
|
231 |
+
if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
|
232 |
+
args.insert(2, n) # number of repeats
|
233 |
+
n = 1
|
234 |
+
elif m is nn.BatchNorm2d:
|
235 |
+
args = [ch[f]]
|
236 |
+
elif m is Concat:
|
237 |
+
c2 = sum(ch[x] for x in f)
|
238 |
+
elif m is Detect:
|
239 |
+
args.append([ch[x] for x in f])
|
240 |
+
if isinstance(args[1], int): # number of anchors
|
241 |
+
args[1] = [list(range(args[1] * 2))] * len(f)
|
242 |
+
elif m is Contract:
|
243 |
+
c2 = ch[f] * args[0] ** 2
|
244 |
+
elif m is Expand:
|
245 |
+
c2 = ch[f] // args[0] ** 2
|
246 |
+
else:
|
247 |
+
c2 = ch[f]
|
248 |
+
|
249 |
+
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
|
250 |
+
t = str(m)[8:-2].replace('__main__.', '') # module type
|
251 |
+
np = sum(x.numel() for x in m_.parameters()) # number params
|
252 |
+
m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
|
253 |
+
# LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}') # print
|
254 |
+
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
|
255 |
+
layers.append(m_)
|
256 |
+
if i == 0:
|
257 |
+
ch = []
|
258 |
+
ch.append(c2)
|
259 |
+
return nn.Sequential(*layers), sorted(save)
|
260 |
+
|
261 |
+
def load_yolov5(weights, map_location='cuda', fuse=True, inplace=True, out_indices=[1, 3, 5, 7, 9]):
|
262 |
+
if isinstance(weights, str):
|
263 |
+
ckpt = torch.load(weights, map_location=map_location) # load
|
264 |
+
else:
|
265 |
+
ckpt = weights
|
266 |
+
|
267 |
+
if fuse:
|
268 |
+
model = ckpt['model'].float().fuse().eval() # FP32 model
|
269 |
+
else:
|
270 |
+
model = ckpt['model'].float().eval() # without layer fuse
|
271 |
+
|
272 |
+
# Compatibility updates
|
273 |
+
for m in model.modules():
|
274 |
+
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
|
275 |
+
m.inplace = inplace # pytorch 1.7.0 compatibility
|
276 |
+
if type(m) is Detect:
|
277 |
+
if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
|
278 |
+
delattr(m, 'anchor_grid')
|
279 |
+
setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
|
280 |
+
elif type(m) is Conv:
|
281 |
+
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
|
282 |
+
model.out_indices = out_indices
|
283 |
+
return model
|
284 |
+
|
285 |
+
@torch.no_grad()
|
286 |
+
def load_yolov5_ckpt(weights, map_location='cpu', fuse=True, inplace=True, out_indices=[1, 3, 5, 7, 9]):
|
287 |
+
if isinstance(weights, str):
|
288 |
+
ckpt = torch.load(weights, map_location=map_location) # load
|
289 |
+
else:
|
290 |
+
ckpt = weights
|
291 |
+
|
292 |
+
model = Model(ckpt['cfg'])
|
293 |
+
model.load_state_dict(ckpt['weights'], strict=True)
|
294 |
+
|
295 |
+
if fuse:
|
296 |
+
model = model.float().fuse().eval() # FP32 model
|
297 |
+
else:
|
298 |
+
model = model.float().eval() # without layer fuse
|
299 |
+
|
300 |
+
# Compatibility updates
|
301 |
+
for m in model.modules():
|
302 |
+
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
|
303 |
+
m.inplace = inplace # pytorch 1.7.0 compatibility
|
304 |
+
if type(m) is Detect:
|
305 |
+
if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
|
306 |
+
delattr(m, 'anchor_grid')
|
307 |
+
setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
|
308 |
+
elif type(m) is Conv:
|
309 |
+
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
|
310 |
+
model.out_indices = out_indices
|
311 |
+
return model
|
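A minimal usage sketch for the loaders above (hedged: the checkpoint filename and the 640x640 input size are placeholders, and the checkpoint is assumed to carry the 'cfg' and 'weights' keys that load_yolov5_ckpt reads):

    import torch
    # 'textdet_yolov5.ckpt' is a hypothetical path; the real file depends on the model release.
    model = load_yolov5_ckpt('textdet_yolov5.ckpt', map_location='cpu', fuse=True)
    with torch.no_grad():
        feats = model(torch.zeros(1, 3, 640, 640))                    # feature maps at the configured out_indices
        det, feats = model(torch.zeros(1, 3, 640, 640), detect=True)  # Detect head output plus the same features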
manga_translator/detection/dbnet_convnext.py
ADDED
@@ -0,0 +1,596 @@
1 |
+
|
2 |
+
from functools import partial
|
3 |
+
import shutil
|
4 |
+
from typing import Callable, Optional, Tuple, Union
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.nn.init as init
|
11 |
+
|
12 |
+
from torchvision.models import resnet34
|
13 |
+
|
14 |
+
import einops
|
15 |
+
import math
|
16 |
+
|
17 |
+
from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
|
18 |
+
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
|
19 |
+
|
20 |
+
class Downsample(nn.Module):
|
21 |
+
|
22 |
+
def __init__(self, in_chs, out_chs, stride=1, dilation=1):
|
23 |
+
super().__init__()
|
24 |
+
avg_stride = stride if dilation == 1 else 1
|
25 |
+
if stride > 1 or dilation > 1:
|
26 |
+
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
|
27 |
+
self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
|
28 |
+
else:
|
29 |
+
self.pool = nn.Identity()
|
30 |
+
|
31 |
+
if in_chs != out_chs:
|
32 |
+
self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
|
33 |
+
else:
|
34 |
+
self.conv = nn.Identity()
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
x = self.pool(x)
|
38 |
+
x = self.conv(x)
|
39 |
+
return x
|
40 |
+
|
41 |
+
|
42 |
+
class ConvNeXtBlock(nn.Module):
|
43 |
+
""" ConvNeXt Block
|
44 |
+
There are two equivalent implementations:
|
45 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
46 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
47 |
+
|
48 |
+
Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
|
49 |
+
choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
|
50 |
+
is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
|
51 |
+
"""
|
52 |
+
|
53 |
+
def __init__(
|
54 |
+
self,
|
55 |
+
in_chs: int,
|
56 |
+
out_chs: Optional[int] = None,
|
57 |
+
kernel_size: int = 7,
|
58 |
+
stride: int = 1,
|
59 |
+
dilation: Union[int, Tuple[int, int]] = (1, 1),
|
60 |
+
mlp_ratio: float = 4,
|
61 |
+
conv_mlp: bool = False,
|
62 |
+
conv_bias: bool = True,
|
63 |
+
use_grn: bool = False,
|
64 |
+
ls_init_value: Optional[float] = 1e-6,
|
65 |
+
act_layer: Union[str, Callable] = 'gelu',
|
66 |
+
norm_layer: Optional[Callable] = None,
|
67 |
+
drop_path: float = 0.,
|
68 |
+
):
|
69 |
+
"""
|
70 |
+
|
71 |
+
Args:
|
72 |
+
in_chs: Block input channels.
|
73 |
+
out_chs: Block output channels (same as in_chs if None).
|
74 |
+
kernel_size: Depthwise convolution kernel size.
|
75 |
+
stride: Stride of depthwise convolution.
|
76 |
+
dilation: Tuple specifying input and output dilation of block.
|
77 |
+
mlp_ratio: MLP expansion ratio.
|
78 |
+
conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
|
79 |
+
conv_bias: Apply bias for all convolution (linear) layers.
|
80 |
+
use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
|
81 |
+
ls_init_value: Layer-scale init values, layer-scale applied if not None.
|
82 |
+
act_layer: Activation layer.
|
83 |
+
norm_layer: Normalization layer (defaults to LN if not specified).
|
84 |
+
drop_path: Stochastic depth probability.
|
85 |
+
"""
|
86 |
+
super().__init__()
|
87 |
+
out_chs = out_chs or in_chs
|
88 |
+
dilation = to_ntuple(2)(dilation)
|
89 |
+
act_layer = get_act_layer(act_layer)
|
90 |
+
if not norm_layer:
|
91 |
+
norm_layer = LayerNorm2d if conv_mlp else LayerNorm
|
92 |
+
mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
|
93 |
+
self.use_conv_mlp = conv_mlp
|
94 |
+
self.conv_dw = create_conv2d(
|
95 |
+
in_chs,
|
96 |
+
out_chs,
|
97 |
+
kernel_size=kernel_size,
|
98 |
+
stride=stride,
|
99 |
+
dilation=dilation[0],
|
100 |
+
depthwise=True if out_chs >= in_chs else False,
|
101 |
+
bias=conv_bias,
|
102 |
+
)
|
103 |
+
self.norm = norm_layer(out_chs)
|
104 |
+
self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
|
105 |
+
self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
|
106 |
+
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
|
107 |
+
self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
|
108 |
+
else:
|
109 |
+
self.shortcut = nn.Identity()
|
110 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
shortcut = x
|
114 |
+
x = self.conv_dw(x)
|
115 |
+
if self.use_conv_mlp:
|
116 |
+
x = self.norm(x)
|
117 |
+
x = self.mlp(x)
|
118 |
+
else:
|
119 |
+
x = x.permute(0, 2, 3, 1)
|
120 |
+
x = self.norm(x)
|
121 |
+
x = self.mlp(x)
|
122 |
+
x = x.permute(0, 3, 1, 2)
|
123 |
+
if self.gamma is not None:
|
124 |
+
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
|
125 |
+
|
126 |
+
x = self.drop_path(x) + self.shortcut(shortcut)
|
127 |
+
return x
|
128 |
+
|
129 |
+
|
130 |
+
class ConvNeXtStage(nn.Module):
|
131 |
+
|
132 |
+
def __init__(
|
133 |
+
self,
|
134 |
+
in_chs,
|
135 |
+
out_chs,
|
136 |
+
kernel_size=7,
|
137 |
+
stride=2,
|
138 |
+
depth=2,
|
139 |
+
dilation=(1, 1),
|
140 |
+
drop_path_rates=None,
|
141 |
+
ls_init_value=1.0,
|
142 |
+
conv_mlp=False,
|
143 |
+
conv_bias=True,
|
144 |
+
use_grn=False,
|
145 |
+
act_layer='gelu',
|
146 |
+
norm_layer=None,
|
147 |
+
norm_layer_cl=None
|
148 |
+
):
|
149 |
+
super().__init__()
|
150 |
+
self.grad_checkpointing = False
|
151 |
+
|
152 |
+
if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
|
153 |
+
ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
|
154 |
+
pad = 'same' if dilation[1] > 1 else 0 # same padding needed if dilation used
|
155 |
+
self.downsample = nn.Sequential(
|
156 |
+
norm_layer(in_chs),
|
157 |
+
create_conv2d(
|
158 |
+
in_chs,
|
159 |
+
out_chs,
|
160 |
+
kernel_size=ds_ks,
|
161 |
+
stride=stride,
|
162 |
+
dilation=dilation[0],
|
163 |
+
padding=pad,
|
164 |
+
bias=conv_bias,
|
165 |
+
),
|
166 |
+
)
|
167 |
+
in_chs = out_chs
|
168 |
+
else:
|
169 |
+
self.downsample = nn.Identity()
|
170 |
+
|
171 |
+
drop_path_rates = drop_path_rates or [0.] * depth
|
172 |
+
stage_blocks = []
|
173 |
+
for i in range(depth):
|
174 |
+
stage_blocks.append(ConvNeXtBlock(
|
175 |
+
in_chs=in_chs,
|
176 |
+
out_chs=out_chs,
|
177 |
+
kernel_size=kernel_size,
|
178 |
+
dilation=dilation[1],
|
179 |
+
drop_path=drop_path_rates[i],
|
180 |
+
ls_init_value=ls_init_value,
|
181 |
+
conv_mlp=conv_mlp,
|
182 |
+
conv_bias=conv_bias,
|
183 |
+
use_grn=use_grn,
|
184 |
+
act_layer=act_layer,
|
185 |
+
norm_layer=norm_layer if conv_mlp else norm_layer_cl,
|
186 |
+
))
|
187 |
+
in_chs = out_chs
|
188 |
+
self.blocks = nn.Sequential(*stage_blocks)
|
189 |
+
|
190 |
+
def forward(self, x):
|
191 |
+
x = self.downsample(x)
|
192 |
+
x = self.blocks(x)
|
193 |
+
return x
|
194 |
+
|
195 |
+
|
196 |
+
class ConvNeXt(nn.Module):
|
197 |
+
r""" ConvNeXt
|
198 |
+
A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
|
199 |
+
"""
|
200 |
+
|
201 |
+
def __init__(
|
202 |
+
self,
|
203 |
+
in_chans: int = 3,
|
204 |
+
num_classes: int = 1000,
|
205 |
+
global_pool: str = 'avg',
|
206 |
+
output_stride: int = 32,
|
207 |
+
depths: Tuple[int, ...] = (3, 3, 9, 3),
|
208 |
+
dims: Tuple[int, ...] = (96, 192, 384, 768),
|
209 |
+
kernel_sizes: Union[int, Tuple[int, ...]] = 7,
|
210 |
+
ls_init_value: Optional[float] = 1e-6,
|
211 |
+
stem_type: str = 'patch',
|
212 |
+
patch_size: int = 4,
|
213 |
+
head_init_scale: float = 1.,
|
214 |
+
head_norm_first: bool = False,
|
215 |
+
head_hidden_size: Optional[int] = None,
|
216 |
+
conv_mlp: bool = False,
|
217 |
+
conv_bias: bool = True,
|
218 |
+
use_grn: bool = False,
|
219 |
+
act_layer: Union[str, Callable] = 'gelu',
|
220 |
+
norm_layer: Optional[Union[str, Callable]] = None,
|
221 |
+
norm_eps: Optional[float] = None,
|
222 |
+
drop_rate: float = 0.,
|
223 |
+
drop_path_rate: float = 0.,
|
224 |
+
):
|
225 |
+
"""
|
226 |
+
Args:
|
227 |
+
in_chans: Number of input image channels.
|
228 |
+
num_classes: Number of classes for classification head.
|
229 |
+
global_pool: Global pooling type.
|
230 |
+
output_stride: Output stride of network, one of (8, 16, 32).
|
231 |
+
depths: Number of blocks at each stage.
|
232 |
+
dims: Feature dimension at each stage.
|
233 |
+
kernel_sizes: Depthwise convolution kernel-sizes for each stage.
|
234 |
+
ls_init_value: Init value for Layer Scale, disabled if None.
|
235 |
+
stem_type: Type of stem.
|
236 |
+
patch_size: Stem patch size for patch stem.
|
237 |
+
head_init_scale: Init scaling value for classifier weights and biases.
|
238 |
+
head_norm_first: Apply normalization before global pool + head.
|
239 |
+
head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
|
240 |
+
conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
|
241 |
+
conv_bias: Use bias layers w/ all convolutions.
|
242 |
+
use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
|
243 |
+
act_layer: Activation layer type.
|
244 |
+
norm_layer: Normalization layer type.
|
245 |
+
drop_rate: Head pre-classifier dropout rate.
|
246 |
+
drop_path_rate: Stochastic depth drop rate.
|
247 |
+
"""
|
248 |
+
super().__init__()
|
249 |
+
assert output_stride in (8, 16, 32)
|
250 |
+
kernel_sizes = to_ntuple(4)(kernel_sizes)
|
251 |
+
if norm_layer is None:
|
252 |
+
norm_layer = LayerNorm2d
|
253 |
+
norm_layer_cl = norm_layer if conv_mlp else LayerNorm
|
254 |
+
if norm_eps is not None:
|
255 |
+
norm_layer = partial(norm_layer, eps=norm_eps)
|
256 |
+
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
|
257 |
+
else:
|
258 |
+
assert conv_mlp,\
|
259 |
+
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
|
260 |
+
norm_layer_cl = norm_layer
|
261 |
+
if norm_eps is not None:
|
262 |
+
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
|
263 |
+
|
264 |
+
self.num_classes = num_classes
|
265 |
+
self.drop_rate = drop_rate
|
266 |
+
self.feature_info = []
|
267 |
+
|
268 |
+
assert stem_type in ('patch', 'overlap', 'overlap_tiered')
|
269 |
+
if stem_type == 'patch':
|
270 |
+
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
|
271 |
+
self.stem = nn.Sequential(
|
272 |
+
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
|
273 |
+
norm_layer(dims[0]),
|
274 |
+
)
|
275 |
+
stem_stride = patch_size
|
276 |
+
else:
|
277 |
+
mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
|
278 |
+
self.stem = nn.Sequential(
|
279 |
+
nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
|
280 |
+
nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
|
281 |
+
norm_layer(dims[0]),
|
282 |
+
)
|
283 |
+
stem_stride = 4
|
284 |
+
|
285 |
+
self.stages = nn.Sequential()
|
286 |
+
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
287 |
+
stages = []
|
288 |
+
prev_chs = dims[0]
|
289 |
+
curr_stride = stem_stride
|
290 |
+
dilation = 1
|
291 |
+
# 4 feature resolution stages, each consisting of multiple residual blocks
|
292 |
+
for i in range(4):
|
293 |
+
stride = 2 if curr_stride == 2 or i > 0 else 1
|
294 |
+
if curr_stride >= output_stride and stride > 1:
|
295 |
+
dilation *= stride
|
296 |
+
stride = 1
|
297 |
+
curr_stride *= stride
|
298 |
+
first_dilation = 1 if dilation in (1, 2) else 2
|
299 |
+
out_chs = dims[i]
|
300 |
+
stages.append(ConvNeXtStage(
|
301 |
+
prev_chs,
|
302 |
+
out_chs,
|
303 |
+
kernel_size=kernel_sizes[i],
|
304 |
+
stride=stride,
|
305 |
+
dilation=(first_dilation, dilation),
|
306 |
+
depth=depths[i],
|
307 |
+
drop_path_rates=dp_rates[i],
|
308 |
+
ls_init_value=ls_init_value,
|
309 |
+
conv_mlp=conv_mlp,
|
310 |
+
conv_bias=conv_bias,
|
311 |
+
use_grn=use_grn,
|
312 |
+
act_layer=act_layer,
|
313 |
+
norm_layer=norm_layer,
|
314 |
+
norm_layer_cl=norm_layer_cl,
|
315 |
+
))
|
316 |
+
prev_chs = out_chs
|
317 |
+
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
|
318 |
+
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
|
319 |
+
self.stages = nn.Sequential(*stages)
|
320 |
+
self.num_features = prev_chs
|
321 |
+
|
322 |
+
@torch.jit.ignore
|
323 |
+
def group_matcher(self, coarse=False):
|
324 |
+
return dict(
|
325 |
+
stem=r'^stem',
|
326 |
+
blocks=r'^stages\.(\d+)' if coarse else [
|
327 |
+
(r'^stages\.(\d+)\.downsample', (0,)), # blocks
|
328 |
+
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
|
329 |
+
(r'^norm_pre', (99999,))
|
330 |
+
]
|
331 |
+
)
|
332 |
+
|
333 |
+
@torch.jit.ignore
|
334 |
+
def set_grad_checkpointing(self, enable=True):
|
335 |
+
for s in self.stages:
|
336 |
+
s.grad_checkpointing = enable
|
337 |
+
|
338 |
+
@torch.jit.ignore
|
339 |
+
def get_classifier(self):
|
340 |
+
return self.head.fc
|
341 |
+
|
342 |
+
def forward_features(self, x):
|
343 |
+
x = self.stem(x)
|
344 |
+
x = self.stages(x)
|
345 |
+
return x
|
346 |
+
|
347 |
+
def _init_weights(module, name=None, head_init_scale=1.0):
|
348 |
+
if isinstance(module, nn.Conv2d):
|
349 |
+
trunc_normal_(module.weight, std=.02)
|
350 |
+
if module.bias is not None:
|
351 |
+
nn.init.zeros_(module.bias)
|
352 |
+
elif isinstance(module, nn.Linear):
|
353 |
+
trunc_normal_(module.weight, std=.02)
|
354 |
+
nn.init.zeros_(module.bias)
|
355 |
+
if name and 'head.' in name:
|
356 |
+
module.weight.data.mul_(head_init_scale)
|
357 |
+
module.bias.data.mul_(head_init_scale)
|
358 |
+
|
359 |
+
class UpconvSkip(nn.Module) :
|
360 |
+
def __init__(self, ch1, ch2, out_ch) -> None:
|
361 |
+
super().__init__()
|
362 |
+
self.conv = ConvNeXtBlock(
|
363 |
+
in_chs=ch1 + ch2,
|
364 |
+
out_chs=out_ch,
|
365 |
+
kernel_size=7,
|
366 |
+
dilation=1,
|
367 |
+
drop_path=0,
|
368 |
+
ls_init_value=1.0,
|
369 |
+
conv_mlp=False,
|
370 |
+
conv_bias=True,
|
371 |
+
use_grn=False,
|
372 |
+
act_layer='gelu',
|
373 |
+
norm_layer=LayerNorm,
|
374 |
+
)
|
375 |
+
self.upconv = nn.ConvTranspose2d(out_ch, out_ch, 2, 2, 0, 0)
|
376 |
+
|
377 |
+
def forward(self, x) :
|
378 |
+
x = self.conv(x)
|
379 |
+
x = self.upconv(x)
|
380 |
+
return x
|
381 |
+
|
382 |
+
class DBHead(nn.Module):
|
383 |
+
def __init__(self, in_channels, k = 50):
|
384 |
+
super().__init__()
|
385 |
+
self.k = k
|
386 |
+
self.binarize = nn.Sequential(
|
387 |
+
nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
|
388 |
+
#nn.BatchNorm2d(in_channels // 4),
|
389 |
+
nn.SiLU(inplace=True),
|
390 |
+
nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 4, 2, 1),
|
391 |
+
#nn.BatchNorm2d(in_channels // 4),
|
392 |
+
nn.SiLU(inplace=True),
|
393 |
+
nn.ConvTranspose2d(in_channels // 4, 1, 4, 2, 1),
|
394 |
+
)
|
395 |
+
self.binarize.apply(self.weights_init)
|
396 |
+
|
397 |
+
self.thresh = self._init_thresh(in_channels)
|
398 |
+
self.thresh.apply(self.weights_init)
|
399 |
+
|
400 |
+
def forward(self, x):
|
401 |
+
shrink_maps = self.binarize(x)
|
402 |
+
threshold_maps = self.thresh(x)
|
403 |
+
if self.training:
|
404 |
+
binary_maps = self.step_function(shrink_maps.sigmoid(), threshold_maps)
|
405 |
+
y = torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
|
406 |
+
else:
|
407 |
+
y = torch.cat((shrink_maps, threshold_maps), dim=1)
|
408 |
+
return y
|
409 |
+
|
410 |
+
def weights_init(self, m):
|
411 |
+
classname = m.__class__.__name__
|
412 |
+
if classname.find('Conv') != -1:
|
413 |
+
nn.init.kaiming_normal_(m.weight.data)
|
414 |
+
elif classname.find('BatchNorm') != -1:
|
415 |
+
m.weight.data.fill_(1.)
|
416 |
+
m.bias.data.fill_(1e-4)
|
417 |
+
|
418 |
+
def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
|
419 |
+
in_channels = inner_channels
|
420 |
+
if serial:
|
421 |
+
in_channels += 1
|
422 |
+
self.thresh = nn.Sequential(
|
423 |
+
nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
|
424 |
+
#nn.GroupNorm(inner_channels // 4),
|
425 |
+
nn.SiLU(inplace=True),
|
426 |
+
self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
|
427 |
+
#nn.GroupNorm(inner_channels // 4),
|
428 |
+
nn.SiLU(inplace=True),
|
429 |
+
self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
|
430 |
+
nn.Sigmoid())
|
431 |
+
return self.thresh
|
432 |
+
|
433 |
+
def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
|
434 |
+
if smooth:
|
435 |
+
inter_out_channels = out_channels
|
436 |
+
if out_channels == 1:
|
437 |
+
inter_out_channels = in_channels
|
438 |
+
module_list = [
|
439 |
+
nn.Upsample(scale_factor=2, mode='bilinear'),
|
440 |
+
nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)]
|
441 |
+
if out_channels == 1:
|
442 |
+
module_list.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=1, bias=True))
|
443 |
+
return nn.Sequential(*module_list)
|
444 |
+
else:
|
445 |
+
return nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1)
|
446 |
+
|
447 |
+
def step_function(self, x, y):
|
448 |
+
return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
|
449 |
+
|
450 |
+
class DBNetConvNext(nn.Module) :
|
451 |
+
def __init__(self) :
|
452 |
+
super(DBNetConvNext, self).__init__()
|
453 |
+
self.backbone = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
|
454 |
+
|
455 |
+
self.conv_mask = nn.Sequential(
|
456 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.SiLU(inplace=True),
|
457 |
+
nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.SiLU(inplace=True),
|
458 |
+
nn.Conv2d(32, 1, kernel_size=1),
|
459 |
+
nn.Sigmoid()
|
460 |
+
)
|
461 |
+
|
462 |
+
self.down_conv1 = ConvNeXtStage(1024, 1024, depth = 2, norm_layer = LayerNorm2d)
|
463 |
+
self.down_conv2 = ConvNeXtStage(1024, 1024, depth = 2, norm_layer = LayerNorm2d)
|
464 |
+
|
465 |
+
self.upconv1 = UpconvSkip(0, 1024, 128)
|
466 |
+
self.upconv2 = UpconvSkip(128, 1024, 128)
|
467 |
+
self.upconv3 = UpconvSkip(128, 1024, 128)
|
468 |
+
self.upconv4 = UpconvSkip(128, 512, 128)
|
469 |
+
self.upconv5 = UpconvSkip(128, 256, 128)
|
470 |
+
self.upconv6 = UpconvSkip(128, 128, 64)
|
471 |
+
|
472 |
+
self.conv_db = DBHead(128)
|
473 |
+
|
474 |
+
def forward(self, x) :
|
475 |
+
# in 3@1536
|
476 |
+
x = self.backbone.stem(x) # 128@384
|
477 |
+
h4 = self.backbone.stages[0](x) # 128@384
|
478 |
+
h8 = self.backbone.stages[1](h4) # 256@192
|
479 |
+
h16 = self.backbone.stages[2](h8) # 512@96
|
480 |
+
h32 = self.backbone.stages[3](h16) # 1024@48
|
481 |
+
h64 = self.down_conv1(h32) # 1024@24
|
482 |
+
h128 = self.down_conv2(h64) # 1024@12
|
483 |
+
|
484 |
+
up128 = self.upconv1(h128)
|
485 |
+
up64 = self.upconv2(torch.cat([up128, h64], dim = 1))
|
486 |
+
up32 = self.upconv3(torch.cat([up64, h32], dim = 1))
|
487 |
+
up16 = self.upconv4(torch.cat([up32, h16], dim = 1))
|
488 |
+
up8 = self.upconv5(torch.cat([up16, h8], dim = 1))
|
489 |
+
up4 = self.upconv6(torch.cat([up8, h4], dim = 1))
|
490 |
+
|
491 |
+
return self.conv_db(up8), self.conv_mask(up4)
|
492 |
+
|
493 |
+
import os
|
494 |
+
from .default_utils import imgproc, dbnet_utils, craft_utils
|
495 |
+
from .common import OfflineDetector
|
496 |
+
from ..utils import TextBlock, Quadrilateral, det_rearrange_forward
|
497 |
+
|
498 |
+
MODEL = None
|
499 |
+
def det_batch_forward_default(batch: np.ndarray, device: str):
|
500 |
+
global MODEL
|
501 |
+
if isinstance(batch, list):
|
502 |
+
batch = np.array(batch)
|
503 |
+
batch = einops.rearrange(batch.astype(np.float32) / 127.5 - 1.0, 'n h w c -> n c h w')
|
504 |
+
batch = torch.from_numpy(batch).to(device)
|
505 |
+
with torch.no_grad():
|
506 |
+
db, mask = MODEL(batch)
|
507 |
+
db = db.sigmoid().cpu().numpy()
|
508 |
+
mask = mask.cpu().numpy()
|
509 |
+
return db, mask
|
510 |
+
|
511 |
+
|
512 |
+
class DBConvNextDetector(OfflineDetector):
|
513 |
+
_MODEL_MAPPING = {
|
514 |
+
'model': {
|
515 |
+
'url': '',
|
516 |
+
'hash': '',
|
517 |
+
'file': '.',
|
518 |
+
}
|
519 |
+
}
|
520 |
+
|
521 |
+
def __init__(self, *args, **kwargs):
|
522 |
+
os.makedirs(self.model_dir, exist_ok=True)
|
523 |
+
if os.path.exists('dbnet_convnext.ckpt'):
|
524 |
+
shutil.move('dbnet_convnext.ckpt', self._get_file_path('dbnet_convnext.ckpt'))
|
525 |
+
super().__init__(*args, **kwargs)
|
526 |
+
|
527 |
+
async def _load(self, device: str):
|
528 |
+
self.model = DBNetConvNext()
|
529 |
+
sd = torch.load(self._get_file_path('dbnet_convnext.ckpt'), map_location='cpu')
|
530 |
+
self.model.load_state_dict(sd['model'] if 'model' in sd else sd)
|
531 |
+
self.model.eval()
|
532 |
+
self.device = device
|
533 |
+
if device == 'cuda' or device == 'mps':
|
534 |
+
self.model = self.model.to(self.device)
|
535 |
+
global MODEL
|
536 |
+
MODEL = self.model
|
537 |
+
|
538 |
+
async def _unload(self):
|
539 |
+
del self.model
|
540 |
+
|
541 |
+
async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
|
542 |
+
unclip_ratio: float, verbose: bool = False):
|
543 |
+
|
544 |
+
# TODO: Move det_rearrange_forward to common.py and refactor
|
545 |
+
db, mask = det_rearrange_forward(image, det_batch_forward_default, detect_size, 4, device=self.device, verbose=verbose)
|
546 |
+
|
547 |
+
if db is None:
|
548 |
+
# rearrangement is not required, fallback to default forward
|
549 |
+
img_resized, target_ratio, _, pad_w, pad_h = imgproc.resize_aspect_ratio(cv2.bilateralFilter(image, 17, 80, 80), detect_size, cv2.INTER_LINEAR, mag_ratio = 1)
|
550 |
+
img_resized_h, img_resized_w = img_resized.shape[:2]
|
551 |
+
ratio_h = ratio_w = 1 / target_ratio
|
552 |
+
db, mask = det_batch_forward_default([img_resized], self.device)
|
553 |
+
else:
|
554 |
+
img_resized_h, img_resized_w = image.shape[:2]
|
555 |
+
ratio_w = ratio_h = 1
|
556 |
+
pad_h = pad_w = 0
|
557 |
+
self.logger.info(f'Detection resolution: {img_resized_w}x{img_resized_h}')
|
558 |
+
|
559 |
+
mask = mask[0, 0, :, :]
|
560 |
+
det = dbnet_utils.SegDetectorRepresenter(text_threshold, box_threshold, unclip_ratio=unclip_ratio)
|
561 |
+
# boxes, scores = det({'shape': [(img_resized.shape[0], img_resized.shape[1])]}, db)
|
562 |
+
boxes, scores = det({'shape':[(img_resized_h, img_resized_w)]}, db)
|
563 |
+
boxes, scores = boxes[0], scores[0]
|
564 |
+
if boxes.size == 0:
|
565 |
+
polys = []
|
566 |
+
else:
|
567 |
+
idx = boxes.reshape(boxes.shape[0], -1).sum(axis=1) > 0
|
568 |
+
polys, _ = boxes[idx], scores[idx]
|
569 |
+
polys = polys.astype(np.float64)
|
570 |
+
polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=1)
|
571 |
+
polys = polys.astype(np.int16)
|
572 |
+
|
573 |
+
textlines = [Quadrilateral(pts.astype(int), '', score) for pts, score in zip(polys, scores)]
|
574 |
+
textlines = list(filter(lambda q: q.area > 16, textlines))
|
575 |
+
mask_resized = cv2.resize(mask, (mask.shape[1] * 2, mask.shape[0] * 2), interpolation=cv2.INTER_LINEAR)
|
576 |
+
if pad_h > 0:
|
577 |
+
mask_resized = mask_resized[:-pad_h, :]
|
578 |
+
elif pad_w > 0:
|
579 |
+
mask_resized = mask_resized[:, :-pad_w]
|
580 |
+
raw_mask = np.clip(mask_resized * 255, 0, 255).astype(np.uint8)
|
581 |
+
|
582 |
+
# if verbose:
|
583 |
+
# img_bbox_raw = np.copy(image)
|
584 |
+
# for txtln in textlines:
|
585 |
+
# cv2.polylines(img_bbox_raw, [txtln.pts], True, color=(255, 0, 0), thickness=2)
|
586 |
+
# cv2.imwrite(f'result/bboxes_unfiltered.png', cv2.cvtColor(img_bbox_raw, cv2.COLOR_RGB2BGR))
|
587 |
+
|
588 |
+
return textlines, raw_mask, None
|
589 |
+
|
590 |
+
|
591 |
+
if __name__ == '__main__' :
|
592 |
+
net = DBNetConvNext().cuda()
|
593 |
+
img = torch.randn(2, 3, 1536, 1536).cuda()
|
594 |
+
ret1, ret2 = net.forward(img)
|
595 |
+
print(ret1.shape)
|
596 |
+
print(ret2.shape)
|
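For reference, DBHead.step_function above is the differentiable binarization from the DB paper, B = 1 / (1 + exp(-k * (P - T))) with k = 50: a soft, trainable approximation of thresholding the shrink map P against the learned threshold map T. A quick numeric sketch (the sample values are made up):

    import torch
    k = 50
    P = torch.tensor([0.2, 0.5, 0.8])   # shrink/probability map samples
    T = torch.tensor([0.5, 0.5, 0.5])   # threshold map samples
    B = torch.reciprocal(1 + torch.exp(-k * (P - T)))
    # B saturates towards 0 below the threshold, equals 0.5 at the threshold, and towards 1 above it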
manga_translator/detection/default.py
ADDED
@@ -0,0 +1,103 @@
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import cv2
|
6 |
+
import einops
|
7 |
+
from typing import List, Tuple
|
8 |
+
|
9 |
+
from .default_utils.DBNet_resnet34 import TextDetection as TextDetectionDefault
|
10 |
+
from .default_utils import imgproc, dbnet_utils, craft_utils
|
11 |
+
from .common import OfflineDetector
|
12 |
+
from ..utils import TextBlock, Quadrilateral, det_rearrange_forward
|
13 |
+
|
14 |
+
MODEL = None
|
15 |
+
def det_batch_forward_default(batch: np.ndarray, device: str):
|
16 |
+
global MODEL
|
17 |
+
if isinstance(batch, list):
|
18 |
+
batch = np.array(batch)
|
19 |
+
batch = einops.rearrange(batch.astype(np.float32) / 127.5 - 1.0, 'n h w c -> n c h w')
|
20 |
+
batch = torch.from_numpy(batch).to(device)
|
21 |
+
with torch.no_grad():
|
22 |
+
db, mask = MODEL(batch)
|
23 |
+
db = db.sigmoid().cpu().numpy()
|
24 |
+
mask = mask.cpu().numpy()
|
25 |
+
return db, mask
|
26 |
+
|
27 |
+
class DefaultDetector(OfflineDetector):
|
28 |
+
_MODEL_MAPPING = {
|
29 |
+
'model': {
|
30 |
+
'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/detect.ckpt',
|
31 |
+
'hash': '69080aea78de0803092bc6b751ae283ca463011de5f07e1d20e6491b05571a30',
|
32 |
+
'file': '.',
|
33 |
+
}
|
34 |
+
}
|
35 |
+
|
36 |
+
def __init__(self, *args, **kwargs):
|
37 |
+
os.makedirs(self.model_dir, exist_ok=True)
|
38 |
+
if os.path.exists('detect.ckpt'):
|
39 |
+
shutil.move('detect.ckpt', self._get_file_path('detect.ckpt'))
|
40 |
+
super().__init__(*args, **kwargs)
|
41 |
+
|
42 |
+
async def _load(self, device: str):
|
43 |
+
self.model = TextDetectionDefault()
|
44 |
+
sd = torch.load(self._get_file_path('detect.ckpt'), map_location='cpu')
|
45 |
+
self.model.load_state_dict(sd['model'] if 'model' in sd else sd)
|
46 |
+
self.model.eval()
|
47 |
+
self.device = device
|
48 |
+
if device == 'cuda' or device == 'mps':
|
49 |
+
self.model = self.model.to(self.device)
|
50 |
+
global MODEL
|
51 |
+
MODEL = self.model
|
52 |
+
|
53 |
+
async def _unload(self):
|
54 |
+
del self.model
|
55 |
+
|
56 |
+
async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
|
57 |
+
unclip_ratio: float, verbose: bool = False):
|
58 |
+
|
59 |
+
# TODO: Move det_rearrange_forward to common.py and refactor
|
60 |
+
db, mask = det_rearrange_forward(image, det_batch_forward_default, detect_size, 4, device=self.device, verbose=verbose)
|
61 |
+
|
62 |
+
if db is None:
|
63 |
+
# rearrangement is not required, fallback to default forward
|
64 |
+
img_resized, target_ratio, _, pad_w, pad_h = imgproc.resize_aspect_ratio(cv2.bilateralFilter(image, 17, 80, 80), detect_size, cv2.INTER_LINEAR, mag_ratio = 1)
|
65 |
+
img_resized_h, img_resized_w = img_resized.shape[:2]
|
66 |
+
ratio_h = ratio_w = 1 / target_ratio
|
67 |
+
db, mask = det_batch_forward_default([img_resized], self.device)
|
68 |
+
else:
|
69 |
+
img_resized_h, img_resized_w = image.shape[:2]
|
70 |
+
ratio_w = ratio_h = 1
|
71 |
+
pad_h = pad_w = 0
|
72 |
+
self.logger.info(f'Detection resolution: {img_resized_w}x{img_resized_h}')
|
73 |
+
|
74 |
+
mask = mask[0, 0, :, :]
|
75 |
+
det = dbnet_utils.SegDetectorRepresenter(text_threshold, box_threshold, unclip_ratio=unclip_ratio)
|
76 |
+
# boxes, scores = det({'shape': [(img_resized.shape[0], img_resized.shape[1])]}, db)
|
77 |
+
boxes, scores = det({'shape':[(img_resized_h, img_resized_w)]}, db)
|
78 |
+
boxes, scores = boxes[0], scores[0]
|
79 |
+
if boxes.size == 0:
|
80 |
+
polys = []
|
81 |
+
else:
|
82 |
+
idx = boxes.reshape(boxes.shape[0], -1).sum(axis=1) > 0
|
83 |
+
polys, _ = boxes[idx], scores[idx]
|
84 |
+
polys = polys.astype(np.float64)
|
85 |
+
polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=1)
|
86 |
+
polys = polys.astype(np.int16)
|
87 |
+
|
88 |
+
textlines = [Quadrilateral(pts.astype(int), '', score) for pts, score in zip(polys, scores)]
|
89 |
+
textlines = list(filter(lambda q: q.area > 16, textlines))
|
90 |
+
mask_resized = cv2.resize(mask, (mask.shape[1] * 2, mask.shape[0] * 2), interpolation=cv2.INTER_LINEAR)
|
91 |
+
if pad_h > 0:
|
92 |
+
mask_resized = mask_resized[:-pad_h, :]
|
93 |
+
elif pad_w > 0:
|
94 |
+
mask_resized = mask_resized[:, :-pad_w]
|
95 |
+
raw_mask = np.clip(mask_resized * 255, 0, 255).astype(np.uint8)
|
96 |
+
|
97 |
+
# if verbose:
|
98 |
+
# img_bbox_raw = np.copy(image)
|
99 |
+
# for txtln in textlines:
|
100 |
+
# cv2.polylines(img_bbox_raw, [txtln.pts], True, color=(255, 0, 0), thickness=2)
|
101 |
+
# cv2.imwrite(f'result/bboxes_unfiltered.png', cv2.cvtColor(img_bbox_raw, cv2.COLOR_RGB2BGR))
|
102 |
+
|
103 |
+
return textlines, raw_mask, None
|
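det_batch_forward_default above scales 0-255 pixel values into [-1, 1] and converts NHWC to NCHW before the forward pass; a minimal sketch of that preprocessing in isolation (the 1536x1536 size is only an example):

    import numpy as np
    import einops
    img = np.zeros((1536, 1536, 3), dtype=np.uint8)            # placeholder page image
    batch = np.array([img]).astype(np.float32) / 127.5 - 1.0   # scale to [-1, 1]
    batch = einops.rearrange(batch, 'n h w c -> n c h w')      # NHWC -> NCHW for PyTorch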
manga_translator/detection/default_utils/CRAFT_resnet34.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch.nn.init as init
|
6 |
+
|
7 |
+
from torchvision.models import resnet34
|
8 |
+
|
9 |
+
import einops
|
10 |
+
import math
|
11 |
+
|
12 |
+
class ImageMultiheadSelfAttention(nn.Module):
|
13 |
+
def __init__(self, planes):
|
14 |
+
super(ImageMultiheadSelfAttention, self).__init__()
|
15 |
+
self.attn = nn.MultiheadAttention(planes, 4)
|
16 |
+
def forward(self, x):
|
17 |
+
res = x
|
18 |
+
n, c, h, w = x.shape
|
19 |
+
x = einops.rearrange(x, 'n c h w -> (h w) n c')
|
20 |
+
x = self.attn(x, x, x)[0]
|
21 |
+
x = einops.rearrange(x, '(h w) n c -> n c h w', n = n, c = c, h = h, w = w)
|
22 |
+
return res + x
|
23 |
+
|
24 |
+
class double_conv(nn.Module):
|
25 |
+
def __init__(self, in_ch, mid_ch, out_ch, stride = 1, planes = 256):
|
26 |
+
super(double_conv, self).__init__()
|
27 |
+
self.planes = planes
|
28 |
+
# down = None
|
29 |
+
# if stride > 1:
|
30 |
+
# down = nn.Sequential(
|
31 |
+
# nn.AvgPool2d(2, 2),
|
32 |
+
# nn.Conv2d(in_ch + mid_ch, self.planes * Bottleneck.expansion, kernel_size=1, stride=1, bias=False),nn.BatchNorm2d(self.planes * Bottleneck.expansion)
|
33 |
+
# )
|
34 |
+
self.down = None
|
35 |
+
if stride > 1:
|
36 |
+
self.down = nn.AvgPool2d(2,stride=2)
|
37 |
+
self.conv = nn.Sequential(
|
38 |
+
nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=3, padding=1, stride = 1, bias=False),
|
39 |
+
nn.BatchNorm2d(mid_ch),
|
40 |
+
nn.ReLU(inplace=True),
|
41 |
+
#Bottleneck(mid_ch, self.planes, stride, down, 2, 1, avd = True, norm_layer = nn.BatchNorm2d),
|
42 |
+
nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride = 1, padding=1, bias=False),
|
43 |
+
nn.BatchNorm2d(out_ch),
|
44 |
+
nn.ReLU(inplace=True),
|
45 |
+
)
|
46 |
+
|
47 |
+
def forward(self, x):
|
48 |
+
if self.down is not None:
|
49 |
+
x = self.down(x)
|
50 |
+
x = self.conv(x)
|
51 |
+
return x
|
52 |
+
|
53 |
+
class CRAFT_net(nn.Module):
|
54 |
+
def __init__(self):
|
55 |
+
super(CRAFT_net, self).__init__()
|
56 |
+
self.backbone = resnet34()
|
57 |
+
|
58 |
+
self.conv_rs = nn.Sequential(
|
59 |
+
nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
|
60 |
+
nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
|
61 |
+
nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
|
62 |
+
nn.Conv2d(32, 32, kernel_size=1), nn.ReLU(inplace=True),
|
63 |
+
nn.Conv2d(32, 1, kernel_size=1),
|
64 |
+
nn.Sigmoid()
|
65 |
+
)
|
66 |
+
|
67 |
+
self.conv_as = nn.Sequential(
|
68 |
+
nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
|
69 |
+
nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
|
70 |
+
nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
|
71 |
+
nn.Conv2d(32, 32, kernel_size=1), nn.ReLU(inplace=True),
|
72 |
+
nn.Conv2d(32, 1, kernel_size=1),
|
73 |
+
nn.Sigmoid()
|
74 |
+
)
|
75 |
+
|
76 |
+
self.conv_mask = nn.Sequential(
|
77 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True),
|
78 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True),
|
79 |
+
nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
|
80 |
+
nn.Conv2d(32, 1, kernel_size=1),
|
81 |
+
nn.Sigmoid()
|
82 |
+
)
|
83 |
+
|
84 |
+
self.down_conv1 = double_conv(0, 512, 512, 2)
|
85 |
+
self.down_conv2 = double_conv(0, 512, 512, 2)
|
86 |
+
self.down_conv3 = double_conv(0, 512, 512, 2)
|
87 |
+
|
88 |
+
self.upconv1 = double_conv(0, 512, 256)
|
89 |
+
self.upconv2 = double_conv(256, 512, 256)
|
90 |
+
self.upconv3 = double_conv(256, 512, 256)
|
91 |
+
self.upconv4 = double_conv(256, 512, 256, planes = 128)
|
92 |
+
self.upconv5 = double_conv(256, 256, 128, planes = 64)
|
93 |
+
self.upconv6 = double_conv(128, 128, 64, planes = 32)
|
94 |
+
self.upconv7 = double_conv(64, 64, 64, planes = 16)
|
95 |
+
|
96 |
+
def forward_train(self, x):
|
97 |
+
x = self.backbone.conv1(x)
|
98 |
+
x = self.backbone.bn1(x)
|
99 |
+
x = self.backbone.relu(x)
|
100 |
+
x = self.backbone.maxpool(x) # 64@384
|
101 |
+
|
102 |
+
h4 = self.backbone.layer1(x) # 64@384
|
103 |
+
h8 = self.backbone.layer2(h4) # 128@192
|
104 |
+
h16 = self.backbone.layer3(h8) # 256@96
|
105 |
+
h32 = self.backbone.layer4(h16) # 512@48
|
106 |
+
h64 = self.down_conv1(h32) # 512@24
|
107 |
+
h128 = self.down_conv2(h64) # 512@12
|
108 |
+
h256 = self.down_conv3(h128) # 512@6
|
109 |
+
|
110 |
+
up256 = F.interpolate(self.upconv1(h256), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 512@12
|
111 |
+
up128 = F.interpolate(self.upconv2(torch.cat([up256, h128], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) #51264@24
|
112 |
+
up64 = F.interpolate(self.upconv3(torch.cat([up128, h64], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 256@48
|
113 |
+
up32 = F.interpolate(self.upconv4(torch.cat([up64, h32], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 256@96
|
114 |
+
up16 = F.interpolate(self.upconv5(torch.cat([up32, h16], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 128@192
|
115 |
+
up8 = F.interpolate(self.upconv6(torch.cat([up16, h8], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 64@384
|
116 |
+
up4 = F.interpolate(self.upconv7(torch.cat([up8, h4], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 64@768
|
117 |
+
|
118 |
+
ascore = self.conv_as(up4)
|
119 |
+
rscore = self.conv_rs(up4)
|
120 |
+
|
121 |
+
return torch.cat([rscore, ascore], dim = 1), self.conv_mask(up4)
|
122 |
+
|
123 |
+
def forward(self, x):
|
124 |
+
x = self.backbone.conv1(x)
|
125 |
+
x = self.backbone.bn1(x)
|
126 |
+
x = self.backbone.relu(x)
|
127 |
+
x = self.backbone.maxpool(x) # 64@384
|
128 |
+
|
129 |
+
h4 = self.backbone.layer1(x) # 64@384
|
130 |
+
h8 = self.backbone.layer2(h4) # 128@192
|
131 |
+
h16 = self.backbone.layer3(h8) # 256@96
|
132 |
+
h32 = self.backbone.layer4(h16) # 512@48
|
133 |
+
h64 = self.down_conv1(h32) # 512@24
|
134 |
+
h128 = self.down_conv2(h64) # 512@12
|
135 |
+
h256 = self.down_conv3(h128) # 512@6
|
136 |
+
|
137 |
+
up256 = F.interpolate(self.upconv1(h256), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 512@12
|
138 |
+
up128 = F.interpolate(self.upconv2(torch.cat([up256, h128], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 256@24
|
139 |
+
up64 = F.interpolate(self.upconv3(torch.cat([up128, h64], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 256@48
|
140 |
+
up32 = F.interpolate(self.upconv4(torch.cat([up64, h32], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 256@96
|
141 |
+
up16 = F.interpolate(self.upconv5(torch.cat([up32, h16], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 128@192
|
142 |
+
up8 = F.interpolate(self.upconv6(torch.cat([up16, h8], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 64@384
|
143 |
+
up4 = F.interpolate(self.upconv7(torch.cat([up8, h4], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 64@768
|
144 |
+
|
145 |
+
ascore = self.conv_as(up4)
|
146 |
+
rscore = self.conv_rs(up4)
|
147 |
+
|
148 |
+
return torch.cat([rscore, ascore], dim = 1), self.conv_mask(up4)
|
149 |
+
|
150 |
+
if __name__ == '__main__':
|
151 |
+
net = CRAFT_net().cuda()
|
152 |
+
img = torch.randn(2, 3, 1536, 1536).cuda()
|
153 |
+
print(net.forward_train(img)[0].shape)
|
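As the __main__ smoke test above indicates, CRAFT_net returns a two-channel score map (region score and affinity score concatenated along the channel dimension) plus a one-channel mask, both at half the input resolution; a smaller CPU-only variant of the same check (sketch):

    import torch
    net = CRAFT_net().eval()
    with torch.no_grad():
        scores, mask = net(torch.randn(1, 3, 512, 512))
    # scores: (1, 2, 256, 256), mask: (1, 1, 256, 256)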