Spaces:
Configuration error
Configuration error
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .github/workflows/bot-autolint.yaml +50 -0
- .github/workflows/ci.yaml +54 -0
- .gitignore +178 -0
- .pre-commit-config.yaml +62 -0
- CIs/add_license_all.sh +2 -0
- Dockerfile +20 -0
- LICENSE +117 -0
- README.md +231 -12
- app/app_sana.py +488 -0
- app/app_sana_multithread.py +565 -0
- app/safety_check.py +72 -0
- app/sana_pipeline.py +324 -0
- asset/Sana.jpg +3 -0
- asset/docs/metrics_toolkit.md +118 -0
- asset/example_data/00000000.txt +1 -0
- asset/examples.py +69 -0
- asset/model-incremental.jpg +0 -0
- asset/model_paths.txt +2 -0
- asset/samples.txt +125 -0
- asset/samples_mini.txt +10 -0
- configs/sana_app_config/Sana_1600M_app.yaml +107 -0
- configs/sana_app_config/Sana_600M_app.yaml +105 -0
- configs/sana_base.yaml +140 -0
- configs/sana_config/1024ms/Sana_1600M_img1024.yaml +109 -0
- configs/sana_config/1024ms/Sana_600M_img1024.yaml +105 -0
- configs/sana_config/512ms/Sana_1600M_img512.yaml +108 -0
- configs/sana_config/512ms/Sana_600M_img512.yaml +107 -0
- configs/sana_config/512ms/ci_Sana_600M_img512.yaml +107 -0
- configs/sana_config/512ms/sample_dataset.yaml +107 -0
- diffusion/__init__.py +9 -0
- diffusion/data/__init__.py +2 -0
- diffusion/data/builder.py +76 -0
- diffusion/data/datasets/__init__.py +3 -0
- diffusion/data/datasets/sana_data.py +467 -0
- diffusion/data/datasets/sana_data_multi_scale.py +265 -0
- diffusion/data/datasets/utils.py +506 -0
- diffusion/data/transforms.py +46 -0
- diffusion/data/wids/__init__.py +16 -0
- diffusion/data/wids/wids.py +1051 -0
- diffusion/data/wids/wids_dl.py +174 -0
- diffusion/data/wids/wids_lru.py +81 -0
- diffusion/data/wids/wids_mmtar.py +168 -0
- diffusion/data/wids/wids_specs.py +192 -0
- diffusion/data/wids/wids_tar.py +98 -0
- diffusion/dpm_solver.py +69 -0
- diffusion/flow_euler_sampler.py +74 -0
- diffusion/iddpm.py +76 -0
- diffusion/lcm_scheduler.py +457 -0
- diffusion/model/__init__.py +1 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
asset/Sana.jpg filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/bot-autolint.yaml
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Auto Lint (triggered by "auto lint" label)
|
2 |
+
on:
|
3 |
+
pull_request:
|
4 |
+
types:
|
5 |
+
- opened
|
6 |
+
- edited
|
7 |
+
- closed
|
8 |
+
- reopened
|
9 |
+
- synchronize
|
10 |
+
- labeled
|
11 |
+
- unlabeled
|
12 |
+
# run only one unit test for a branch / tag.
|
13 |
+
concurrency:
|
14 |
+
group: ci-lint-${{ github.ref }}
|
15 |
+
cancel-in-progress: true
|
16 |
+
jobs:
|
17 |
+
lint-by-label:
|
18 |
+
if: contains(github.event.pull_request.labels.*.name, 'lint wanted')
|
19 |
+
runs-on: ubuntu-latest
|
20 |
+
steps:
|
21 |
+
- name: Check out Git repository
|
22 |
+
uses: actions/checkout@v4
|
23 |
+
with:
|
24 |
+
token: ${{ secrets.PAT }}
|
25 |
+
ref: ${{ github.event.pull_request.head.ref }}
|
26 |
+
- name: Set up Python
|
27 |
+
uses: actions/setup-python@v5
|
28 |
+
with:
|
29 |
+
python-version: '3.10'
|
30 |
+
- name: Test pre-commit hooks
|
31 |
+
continue-on-error: true
|
32 |
+
uses: pre-commit/[email protected] # sync with https://github.com/Efficient-Large-Model/VILA-Internal/blob/main/.github/workflows/pre-commit.yaml
|
33 |
+
with:
|
34 |
+
extra_args: --all-files
|
35 |
+
- name: Check if there are any changes
|
36 |
+
id: verify_diff
|
37 |
+
run: |
|
38 |
+
git diff --quiet . || echo "changed=true" >> $GITHUB_OUTPUT
|
39 |
+
- name: Commit files
|
40 |
+
if: steps.verify_diff.outputs.changed == 'true'
|
41 |
+
run: |
|
42 |
+
git config --local user.email "[email protected]"
|
43 |
+
git config --local user.name "GitHub Action"
|
44 |
+
git add .
|
45 |
+
git commit -m "[CI-Lint] Fix code style issues with pre-commit ${{ github.sha }}" -a
|
46 |
+
git push
|
47 |
+
- name: Remove label(s) after lint
|
48 |
+
uses: actions-ecosystem/action-remove-labels@v1
|
49 |
+
with:
|
50 |
+
labels: lint wanted
|
.github/workflows/ci.yaml
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: ci
|
2 |
+
on:
|
3 |
+
pull_request:
|
4 |
+
push:
|
5 |
+
branches: [main, feat/Sana-public, feat/Sana-public-for-NVLab]
|
6 |
+
concurrency:
|
7 |
+
group: ci-${{ github.workflow }}-${{ github.ref }}
|
8 |
+
cancel-in-progress: true
|
9 |
+
# if: ${{ github.repository == 'Efficient-Large-Model/Sana' }}
|
10 |
+
jobs:
|
11 |
+
pre-commit:
|
12 |
+
runs-on: ubuntu-latest
|
13 |
+
steps:
|
14 |
+
- name: Check out Git repository
|
15 |
+
uses: actions/checkout@v4
|
16 |
+
- name: Set up Python
|
17 |
+
uses: actions/setup-python@v5
|
18 |
+
with:
|
19 |
+
python-version: 3.10.10
|
20 |
+
- name: Test pre-commit hooks
|
21 |
+
uses: pre-commit/[email protected]
|
22 |
+
tests-bash:
|
23 |
+
# needs: pre-commit
|
24 |
+
runs-on: self-hosted
|
25 |
+
steps:
|
26 |
+
- name: Check out Git repository
|
27 |
+
uses: actions/checkout@v4
|
28 |
+
- name: Set up Python
|
29 |
+
uses: actions/setup-python@v5
|
30 |
+
with:
|
31 |
+
python-version: 3.10.10
|
32 |
+
- name: Set up the environment
|
33 |
+
run: |
|
34 |
+
bash environment_setup.sh
|
35 |
+
- name: Run tests with Slurm
|
36 |
+
run: |
|
37 |
+
sana-run --pty -m ci -J tests-bash bash tests/bash/entry.sh
|
38 |
+
|
39 |
+
# tests-python:
|
40 |
+
# needs: pre-commit
|
41 |
+
# runs-on: self-hosted
|
42 |
+
# steps:
|
43 |
+
# - name: Check out Git repository
|
44 |
+
# uses: actions/checkout@v4
|
45 |
+
# - name: Set up Python
|
46 |
+
# uses: actions/setup-python@v5
|
47 |
+
# with:
|
48 |
+
# python-version: 3.10.10
|
49 |
+
# - name: Set up the environment
|
50 |
+
# run: |
|
51 |
+
# ./environment_setup.sh
|
52 |
+
# - name: Run tests with Slurm
|
53 |
+
# run: |
|
54 |
+
# sana-run --pty -m ci -J tests-python pytest tests/python
|
.gitignore
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Sana related files
|
2 |
+
.idea/
|
3 |
+
*.png
|
4 |
+
*.json
|
5 |
+
tmp*
|
6 |
+
output*
|
7 |
+
output/
|
8 |
+
outputs/
|
9 |
+
wandb/
|
10 |
+
.vscode/
|
11 |
+
private/
|
12 |
+
ldm_ae*
|
13 |
+
data/*
|
14 |
+
*.pth
|
15 |
+
.gradio/
|
16 |
+
|
17 |
+
# Byte-compiled / optimized / DLL files
|
18 |
+
__pycache__/
|
19 |
+
*.py[cod]
|
20 |
+
*$py.class
|
21 |
+
|
22 |
+
# C extensions
|
23 |
+
*.so
|
24 |
+
|
25 |
+
# Distribution / packaging
|
26 |
+
.Python
|
27 |
+
build/
|
28 |
+
develop-eggs/
|
29 |
+
dist/
|
30 |
+
downloads/
|
31 |
+
eggs/
|
32 |
+
.eggs/
|
33 |
+
lib/
|
34 |
+
lib64/
|
35 |
+
parts/
|
36 |
+
sdist/
|
37 |
+
var/
|
38 |
+
wheels/
|
39 |
+
share/python-wheels/
|
40 |
+
*.egg-info/
|
41 |
+
.installed.cfg
|
42 |
+
*.egg
|
43 |
+
MANIFEST
|
44 |
+
|
45 |
+
# PyInstaller
|
46 |
+
# Usually these files are written by a python script from a template
|
47 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
48 |
+
*.manifest
|
49 |
+
*.spec
|
50 |
+
|
51 |
+
# Installer logs
|
52 |
+
pip-log.txt
|
53 |
+
pip-delete-this-directory.txt
|
54 |
+
|
55 |
+
# Unit test / coverage reports
|
56 |
+
htmlcov/
|
57 |
+
.tox/
|
58 |
+
.nox/
|
59 |
+
.coverage
|
60 |
+
.coverage.*
|
61 |
+
.cache
|
62 |
+
nosetests.xml
|
63 |
+
coverage.xml
|
64 |
+
*.cover
|
65 |
+
*.py,cover
|
66 |
+
.hypothesis/
|
67 |
+
.pytest_cache/
|
68 |
+
cover/
|
69 |
+
|
70 |
+
# Translations
|
71 |
+
*.mo
|
72 |
+
*.pot
|
73 |
+
|
74 |
+
# Django stuff:
|
75 |
+
*.log
|
76 |
+
local_settings.py
|
77 |
+
db.sqlite3
|
78 |
+
db.sqlite3-journal
|
79 |
+
|
80 |
+
# Flask stuff:
|
81 |
+
instance/
|
82 |
+
.webassets-cache
|
83 |
+
|
84 |
+
# Scrapy stuff:
|
85 |
+
.scrapy
|
86 |
+
|
87 |
+
# Sphinx documentation
|
88 |
+
docs/_build/
|
89 |
+
|
90 |
+
# PyBuilder
|
91 |
+
.pybuilder/
|
92 |
+
target/
|
93 |
+
|
94 |
+
# Jupyter Notebook
|
95 |
+
.ipynb_checkpoints
|
96 |
+
|
97 |
+
# IPython
|
98 |
+
profile_default/
|
99 |
+
ipython_config.py
|
100 |
+
|
101 |
+
# pyenv
|
102 |
+
# For a library or package, you might want to ignore these files since the code is
|
103 |
+
# intended to run in multiple environments; otherwise, check them in:
|
104 |
+
# .python-version
|
105 |
+
|
106 |
+
# pipenv
|
107 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
108 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
109 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
110 |
+
# install all needed dependencies.
|
111 |
+
#Pipfile.lock
|
112 |
+
|
113 |
+
# poetry
|
114 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
115 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
116 |
+
# commonly ignored for libraries.
|
117 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
118 |
+
#poetry.lock
|
119 |
+
|
120 |
+
# pdm
|
121 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
122 |
+
#pdm.lock
|
123 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
124 |
+
# in version control.
|
125 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
126 |
+
.pdm.toml
|
127 |
+
.pdm-python
|
128 |
+
.pdm-build/
|
129 |
+
|
130 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
131 |
+
__pypackages__/
|
132 |
+
|
133 |
+
# Celery stuff
|
134 |
+
celerybeat-schedule
|
135 |
+
celerybeat.pid
|
136 |
+
|
137 |
+
# SageMath parsed files
|
138 |
+
*.sage.py
|
139 |
+
|
140 |
+
# Environments
|
141 |
+
.env
|
142 |
+
.venv
|
143 |
+
env/
|
144 |
+
venv/
|
145 |
+
ENV/
|
146 |
+
env.bak/
|
147 |
+
venv.bak/
|
148 |
+
|
149 |
+
# Spyder project settings
|
150 |
+
.spyderproject
|
151 |
+
.spyproject
|
152 |
+
|
153 |
+
# Rope project settings
|
154 |
+
.ropeproject
|
155 |
+
|
156 |
+
# mkdocs documentation
|
157 |
+
/site
|
158 |
+
|
159 |
+
# mypy
|
160 |
+
.mypy_cache/
|
161 |
+
.dmypy.json
|
162 |
+
dmypy.json
|
163 |
+
|
164 |
+
# Pyre type checker
|
165 |
+
.pyre/
|
166 |
+
|
167 |
+
# pytype static type analyzer
|
168 |
+
.pytype/
|
169 |
+
|
170 |
+
# Cython debug symbols
|
171 |
+
cython_debug/
|
172 |
+
|
173 |
+
# PyCharm
|
174 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
175 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
176 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
177 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
178 |
+
#.idea/
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
repos:
|
2 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
3 |
+
rev: v5.0.0
|
4 |
+
hooks:
|
5 |
+
- id: trailing-whitespace
|
6 |
+
name: (Common) Remove trailing whitespaces
|
7 |
+
- id: mixed-line-ending
|
8 |
+
name: (Common) Fix mixed line ending
|
9 |
+
args: [--fix=lf]
|
10 |
+
- id: end-of-file-fixer
|
11 |
+
name: (Common) Remove extra EOF newlines
|
12 |
+
- id: check-merge-conflict
|
13 |
+
name: (Common) Check for merge conflicts
|
14 |
+
- id: requirements-txt-fixer
|
15 |
+
name: (Common) Sort "requirements.txt"
|
16 |
+
- id: fix-encoding-pragma
|
17 |
+
name: (Python) Remove encoding pragmas
|
18 |
+
args: [--remove]
|
19 |
+
# - id: debug-statements
|
20 |
+
# name: (Python) Check for debugger imports
|
21 |
+
- id: check-json
|
22 |
+
name: (JSON) Check syntax
|
23 |
+
- id: check-yaml
|
24 |
+
name: (YAML) Check syntax
|
25 |
+
- id: check-toml
|
26 |
+
name: (TOML) Check syntax
|
27 |
+
# - repo: https://github.com/shellcheck-py/shellcheck-py
|
28 |
+
# rev: v0.10.0.1
|
29 |
+
# hooks:
|
30 |
+
# - id: shellcheck
|
31 |
+
- repo: https://github.com/google/yamlfmt
|
32 |
+
rev: v0.13.0
|
33 |
+
hooks:
|
34 |
+
- id: yamlfmt
|
35 |
+
- repo: https://github.com/executablebooks/mdformat
|
36 |
+
rev: 0.7.16
|
37 |
+
hooks:
|
38 |
+
- id: mdformat
|
39 |
+
name: (Markdown) Format docs with mdformat
|
40 |
+
- repo: https://github.com/asottile/pyupgrade
|
41 |
+
rev: v3.2.2
|
42 |
+
hooks:
|
43 |
+
- id: pyupgrade
|
44 |
+
name: (Python) Update syntax for newer versions
|
45 |
+
args: [--py37-plus]
|
46 |
+
- repo: https://github.com/psf/black
|
47 |
+
rev: 22.10.0
|
48 |
+
hooks:
|
49 |
+
- id: black
|
50 |
+
name: (Python) Format code with black
|
51 |
+
- repo: https://github.com/pycqa/isort
|
52 |
+
rev: 5.12.0
|
53 |
+
hooks:
|
54 |
+
- id: isort
|
55 |
+
name: (Python) Sort imports with isort
|
56 |
+
- repo: https://github.com/pre-commit/mirrors-clang-format
|
57 |
+
rev: v15.0.4
|
58 |
+
hooks:
|
59 |
+
- id: clang-format
|
60 |
+
name: (C/C++/CUDA) Format code with clang-format
|
61 |
+
args: [-style=google, -i]
|
62 |
+
types_or: [c, c++, cuda]
|
CIs/add_license_all.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
#/bin/bash
|
2 |
+
addlicense -s -c 'NVIDIA CORPORATION & AFFILIATES' -ignore "**/*__init__.py" **/*.py
|
Dockerfile
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvcr.io/nvidia/pytorch:24.06-py3
|
2 |
+
|
3 |
+
WORKDIR /app
|
4 |
+
|
5 |
+
RUN curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o ~/miniconda.sh \
|
6 |
+
&& sh ~/miniconda.sh -b -p /opt/conda \
|
7 |
+
&& rm ~/miniconda.sh
|
8 |
+
|
9 |
+
ENV PATH /opt/conda/bin:$PATH
|
10 |
+
COPY pyproject.toml pyproject.toml
|
11 |
+
COPY diffusion diffusion
|
12 |
+
COPY configs configs
|
13 |
+
COPY sana sana
|
14 |
+
COPY app app
|
15 |
+
|
16 |
+
COPY environment_setup.sh environment_setup.sh
|
17 |
+
RUN ./environment_setup.sh sana
|
18 |
+
|
19 |
+
# COPY server.py server.py
|
20 |
+
CMD ["conda", "run", "-n", "sana", "--no-capture-output", "python", "-u", "-W", "ignore", "app/app_sana.py", "--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml", "--model_path=hf://Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",]
|
LICENSE
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
2 |
+
|
3 |
+
|
4 |
+
Nvidia Source Code License-NC
|
5 |
+
|
6 |
+
=======================================================================
|
7 |
+
|
8 |
+
1. Definitions
|
9 |
+
|
10 |
+
“Licensor” means any person or entity that distributes its Work.
|
11 |
+
|
12 |
+
“Work” means (a) the original work of authorship made available under
|
13 |
+
this license, which may include software, documentation, or other
|
14 |
+
files, and (b) any additions to or derivative works thereof
|
15 |
+
that are made available under this license.
|
16 |
+
|
17 |
+
“NVIDIA Processors” means any central processing unit (CPU),
|
18 |
+
graphics processing unit (GPU), field-programmable gate array (FPGA),
|
19 |
+
application-specific integrated circuit (ASIC) or any combination
|
20 |
+
thereof designed, made, sold, or provided by NVIDIA or its affiliates.
|
21 |
+
|
22 |
+
The terms “reproduce,” “reproduction,” “derivative works,” and
|
23 |
+
“distribution” have the meaning as provided under U.S. copyright law;
|
24 |
+
provided, however, that for the purposes of this license, derivative
|
25 |
+
works shall not include works that remain separable from, or merely
|
26 |
+
link (or bind by name) to the interfaces of, the Work.
|
27 |
+
|
28 |
+
Works are “made available” under this license by including in or with
|
29 |
+
the Work either (a) a copyright notice referencing the applicability
|
30 |
+
of this license to the Work, or (b) a copy of this license.
|
31 |
+
|
32 |
+
"Safe Model" means ShieldGemma-2B, which is a series of safety
|
33 |
+
content moderation models designed to moderate four categories of
|
34 |
+
harmful content: sexually explicit material, dangerous content,
|
35 |
+
hate speech, and harassment, and which you separately obtain
|
36 |
+
from Google at https://huggingface.co/google/shieldgemma-2b.
|
37 |
+
|
38 |
+
|
39 |
+
2. License Grant
|
40 |
+
|
41 |
+
2.1 Copyright Grant. Subject to the terms and conditions of this
|
42 |
+
license, each Licensor grants to you a perpetual, worldwide,
|
43 |
+
non-exclusive, royalty-free, copyright license to use, reproduce,
|
44 |
+
prepare derivative works of, publicly display, publicly perform,
|
45 |
+
sublicense and distribute its Work and any resulting derivative
|
46 |
+
works in any form.
|
47 |
+
|
48 |
+
3. Limitations
|
49 |
+
|
50 |
+
3.1 Redistribution. You may reproduce or distribute the Work only if
|
51 |
+
(a) you do so under this license, (b) you include a complete copy of
|
52 |
+
this license with your distribution, and (c) you retain without
|
53 |
+
modification any copyright, patent, trademark, or attribution notices
|
54 |
+
that are present in the Work.
|
55 |
+
|
56 |
+
3.2 Derivative Works. You may specify that additional or different
|
57 |
+
terms apply to the use, reproduction, and distribution of your
|
58 |
+
derivative works of the Work (“Your Terms”) only if (a) Your Terms
|
59 |
+
provide that the use limitation in Section 3.3 applies to your
|
60 |
+
derivative works, and (b) you identify the specific derivative works
|
61 |
+
that are subject to Your Terms. Notwithstanding Your Terms, this
|
62 |
+
license (including the redistribution requirements in Section 3.1)
|
63 |
+
will continue to apply to the Work itself.
|
64 |
+
|
65 |
+
3.3 Use Limitation. The Work and any derivative works thereof only may
|
66 |
+
be used or intended for use non-commercially and with NVIDIA Processors,
|
67 |
+
in accordance with Section 3.4, below. Notwithstanding the foregoing,
|
68 |
+
NVIDIA Corporation and its affiliates may use the Work and any
|
69 |
+
derivative works commercially. As used herein, “non-commercially”
|
70 |
+
means for research or evaluation purposes only.
|
71 |
+
|
72 |
+
3.4 You shall filter your input content to the Work and any derivative
|
73 |
+
works thereof through the Safe Model to ensure that no content described
|
74 |
+
as Not Safe For Work (NSFW) is processed or generated. You shall not use
|
75 |
+
the Work to process or generate NSFW content. You are solely responsible
|
76 |
+
for any damages and liabilities arising from your failure to adequately
|
77 |
+
filter content in accordance with this section. As used herein,
|
78 |
+
“Not Safe For Work” or “NSFW” means content, videos or website pages
|
79 |
+
that contain potentially disturbing subject matter, including but not
|
80 |
+
limited to content that is sexually explicit, dangerous, hate,
|
81 |
+
or harassment.
|
82 |
+
|
83 |
+
3.5 Patent Claims. If you bring or threaten to bring a patent claim
|
84 |
+
against any Licensor (including any claim, cross-claim or counterclaim
|
85 |
+
in a lawsuit) to enforce any patents that you allege are infringed by
|
86 |
+
any Work, then your rights under this license from such Licensor
|
87 |
+
(including the grant in Section 2.1) will terminate immediately.
|
88 |
+
|
89 |
+
3.6 Trademarks. This license does not grant any rights to use any
|
90 |
+
Licensor’s or its affiliates’ names, logos, or trademarks, except as
|
91 |
+
necessary to reproduce the notices described in this license.
|
92 |
+
|
93 |
+
3.7 Termination. If you violate any term of this license, then your
|
94 |
+
rights under this license (including the grant in Section 2.1) will
|
95 |
+
terminate immediately.
|
96 |
+
|
97 |
+
4. Disclaimer of Warranty.
|
98 |
+
|
99 |
+
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
100 |
+
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
|
101 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
|
102 |
+
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES
|
103 |
+
UNDER THIS LICENSE.
|
104 |
+
|
105 |
+
5. Limitation of Liability.
|
106 |
+
|
107 |
+
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
|
108 |
+
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
|
109 |
+
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
|
110 |
+
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
|
111 |
+
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
|
112 |
+
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
|
113 |
+
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
|
114 |
+
DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE
|
115 |
+
POSSIBILITY OF SUCH DAMAGES.
|
116 |
+
|
117 |
+
=======================================================================
|
README.md
CHANGED
@@ -1,12 +1,231 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<p align="center" style="border-radius: 10px">
|
2 |
+
<img src="asset/logo.png" width="35%" alt="logo"/>
|
3 |
+
</p>
|
4 |
+
|
5 |
+
# ⚡️Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer
|
6 |
+
|
7 |
+
<div align="center">
|
8 |
+
<a href="https://nvlabs.github.io/Sana/"><img src="https://img.shields.io/static/v1?label=Project&message=Github&color=blue&logo=github-pages"></a>  
|
9 |
+
<a href="https://hanlab.mit.edu/projects/sana/"><img src="https://img.shields.io/static/v1?label=Page&message=MIT&color=darkred&logo=github-pages"></a>  
|
10 |
+
<a href="https://arxiv.org/abs/2410.10629"><img src="https://img.shields.io/static/v1?label=Arxiv&message=Sana&color=red&logo=arxiv"></a>  
|
11 |
+
<a href="https://nv-sana.mit.edu/"><img src="https://img.shields.io/static/v1?label=Demo&message=MIT&color=yellow"></a>  
|
12 |
+
<a href="https://discord.gg/rde6eaE5Ta"><img src="https://img.shields.io/static/v1?label=Discuss&message=Discord&color=purple&logo=discord"></a>  
|
13 |
+
</div>
|
14 |
+
|
15 |
+
<p align="center" border-raduis="10px">
|
16 |
+
<img src="asset/Sana.jpg" width="90%" alt="teaser_page1"/>
|
17 |
+
</p>
|
18 |
+
|
19 |
+
## 💡 Introduction
|
20 |
+
|
21 |
+
We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096 × 4096 resolution.
|
22 |
+
Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU.
|
23 |
+
Core designs include:
|
24 |
+
|
25 |
+
(1) [**DC-AE**](https://hanlab.mit.edu/projects/dc-ae): unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. \
|
26 |
+
(2) **Linear DiT**: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. \
|
27 |
+
(3) **Decoder-only text encoder**: we replaced T5 with modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment. \
|
28 |
+
(4) **Efficient training and sampling**: we propose **Flow-DPM-Solver** to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence.
|
29 |
+
|
30 |
+
As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g. Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024 × 1024 resolution image. Sana enables content creation at low cost.
|
31 |
+
|
32 |
+
<p align="center" border-raduis="10px">
|
33 |
+
<img src="asset/model-incremental.jpg" width="90%" alt="teaser_page2"/>
|
34 |
+
</p>
|
35 |
+
|
36 |
+
## 🔥🔥 News
|
37 |
+
|
38 |
+
- (🔥 New) \[2024/11\] 1.6B [Sana models](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released.
|
39 |
+
- (🔥 New) \[2024/11\] Training & Inference & Metrics code are released.
|
40 |
+
- (🔥 New) \[2024/11\] Working on [`diffusers`](https://github.com/huggingface/diffusers/pull/9982).
|
41 |
+
- \[2024/10\] [Demo](https://nv-sana.mit.edu/) is released.
|
42 |
+
- \[2024/10\] [DC-AE Code](https://github.com/mit-han-lab/efficientvit/blob/master/applications/dc_ae/README.md) and [weights](https://huggingface.co/collections/mit-han-lab/dc-ae-670085b9400ad7197bb1009b) are released!
|
43 |
+
- \[2024/10\] [Paper](https://arxiv.org/abs/2410.10629) is on Arxiv!
|
44 |
+
|
45 |
+
## Performance
|
46 |
+
|
47 |
+
| Methods (1024x1024) | Throughput (samples/s) | Latency (s) | Params (B) | Speedup | FID 👆 | CLIP 👆 | GenEval 👆 | DPG 👆 |
|
48 |
+
|------------------------------|------------------------|-------------|------------|-----------|-------------|--------------|-------------|-------------|
|
49 |
+
| FLUX-dev | 0.04 | 23.0 | 12.0 | 1.0× | 10.15 | 27.47 | _0.67_ | _84.0_ |
|
50 |
+
| **Sana-0.6B** | 1.7 | 0.9 | 0.6 | **39.5×** | <u>5.81</u> | 28.36 | 0.64 | 83.6 |
|
51 |
+
| **Sana-1.6B** | 1.0 | 1.2 | 1.6 | **23.3×** | **5.76** | <u>28.67</u> | <u>0.66</u> | **84.8** |
|
52 |
+
|
53 |
+
<details>
|
54 |
+
<summary><h3>Click to show all</h3></summary>
|
55 |
+
|
56 |
+
| Methods | Throughput (samples/s) | Latency (s) | Params (B) | Speedup | FID 👆 | CLIP 👆 | GenEval 👆 | DPG 👆 |
|
57 |
+
|------------------------------|------------------------|-------------|------------|-----------|-------------|--------------|-------------|-------------|
|
58 |
+
| _**512 × 512 resolution**_ | | | | | | | | |
|
59 |
+
| PixArt-α | 1.5 | 1.2 | 0.6 | 1.0× | 6.14 | 27.55 | 0.48 | 71.6 |
|
60 |
+
| PixArt-Σ | 1.5 | 1.2 | 0.6 | 1.0× | _6.34_ | _27.62_ | <u>0.52</u> | _79.5_ |
|
61 |
+
| **Sana-0.6B** | 6.7 | 0.8 | 0.6 | 5.0× | <u>5.67</u> | <u>27.92</u> | _0.64_ | <u>84.3</u> |
|
62 |
+
| **Sana-1.6B** | 3.8 | 0.6 | 1.6 | 2.5× | **5.16** | **28.19** | **0.66** | **85.5** |
|
63 |
+
| _**1024 × 1024 resolution**_ | | | | | | | | |
|
64 |
+
| LUMINA-Next | 0.12 | 9.1 | 2.0 | 2.8× | 7.58 | 26.84 | 0.46 | 74.6 |
|
65 |
+
| SDXL | 0.15 | 6.5 | 2.6 | 3.5× | 6.63 | _29.03_ | 0.55 | 74.7 |
|
66 |
+
| PlayGroundv2.5 | 0.21 | 5.3 | 2.6 | 4.9× | _6.09_ | **29.13** | 0.56 | 75.5 |
|
67 |
+
| Hunyuan-DiT | 0.05 | 18.2 | 1.5 | 1.2× | 6.54 | 28.19 | 0.63 | 78.9 |
|
68 |
+
| PixArt-Σ | 0.4 | 2.7 | 0.6 | 9.3× | 6.15 | 28.26 | 0.54 | 80.5 |
|
69 |
+
| DALLE3 | - | - | - | - | - | - | _0.67_ | 83.5 |
|
70 |
+
| SD3-medium | 0.28 | 4.4 | 2.0 | 6.5× | 11.92 | 27.83 | 0.62 | <u>84.1</u> |
|
71 |
+
| FLUX-dev | 0.04 | 23.0 | 12.0 | 1.0× | 10.15 | 27.47 | _0.67_ | _84.0_ |
|
72 |
+
| FLUX-schnell | 0.5 | 2.1 | 12.0 | 11.6× | 7.94 | 28.14 | **0.71** | **84.8** |
|
73 |
+
| **Sana-0.6B** | 1.7 | 0.9 | 0.6 | **39.5×** | <u>5.81</u> | 28.36 | 0.64 | 83.6 |
|
74 |
+
| **Sana-1.6B** | 1.0 | 1.2 | 1.6 | **23.3×** | **5.76** | <u>28.67</u> | <u>0.66</u> | **84.8** |
|
75 |
+
|
76 |
+
</details>
|
77 |
+
|
78 |
+
## Contents
|
79 |
+
|
80 |
+
- [Env](#-1-dependencies-and-installation)
|
81 |
+
- [Demo](#-3-how-to-inference)
|
82 |
+
- [Training](#-2-how-to-train)
|
83 |
+
- [Testing](#-4-how-to-inference--test-metrics-fid-clip-score-geneval-dpg-bench-etc)
|
84 |
+
- [TODO](#to-do-list)
|
85 |
+
- [Citation](#bibtex)
|
86 |
+
|
87 |
+
# 🔧 1. Dependencies and Installation
|
88 |
+
|
89 |
+
- Python >= 3.10.0 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
|
90 |
+
- [PyTorch >= 2.0.1+cu12.1](https://pytorch.org/)
|
91 |
+
|
92 |
+
```bash
|
93 |
+
git clone https://github.com/NVlabs/Sana.git
|
94 |
+
cd Sana
|
95 |
+
|
96 |
+
./environment_setup.sh sana
|
97 |
+
# or you can install each components step by step following environment_setup.sh
|
98 |
+
```
|
99 |
+
|
100 |
+
# 💻 2. How to Play with Sana (Inference)
|
101 |
+
|
102 |
+
## 💰Hardware requirement
|
103 |
+
|
104 |
+
- 9GB VRAM is required for 0.6B model and 12GB VRAM for 1.6B model. Our later quantization version will require less than 8GB for inference.
|
105 |
+
- All the tests are done on A100 GPUs. Different GPU version may be different.
|
106 |
+
|
107 |
+
## 🔛 Quick start with [Gradio](https://www.gradio.app/guides/quickstart)
|
108 |
+
|
109 |
+
```bash
|
110 |
+
# official online demo
|
111 |
+
DEMO_PORT=15432 \
|
112 |
+
python app/app_sana.py \
|
113 |
+
--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
|
114 |
+
--model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
|
115 |
+
```
|
116 |
+
|
117 |
+
```python
|
118 |
+
import torch
|
119 |
+
from app.sana_pipeline import SanaPipeline
|
120 |
+
from torchvision.utils import save_image
|
121 |
+
|
122 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
123 |
+
generator = torch.Generator(device=device).manual_seed(42)
|
124 |
+
|
125 |
+
sana = SanaPipeline("configs/sana_config/1024ms/Sana_1600M_img1024.yaml")
|
126 |
+
sana.from_pretrained("hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth")
|
127 |
+
prompt = 'a cyberpunk cat with a neon sign that says "Sana"'
|
128 |
+
|
129 |
+
image = sana(
|
130 |
+
prompt=prompt,
|
131 |
+
height=1024,
|
132 |
+
width=1024,
|
133 |
+
guidance_scale=5.0,
|
134 |
+
pag_guidance_scale=2.0,
|
135 |
+
num_inference_steps=18,
|
136 |
+
generator=generator,
|
137 |
+
)
|
138 |
+
save_image(image, 'output/sana.png', nrow=1, normalize=True, value_range=(-1, 1))
|
139 |
+
```
|
140 |
+
|
141 |
+
## 🔛 Run inference with TXT or JSON files
|
142 |
+
|
143 |
+
```bash
|
144 |
+
# Run samples in a txt file
|
145 |
+
python scripts/inference.py \
|
146 |
+
--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
|
147 |
+
--model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
|
148 |
+
--txt_file=asset/samples_mini.txt
|
149 |
+
|
150 |
+
# Run samples in a json file
|
151 |
+
python scripts/inference.py \
|
152 |
+
--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
|
153 |
+
--model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
|
154 |
+
--json_file=asset/samples_mini.json
|
155 |
+
```
|
156 |
+
|
157 |
+
where each line of [`asset/samples_mini.txt`](asset/samples_mini.txt) contains a prompt to generate
|
158 |
+
|
159 |
+
# 🔥 3. How to Train Sana
|
160 |
+
|
161 |
+
## 💰Hardware requirement
|
162 |
+
|
163 |
+
- 32GB VRAM is required for both 0.6B and 1.6B model's training
|
164 |
+
|
165 |
+
We provide a training example here and you can also select your desired config file from [config files dir](configs/sana_config) based on your data structure.
|
166 |
+
|
167 |
+
To launch Sana training, you will first need to prepare data in the following formats
|
168 |
+
|
169 |
+
```bash
|
170 |
+
asset/example_data
|
171 |
+
├── AAA.txt
|
172 |
+
├── AAA.png
|
173 |
+
├── BCC.txt
|
174 |
+
├── BCC.png
|
175 |
+
├── ......
|
176 |
+
├── CCC.txt
|
177 |
+
└── CCC.png
|
178 |
+
```
|
179 |
+
|
180 |
+
Then Sana's training can be launched via
|
181 |
+
|
182 |
+
```bash
|
183 |
+
# Example of training Sana 0.6B with 512x512 resolution
|
184 |
+
bash train_scripts/train.sh \
|
185 |
+
configs/sana_config/512ms/Sana_600M_img512.yaml \
|
186 |
+
--data.data_dir="[asset/example_data]" \
|
187 |
+
--data.type=SanaImgDataset \
|
188 |
+
--model.multi_scale=false \
|
189 |
+
--train.train_batch_size=32
|
190 |
+
|
191 |
+
# Example of training Sana 1.6B with 1024x1024 resolution
|
192 |
+
bash train_scripts/train.sh \
|
193 |
+
configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
|
194 |
+
--data.data_dir="[asset/example_data]" \
|
195 |
+
--data.type=SanaImgDataset \
|
196 |
+
--model.multi_scale=false \
|
197 |
+
--train.train_batch_size=8
|
198 |
+
```
|
199 |
+
|
200 |
+
# 💻 4. Metric toolkit
|
201 |
+
|
202 |
+
Refer to [Toolkit Manual](asset/docs/metrics_toolkit.md).
|
203 |
+
|
204 |
+
# 💪To-Do List
|
205 |
+
|
206 |
+
We will try our best to release
|
207 |
+
|
208 |
+
- \[x\] Training code
|
209 |
+
- \[x\] Inference code
|
210 |
+
- \[+\] Model zoo
|
211 |
+
- \[ \] working on Diffusers(https://github.com/huggingface/diffusers/pull/9982)
|
212 |
+
- \[ \] ComfyUI
|
213 |
+
- \[ \] Laptop development
|
214 |
+
|
215 |
+
# 🤗Acknowledgements
|
216 |
+
|
217 |
+
- Thanks to [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha), [PixArt-Σ](https://github.com/PixArt-alpha/PixArt-sigma) and [Efficient-ViT](https://github.com/mit-han-lab/efficientvit) for their wonderful work and codebase!
|
218 |
+
|
219 |
+
# 📖BibTeX
|
220 |
+
|
221 |
+
```
|
222 |
+
@misc{xie2024sana,
|
223 |
+
title={Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer},
|
224 |
+
author={Enze Xie and Junsong Chen and Junyu Chen and Han Cai and Haotian Tang and Yujun Lin and Zhekai Zhang and Muyang Li and Ligeng Zhu and Yao Lu and Song Han},
|
225 |
+
year={2024},
|
226 |
+
eprint={2410.10629},
|
227 |
+
archivePrefix={arXiv},
|
228 |
+
primaryClass={cs.CV},
|
229 |
+
url={https://arxiv.org/abs/2410.10629},
|
230 |
+
}
|
231 |
+
```
|
app/app_sana.py
ADDED
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
# SPDX-License-Identifier: Apache-2.0
|
17 |
+
from __future__ import annotations
|
18 |
+
|
19 |
+
import argparse
|
20 |
+
import os
|
21 |
+
import random
|
22 |
+
import time
|
23 |
+
import uuid
|
24 |
+
from datetime import datetime
|
25 |
+
|
26 |
+
import gradio as gr
|
27 |
+
import numpy as np
|
28 |
+
import spaces
|
29 |
+
import torch
|
30 |
+
from PIL import Image
|
31 |
+
from torchvision.utils import make_grid, save_image
|
32 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
33 |
+
|
34 |
+
from app import safety_check
|
35 |
+
from app.sana_pipeline import SanaPipeline
|
36 |
+
|
37 |
+
MAX_SEED = np.iinfo(np.int32).max
|
38 |
+
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
39 |
+
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
|
40 |
+
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
41 |
+
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
42 |
+
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
|
43 |
+
os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
|
44 |
+
|
45 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
46 |
+
|
47 |
+
style_list = [
|
48 |
+
{
|
49 |
+
"name": "(No style)",
|
50 |
+
"prompt": "{prompt}",
|
51 |
+
"negative_prompt": "",
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"name": "Cinematic",
|
55 |
+
"prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
|
56 |
+
"cinemascope, moody, epic, gorgeous, film grain, grainy",
|
57 |
+
"negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"name": "Photographic",
|
61 |
+
"prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
62 |
+
"negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"name": "Anime",
|
66 |
+
"prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
|
67 |
+
"negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"name": "Manga",
|
71 |
+
"prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
|
72 |
+
"negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"name": "Digital Art",
|
76 |
+
"prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
|
77 |
+
"negative_prompt": "photo, photorealistic, realism, ugly",
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"name": "Pixel art",
|
81 |
+
"prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
|
82 |
+
"negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"name": "Fantasy art",
|
86 |
+
"prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
|
87 |
+
"majestic, magical, fantasy art, cover art, dreamy",
|
88 |
+
"negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
|
89 |
+
"glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
|
90 |
+
"disfigured, sloppy, duplicate, mutated, black and white",
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"name": "Neonpunk",
|
94 |
+
"prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
|
95 |
+
"detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
|
96 |
+
"ultra detailed, intricate, professional",
|
97 |
+
"negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"name": "3D Model",
|
101 |
+
"prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
|
102 |
+
"negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
|
103 |
+
},
|
104 |
+
]
|
105 |
+
|
106 |
+
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|
107 |
+
STYLE_NAMES = list(styles.keys())
|
108 |
+
DEFAULT_STYLE_NAME = "(No style)"
|
109 |
+
SCHEDULE_NAME = ["Flow_DPM_Solver"]
|
110 |
+
DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
|
111 |
+
NUM_IMAGES_PER_PROMPT = 1
|
112 |
+
TEST_TIMES = 0
|
113 |
+
INFER_SPEED = 0
|
114 |
+
FILENAME = f"output/port{DEMO_PORT}_inference_count.txt"
|
115 |
+
|
116 |
+
|
117 |
+
def read_inference_count():
|
118 |
+
global TEST_TIMES
|
119 |
+
try:
|
120 |
+
with open(FILENAME) as f:
|
121 |
+
count = int(f.read().strip())
|
122 |
+
except FileNotFoundError:
|
123 |
+
count = 0
|
124 |
+
TEST_TIMES = count
|
125 |
+
|
126 |
+
return count
|
127 |
+
|
128 |
+
|
129 |
+
def write_inference_count(count):
|
130 |
+
with open(FILENAME, "w") as f:
|
131 |
+
f.write(str(count))
|
132 |
+
|
133 |
+
|
134 |
+
def run_inference(num_imgs=1):
|
135 |
+
TEST_TIMES = read_inference_count()
|
136 |
+
TEST_TIMES += int(num_imgs)
|
137 |
+
write_inference_count(TEST_TIMES)
|
138 |
+
|
139 |
+
return (
|
140 |
+
f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
|
141 |
+
f"16px; color:red; font-weight: bold;'>{TEST_TIMES}</span>"
|
142 |
+
)
|
143 |
+
|
144 |
+
|
145 |
+
def update_inference_count():
|
146 |
+
count = read_inference_count()
|
147 |
+
return (
|
148 |
+
f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
|
149 |
+
f"16px; color:red; font-weight: bold;'>{count}</span>"
|
150 |
+
)
|
151 |
+
|
152 |
+
|
153 |
+
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
|
154 |
+
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
155 |
+
if not negative:
|
156 |
+
negative = ""
|
157 |
+
return p.replace("{prompt}", positive), n + negative
|
158 |
+
|
159 |
+
|
160 |
+
def get_args():
|
161 |
+
parser = argparse.ArgumentParser()
|
162 |
+
parser.add_argument("--config", type=str, help="config")
|
163 |
+
parser.add_argument(
|
164 |
+
"--model_path",
|
165 |
+
nargs="?",
|
166 |
+
default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
|
167 |
+
type=str,
|
168 |
+
help="Path to the model file (positional)",
|
169 |
+
)
|
170 |
+
parser.add_argument("--output", default="./", type=str)
|
171 |
+
parser.add_argument("--bs", default=1, type=int)
|
172 |
+
parser.add_argument("--image_size", default=1024, type=int)
|
173 |
+
parser.add_argument("--cfg_scale", default=5.0, type=float)
|
174 |
+
parser.add_argument("--pag_scale", default=2.0, type=float)
|
175 |
+
parser.add_argument("--seed", default=42, type=int)
|
176 |
+
parser.add_argument("--step", default=-1, type=int)
|
177 |
+
parser.add_argument("--custom_image_size", default=None, type=int)
|
178 |
+
parser.add_argument(
|
179 |
+
"--shield_model_path",
|
180 |
+
type=str,
|
181 |
+
help="The path to shield model, we employ ShieldGemma-2B by default.",
|
182 |
+
default="google/shieldgemma-2b",
|
183 |
+
)
|
184 |
+
|
185 |
+
return parser.parse_known_args()[0]
|
186 |
+
|
187 |
+
|
188 |
+
args = get_args()
|
189 |
+
|
190 |
+
if torch.cuda.is_available():
|
191 |
+
weight_dtype = torch.float16
|
192 |
+
model_path = args.model_path
|
193 |
+
pipe = SanaPipeline(args.config)
|
194 |
+
pipe.from_pretrained(model_path)
|
195 |
+
pipe.register_progress_bar(gr.Progress())
|
196 |
+
|
197 |
+
# safety checker
|
198 |
+
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
199 |
+
safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
200 |
+
args.shield_model_path,
|
201 |
+
device_map="auto",
|
202 |
+
torch_dtype=torch.bfloat16,
|
203 |
+
).to(device)
|
204 |
+
|
205 |
+
|
206 |
+
def save_image_sana(img, seed="", save_img=False):
|
207 |
+
unique_name = f"{str(uuid.uuid4())}_{seed}.png"
|
208 |
+
save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
|
209 |
+
os.umask(0o000) # file permission: 666; dir permission: 777
|
210 |
+
os.makedirs(save_path, exist_ok=True)
|
211 |
+
unique_name = os.path.join(save_path, unique_name)
|
212 |
+
if save_img:
|
213 |
+
save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
|
214 |
+
|
215 |
+
return unique_name
|
216 |
+
|
217 |
+
|
218 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
219 |
+
if randomize_seed:
|
220 |
+
seed = random.randint(0, MAX_SEED)
|
221 |
+
return seed
|
222 |
+
|
223 |
+
|
224 |
+
@torch.no_grad()
|
225 |
+
@torch.inference_mode()
|
226 |
+
@spaces.GPU(enable_queue=True)
|
227 |
+
def generate(
|
228 |
+
prompt: str = None,
|
229 |
+
negative_prompt: str = "",
|
230 |
+
style: str = DEFAULT_STYLE_NAME,
|
231 |
+
use_negative_prompt: bool = False,
|
232 |
+
num_imgs: int = 1,
|
233 |
+
seed: int = 0,
|
234 |
+
height: int = 1024,
|
235 |
+
width: int = 1024,
|
236 |
+
flow_dpms_guidance_scale: float = 5.0,
|
237 |
+
flow_dpms_pag_guidance_scale: float = 2.0,
|
238 |
+
flow_dpms_inference_steps: int = 20,
|
239 |
+
randomize_seed: bool = False,
|
240 |
+
):
|
241 |
+
global TEST_TIMES
|
242 |
+
global INFER_SPEED
|
243 |
+
# seed = 823753551
|
244 |
+
seed = int(randomize_seed_fn(seed, randomize_seed))
|
245 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
246 |
+
print(f"PORT: {DEMO_PORT}, model_path: {model_path}, time_times: {TEST_TIMES}")
|
247 |
+
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
|
248 |
+
prompt = "A red heart."
|
249 |
+
|
250 |
+
print(prompt)
|
251 |
+
|
252 |
+
num_inference_steps = flow_dpms_inference_steps
|
253 |
+
guidance_scale = flow_dpms_guidance_scale
|
254 |
+
pag_guidance_scale = flow_dpms_pag_guidance_scale
|
255 |
+
|
256 |
+
if not use_negative_prompt:
|
257 |
+
negative_prompt = None # type: ignore
|
258 |
+
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
259 |
+
|
260 |
+
pipe.progress_fn(0, desc="Sana Start")
|
261 |
+
|
262 |
+
time_start = time.time()
|
263 |
+
images = pipe(
|
264 |
+
prompt=prompt,
|
265 |
+
height=height,
|
266 |
+
width=width,
|
267 |
+
negative_prompt=negative_prompt,
|
268 |
+
guidance_scale=guidance_scale,
|
269 |
+
pag_guidance_scale=pag_guidance_scale,
|
270 |
+
num_inference_steps=num_inference_steps,
|
271 |
+
num_images_per_prompt=num_imgs,
|
272 |
+
generator=generator,
|
273 |
+
)
|
274 |
+
|
275 |
+
pipe.progress_fn(1.0, desc="Sana End")
|
276 |
+
INFER_SPEED = (time.time() - time_start) / num_imgs
|
277 |
+
|
278 |
+
save_img = False
|
279 |
+
if save_img:
|
280 |
+
img = [save_image_sana(img, seed, save_img=save_image) for img in images]
|
281 |
+
print(img)
|
282 |
+
else:
|
283 |
+
if num_imgs > 1:
|
284 |
+
nrow = 2
|
285 |
+
else:
|
286 |
+
nrow = 1
|
287 |
+
img = make_grid(images, nrow=nrow, normalize=True, value_range=(-1, 1))
|
288 |
+
img = img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
289 |
+
img = [Image.fromarray(img.astype(np.uint8))]
|
290 |
+
|
291 |
+
torch.cuda.empty_cache()
|
292 |
+
|
293 |
+
return (
|
294 |
+
img,
|
295 |
+
seed,
|
296 |
+
f"<span style='font-size: 16px; font-weight: bold;'>Inference Speed: {INFER_SPEED:.3f} s/Img</span>",
|
297 |
+
)
|
298 |
+
|
299 |
+
|
300 |
+
TEST_TIMES = read_inference_count()
|
301 |
+
model_size = "1.6" if "D20" in args.model_path else "0.6"
|
302 |
+
title = f"""
|
303 |
+
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
|
304 |
+
<img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
|
305 |
+
</div>
|
306 |
+
"""
|
307 |
+
DESCRIPTION = f"""
|
308 |
+
<p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
|
309 |
+
<p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
|
310 |
+
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github(coming soon)]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
|
311 |
+
<p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space</p>, running on A6000 node.
|
312 |
+
<p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
|
313 |
+
"""
|
314 |
+
if model_size == "0.6":
|
315 |
+
DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
|
316 |
+
if not torch.cuda.is_available():
|
317 |
+
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
318 |
+
|
319 |
+
examples = [
|
320 |
+
'a cyberpunk cat with a neon sign that says "Sana"',
|
321 |
+
"A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
|
322 |
+
"Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
|
323 |
+
"portrait photo of a girl, photograph, highly detailed face, depth of field",
|
324 |
+
'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
|
325 |
+
"🐶 Wearing 🕶 flying on the 🌈",
|
326 |
+
"👧 with 🌹 in the ❄️",
|
327 |
+
"an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
|
328 |
+
"professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
|
329 |
+
"Astronaut in a jungle, cold color palette, muted colors, detailed",
|
330 |
+
"a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
|
331 |
+
]
|
332 |
+
|
333 |
+
css = """
|
334 |
+
.gradio-container{max-width: 640px !important}
|
335 |
+
h1{text-align:center}
|
336 |
+
"""
|
337 |
+
with gr.Blocks(css=css) as demo:
|
338 |
+
gr.Markdown(title)
|
339 |
+
gr.Markdown(DESCRIPTION)
|
340 |
+
gr.DuplicateButton(
|
341 |
+
value="Duplicate Space for private use",
|
342 |
+
elem_id="duplicate-button",
|
343 |
+
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
|
344 |
+
)
|
345 |
+
info_box = gr.Markdown(
|
346 |
+
value=f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: 16px; color:red; font-weight: bold;'>{read_inference_count()}</span>"
|
347 |
+
)
|
348 |
+
demo.load(fn=update_inference_count, outputs=info_box) # update the value when re-loading the page
|
349 |
+
# with gr.Row(equal_height=False):
|
350 |
+
with gr.Group():
|
351 |
+
with gr.Row():
|
352 |
+
prompt = gr.Text(
|
353 |
+
label="Prompt",
|
354 |
+
show_label=False,
|
355 |
+
max_lines=1,
|
356 |
+
placeholder="Enter your prompt",
|
357 |
+
container=False,
|
358 |
+
)
|
359 |
+
run_button = gr.Button("Run", scale=0)
|
360 |
+
result = gr.Gallery(label="Result", show_label=False, columns=NUM_IMAGES_PER_PROMPT, format="png")
|
361 |
+
speed_box = gr.Markdown(
|
362 |
+
value=f"<span style='font-size: 16px; font-weight: bold;'>Inference speed: {INFER_SPEED} s/Img</span>"
|
363 |
+
)
|
364 |
+
with gr.Accordion("Advanced options", open=False):
|
365 |
+
with gr.Group():
|
366 |
+
with gr.Row(visible=True):
|
367 |
+
height = gr.Slider(
|
368 |
+
label="Height",
|
369 |
+
minimum=256,
|
370 |
+
maximum=MAX_IMAGE_SIZE,
|
371 |
+
step=32,
|
372 |
+
value=1024,
|
373 |
+
)
|
374 |
+
width = gr.Slider(
|
375 |
+
label="Width",
|
376 |
+
minimum=256,
|
377 |
+
maximum=MAX_IMAGE_SIZE,
|
378 |
+
step=32,
|
379 |
+
value=1024,
|
380 |
+
)
|
381 |
+
with gr.Row():
|
382 |
+
flow_dpms_inference_steps = gr.Slider(
|
383 |
+
label="Sampling steps",
|
384 |
+
minimum=5,
|
385 |
+
maximum=40,
|
386 |
+
step=1,
|
387 |
+
value=18,
|
388 |
+
)
|
389 |
+
flow_dpms_guidance_scale = gr.Slider(
|
390 |
+
label="CFG Guidance scale",
|
391 |
+
minimum=1,
|
392 |
+
maximum=10,
|
393 |
+
step=0.1,
|
394 |
+
value=5.0,
|
395 |
+
)
|
396 |
+
flow_dpms_pag_guidance_scale = gr.Slider(
|
397 |
+
label="PAG Guidance scale",
|
398 |
+
minimum=1,
|
399 |
+
maximum=4,
|
400 |
+
step=0.5,
|
401 |
+
value=2.0,
|
402 |
+
)
|
403 |
+
with gr.Row():
|
404 |
+
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
|
405 |
+
negative_prompt = gr.Text(
|
406 |
+
label="Negative prompt",
|
407 |
+
max_lines=1,
|
408 |
+
placeholder="Enter a negative prompt",
|
409 |
+
visible=True,
|
410 |
+
)
|
411 |
+
style_selection = gr.Radio(
|
412 |
+
show_label=True,
|
413 |
+
container=True,
|
414 |
+
interactive=True,
|
415 |
+
choices=STYLE_NAMES,
|
416 |
+
value=DEFAULT_STYLE_NAME,
|
417 |
+
label="Image Style",
|
418 |
+
)
|
419 |
+
seed = gr.Slider(
|
420 |
+
label="Seed",
|
421 |
+
minimum=0,
|
422 |
+
maximum=MAX_SEED,
|
423 |
+
step=1,
|
424 |
+
value=0,
|
425 |
+
)
|
426 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
427 |
+
with gr.Row(visible=True):
|
428 |
+
schedule = gr.Radio(
|
429 |
+
show_label=True,
|
430 |
+
container=True,
|
431 |
+
interactive=True,
|
432 |
+
choices=SCHEDULE_NAME,
|
433 |
+
value=DEFAULT_SCHEDULE_NAME,
|
434 |
+
label="Sampler Schedule",
|
435 |
+
visible=True,
|
436 |
+
)
|
437 |
+
num_imgs = gr.Slider(
|
438 |
+
label="Num Images",
|
439 |
+
minimum=1,
|
440 |
+
maximum=6,
|
441 |
+
step=1,
|
442 |
+
value=1,
|
443 |
+
)
|
444 |
+
|
445 |
+
run_button.click(fn=run_inference, inputs=num_imgs, outputs=info_box)
|
446 |
+
|
447 |
+
gr.Examples(
|
448 |
+
examples=examples,
|
449 |
+
inputs=prompt,
|
450 |
+
outputs=[result, seed],
|
451 |
+
fn=generate,
|
452 |
+
cache_examples=CACHE_EXAMPLES,
|
453 |
+
)
|
454 |
+
|
455 |
+
use_negative_prompt.change(
|
456 |
+
fn=lambda x: gr.update(visible=x),
|
457 |
+
inputs=use_negative_prompt,
|
458 |
+
outputs=negative_prompt,
|
459 |
+
api_name=False,
|
460 |
+
)
|
461 |
+
|
462 |
+
gr.on(
|
463 |
+
triggers=[
|
464 |
+
prompt.submit,
|
465 |
+
negative_prompt.submit,
|
466 |
+
run_button.click,
|
467 |
+
],
|
468 |
+
fn=generate,
|
469 |
+
inputs=[
|
470 |
+
prompt,
|
471 |
+
negative_prompt,
|
472 |
+
style_selection,
|
473 |
+
use_negative_prompt,
|
474 |
+
num_imgs,
|
475 |
+
seed,
|
476 |
+
height,
|
477 |
+
width,
|
478 |
+
flow_dpms_guidance_scale,
|
479 |
+
flow_dpms_pag_guidance_scale,
|
480 |
+
flow_dpms_inference_steps,
|
481 |
+
randomize_seed,
|
482 |
+
],
|
483 |
+
outputs=[result, seed, speed_box],
|
484 |
+
api_name="run",
|
485 |
+
)
|
486 |
+
|
487 |
+
if __name__ == "__main__":
|
488 |
+
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=True, share=True)
|
app/app_sana_multithread.py
ADDED
@@ -0,0 +1,565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
# SPDX-License-Identifier: Apache-2.0
|
17 |
+
from __future__ import annotations
|
18 |
+
|
19 |
+
import argparse
|
20 |
+
import os
|
21 |
+
import random
|
22 |
+
import uuid
|
23 |
+
from datetime import datetime
|
24 |
+
|
25 |
+
import gradio as gr
|
26 |
+
import numpy as np
|
27 |
+
import spaces
|
28 |
+
import torch
|
29 |
+
from diffusers import FluxPipeline
|
30 |
+
from PIL import Image
|
31 |
+
from torchvision.utils import make_grid, save_image
|
32 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
33 |
+
|
34 |
+
from app import safety_check
|
35 |
+
from app.sana_pipeline import SanaPipeline
|
36 |
+
|
37 |
+
MAX_SEED = np.iinfo(np.int32).max
|
38 |
+
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
39 |
+
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
|
40 |
+
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
41 |
+
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
42 |
+
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
|
43 |
+
os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
|
44 |
+
|
45 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
46 |
+
|
47 |
+
style_list = [
|
48 |
+
{
|
49 |
+
"name": "(No style)",
|
50 |
+
"prompt": "{prompt}",
|
51 |
+
"negative_prompt": "",
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"name": "Cinematic",
|
55 |
+
"prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
|
56 |
+
"cinemascope, moody, epic, gorgeous, film grain, grainy",
|
57 |
+
"negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"name": "Photographic",
|
61 |
+
"prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
62 |
+
"negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"name": "Anime",
|
66 |
+
"prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
|
67 |
+
"negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"name": "Manga",
|
71 |
+
"prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
|
72 |
+
"negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"name": "Digital Art",
|
76 |
+
"prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
|
77 |
+
"negative_prompt": "photo, photorealistic, realism, ugly",
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"name": "Pixel art",
|
81 |
+
"prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
|
82 |
+
"negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"name": "Fantasy art",
|
86 |
+
"prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
|
87 |
+
"majestic, magical, fantasy art, cover art, dreamy",
|
88 |
+
"negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
|
89 |
+
"glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
|
90 |
+
"disfigured, sloppy, duplicate, mutated, black and white",
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"name": "Neonpunk",
|
94 |
+
"prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
|
95 |
+
"detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
|
96 |
+
"ultra detailed, intricate, professional",
|
97 |
+
"negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"name": "3D Model",
|
101 |
+
"prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
|
102 |
+
"negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
|
103 |
+
},
|
104 |
+
]
|
105 |
+
|
106 |
+
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|
107 |
+
STYLE_NAMES = list(styles.keys())
|
108 |
+
DEFAULT_STYLE_NAME = "(No style)"
|
109 |
+
SCHEDULE_NAME = ["Flow_DPM_Solver"]
|
110 |
+
DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
|
111 |
+
NUM_IMAGES_PER_PROMPT = 1
|
112 |
+
TEST_TIMES = 0
|
113 |
+
FILENAME = f"output/port{DEMO_PORT}_inference_count.txt"
|
114 |
+
|
115 |
+
|
116 |
+
def set_env(seed=0):
|
117 |
+
torch.manual_seed(seed)
|
118 |
+
torch.set_grad_enabled(False)
|
119 |
+
|
120 |
+
|
121 |
+
def read_inference_count():
|
122 |
+
global TEST_TIMES
|
123 |
+
try:
|
124 |
+
with open(FILENAME) as f:
|
125 |
+
count = int(f.read().strip())
|
126 |
+
except FileNotFoundError:
|
127 |
+
count = 0
|
128 |
+
TEST_TIMES = count
|
129 |
+
|
130 |
+
return count
|
131 |
+
|
132 |
+
|
133 |
+
def write_inference_count(count):
|
134 |
+
with open(FILENAME, "w") as f:
|
135 |
+
f.write(str(count))
|
136 |
+
|
137 |
+
|
138 |
+
def run_inference(num_imgs=1):
|
139 |
+
TEST_TIMES = read_inference_count()
|
140 |
+
TEST_TIMES += int(num_imgs)
|
141 |
+
write_inference_count(TEST_TIMES)
|
142 |
+
|
143 |
+
return (
|
144 |
+
f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
|
145 |
+
f"16px; color:red; font-weight: bold;'>{TEST_TIMES}</span>"
|
146 |
+
)
|
147 |
+
|
148 |
+
|
149 |
+
def update_inference_count():
|
150 |
+
count = read_inference_count()
|
151 |
+
return (
|
152 |
+
f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
|
153 |
+
f"16px; color:red; font-weight: bold;'>{count}</span>"
|
154 |
+
)
|
155 |
+
|
156 |
+
|
157 |
+
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
|
158 |
+
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
159 |
+
if not negative:
|
160 |
+
negative = ""
|
161 |
+
return p.replace("{prompt}", positive), n + negative
|
162 |
+
|
163 |
+
|
164 |
+
def get_args():
|
165 |
+
parser = argparse.ArgumentParser()
|
166 |
+
parser.add_argument("--config", type=str, help="config")
|
167 |
+
parser.add_argument(
|
168 |
+
"--model_path",
|
169 |
+
nargs="?",
|
170 |
+
default="output/Sana_D20/SANA.pth",
|
171 |
+
type=str,
|
172 |
+
help="Path to the model file (positional)",
|
173 |
+
)
|
174 |
+
parser.add_argument("--output", default="./", type=str)
|
175 |
+
parser.add_argument("--bs", default=1, type=int)
|
176 |
+
parser.add_argument("--image_size", default=1024, type=int)
|
177 |
+
parser.add_argument("--cfg_scale", default=5.0, type=float)
|
178 |
+
parser.add_argument("--pag_scale", default=2.0, type=float)
|
179 |
+
parser.add_argument("--seed", default=42, type=int)
|
180 |
+
parser.add_argument("--step", default=-1, type=int)
|
181 |
+
parser.add_argument("--custom_image_size", default=None, type=int)
|
182 |
+
parser.add_argument(
|
183 |
+
"--shield_model_path",
|
184 |
+
type=str,
|
185 |
+
help="The path to shield model, we employ ShieldGemma-2B by default.",
|
186 |
+
default="google/shieldgemma-2b",
|
187 |
+
)
|
188 |
+
|
189 |
+
return parser.parse_args()
|
190 |
+
|
191 |
+
|
192 |
+
args = get_args()
|
193 |
+
|
194 |
+
if torch.cuda.is_available():
|
195 |
+
weight_dtype = torch.float16
|
196 |
+
model_path = args.model_path
|
197 |
+
pipe = SanaPipeline(args.config)
|
198 |
+
pipe.from_pretrained(model_path)
|
199 |
+
pipe.register_progress_bar(gr.Progress())
|
200 |
+
|
201 |
+
repo_name = "black-forest-labs/FLUX.1-dev"
|
202 |
+
pipe2 = FluxPipeline.from_pretrained(repo_name, torch_dtype=torch.float16).to("cuda")
|
203 |
+
|
204 |
+
# safety checker
|
205 |
+
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
206 |
+
safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
207 |
+
args.shield_model_path,
|
208 |
+
device_map="auto",
|
209 |
+
torch_dtype=torch.bfloat16,
|
210 |
+
).to(device)
|
211 |
+
|
212 |
+
set_env(42)
|
213 |
+
|
214 |
+
|
215 |
+
def save_image_sana(img, seed="", save_img=False):
|
216 |
+
unique_name = f"{str(uuid.uuid4())}_{seed}.png"
|
217 |
+
save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
|
218 |
+
os.umask(0o000) # file permission: 666; dir permission: 777
|
219 |
+
os.makedirs(save_path, exist_ok=True)
|
220 |
+
unique_name = os.path.join(save_path, unique_name)
|
221 |
+
if save_img:
|
222 |
+
save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
|
223 |
+
|
224 |
+
return unique_name
|
225 |
+
|
226 |
+
|
227 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
228 |
+
if randomize_seed:
|
229 |
+
seed = random.randint(0, MAX_SEED)
|
230 |
+
return seed
|
231 |
+
|
232 |
+
|
233 |
+
@spaces.GPU(enable_queue=True)
|
234 |
+
async def generate_2(
|
235 |
+
prompt: str = None,
|
236 |
+
negative_prompt: str = "",
|
237 |
+
style: str = DEFAULT_STYLE_NAME,
|
238 |
+
use_negative_prompt: bool = False,
|
239 |
+
num_imgs: int = 1,
|
240 |
+
seed: int = 0,
|
241 |
+
height: int = 1024,
|
242 |
+
width: int = 1024,
|
243 |
+
flow_dpms_guidance_scale: float = 5.0,
|
244 |
+
flow_dpms_pag_guidance_scale: float = 2.0,
|
245 |
+
flow_dpms_inference_steps: int = 20,
|
246 |
+
randomize_seed: bool = False,
|
247 |
+
):
|
248 |
+
seed = int(randomize_seed_fn(seed, randomize_seed))
|
249 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
250 |
+
print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
|
251 |
+
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt):
|
252 |
+
prompt = "A red heart."
|
253 |
+
|
254 |
+
print(prompt)
|
255 |
+
|
256 |
+
if not use_negative_prompt:
|
257 |
+
negative_prompt = None # type: ignore
|
258 |
+
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
259 |
+
|
260 |
+
with torch.no_grad():
|
261 |
+
images = pipe2(
|
262 |
+
prompt=prompt,
|
263 |
+
height=height,
|
264 |
+
width=width,
|
265 |
+
guidance_scale=3.5,
|
266 |
+
num_inference_steps=50,
|
267 |
+
num_images_per_prompt=num_imgs,
|
268 |
+
max_sequence_length=256,
|
269 |
+
generator=generator,
|
270 |
+
).images
|
271 |
+
|
272 |
+
save_img = False
|
273 |
+
img = images
|
274 |
+
if save_img:
|
275 |
+
img = [save_image_sana(img, seed, save_img=save_image) for img in images]
|
276 |
+
print(img)
|
277 |
+
torch.cuda.empty_cache()
|
278 |
+
|
279 |
+
return img
|
280 |
+
|
281 |
+
|
282 |
+
@spaces.GPU(enable_queue=True)
|
283 |
+
async def generate(
|
284 |
+
prompt: str = None,
|
285 |
+
negative_prompt: str = "",
|
286 |
+
style: str = DEFAULT_STYLE_NAME,
|
287 |
+
use_negative_prompt: bool = False,
|
288 |
+
num_imgs: int = 1,
|
289 |
+
seed: int = 0,
|
290 |
+
height: int = 1024,
|
291 |
+
width: int = 1024,
|
292 |
+
flow_dpms_guidance_scale: float = 5.0,
|
293 |
+
flow_dpms_pag_guidance_scale: float = 2.0,
|
294 |
+
flow_dpms_inference_steps: int = 20,
|
295 |
+
randomize_seed: bool = False,
|
296 |
+
):
|
297 |
+
global TEST_TIMES
|
298 |
+
# seed = 823753551
|
299 |
+
seed = int(randomize_seed_fn(seed, randomize_seed))
|
300 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
301 |
+
print(f"PORT: {DEMO_PORT}, model_path: {model_path}, time_times: {TEST_TIMES}")
|
302 |
+
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt):
|
303 |
+
prompt = "A red heart."
|
304 |
+
|
305 |
+
print(prompt)
|
306 |
+
|
307 |
+
num_inference_steps = flow_dpms_inference_steps
|
308 |
+
guidance_scale = flow_dpms_guidance_scale
|
309 |
+
pag_guidance_scale = flow_dpms_pag_guidance_scale
|
310 |
+
|
311 |
+
if not use_negative_prompt:
|
312 |
+
negative_prompt = None # type: ignore
|
313 |
+
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
314 |
+
|
315 |
+
pipe.progress_fn(0, desc="Sana Start")
|
316 |
+
|
317 |
+
with torch.no_grad():
|
318 |
+
images = pipe(
|
319 |
+
prompt=prompt,
|
320 |
+
height=height,
|
321 |
+
width=width,
|
322 |
+
negative_prompt=negative_prompt,
|
323 |
+
guidance_scale=guidance_scale,
|
324 |
+
pag_guidance_scale=pag_guidance_scale,
|
325 |
+
num_inference_steps=num_inference_steps,
|
326 |
+
num_images_per_prompt=num_imgs,
|
327 |
+
generator=generator,
|
328 |
+
)
|
329 |
+
|
330 |
+
pipe.progress_fn(1.0, desc="Sana End")
|
331 |
+
|
332 |
+
save_img = False
|
333 |
+
if save_img:
|
334 |
+
img = [save_image_sana(img, seed, save_img=save_image) for img in images]
|
335 |
+
print(img)
|
336 |
+
else:
|
337 |
+
if num_imgs > 1:
|
338 |
+
nrow = 2
|
339 |
+
else:
|
340 |
+
nrow = 1
|
341 |
+
img = make_grid(images, nrow=nrow, normalize=True, value_range=(-1, 1))
|
342 |
+
img = img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
343 |
+
img = [Image.fromarray(img.astype(np.uint8))]
|
344 |
+
|
345 |
+
torch.cuda.empty_cache()
|
346 |
+
|
347 |
+
return img
|
348 |
+
|
349 |
+
|
350 |
+
TEST_TIMES = read_inference_count()
|
351 |
+
model_size = "1.6" if "D20" in args.model_path else "0.6"
|
352 |
+
title = f"""
|
353 |
+
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
|
354 |
+
<img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
|
355 |
+
</div>
|
356 |
+
"""
|
357 |
+
DESCRIPTION = f"""
|
358 |
+
<p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
|
359 |
+
<p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
|
360 |
+
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github(coming soon)]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
|
361 |
+
<p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space</p>
|
362 |
+
<p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
|
363 |
+
"""
|
364 |
+
if model_size == "0.6":
|
365 |
+
DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
|
366 |
+
if not torch.cuda.is_available():
|
367 |
+
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
368 |
+
|
369 |
+
examples = [
|
370 |
+
'a cyberpunk cat with a neon sign that says "Sana"',
|
371 |
+
"A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
|
372 |
+
"Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
|
373 |
+
"portrait photo of a girl, photograph, highly detailed face, depth of field",
|
374 |
+
'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
|
375 |
+
"🐶 Wearing 🕶 flying on the 🌈",
|
376 |
+
# "👧 with 🌹 in the ❄️",
|
377 |
+
# "an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
|
378 |
+
# "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
|
379 |
+
# "Astronaut in a jungle, cold color palette, muted colors, detailed",
|
380 |
+
# "a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
|
381 |
+
]
|
382 |
+
|
383 |
+
css = """
|
384 |
+
.gradio-container{max-width: 1024px !important}
|
385 |
+
h1{text-align:center}
|
386 |
+
"""
|
387 |
+
with gr.Blocks(css=css) as demo:
|
388 |
+
gr.Markdown(title)
|
389 |
+
gr.Markdown(DESCRIPTION)
|
390 |
+
gr.DuplicateButton(
|
391 |
+
value="Duplicate Space for private use",
|
392 |
+
elem_id="duplicate-button",
|
393 |
+
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
|
394 |
+
)
|
395 |
+
info_box = gr.Markdown(
|
396 |
+
value=f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: 16px; color:red; font-weight: bold;'>{read_inference_count()}</span>"
|
397 |
+
)
|
398 |
+
demo.load(fn=update_inference_count, outputs=info_box) # update the value when re-loading the page
|
399 |
+
# with gr.Row(equal_height=False):
|
400 |
+
with gr.Group():
|
401 |
+
with gr.Row():
|
402 |
+
prompt = gr.Text(
|
403 |
+
label="Prompt",
|
404 |
+
show_label=False,
|
405 |
+
max_lines=1,
|
406 |
+
placeholder="Enter your prompt",
|
407 |
+
container=False,
|
408 |
+
)
|
409 |
+
run_button = gr.Button("Run-sana", scale=0)
|
410 |
+
run_button2 = gr.Button("Run-flux", scale=0)
|
411 |
+
|
412 |
+
with gr.Row():
|
413 |
+
result = gr.Gallery(label="Result from Sana", show_label=True, columns=NUM_IMAGES_PER_PROMPT, format="webp")
|
414 |
+
result_2 = gr.Gallery(
|
415 |
+
label="Result from FLUX", show_label=True, columns=NUM_IMAGES_PER_PROMPT, format="webp"
|
416 |
+
)
|
417 |
+
|
418 |
+
with gr.Accordion("Advanced options", open=False):
|
419 |
+
with gr.Group():
|
420 |
+
with gr.Row(visible=True):
|
421 |
+
height = gr.Slider(
|
422 |
+
label="Height",
|
423 |
+
minimum=256,
|
424 |
+
maximum=MAX_IMAGE_SIZE,
|
425 |
+
step=32,
|
426 |
+
value=1024,
|
427 |
+
)
|
428 |
+
width = gr.Slider(
|
429 |
+
label="Width",
|
430 |
+
minimum=256,
|
431 |
+
maximum=MAX_IMAGE_SIZE,
|
432 |
+
step=32,
|
433 |
+
value=1024,
|
434 |
+
)
|
435 |
+
with gr.Row():
|
436 |
+
flow_dpms_inference_steps = gr.Slider(
|
437 |
+
label="Sampling steps",
|
438 |
+
minimum=5,
|
439 |
+
maximum=40,
|
440 |
+
step=1,
|
441 |
+
value=18,
|
442 |
+
)
|
443 |
+
flow_dpms_guidance_scale = gr.Slider(
|
444 |
+
label="CFG Guidance scale",
|
445 |
+
minimum=1,
|
446 |
+
maximum=10,
|
447 |
+
step=0.1,
|
448 |
+
value=5.0,
|
449 |
+
)
|
450 |
+
flow_dpms_pag_guidance_scale = gr.Slider(
|
451 |
+
label="PAG Guidance scale",
|
452 |
+
minimum=1,
|
453 |
+
maximum=4,
|
454 |
+
step=0.5,
|
455 |
+
value=2.0,
|
456 |
+
)
|
457 |
+
with gr.Row():
|
458 |
+
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
|
459 |
+
negative_prompt = gr.Text(
|
460 |
+
label="Negative prompt",
|
461 |
+
max_lines=1,
|
462 |
+
placeholder="Enter a negative prompt",
|
463 |
+
visible=True,
|
464 |
+
)
|
465 |
+
style_selection = gr.Radio(
|
466 |
+
show_label=True,
|
467 |
+
container=True,
|
468 |
+
interactive=True,
|
469 |
+
choices=STYLE_NAMES,
|
470 |
+
value=DEFAULT_STYLE_NAME,
|
471 |
+
label="Image Style",
|
472 |
+
)
|
473 |
+
seed = gr.Slider(
|
474 |
+
label="Seed",
|
475 |
+
minimum=0,
|
476 |
+
maximum=MAX_SEED,
|
477 |
+
step=1,
|
478 |
+
value=0,
|
479 |
+
)
|
480 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
481 |
+
with gr.Row(visible=True):
|
482 |
+
schedule = gr.Radio(
|
483 |
+
show_label=True,
|
484 |
+
container=True,
|
485 |
+
interactive=True,
|
486 |
+
choices=SCHEDULE_NAME,
|
487 |
+
value=DEFAULT_SCHEDULE_NAME,
|
488 |
+
label="Sampler Schedule",
|
489 |
+
visible=True,
|
490 |
+
)
|
491 |
+
num_imgs = gr.Slider(
|
492 |
+
label="Num Images",
|
493 |
+
minimum=1,
|
494 |
+
maximum=6,
|
495 |
+
step=1,
|
496 |
+
value=1,
|
497 |
+
)
|
498 |
+
|
499 |
+
run_button.click(fn=run_inference, inputs=num_imgs, outputs=info_box)
|
500 |
+
|
501 |
+
gr.Examples(
|
502 |
+
examples=examples,
|
503 |
+
inputs=prompt,
|
504 |
+
outputs=[result],
|
505 |
+
fn=generate,
|
506 |
+
cache_examples=CACHE_EXAMPLES,
|
507 |
+
)
|
508 |
+
gr.Examples(
|
509 |
+
examples=examples,
|
510 |
+
inputs=prompt,
|
511 |
+
outputs=[result_2],
|
512 |
+
fn=generate_2,
|
513 |
+
cache_examples=CACHE_EXAMPLES,
|
514 |
+
)
|
515 |
+
|
516 |
+
use_negative_prompt.change(
|
517 |
+
fn=lambda x: gr.update(visible=x),
|
518 |
+
inputs=use_negative_prompt,
|
519 |
+
outputs=negative_prompt,
|
520 |
+
api_name=False,
|
521 |
+
)
|
522 |
+
|
523 |
+
run_button.click(
|
524 |
+
fn=generate,
|
525 |
+
inputs=[
|
526 |
+
prompt,
|
527 |
+
negative_prompt,
|
528 |
+
style_selection,
|
529 |
+
use_negative_prompt,
|
530 |
+
num_imgs,
|
531 |
+
seed,
|
532 |
+
height,
|
533 |
+
width,
|
534 |
+
flow_dpms_guidance_scale,
|
535 |
+
flow_dpms_pag_guidance_scale,
|
536 |
+
flow_dpms_inference_steps,
|
537 |
+
randomize_seed,
|
538 |
+
],
|
539 |
+
outputs=[result],
|
540 |
+
queue=True,
|
541 |
+
)
|
542 |
+
|
543 |
+
run_button2.click(
|
544 |
+
fn=generate_2,
|
545 |
+
inputs=[
|
546 |
+
prompt,
|
547 |
+
negative_prompt,
|
548 |
+
style_selection,
|
549 |
+
use_negative_prompt,
|
550 |
+
num_imgs,
|
551 |
+
seed,
|
552 |
+
height,
|
553 |
+
width,
|
554 |
+
flow_dpms_guidance_scale,
|
555 |
+
flow_dpms_pag_guidance_scale,
|
556 |
+
flow_dpms_inference_steps,
|
557 |
+
randomize_seed,
|
558 |
+
],
|
559 |
+
outputs=[result_2],
|
560 |
+
queue=True,
|
561 |
+
)
|
562 |
+
|
563 |
+
|
564 |
+
if __name__ == "__main__":
|
565 |
+
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=True, share=True)
|
app/safety_check.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
# This file is copy from https://github.com/mit-han-lab/hart/blob/main/hart/utils/safety_check.py
|
17 |
+
import json
|
18 |
+
from functools import lru_cache
|
19 |
+
from typing import Dict, Sequence
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import transformers
|
23 |
+
from torch.nn.functional import softmax
|
24 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
25 |
+
|
26 |
+
__all__ = ["is_dangerous"]
|
27 |
+
|
28 |
+
__rank, __local_rank, __world_size, __device = (
|
29 |
+
0,
|
30 |
+
0,
|
31 |
+
1,
|
32 |
+
"cuda" if torch.cuda.is_available() else "cpu",
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
safety_policy = """
|
37 |
+
The prompt must not seek to generate harmful, abusive, hateful, sexually explicit, or otherwise inappropriate content targeting individuals or protected groups.
|
38 |
+
"""
|
39 |
+
|
40 |
+
|
41 |
+
def is_dangerous(tokenizer, model, user_prompt, threshold=0.5):
|
42 |
+
prompt = f"""You are a policy expert trying to help determine whether a user
|
43 |
+
prompt is in violation of the defined safety policies.
|
44 |
+
|
45 |
+
<start_of_turn>
|
46 |
+
Human Question: {user_prompt.strip()}
|
47 |
+
<end_of_turn>
|
48 |
+
|
49 |
+
Our safety principle is defined in the below:
|
50 |
+
|
51 |
+
{safety_policy.strip()}
|
52 |
+
|
53 |
+
Does the human question violate the above principle? Your answer must start
|
54 |
+
with 'Yes' or 'No'. And then walk through step by step to be sure we answer
|
55 |
+
correctly.
|
56 |
+
"""
|
57 |
+
|
58 |
+
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
59 |
+
with torch.no_grad():
|
60 |
+
logits = model(**inputs).logits
|
61 |
+
|
62 |
+
# Extract the logits for the Yes and No tokens
|
63 |
+
vocab = tokenizer.get_vocab()
|
64 |
+
selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]]
|
65 |
+
|
66 |
+
# Convert these logits to a probability with softmax
|
67 |
+
probabilities = softmax(selected_logits, dim=0)
|
68 |
+
|
69 |
+
# Return probability of 'Yes'
|
70 |
+
score = probabilities[0].item()
|
71 |
+
|
72 |
+
return score > threshold
|
app/sana_pipeline.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
import argparse
|
17 |
+
import warnings
|
18 |
+
from dataclasses import dataclass, field
|
19 |
+
from typing import Optional, Tuple
|
20 |
+
|
21 |
+
import pyrallis
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
|
25 |
+
warnings.filterwarnings("ignore") # ignore warning
|
26 |
+
|
27 |
+
|
28 |
+
from diffusion import DPMS, FlowEuler
|
29 |
+
from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST
|
30 |
+
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
|
31 |
+
from diffusion.model.utils import prepare_prompt_ar, resize_and_crop_tensor
|
32 |
+
from diffusion.utils.config import SanaConfig
|
33 |
+
from diffusion.utils.logger import get_root_logger
|
34 |
+
|
35 |
+
# from diffusion.utils.misc import read_config
|
36 |
+
from tools.download import find_model
|
37 |
+
|
38 |
+
|
39 |
+
def guidance_type_select(default_guidance_type, pag_scale, attn_type):
|
40 |
+
guidance_type = default_guidance_type
|
41 |
+
if not (pag_scale > 1.0 and attn_type == "linear"):
|
42 |
+
guidance_type = "classifier-free"
|
43 |
+
return guidance_type
|
44 |
+
|
45 |
+
|
46 |
+
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
|
47 |
+
"""Returns binned height and width."""
|
48 |
+
ar = float(height / width)
|
49 |
+
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
|
50 |
+
default_hw = ratios[closest_ratio]
|
51 |
+
return int(default_hw[0]), int(default_hw[1])
|
52 |
+
|
53 |
+
|
54 |
+
@dataclass
|
55 |
+
class SanaInference(SanaConfig):
|
56 |
+
config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml" # config
|
57 |
+
model_path: str = field(
|
58 |
+
default="output/Sana_D20/SANA.pth", metadata={"help": "Path to the model file (positional)"}
|
59 |
+
)
|
60 |
+
output: str = "./output"
|
61 |
+
bs: int = 1
|
62 |
+
image_size: int = 1024
|
63 |
+
cfg_scale: float = 5.0
|
64 |
+
pag_scale: float = 2.0
|
65 |
+
seed: int = 42
|
66 |
+
step: int = -1
|
67 |
+
custom_image_size: Optional[int] = None
|
68 |
+
shield_model_path: str = field(
|
69 |
+
default="google/shieldgemma-2b",
|
70 |
+
metadata={"help": "The path to shield model, we employ ShieldGemma-2B by default."},
|
71 |
+
)
|
72 |
+
|
73 |
+
|
74 |
+
class SanaPipeline(nn.Module):
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml",
|
78 |
+
):
|
79 |
+
super().__init__()
|
80 |
+
config = pyrallis.load(SanaInference, open(config))
|
81 |
+
self.args = self.config = config
|
82 |
+
|
83 |
+
# set some hyper-parameters
|
84 |
+
self.image_size = self.config.model.image_size
|
85 |
+
|
86 |
+
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
87 |
+
logger = get_root_logger()
|
88 |
+
self.logger = logger
|
89 |
+
self.progress_fn = lambda progress, desc: None
|
90 |
+
|
91 |
+
self.latent_size = self.image_size // config.vae.vae_downsample_rate
|
92 |
+
self.max_sequence_length = config.text_encoder.model_max_length
|
93 |
+
self.flow_shift = config.scheduler.flow_shift
|
94 |
+
guidance_type = "classifier-free_PAG"
|
95 |
+
|
96 |
+
if config.model.mixed_precision == "fp16":
|
97 |
+
weight_dtype = torch.float16
|
98 |
+
elif config.model.mixed_precision == "bf16":
|
99 |
+
weight_dtype = torch.bfloat16
|
100 |
+
elif config.model.mixed_precision == "fp32":
|
101 |
+
weight_dtype = torch.float32
|
102 |
+
else:
|
103 |
+
raise ValueError(f"weigh precision {config.model.mixed_precision} is not defined")
|
104 |
+
self.weight_dtype = weight_dtype
|
105 |
+
|
106 |
+
self.base_ratios = eval(f"ASPECT_RATIO_{self.image_size}_TEST")
|
107 |
+
self.vis_sampler = self.config.scheduler.vis_sampler
|
108 |
+
logger.info(f"Sampler {self.vis_sampler}, flow_shift: {self.flow_shift}")
|
109 |
+
self.guidance_type = guidance_type_select(guidance_type, self.args.pag_scale, config.model.attn_type)
|
110 |
+
logger.info(f"Inference with {self.weight_dtype}, PAG guidance layer: {self.config.model.pag_applied_layers}")
|
111 |
+
|
112 |
+
# 1. build vae and text encoder
|
113 |
+
self.vae = self.build_vae(config.vae)
|
114 |
+
self.tokenizer, self.text_encoder = self.build_text_encoder(config.text_encoder)
|
115 |
+
|
116 |
+
# 2. build Sana model
|
117 |
+
self.model = self.build_sana_model(config).to(self.device)
|
118 |
+
|
119 |
+
# 3. pre-compute null embedding
|
120 |
+
with torch.no_grad():
|
121 |
+
null_caption_token = self.tokenizer(
|
122 |
+
"", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
|
123 |
+
).to(self.device)
|
124 |
+
self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
|
125 |
+
0
|
126 |
+
]
|
127 |
+
|
128 |
+
def build_vae(self, config):
|
129 |
+
vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.weight_dtype)
|
130 |
+
return vae
|
131 |
+
|
132 |
+
def build_text_encoder(self, config):
|
133 |
+
tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder_name, device=self.device)
|
134 |
+
return tokenizer, text_encoder
|
135 |
+
|
136 |
+
def build_sana_model(self, config):
|
137 |
+
# model setting
|
138 |
+
pred_sigma = getattr(config.scheduler, "pred_sigma", True)
|
139 |
+
learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
|
140 |
+
model_kwargs = {
|
141 |
+
"input_size": self.latent_size,
|
142 |
+
"pe_interpolation": config.model.pe_interpolation,
|
143 |
+
"config": config,
|
144 |
+
"model_max_length": config.text_encoder.model_max_length,
|
145 |
+
"qk_norm": config.model.qk_norm,
|
146 |
+
"micro_condition": config.model.micro_condition,
|
147 |
+
"caption_channels": self.text_encoder.config.hidden_size,
|
148 |
+
"y_norm": config.text_encoder.y_norm,
|
149 |
+
"attn_type": config.model.attn_type,
|
150 |
+
"ffn_type": config.model.ffn_type,
|
151 |
+
"mlp_ratio": config.model.mlp_ratio,
|
152 |
+
"mlp_acts": list(config.model.mlp_acts),
|
153 |
+
"in_channels": config.vae.vae_latent_dim,
|
154 |
+
"y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
|
155 |
+
"use_pe": config.model.use_pe,
|
156 |
+
"pred_sigma": pred_sigma,
|
157 |
+
"learn_sigma": learn_sigma,
|
158 |
+
"use_fp32_attention": config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
|
159 |
+
}
|
160 |
+
model = build_model(config.model.model, **model_kwargs)
|
161 |
+
model = model.to(self.weight_dtype)
|
162 |
+
|
163 |
+
self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
|
164 |
+
self.logger.info(
|
165 |
+
f"{model.__class__.__name__}:{config.model.model},"
|
166 |
+
f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}"
|
167 |
+
)
|
168 |
+
return model
|
169 |
+
|
170 |
+
def from_pretrained(self, model_path):
|
171 |
+
state_dict = find_model(model_path)
|
172 |
+
state_dict = state_dict.get("state_dict", state_dict)
|
173 |
+
if "pos_embed" in state_dict:
|
174 |
+
del state_dict["pos_embed"]
|
175 |
+
missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
|
176 |
+
self.model.eval().to(self.weight_dtype)
|
177 |
+
|
178 |
+
self.logger.info("Generating sample from ckpt: %s" % model_path)
|
179 |
+
self.logger.warning(f"Missing keys: {missing}")
|
180 |
+
self.logger.warning(f"Unexpected keys: {unexpected}")
|
181 |
+
|
182 |
+
def register_progress_bar(self, progress_fn=None):
|
183 |
+
self.progress_fn = progress_fn if progress_fn is not None else self.progress_fn
|
184 |
+
|
185 |
+
@torch.inference_mode()
|
186 |
+
def forward(
|
187 |
+
self,
|
188 |
+
prompt=None,
|
189 |
+
height=1024,
|
190 |
+
width=1024,
|
191 |
+
negative_prompt="",
|
192 |
+
num_inference_steps=20,
|
193 |
+
guidance_scale=5,
|
194 |
+
pag_guidance_scale=2.5,
|
195 |
+
num_images_per_prompt=1,
|
196 |
+
generator=torch.Generator().manual_seed(42),
|
197 |
+
latents=None,
|
198 |
+
):
|
199 |
+
self.ori_height, self.ori_width = height, width
|
200 |
+
self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
|
201 |
+
self.latent_size_h, self.latent_size_w = (
|
202 |
+
self.height // self.config.vae.vae_downsample_rate,
|
203 |
+
self.width // self.config.vae.vae_downsample_rate,
|
204 |
+
)
|
205 |
+
self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)
|
206 |
+
|
207 |
+
# 1. pre-compute negative embedding
|
208 |
+
if negative_prompt != "":
|
209 |
+
null_caption_token = self.tokenizer(
|
210 |
+
negative_prompt,
|
211 |
+
max_length=self.max_sequence_length,
|
212 |
+
padding="max_length",
|
213 |
+
truncation=True,
|
214 |
+
return_tensors="pt",
|
215 |
+
).to(self.device)
|
216 |
+
self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
|
217 |
+
0
|
218 |
+
]
|
219 |
+
|
220 |
+
if prompt is None:
|
221 |
+
prompt = [""]
|
222 |
+
prompts = prompt if isinstance(prompt, list) else [prompt]
|
223 |
+
samples = []
|
224 |
+
|
225 |
+
for prompt in prompts:
|
226 |
+
# data prepare
|
227 |
+
prompts, hw, ar = (
|
228 |
+
[],
|
229 |
+
torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
|
230 |
+
num_images_per_prompt, 1
|
231 |
+
),
|
232 |
+
torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
|
233 |
+
)
|
234 |
+
for _ in range(num_images_per_prompt):
|
235 |
+
with torch.no_grad():
|
236 |
+
prompts.append(
|
237 |
+
prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip()
|
238 |
+
)
|
239 |
+
|
240 |
+
# prepare text feature
|
241 |
+
if not self.config.text_encoder.chi_prompt:
|
242 |
+
max_length_all = self.config.text_encoder.model_max_length
|
243 |
+
prompts_all = prompts
|
244 |
+
else:
|
245 |
+
chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
|
246 |
+
prompts_all = [chi_prompt + prompt for prompt in prompts]
|
247 |
+
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
|
248 |
+
max_length_all = (
|
249 |
+
num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
|
250 |
+
) # magic number 2: [bos], [_]
|
251 |
+
|
252 |
+
caption_token = self.tokenizer(
|
253 |
+
prompts_all,
|
254 |
+
max_length=max_length_all,
|
255 |
+
padding="max_length",
|
256 |
+
truncation=True,
|
257 |
+
return_tensors="pt",
|
258 |
+
).to(device=self.device)
|
259 |
+
select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
|
260 |
+
caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
|
261 |
+
:, :, select_index
|
262 |
+
].to(self.weight_dtype)
|
263 |
+
emb_masks = caption_token.attention_mask[:, select_index]
|
264 |
+
null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
|
265 |
+
|
266 |
+
n = len(prompts)
|
267 |
+
if latents is None:
|
268 |
+
z = torch.randn(
|
269 |
+
n,
|
270 |
+
self.config.vae.vae_latent_dim,
|
271 |
+
self.latent_size_h,
|
272 |
+
self.latent_size_w,
|
273 |
+
generator=generator,
|
274 |
+
device=self.device,
|
275 |
+
dtype=self.weight_dtype,
|
276 |
+
)
|
277 |
+
else:
|
278 |
+
z = latents.to(self.weight_dtype).to(self.device)
|
279 |
+
model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
|
280 |
+
if self.vis_sampler == "flow_euler":
|
281 |
+
flow_solver = FlowEuler(
|
282 |
+
self.model,
|
283 |
+
condition=caption_embs,
|
284 |
+
uncondition=null_y,
|
285 |
+
cfg_scale=guidance_scale,
|
286 |
+
model_kwargs=model_kwargs,
|
287 |
+
)
|
288 |
+
sample = flow_solver.sample(
|
289 |
+
z,
|
290 |
+
steps=num_inference_steps,
|
291 |
+
)
|
292 |
+
elif self.vis_sampler == "flow_dpm-solver":
|
293 |
+
scheduler = DPMS(
|
294 |
+
self.model,
|
295 |
+
condition=caption_embs,
|
296 |
+
uncondition=null_y,
|
297 |
+
guidance_type=self.guidance_type,
|
298 |
+
cfg_scale=guidance_scale,
|
299 |
+
pag_scale=pag_guidance_scale,
|
300 |
+
pag_applied_layers=self.config.model.pag_applied_layers,
|
301 |
+
model_type="flow",
|
302 |
+
model_kwargs=model_kwargs,
|
303 |
+
schedule="FLOW",
|
304 |
+
)
|
305 |
+
scheduler.register_progress_bar(self.progress_fn)
|
306 |
+
sample = scheduler.sample(
|
307 |
+
z,
|
308 |
+
steps=num_inference_steps,
|
309 |
+
order=2,
|
310 |
+
skip_type="time_uniform_flow",
|
311 |
+
method="multistep",
|
312 |
+
flow_shift=self.flow_shift,
|
313 |
+
)
|
314 |
+
|
315 |
+
sample = sample.to(self.weight_dtype)
|
316 |
+
with torch.no_grad():
|
317 |
+
sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
|
318 |
+
|
319 |
+
sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
|
320 |
+
samples.append(sample)
|
321 |
+
|
322 |
+
return sample
|
323 |
+
|
324 |
+
return samples
|
asset/Sana.jpg
ADDED
Git LFS Details
|
asset/docs/metrics_toolkit.md
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 💻 How to Inference & Test Metrics (FID, CLIP Score, GenEval, DPG-Bench, etc...)
|
2 |
+
|
3 |
+
This ToolKit will automatically inference your model and log the metrics results onto wandb as chart for better illustration. We curerntly support:
|
4 |
+
|
5 |
+
- \[x\] [FID](https://github.com/mseitzer/pytorch-fid) & [CLIP-Score](https://github.com/openai/CLIP)
|
6 |
+
- \[x\] [GenEval](https://github.com/djghosh13/geneval)
|
7 |
+
- \[x\] [DPG-Bench](https://github.com/TencentQQGYLab/ELLA)
|
8 |
+
- \[x\] [ImageReward](https://github.com/THUDM/ImageReward/tree/main)
|
9 |
+
|
10 |
+
### 0. Install corresponding env for GenEval and DPG-Bench
|
11 |
+
|
12 |
+
Make sure you can activate the following envs:
|
13 |
+
|
14 |
+
- `conda activate geneval`([GenEval](https://github.com/djghosh13/geneval))
|
15 |
+
- `conda activate dpg`([DGB-Bench](https://github.com/TencentQQGYLab/ELLA))
|
16 |
+
|
17 |
+
### 0.1 Prepare data.
|
18 |
+
|
19 |
+
Metirc FID & CLIP-Score on [MJHQ-30K](https://huggingface.co/datasets/playgroundai/MJHQ-30K)
|
20 |
+
|
21 |
+
```python
|
22 |
+
from huggingface_hub import hf_hub_download
|
23 |
+
|
24 |
+
hf_hub_download(
|
25 |
+
repo_id="playgroundai/MJHQ-30K",
|
26 |
+
filename="mjhq30k_imgs.zip",
|
27 |
+
local_dir="data/test/PG-eval-data/MJHQ-30K/",
|
28 |
+
repo_type="dataset"
|
29 |
+
)
|
30 |
+
```
|
31 |
+
|
32 |
+
Unzip mjhq30k_imgs.zip into its per-category folder structure.
|
33 |
+
|
34 |
+
```
|
35 |
+
data/test/PG-eval-data/MJHQ-30K/imgs/
|
36 |
+
├── animals
|
37 |
+
├── art
|
38 |
+
├── fashion
|
39 |
+
├── food
|
40 |
+
├── indoor
|
41 |
+
├── landscape
|
42 |
+
├── logo
|
43 |
+
├── people
|
44 |
+
├── plants
|
45 |
+
└── vehicles
|
46 |
+
```
|
47 |
+
|
48 |
+
### 0.2 Prepare checkpoints
|
49 |
+
|
50 |
+
```bash
|
51 |
+
huggingface-cli download Efficient-Large-Model/Sana_1600M_1024px --repo-type model --local-dir ./output/Sana_1600M_1024px --local-dir-use-symlinks False
|
52 |
+
```
|
53 |
+
|
54 |
+
### 1. directly \[Inference and Metric\] a .pth file
|
55 |
+
|
56 |
+
```bash
|
57 |
+
# We provide four scripts for evaluating metrics:
|
58 |
+
fid_clipscore_launch=scripts/bash_run_inference_metric.sh
|
59 |
+
geneval_launch=scripts/bash_run_inference_metric_geneval.sh
|
60 |
+
dpg_launch=scripts/bash_run_inference_metric_dpg.sh
|
61 |
+
image_reward_launch=scripts/bash_run_inference_metric_imagereward.sh
|
62 |
+
|
63 |
+
# Use following format to metric your models:
|
64 |
+
# bash $correspoinding_metric_launch $your_config_file_path $your_relative_pth_file_path
|
65 |
+
|
66 |
+
# example
|
67 |
+
bash $geneval_launch \
|
68 |
+
configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
|
69 |
+
output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
|
70 |
+
```
|
71 |
+
|
72 |
+
### 2. \[Inference and Metric\] a list of .pth files using a txt file
|
73 |
+
|
74 |
+
You can also write all your pth files of a job in one txt file, eg. [model_paths.txt](../model_paths.txt)
|
75 |
+
|
76 |
+
```bash
|
77 |
+
# Use following format to metric your models, gathering in a txt file:
|
78 |
+
# bash $correspoinding_metric_launch $your_config_file_path $your_txt_file_path_containing_pth_path
|
79 |
+
|
80 |
+
# We suggest follow the file tree structure in our project for robust experiment
|
81 |
+
# example
|
82 |
+
bash scripts/bash_run_inference_metric.sh \
|
83 |
+
configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
|
84 |
+
asset/model_paths.txt
|
85 |
+
```
|
86 |
+
|
87 |
+
### 3. You will get the following data tree.
|
88 |
+
|
89 |
+
```
|
90 |
+
output
|
91 |
+
├──your_job_name/ (everything will be saved here)
|
92 |
+
│ ├──config.yaml
|
93 |
+
│ ├──train_log.log
|
94 |
+
|
95 |
+
│ ├──checkpoints (all checkpoints)
|
96 |
+
│ │ ├──epoch_1_step_6666.pth
|
97 |
+
│ │ ├──epoch_1_step_8888.pth
|
98 |
+
│ │ ├──......
|
99 |
+
|
100 |
+
│ ├──vis (all visualization result dirs)
|
101 |
+
│ │ ├──visualization_file_name
|
102 |
+
│ │ │ ├──xxxxxxx.jpg
|
103 |
+
│ │ │ ├──......
|
104 |
+
│ │ ├──visualization_file_name2
|
105 |
+
│ │ │ ├──xxxxxxx.jpg
|
106 |
+
│ │ │ ├──......
|
107 |
+
│ ├──......
|
108 |
+
|
109 |
+
│ ├──metrics (all metrics testing related files)
|
110 |
+
│ │ ├──model_paths.txt Optional(👈)(relative path of testing ckpts)
|
111 |
+
│ │ │ ├──output/your_job_name/checkpoings/epoch_1_step_6666.pth
|
112 |
+
│ │ │ ├──output/your_job_name/checkpoings/epoch_1_step_8888.pth
|
113 |
+
│ │ ├──fid_img_paths.txt Optional(👈)(name of testing img_dir in vis)
|
114 |
+
│ │ │ ├──visualization_file_name
|
115 |
+
│ │ │ ├──visualization_file_name2
|
116 |
+
│ │ ├──cached_img_paths.txt Optional(👈)
|
117 |
+
│ │ ├──......
|
118 |
+
```
|
asset/example_data/00000000.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
a cyberpunk cat with a neon sign that says "Sana".
|
asset/examples.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
examples = [
|
18 |
+
[
|
19 |
+
"A small cactus with a happy face in the Sahara desert.",
|
20 |
+
"flow_dpm-solver",
|
21 |
+
20,
|
22 |
+
5.0,
|
23 |
+
2.5,
|
24 |
+
],
|
25 |
+
[
|
26 |
+
"An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history"
|
27 |
+
"of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits "
|
28 |
+
"mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt, he wears a brown beret "
|
29 |
+
"and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile "
|
30 |
+
"as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and "
|
31 |
+
"the Parisian streets and city in the background, depth of field, cinematic 35mm film.",
|
32 |
+
"flow_dpm-solver",
|
33 |
+
20,
|
34 |
+
5.0,
|
35 |
+
2.5,
|
36 |
+
],
|
37 |
+
[
|
38 |
+
"An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. "
|
39 |
+
"Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. "
|
40 |
+
"The quote 'Find the universe within you' is etched in bold letters across the horizon."
|
41 |
+
"blue and pink, brilliantly illuminated in the background.",
|
42 |
+
"flow_dpm-solver",
|
43 |
+
20,
|
44 |
+
5.0,
|
45 |
+
2.5,
|
46 |
+
],
|
47 |
+
[
|
48 |
+
"A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.",
|
49 |
+
"flow_dpm-solver",
|
50 |
+
20,
|
51 |
+
5.0,
|
52 |
+
2.5,
|
53 |
+
],
|
54 |
+
[
|
55 |
+
"A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.",
|
56 |
+
"flow_dpm-solver",
|
57 |
+
20,
|
58 |
+
5.0,
|
59 |
+
2.5,
|
60 |
+
],
|
61 |
+
[
|
62 |
+
"a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, "
|
63 |
+
"national geographic photo, 8k resolution, crayon art, interactive artwork",
|
64 |
+
"flow_dpm-solver",
|
65 |
+
20,
|
66 |
+
5.0,
|
67 |
+
2.5,
|
68 |
+
],
|
69 |
+
]
|
asset/model-incremental.jpg
ADDED
asset/model_paths.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
|
2 |
+
output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
|
asset/samples.txt
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
A small cactus with a happy face in the Sahara desert.
|
2 |
+
Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.
|
3 |
+
beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background
|
4 |
+
stars, water, brilliantly, gorgeous large scale scene, a little girl, in the style of dreamy realism, light gold and amber, blue and pink, brilliantly illuminated in the background.
|
5 |
+
nature vs human nature, surreal, UHD, 8k, hyper details, rich colors, photograph.
|
6 |
+
Spectacular Tiny World in the Transparent Jar On the Table, interior of the Great Hall, Elaborate, Carved Architecture, Anatomy, Symetrical, Geometric and Parameteric Details, Precision Flat line Details, Pattern, Dark fantasy, Dark errie mood and ineffably mysterious mood, Technical design, Intricate Ultra Detail, Ornate Detail, Stylized and Futuristic and Biomorphic Details, Architectural Concept, Low contrast Details, Cinematic Lighting, 8k, by moebius, Fullshot, Epic, Fullshot, Octane render, Unreal ,Photorealistic, Hyperrealism
|
7 |
+
anthropomorphic profile of the white snow owl Crystal priestess , art deco painting, pretty and expressive eyes, ornate costume, mythical, ethereal, intricate, elaborate, hyperrealism, hyper detailed, 3D, 8K, Ultra Realistic, high octane, ultra resolution, amazing detail, perfection, In frame, photorealistic, cinematic lighting, visual clarity, shading , Lumen Reflections, Super-Resolution, gigapixel, color grading, retouch, enhanced, PBR, Blender, V-ray, Procreate, zBrush, Unreal Engine 5, cinematic, volumetric, dramatic, neon lighting, wide angle lens ,no digital painting blur
|
8 |
+
The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8
|
9 |
+
Bright scene, aerial view, ancient city, fantasy, gorgeous light, mirror reflection, high detail, wide angle lens.
|
10 |
+
8k uhd A man looks up at the starry sky, lonely and ethereal, Minimalism, Chaotic composition Op Art
|
11 |
+
A middle-aged woman of Asian descent, her dark hair streaked with silver, appears fractured and splintered, intricately embedded within a sea of broken porcelain. The porcelain glistens with splatter paint patterns in a harmonious blend of glossy and matte blues, greens, oranges, and reds, capturing her dance in a surreal juxtaposition of movement and stillness. Her skin tone, a light hue like the porcelain, adds an almost mystical quality to her form.
|
12 |
+
A 4k dslr image of a lemur wearing a red magician hat and a blue coat performing magic tricks with cards in a garden.
|
13 |
+
A alpaca made of colorful building blocks, cyberpunk
|
14 |
+
A baby painter trying to draw very simple picture, white background
|
15 |
+
A boy and a girl fall in love
|
16 |
+
A dog that has been meditating all the time
|
17 |
+
A man is sitting in a chair with his chin resting on his hand. The chair, along with the man's feet, are submerged in the sea. Strikingly, the man's back is on fire.
|
18 |
+
A painter study hard to learn how to draw with many concepts in the air, white background
|
19 |
+
A painter with low quality, white background, pixel art
|
20 |
+
A person standing on the desert, desert waves, gossip illustration, half red, half blue, abstract image of sand, clear style, trendy illustration, outdoor, top view, clear style, precision art, ultra high definition image
|
21 |
+
A silhouette of a grand piano overlooking a dusky cityscape viewed from a top-floor penthouse, rendered in the bold and vivid sytle of a vintage travel poster.
|
22 |
+
A sureal parallel world where mankind avoid extinction by preserving nature, epic trees, water streams, various flowers, intricate details, rich colors, rich vegetation, cinematic, symmetrical, beautiful lighting, V-Ray render, sun rays, magical lights, photography
|
23 |
+
A woman is shopping for fresh produce at the farmer's market.
|
24 |
+
A worker that looks like a mixture of cow and horse is working hard to type code
|
25 |
+
A young man dressed in ancient Chinese clothing, Asian people, White robe, Handsome, Hand gestures forming a spell, Martial arts and fairy-like vibe, Carrying a legendary-level giant sword on the back, Game character, Surrounded by runes, Cyberpunk style, neon lights, best quality, masterpiece, cg, hdr, high-definition, extremely detailed, photorealistic, epic, character design, detailed face, superhero, hero, detailed UHD, real-time, vfx, 3D rendering, 8k
|
26 |
+
An alien octopus floats through a protal reading a newspaper
|
27 |
+
An epressive oil painting of a basketbal player dunking, depicted as an explosion of a nebula
|
28 |
+
art collection style and fashion shoot, in the style of made of glass, dark blue and light pink, paul rand, solarpunk, camille vivier, beth didonato hair, barbiecore, hyper-realistic
|
29 |
+
artistic
|
30 |
+
beautiful secen
|
31 |
+
Crocodile in a sweater
|
32 |
+
Design a letter A, 3D stereoscopic Ice material Interior light blue Conceptual product design Futuristic Blind box toy Handcrafted Exquisite 3D effect Full body display Ultra-high precision Ultra-detailed Perfect lighting OC Renderer Blender 8k Ultra-sharp Ultra-noise reduction
|
33 |
+
Floating,colossal,futuristic statue in the sky, awe-inspiring and serenein the style of Stuart Lippincott:2with detailed composition and subtle geometric elements.This sanctuary-ike atmosphere features crisp clarity and soft amber tones.In contrasttiny human figures surround the statueThe pieceincorporates flowing draperiesreminiscent of Shwedoff and Philip McKay's stylesemphasizing thejuxtaposition between the powerful presence of the statue and thevulnerability of the minuscule human figuresshwedoff
|
34 |
+
knolling of a drawing tools for painter
|
35 |
+
Leonardo da Vinci's Last Supper content, Van Goph's Starry Night Style
|
36 |
+
Luffy from ONEPIECE, handsome face, fantasy
|
37 |
+
photography shot through an outdoor window of a coffee shop with neon sign lighting, window glares and reflections, depth of field, {little girl with red hair sitting at a table, portrait, kodak portra 800,105 mm f1.8
|
38 |
+
poster of a mechanical cat, techical Schematics viewed from front and side view on light white blueprint paper, illustartion drafting style, illustation, typography, conceptual art, dark fantasy steampunk, cinematic, dark fantasy
|
39 |
+
The girl in the car is filled with goldfish and flowers, goldfish can fly, Kawaguchi Renko's art, natural posture, holiday dadcore, youthful energy and pressure, body stretching, goldfish simulation movies in the sky, super details, and dreamy high photography. Colorful. Covered by water and goldfish, indoor scene, close-up shot in XT4 movie
|
40 |
+
The image features a woman wearing a red shirt with an icon. She appears to be posing for the camera, and her outfit includes a pair of jeans. The woman seems to be in a good mood, as she is smiling. The background of the image is blurry, focusing more on the woman and her attire.
|
41 |
+
The towel was on top of the hard counter.
|
42 |
+
A vast landscape made entirely of various meats spreads out before the viewer. tender, succulent hills of roast beef, chicken drumstick trees, bacon rivers, and ham boulders create a surreal, yet appetizing scene. the sky is adorned with pepperoni sun and salami clouds.
|
43 |
+
I want to supplement vitamin c, please help me paint related food.
|
44 |
+
A vibrant yellow banana-shaped couch sits in a cozy living room, its curve cradling a pile of colorful cushions. on the wooden floor, a patterned rug adds a touch of eclectic charm, and a potted plant sits in the corner, reaching towards the sunlight filtering through the window.
|
45 |
+
A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.
|
46 |
+
A blue jay standing on a large basket of rainbow macarons.
|
47 |
+
A bucket bag made of blue suede. The bag is decorated with intricate golden paisley patterns. The handle of the bag is made of rubies and pearls.
|
48 |
+
An alien octopus floats through a portal reading a newspaper.
|
49 |
+
bird's eye view of a city.
|
50 |
+
beautiful scene
|
51 |
+
A 2D animation of a folk music band composed of anthropomorphic autumn leaves, each playing traditional bluegrass instruments, amidst a rustic forest setting dappled with the soft light of a harvest moon.
|
52 |
+
In front of a deep black backdrop, a figure of middle years, her Tongan skin rich and glowing, is captured mid-twirl, her curly hair flowing like a storm behind her. Her attire resembles a whirlwind of marble and porcelain fragments. Illuminated by the gleam of scattered porcelain shards, creating a dreamlike atmosphere, the dancer manages to appear fragmented, yet maintains a harmonious and fluid form.
|
53 |
+
Digital illustration of a beach scene crafted from yarn. The sandy beach is depicted with beige yarn, waves are made of blue and white yarn crashing onto the shore. A yarn sun sets on the horizon, casting a warm glow. Yarn palm trees sway gently, and little yarn seashells dot the shoreline.
|
54 |
+
Illustration of a chic chair with a design reminiscent of a pumpkin’s form, with deep orange cushioning, in a stylish loft setting.
|
55 |
+
A detailed oil painting of an old sea captain, steering his ship through a storm. Saltwater is splashing against his weathered face, determination in his eyes. Twirling malevolent clouds are seen above and stern waves threaten to submerge the ship while seagulls dive and twirl through the chaotic landscape. Thunder and lights embark in the distance, illuminating the scene with an eerie green glow.
|
56 |
+
An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. The quote 'Find the universe within you' is etched in bold letters across the horizon.
|
57 |
+
A modern architectural building with large glass windows, situated on a cliff overlooking a serene ocean at sunset
|
58 |
+
photo of an ancient shipwreck nestled on the ocean floor. Marine plants have claimed the wooden structure, and fish swim in and out of its hollow spaces. Sunken treasures and old cannons are scattered around, providing a glimpse into the past
|
59 |
+
A 3D render of a coffee mug placed on a window sill during a stormy day. The storm outside the window is reflected in the coffee, with miniature lightning bolts and turbulent waves seen inside the mug. The room is dimly lit, adding to the dramatic atmosphere.A minimap diorama of a cafe adorned with indoor plants. Wooden beams crisscross above, and a cold brew station stands out with tiny bottles and glasses.
|
60 |
+
An antique botanical illustration drawn with fine lines and a touch of watercolour whimsy, depicting a strange lily crossed with a Venus flytrap, its petals poised as if ready to snap shut on any unsuspecting insects.An illustration inspired by old-world botanical sketches blends a cactus with lilac blooms into a Möbius strip, using detailed lines and subtle watercolor touches to capture nature's diverse beauty and mathematical intrigue.
|
61 |
+
An ink sketch style illustration of a small hedgehog holding a piece of watermelon with its tiny paws, taking little bites with its eyes closed in delight.Photo of a lychee-inspired spherical chair, with a bumpy white exterior and plush interior, set against a tropical wallpaper.
|
62 |
+
3d digital art of an adorable ghost, glowing within, holding a heart shaped pumpkin, Halloween, super cute, spooky haunted house background
|
63 |
+
professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.
|
64 |
+
an astronaut sitting in a diner, eating fries, cinematic, analog film
|
65 |
+
Chinese architecture, ancient style,mountain, bird, lotus, pond, big tree, 4K Unity, octane rendering.
|
66 |
+
Ethereal fantasy concept art of thunder god with hammer. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy.
|
67 |
+
A Japanese girl walking along a path, surrounding by blooming oriental cherry, pink petal slowly falling down to the ground
|
68 |
+
A Ukiyoe style painting, an astronaut riding a unicorn, In the background there is an ancient Japanese architecture
|
69 |
+
Steampunk makeup, in the style of vray tracing, colorful impasto, uhd image, indonesian art, fine feather details with bright red and yellow and green and pink and orange colours, intricate patterns and details, dark cyan and amber makeup. Rich colourful plumes. Victorian style.
|
70 |
+
A cute teddy bear in front of a plain white wall, warm and brown fur, soft and fluffy
|
71 |
+
The beautiful scenery of Seattle, painting by Al Capp.
|
72 |
+
Photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang.
|
73 |
+
An astronaut riding a horse on the moon, oil painting by Van Gogh.
|
74 |
+
A deep forest clearing with a mirrored pond reflecting a galaxy-filled night sky
|
75 |
+
Realistic oil painting of a stunning model merged in multicolor splash made of finely torn paper, eye contact, walking with class in a street.
|
76 |
+
a chinese model is sitting on a train, magazine cover, clothes made of plastic, photorealistic,futuristic style, gray and green light, movie lighting, 32K HD
|
77 |
+
a handsome 24 years old boy in the middle with sky color background wearing eye glasses, it's super detailed with anime style, it's a portrait with delicated eyes and nice looking face
|
78 |
+
a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, national geographic photo, 8k resolution, crayon art, interactive artwork
|
79 |
+
3D rendering miniature scene design, Many tall buildings, A winding urban road runs through the middle,a lot of cars on the road, transparent material pipeline transports Materials, ,there are many people around, in thestyle of light orange and yellow, graphic design- inspired illustrations, classic still-life, beeple, josan gon-zalez, manga-influenced, miniature dioramas, in thestyle of playful and whimsical designs, graphic de-sign-inspired illustrations, minimalism, hyperrealismlomo lca, e-commerce C4D style, e-commerce posterUl, UX, octane render, blender
|
80 |
+
Close-up photos of models, hazy light and shadow, laser metal hair accessories, soft and beautiful, light gold pupils, white eyelashes, low saturation, real skin details, clear pores and fine lines, light reflection and refraction, ultra-clear, cinematography, award-winning works
|
81 |
+
A cute orange kitten sliding down an aqua slide. happy excited. 16mm lens in front. we see his excitement and scared in the eye. vibrant colors. water splashing on the lens
|
82 |
+
Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.
|
83 |
+
A gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures.
|
84 |
+
An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt , he wears a brown beret and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film.
|
85 |
+
A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.
|
86 |
+
A New Zealand female business owner stands and is happy that his business is growing by having good VoIP and broadband supplied by Voyager Internet. This business owner is dressed semi casual and is standing with a funky office space in the background. The image is light and bright and is well lit. This image needs to be shot like a professional photo shoot using a Canon R6 with high quality 25mm lens. This image has a shallow depth of field
|
87 |
+
The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8
|
88 |
+
Editorial photoshoot of a old woman, high fashion 2000s fashion
|
89 |
+
Mural Painted of Prince in Purple Rain on side of 5 story brick building next to zen garden vacant lot in the urban center district, rgb
|
90 |
+
Cozy Scandinavian living room, there is a cat sleeping on the couch, depth of field
|
91 |
+
Street style centered straight shot photo shot on Afga Vista 400, lense 50mm, of a two women,skin to skin touch face, emotion, hughing, natural blond hair, natural features, ultra detailed, skin texture, Rembrandt light, soft shadows
|
92 |
+
Frog, in forest, colorful, no watermark, no signature, in forest, 8k
|
93 |
+
selfie of a woman and her lion cub on the plains
|
94 |
+
A fisherman fixing his net sitting on a beautiful tropical beach at sunset with bending palm trees fishing gear and a small boat on shore
|
95 |
+
Coast, decorative painting, horizon, modern, fashionable, full of abstract feeling, full of imagination, the picture reveals the sense of time passing, there is a feeling of the end of the world
|
96 |
+
A close up of a branch of a tree and a golden bug on the top a leaf, shutterstock contest winner,ecological art, depth of field, shallow depth of field, macro photography
|
97 |
+
Outdoor style fashion photo, full – body shot of a man with short brown hair, happy and smiling, he is standing on his hipster bicycle wearing a light blue long sleeved blouse with closed buttons and dark blue jeans trousers, in the background the exterior of an Aldi store, fully lit background, natural afternoon lighting
|
98 |
+
beautiful woman sniper, wearing soviet army uniform, one eye on sniper lens, in snow ground
|
99 |
+
A very attractive and natural woman, sitting on a yoka mat, breathing, eye closed, no make up, intense satisfaction, she looks like she is intensely relaxed, yoga class, sunrise, 35mm
|
100 |
+
a close up of a helmet on a person, digital art, inspired by Han Gan, cloisonnism, female, victorian armor, ultramarine, best of behance, anton fadeev 8 k, fined detail, sci-fi character, elegant armor, fantasy art behance
|
101 |
+
a melting apple
|
102 |
+
yellow FIAT 500 Cinquecento 1957 driving through liechtenstein castle with a lot of banknotes scattered behind ,filled with wads of cash , car color yellow, license plate R-33
|
103 |
+
tented resort in the desert, rocky and sandy terrain, 5 star hotel, beautiful landscape, landscape photography, depth of view, Fujifilm GFX 100 –uplight
|
104 |
+
Full body shot, a French woman, Photography, French Streets background, backlighting, rim light, Fujifilm.
|
105 |
+
Modern luxury contemporary luxury home interiors house, in the style of mimicking ruined materials, ray tracing, haunting houses, and stone, capture the essence of nature, gray and bronze, dynamic outdoor shots.
|
106 |
+
Over the shoulder game perspective, game screen of Diablo 4, Inside the gorgeous palace is the wet ground, The necromancer knelt before the king, and a horde of skeletons he summoned stood at his side, cinematic light.
|
107 |
+
Color photo of a corgi made of transparent glass, standing on the riverside in Yosemite National Park.
|
108 |
+
Happy dreamy owl monster sitting on a tree branch, colorful glittering particles, forest background, detailed feathers.
|
109 |
+
Game-Art - An island with different geographical properties and multiple small cities floating in space
|
110 |
+
Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee.
|
111 |
+
A car made out of vegetables.
|
112 |
+
A serene lakeside during autumn with trees displaying a palette of fiery colors.
|
113 |
+
A realistic landscape shot of the Northern Lights dancing over a snowy mountain range in Iceland.
|
114 |
+
A deep forest clearing with a mirrored pond reflecting a galaxy-filled night sky.
|
115 |
+
Drone view of waves crashing against the rugged cliffs along Big Sur’s Garay Point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore.
|
116 |
+
A curvy timber house near a sea, designed by Zaha Hadid, represent the image of a cold, modern architecture, at night, white lighting, highly detailed.
|
117 |
+
Eiffel Tower was Made up of more than 2 million translucent straws to look like a cloud, with the bell tower at the top of the building, Michel installed huge foam-making machines in the forest to blow huge amounts of unpredictable wet clouds in the building's classic architecture.
|
118 |
+
Close-up photos of models, hazy light and shadow, laser metal hair accessories, soft and beautiful, light gold pupils, white eyelashes, low saturation, real skin details, clear pores and fine lines, light reflection and refraction, ultra-clear, cinematography, award-winning works.
|
119 |
+
smiling cartoon dog sits at a table, coffee mug on hand, as a room goes up in flames. "Help" the dog is yelling.
|
120 |
+
A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.
|
121 |
+
A close-up photo of a person. The subject is a woman. She wore a blue coat with a gray dress underneath. She has blue eyes and blond hair and wears a pair of earrings. Behind are blurred city buildings and streets.
|
122 |
+
👧 with 🌹 in the ❄️
|
123 |
+
🐶 Wearing 🕶 flying on the 🌈
|
124 |
+
a cyberpunk cat with a neon sign that says "MIT"
|
125 |
+
a black and white picture of a woman looking through the window, in the style of Duffy Sheridan, Anna Razumovskaya, smooth and shiny, wavy, Patrick Demarchelier, album covers, lush and detailed.
|
asset/samples_mini.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
A cyberpunk cat with a neon sign that says 'Sana'.
|
2 |
+
A small cactus with a happy face in the Sahara desert.
|
3 |
+
The towel was on top of the hard counter.
|
4 |
+
A vast landscape made entirely of various meats spreads out before the viewer. tender, succulent hills of roast beef, chicken drumstick trees, bacon rivers, and ham boulders create a surreal, yet appetizing scene. the sky is adorned with pepperoni sun and salami clouds.
|
5 |
+
I want to supplement vitamin c, please help me paint related food.
|
6 |
+
A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.
|
7 |
+
an old rusted robot wearing pants and a jacket riding skis in a supermarket.
|
8 |
+
professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.
|
9 |
+
Astronaut in a jungle, cold color palette, muted colors, detailed
|
10 |
+
a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests.
|
configs/sana_app_config/Sana_1600M_app.yaml
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
data_dir: []
|
3 |
+
image_size: 1024
|
4 |
+
caption_proportion:
|
5 |
+
prompt: 1
|
6 |
+
external_caption_suffixes: []
|
7 |
+
external_clipscore_suffixes: []
|
8 |
+
clip_thr_temperature: 0.1
|
9 |
+
clip_thr: 25.0
|
10 |
+
load_text_feat: false
|
11 |
+
load_vae_feat: false
|
12 |
+
transform: default_train
|
13 |
+
type: SanaWebDatasetMS
|
14 |
+
data:
|
15 |
+
sort_dataset: false
|
16 |
+
# model config
|
17 |
+
model:
|
18 |
+
model: SanaMS_1600M_P1_D20
|
19 |
+
image_size: 1024
|
20 |
+
mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
|
21 |
+
fp32_attention: true
|
22 |
+
load_from:
|
23 |
+
resume_from:
|
24 |
+
aspect_ratio_type: ASPECT_RATIO_1024
|
25 |
+
multi_scale: true
|
26 |
+
#pe_interpolation: 1.
|
27 |
+
attn_type: linear
|
28 |
+
ffn_type: glumbconv
|
29 |
+
mlp_acts:
|
30 |
+
- silu
|
31 |
+
- silu
|
32 |
+
-
|
33 |
+
mlp_ratio: 2.5
|
34 |
+
use_pe: false
|
35 |
+
qk_norm: false
|
36 |
+
class_dropout_prob: 0.1
|
37 |
+
# CFG & PAG settings
|
38 |
+
pag_applied_layers:
|
39 |
+
- 8
|
40 |
+
# VAE setting
|
41 |
+
vae:
|
42 |
+
vae_type: dc-ae
|
43 |
+
vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
|
44 |
+
scale_factor: 0.41407
|
45 |
+
vae_latent_dim: 32
|
46 |
+
vae_downsample_rate: 32
|
47 |
+
sample_posterior: true
|
48 |
+
# text encoder
|
49 |
+
text_encoder:
|
50 |
+
text_encoder_name: gemma-2-2b-it
|
51 |
+
y_norm: true
|
52 |
+
y_norm_scale_factor: 0.01
|
53 |
+
model_max_length: 300
|
54 |
+
# CHI
|
55 |
+
chi_prompt:
|
56 |
+
- 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
|
57 |
+
- '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
|
58 |
+
- '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
|
59 |
+
- 'Here are examples of how to transform or refine prompts:'
|
60 |
+
- '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
|
61 |
+
- '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
|
62 |
+
- 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
|
63 |
+
- 'User Prompt: '
|
64 |
+
# Sana schedule Flow
|
65 |
+
scheduler:
|
66 |
+
predict_v: true
|
67 |
+
noise_schedule: linear_flow
|
68 |
+
pred_sigma: false
|
69 |
+
flow_shift: 3.0
|
70 |
+
# logit-normal timestep
|
71 |
+
weighting_scheme: logit_normal
|
72 |
+
logit_mean: 0.0
|
73 |
+
logit_std: 1.0
|
74 |
+
vis_sampler: flow_dpm-solver
|
75 |
+
# training setting
|
76 |
+
train:
|
77 |
+
num_workers: 10
|
78 |
+
seed: 1
|
79 |
+
train_batch_size: 64
|
80 |
+
num_epochs: 100
|
81 |
+
gradient_accumulation_steps: 1
|
82 |
+
grad_checkpointing: true
|
83 |
+
gradient_clip: 0.1
|
84 |
+
optimizer:
|
85 |
+
betas:
|
86 |
+
- 0.9
|
87 |
+
- 0.999
|
88 |
+
- 0.9999
|
89 |
+
eps:
|
90 |
+
- 1.0e-30
|
91 |
+
- 1.0e-16
|
92 |
+
lr: 0.0001
|
93 |
+
type: CAMEWrapper
|
94 |
+
weight_decay: 0.0
|
95 |
+
lr_schedule: constant
|
96 |
+
lr_schedule_args:
|
97 |
+
num_warmup_steps: 2000
|
98 |
+
local_save_vis: true # if save log image locally
|
99 |
+
visualize: true
|
100 |
+
eval_sampling_steps: 500
|
101 |
+
log_interval: 20
|
102 |
+
save_model_epochs: 5
|
103 |
+
save_model_steps: 500
|
104 |
+
work_dir: output/debug
|
105 |
+
online_metric: false
|
106 |
+
eval_metric_step: 2000
|
107 |
+
online_metric_dir: metric_helper
|
configs/sana_app_config/Sana_600M_app.yaml
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
data_dir: []
|
3 |
+
image_size: 1024
|
4 |
+
caption_proportion:
|
5 |
+
prompt: 1
|
6 |
+
external_caption_suffixes: []
|
7 |
+
external_clipscore_suffixes: []
|
8 |
+
clip_thr_temperature: 0.1
|
9 |
+
clip_thr: 25.0
|
10 |
+
load_text_feat: false
|
11 |
+
load_vae_feat: true
|
12 |
+
transform: default_train
|
13 |
+
type: SanaWebDatasetMS
|
14 |
+
sort_dataset: false
|
15 |
+
# model config
|
16 |
+
model:
|
17 |
+
model: SanaMS_600M_P1_D28
|
18 |
+
image_size: 1024
|
19 |
+
mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
|
20 |
+
fp32_attention: true
|
21 |
+
load_from:
|
22 |
+
resume_from:
|
23 |
+
aspect_ratio_type: ASPECT_RATIO_1024
|
24 |
+
multi_scale: true
|
25 |
+
attn_type: linear
|
26 |
+
ffn_type: glumbconv
|
27 |
+
mlp_acts:
|
28 |
+
- silu
|
29 |
+
- silu
|
30 |
+
-
|
31 |
+
mlp_ratio: 2.5
|
32 |
+
use_pe: false
|
33 |
+
qk_norm: false
|
34 |
+
class_dropout_prob: 0.1
|
35 |
+
# CFG & PAG settings
|
36 |
+
pag_applied_layers:
|
37 |
+
- 14
|
38 |
+
# VAE setting
|
39 |
+
vae:
|
40 |
+
vae_type: dc-ae
|
41 |
+
vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
|
42 |
+
scale_factor: 0.41407
|
43 |
+
vae_latent_dim: 32
|
44 |
+
vae_downsample_rate: 32
|
45 |
+
sample_posterior: true
|
46 |
+
# text encoder
|
47 |
+
text_encoder:
|
48 |
+
text_encoder_name: gemma-2-2b-it
|
49 |
+
y_norm: true
|
50 |
+
y_norm_scale_factor: 0.01
|
51 |
+
model_max_length: 300
|
52 |
+
# CHI
|
53 |
+
chi_prompt:
|
54 |
+
- 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
|
55 |
+
- '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
|
56 |
+
- '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
|
57 |
+
- 'Here are examples of how to transform or refine prompts:'
|
58 |
+
- '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
|
59 |
+
- '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
|
60 |
+
- 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
|
61 |
+
- 'User Prompt: '
|
62 |
+
# Sana schedule Flow
|
63 |
+
scheduler:
|
64 |
+
predict_v: true
|
65 |
+
noise_schedule: linear_flow
|
66 |
+
pred_sigma: false
|
67 |
+
flow_shift: 4.0
|
68 |
+
# logit-normal timestep
|
69 |
+
weighting_scheme: logit_normal
|
70 |
+
logit_mean: 0.0
|
71 |
+
logit_std: 1.0
|
72 |
+
vis_sampler: flow_dpm-solver
|
73 |
+
# training setting
|
74 |
+
train:
|
75 |
+
num_workers: 10
|
76 |
+
seed: 1
|
77 |
+
train_batch_size: 64
|
78 |
+
num_epochs: 100
|
79 |
+
gradient_accumulation_steps: 1
|
80 |
+
grad_checkpointing: true
|
81 |
+
gradient_clip: 0.1
|
82 |
+
optimizer:
|
83 |
+
betas:
|
84 |
+
- 0.9
|
85 |
+
- 0.999
|
86 |
+
- 0.9999
|
87 |
+
eps:
|
88 |
+
- 1.0e-30
|
89 |
+
- 1.0e-16
|
90 |
+
lr: 0.0001
|
91 |
+
type: CAMEWrapper
|
92 |
+
weight_decay: 0.0
|
93 |
+
lr_schedule: constant
|
94 |
+
lr_schedule_args:
|
95 |
+
num_warmup_steps: 2000
|
96 |
+
local_save_vis: true # if save log image locally
|
97 |
+
visualize: true
|
98 |
+
eval_sampling_steps: 500
|
99 |
+
log_interval: 20
|
100 |
+
save_model_epochs: 5
|
101 |
+
save_model_steps: 500
|
102 |
+
work_dir: output/debug
|
103 |
+
online_metric: false
|
104 |
+
eval_metric_step: 2000
|
105 |
+
online_metric_dir: metric_helper
|
configs/sana_base.yaml
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# data settings
|
2 |
+
data:
|
3 |
+
data_dir: []
|
4 |
+
caption_proportion:
|
5 |
+
prompt: 1
|
6 |
+
external_caption_suffixes: []
|
7 |
+
external_clipscore_suffixes: []
|
8 |
+
clip_thr_temperature: 1.0
|
9 |
+
clip_thr: 0.0
|
10 |
+
sort_dataset: false
|
11 |
+
load_text_feat: false
|
12 |
+
load_vae_feat: false
|
13 |
+
transform: default_train
|
14 |
+
type: SanaWebDatasetMS
|
15 |
+
image_size: 512
|
16 |
+
hq_only: false
|
17 |
+
valid_num: 0
|
18 |
+
# model settings
|
19 |
+
model:
|
20 |
+
model: SanaMS_600M_P1_D28
|
21 |
+
image_size: 512
|
22 |
+
mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
|
23 |
+
fp32_attention: true
|
24 |
+
load_from:
|
25 |
+
resume_from:
|
26 |
+
checkpoint:
|
27 |
+
load_ema: false
|
28 |
+
resume_lr_scheduler: true
|
29 |
+
resume_optimizer: true
|
30 |
+
aspect_ratio_type: ASPECT_RATIO_1024
|
31 |
+
multi_scale: true
|
32 |
+
pe_interpolation: 1.0
|
33 |
+
micro_condition: false
|
34 |
+
attn_type: linear # 'flash', 'linear', 'vanilla', 'triton_linear'
|
35 |
+
cross_norm: false
|
36 |
+
autocast_linear_attn: false
|
37 |
+
ffn_type: glumbconv
|
38 |
+
mlp_acts:
|
39 |
+
- silu
|
40 |
+
- silu
|
41 |
+
-
|
42 |
+
mlp_ratio: 2.5
|
43 |
+
use_pe: false
|
44 |
+
qk_norm: false
|
45 |
+
class_dropout_prob: 0.0
|
46 |
+
linear_head_dim: 32
|
47 |
+
# CFG & PAG settings
|
48 |
+
cfg_scale: 4
|
49 |
+
guidance_type: classifier-free
|
50 |
+
pag_applied_layers: [14]
|
51 |
+
# text encoder settings
|
52 |
+
text_encoder:
|
53 |
+
text_encoder_name: gemma-2-2b-it
|
54 |
+
caption_channels: 2304
|
55 |
+
y_norm: false
|
56 |
+
y_norm_scale_factor: 1.0
|
57 |
+
model_max_length: 300
|
58 |
+
chi_prompt: []
|
59 |
+
# VAE settings
|
60 |
+
vae:
|
61 |
+
vae_type: dc-ae
|
62 |
+
vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
|
63 |
+
scale_factor: 0.41407
|
64 |
+
vae_latent_dim: 32
|
65 |
+
vae_downsample_rate: 32
|
66 |
+
sample_posterior: true
|
67 |
+
# Scheduler settings
|
68 |
+
scheduler:
|
69 |
+
train_sampling_steps: 1000
|
70 |
+
predict_v: True
|
71 |
+
noise_schedule: linear_flow
|
72 |
+
pred_sigma: false
|
73 |
+
flow_shift: 1.0
|
74 |
+
weighting_scheme: logit_normal
|
75 |
+
logit_mean: 0.0
|
76 |
+
logit_std: 1.0
|
77 |
+
vis_sampler: flow_dpm-solver
|
78 |
+
# training settings
|
79 |
+
train:
|
80 |
+
num_workers: 4
|
81 |
+
seed: 43
|
82 |
+
train_batch_size: 32
|
83 |
+
num_epochs: 100
|
84 |
+
gradient_accumulation_steps: 1
|
85 |
+
grad_checkpointing: false
|
86 |
+
gradient_clip: 1.0
|
87 |
+
gc_step: 1
|
88 |
+
# optimizer settings
|
89 |
+
optimizer:
|
90 |
+
eps: 1.0e-10
|
91 |
+
lr: 0.0001
|
92 |
+
type: AdamW
|
93 |
+
weight_decay: 0.03
|
94 |
+
lr_schedule: constant
|
95 |
+
lr_schedule_args:
|
96 |
+
num_warmup_steps: 500
|
97 |
+
auto_lr:
|
98 |
+
rule: sqrt
|
99 |
+
ema_rate: 0.9999
|
100 |
+
eval_batch_size: 16
|
101 |
+
use_fsdp: false
|
102 |
+
use_flash_attn: false
|
103 |
+
eval_sampling_steps: 250
|
104 |
+
lora_rank: 4
|
105 |
+
log_interval: 50
|
106 |
+
mask_type: 'null'
|
107 |
+
mask_loss_coef: 0.0
|
108 |
+
load_mask_index: false
|
109 |
+
snr_loss: false
|
110 |
+
real_prompt_ratio: 1.0
|
111 |
+
debug_nan: false
|
112 |
+
# checkpoint settings
|
113 |
+
save_image_epochs: 1
|
114 |
+
save_model_epochs: 1
|
115 |
+
save_model_steps: 1000000
|
116 |
+
# visualization settings
|
117 |
+
visualize: false
|
118 |
+
null_embed_root: output/pretrained_models/
|
119 |
+
valid_prompt_embed_root: output/tmp_embed/
|
120 |
+
validation_prompts:
|
121 |
+
- dog
|
122 |
+
- portrait photo of a girl, photograph, highly detailed face, depth of field
|
123 |
+
- Self-portrait oil painting, a beautiful cyborg with golden hair, 8k
|
124 |
+
- Astronaut in a jungle, cold color palette, muted colors, detailed, 8k
|
125 |
+
- A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece
|
126 |
+
local_save_vis: false
|
127 |
+
deterministic_validation: true
|
128 |
+
online_metric: false
|
129 |
+
eval_metric_step: 5000
|
130 |
+
online_metric_dir: metric_helper
|
131 |
+
# work dir settings
|
132 |
+
work_dir: /cache/exps/
|
133 |
+
skip_step: 0
|
134 |
+
# LCM settings
|
135 |
+
loss_type: huber
|
136 |
+
huber_c: 0.001
|
137 |
+
num_ddim_timesteps: 50
|
138 |
+
w_max: 15.0
|
139 |
+
w_min: 3.0
|
140 |
+
ema_decay: 0.95
|
configs/sana_config/1024ms/Sana_1600M_img1024.yaml
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
data_dir: [data/data_public/dir1]
|
3 |
+
image_size: 1024
|
4 |
+
caption_proportion:
|
5 |
+
prompt: 1
|
6 |
+
external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
|
7 |
+
external_clipscore_suffixes:
|
8 |
+
- _InternVL2-26B_clip_score
|
9 |
+
- _VILA1-5-13B_clip_score
|
10 |
+
- _prompt_clip_score
|
11 |
+
clip_thr_temperature: 0.1
|
12 |
+
clip_thr: 25.0
|
13 |
+
load_text_feat: false
|
14 |
+
load_vae_feat: false
|
15 |
+
transform: default_train
|
16 |
+
type: SanaWebDatasetMS
|
17 |
+
sort_dataset: false
|
18 |
+
# model config
|
19 |
+
model:
|
20 |
+
model: SanaMS_1600M_P1_D20
|
21 |
+
image_size: 1024
|
22 |
+
mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
|
23 |
+
fp32_attention: true
|
24 |
+
load_from:
|
25 |
+
resume_from:
|
26 |
+
aspect_ratio_type: ASPECT_RATIO_1024
|
27 |
+
multi_scale: true
|
28 |
+
#pe_interpolation: 1.
|
29 |
+
attn_type: linear
|
30 |
+
ffn_type: glumbconv
|
31 |
+
mlp_acts:
|
32 |
+
- silu
|
33 |
+
- silu
|
34 |
+
-
|
35 |
+
mlp_ratio: 2.5
|
36 |
+
use_pe: false
|
37 |
+
qk_norm: false
|
38 |
+
class_dropout_prob: 0.1
|
39 |
+
# PAG
|
40 |
+
pag_applied_layers:
|
41 |
+
- 8
|
42 |
+
# VAE setting
|
43 |
+
vae:
|
44 |
+
vae_type: dc-ae
|
45 |
+
vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
|
46 |
+
scale_factor: 0.41407
|
47 |
+
vae_latent_dim: 32
|
48 |
+
vae_downsample_rate: 32
|
49 |
+
sample_posterior: true
|
50 |
+
# text encoder
|
51 |
+
text_encoder:
|
52 |
+
text_encoder_name: gemma-2-2b-it
|
53 |
+
y_norm: true
|
54 |
+
y_norm_scale_factor: 0.01
|
55 |
+
model_max_length: 300
|
56 |
+
# CHI
|
57 |
+
chi_prompt:
|
58 |
+
- 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
|
59 |
+
- '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
|
60 |
+
- '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
|
61 |
+
- 'Here are examples of how to transform or refine prompts:'
|
62 |
+
- '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
|
63 |
+
- '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
|
64 |
+
- 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
|
65 |
+
- 'User Prompt: '
|
66 |
+
# Sana schedule Flow
|
67 |
+
scheduler:
|
68 |
+
predict_v: true
|
69 |
+
noise_schedule: linear_flow
|
70 |
+
pred_sigma: false
|
71 |
+
flow_shift: 3.0
|
72 |
+
# logit-normal timestep
|
73 |
+
weighting_scheme: logit_normal
|
74 |
+
logit_mean: 0.0
|
75 |
+
logit_std: 1.0
|
76 |
+
vis_sampler: flow_dpm-solver
|
77 |
+
# training setting
|
78 |
+
train:
|
79 |
+
num_workers: 10
|
80 |
+
seed: 1
|
81 |
+
train_batch_size: 64
|
82 |
+
num_epochs: 100
|
83 |
+
gradient_accumulation_steps: 1
|
84 |
+
grad_checkpointing: true
|
85 |
+
gradient_clip: 0.1
|
86 |
+
optimizer:
|
87 |
+
betas:
|
88 |
+
- 0.9
|
89 |
+
- 0.999
|
90 |
+
- 0.9999
|
91 |
+
eps:
|
92 |
+
- 1.0e-30
|
93 |
+
- 1.0e-16
|
94 |
+
lr: 0.0001
|
95 |
+
type: CAMEWrapper
|
96 |
+
weight_decay: 0.0
|
97 |
+
lr_schedule: constant
|
98 |
+
lr_schedule_args:
|
99 |
+
num_warmup_steps: 2000
|
100 |
+
local_save_vis: true # if save log image locally
|
101 |
+
visualize: true
|
102 |
+
eval_sampling_steps: 500
|
103 |
+
log_interval: 20
|
104 |
+
save_model_epochs: 5
|
105 |
+
save_model_steps: 500
|
106 |
+
work_dir: output/debug
|
107 |
+
online_metric: false
|
108 |
+
eval_metric_step: 2000
|
109 |
+
online_metric_dir: metric_helper
|
configs/sana_config/1024ms/Sana_600M_img1024.yaml
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
data_dir: [data/data_public/dir1]
|
3 |
+
image_size: 1024
|
4 |
+
caption_proportion:
|
5 |
+
prompt: 1
|
6 |
+
external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
|
7 |
+
external_clipscore_suffixes:
|
8 |
+
- _InternVL2-26B_clip_score
|
9 |
+
- _VILA1-5-13B_clip_score
|
10 |
+
- _prompt_clip_score
|
11 |
+
clip_thr_temperature: 0.1
|
12 |
+
clip_thr: 25.0
|
13 |
+
load_text_feat: false
|
14 |
+
load_vae_feat: false
|
15 |
+
transform: default_train
|
16 |
+
type: SanaWebDatasetMS
|
17 |
+
sort_dataset: false
|
18 |
+
# model config
|
19 |
+
model:
|
20 |
+
model: SanaMS_600M_P1_D28
|
21 |
+
image_size: 1024
|
22 |
+
mixed_precision: fp16
|
23 |
+
fp32_attention: true
|
24 |
+
load_from:
|
25 |
+
resume_from:
|
26 |
+
aspect_ratio_type: ASPECT_RATIO_1024
|
27 |
+
multi_scale: true
|
28 |
+
attn_type: linear
|
29 |
+
ffn_type: glumbconv
|
30 |
+
mlp_acts:
|
31 |
+
- silu
|
32 |
+
- silu
|
33 |
+
-
|
34 |
+
mlp_ratio: 2.5
|
35 |
+
use_pe: false
|
36 |
+
qk_norm: false
|
37 |
+
class_dropout_prob: 0.1
|
38 |
+
# VAE setting
|
39 |
+
vae:
|
40 |
+
vae_type: dc-ae
|
41 |
+
vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
|
42 |
+
scale_factor: 0.41407
|
43 |
+
vae_latent_dim: 32
|
44 |
+
vae_downsample_rate: 32
|
45 |
+
sample_posterior: true
|
46 |
+
# text encoder
|
47 |
+
text_encoder:
|
48 |
+
text_encoder_name: gemma-2-2b-it
|
49 |
+
y_norm: true
|
50 |
+
y_norm_scale_factor: 0.01
|
51 |
+
model_max_length: 300
|
52 |
+
# CHI
|
53 |
+
chi_prompt:
|
54 |
+
- 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
|
55 |
+
- '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
|
56 |
+
- '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
|
57 |
+
- 'Here are examples of how to transform or refine prompts:'
|
58 |
+
- '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
|
59 |
+
- '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
|
60 |
+
- 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
|
61 |
+
- 'User Prompt: '
|
62 |
+
# Sana schedule Flow
|
63 |
+
scheduler:
|
64 |
+
predict_v: true
|
65 |
+
noise_schedule: linear_flow
|
66 |
+
pred_sigma: false
|
67 |
+
flow_shift: 4.0
|
68 |
+
# logit-normal timestep
|
69 |
+
weighting_scheme: logit_normal
|
70 |
+
logit_mean: 0.0
|
71 |
+
logit_std: 1.0
|
72 |
+
vis_sampler: flow_dpm-solver
|
73 |
+
# training setting
|
74 |
+
train:
|
75 |
+
num_workers: 10
|
76 |
+
seed: 1
|
77 |
+
train_batch_size: 64
|
78 |
+
num_epochs: 100
|
79 |
+
gradient_accumulation_steps: 1
|
80 |
+
grad_checkpointing: true
|
81 |
+
gradient_clip: 0.1
|
82 |
+
optimizer:
|
83 |
+
betas:
|
84 |
+
- 0.9
|
85 |
+
- 0.999
|
86 |
+
- 0.9999
|
87 |
+
eps:
|
88 |
+
- 1.0e-30
|
89 |
+
- 1.0e-16
|
90 |
+
lr: 0.0001
|
91 |
+
type: CAMEWrapper
|
92 |
+
weight_decay: 0.0
|
93 |
+
lr_schedule: constant
|
94 |
+
lr_schedule_args:
|
95 |
+
num_warmup_steps: 2000
|
96 |
+
local_save_vis: true # if save log image locally
|
97 |
+
visualize: true
|
98 |
+
eval_sampling_steps: 500
|
99 |
+
log_interval: 20
|
100 |
+
save_model_epochs: 5
|
101 |
+
save_model_steps: 500
|
102 |
+
work_dir: output/debug
|
103 |
+
online_metric: false
|
104 |
+
eval_metric_step: 2000
|
105 |
+
online_metric_dir: metric_helper
|
configs/sana_config/512ms/Sana_1600M_img512.yaml
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
data_dir: [data/data_public/dir1]
|
3 |
+
image_size: 512
|
4 |
+
caption_proportion:
|
5 |
+
prompt: 1
|
6 |
+
external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
|
7 |
+
external_clipscore_suffixes:
|
8 |
+
- _InternVL2-26B_clip_score
|
9 |
+
- _VILA1-5-13B_clip_score
|
10 |
+
- _prompt_clip_score
|
11 |
+
clip_thr_temperature: 0.1
|
12 |
+
clip_thr: 25.0
|
13 |
+
load_text_feat: false
|
14 |
+
load_vae_feat: false
|
15 |
+
transform: default_train
|
16 |
+
type: SanaWebDatasetMS
|
17 |
+
sort_dataset: false
|
18 |
+
# model config
|
19 |
+
model:
|
20 |
+
model: SanaMS_1600M_P1_D20
|
21 |
+
image_size: 512
|
22 |
+
mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
|
23 |
+
fp32_attention: true
|
24 |
+
load_from:
|
25 |
+
resume_from:
|
26 |
+
aspect_ratio_type: ASPECT_RATIO_512
|
27 |
+
multi_scale: true
|
28 |
+
attn_type: linear
|
29 |
+
ffn_type: glumbconv
|
30 |
+
mlp_acts:
|
31 |
+
- silu
|
32 |
+
- silu
|
33 |
+
-
|
34 |
+
mlp_ratio: 2.5
|
35 |
+
use_pe: false
|
36 |
+
qk_norm: false
|
37 |
+
class_dropout_prob: 0.1
|
38 |
+
# PAG
|
39 |
+
pag_applied_layers:
|
40 |
+
- 8
|
41 |
+
# VAE setting
|
42 |
+
vae:
|
43 |
+
vae_type: dc-ae
|
44 |
+
vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
|
45 |
+
scale_factor: 0.41407
|
46 |
+
vae_latent_dim: 32
|
47 |
+
vae_downsample_rate: 32
|
48 |
+
sample_posterior: true
|
49 |
+
# text encoder
|
50 |
+
text_encoder:
|
51 |
+
text_encoder_name: gemma-2-2b-it
|
52 |
+
y_norm: true
|
53 |
+
y_norm_scale_factor: 0.01
|
54 |
+
model_max_length: 300
|
55 |
+
# CHI
|
56 |
+
chi_prompt:
|
57 |
+
- 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
|
58 |
+
- '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
|
59 |
+
- '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
|
60 |
+
- 'Here are examples of how to transform or refine prompts:'
|
61 |
+
- '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
|
62 |
+
- '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
|
63 |
+
- 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
|
64 |
+
- 'User Prompt: '
|
65 |
+
# Sana schedule Flow
|
66 |
+
scheduler:
|
67 |
+
predict_v: true
|
68 |
+
noise_schedule: linear_flow
|
69 |
+
pred_sigma: false
|
70 |
+
flow_shift: 1.0
|
71 |
+
# logit-normal timestep
|
72 |
+
weighting_scheme: logit_normal
|
73 |
+
logit_mean: 0.0
|
74 |
+
logit_std: 1.0
|
75 |
+
vis_sampler: flow_dpm-solver
|
76 |
+
# training setting
|
77 |
+
train:
|
78 |
+
num_workers: 10
|
79 |
+
seed: 1
|
80 |
+
train_batch_size: 64
|
81 |
+
num_epochs: 100
|
82 |
+
gradient_accumulation_steps: 1
|
83 |
+
grad_checkpointing: true
|
84 |
+
gradient_clip: 0.1
|
85 |
+
optimizer:
|
86 |
+
betas:
|
87 |
+
- 0.9
|
88 |
+
- 0.999
|
89 |
+
- 0.9999
|
90 |
+
eps:
|
91 |
+
- 1.0e-30
|
92 |
+
- 1.0e-16
|
93 |
+
lr: 0.0001
|
94 |
+
type: CAMEWrapper
|
95 |
+
weight_decay: 0.0
|
96 |
+
lr_schedule: constant
|
97 |
+
lr_schedule_args:
|
98 |
+
num_warmup_steps: 2000
|
99 |
+
local_save_vis: true # if save log image locally
|
100 |
+
visualize: true
|
101 |
+
eval_sampling_steps: 500
|
102 |
+
log_interval: 20
|
103 |
+
save_model_epochs: 5
|
104 |
+
save_model_steps: 500
|
105 |
+
work_dir: output/debug
|
106 |
+
online_metric: false
|
107 |
+
eval_metric_step: 2000
|
108 |
+
online_metric_dir: metric_helper
|
configs/sana_config/512ms/Sana_600M_img512.yaml
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
data_dir: [data/data_public/dir1]
|
3 |
+
image_size: 512
|
4 |
+
caption_proportion:
|
5 |
+
prompt: 1
|
6 |
+
external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
|
7 |
+
external_clipscore_suffixes:
|
8 |
+
- _InternVL2-26B_clip_score
|
9 |
+
- _VILA1-5-13B_clip_score
|
10 |
+
- _prompt_clip_score
|
11 |
+
clip_thr_temperature: 0.1
|
12 |
+
clip_thr: 25.0
|
13 |
+
load_text_feat: false
|
14 |
+
load_vae_feat: false
|
15 |
+
transform: default_train
|
16 |
+
type: SanaWebDatasetMS
|
17 |
+
sort_dataset: false
|
18 |
+
# model config
|
19 |
+
model:
|
20 |
+
model: SanaMS_600M_P1_D28
|
21 |
+
image_size: 512
|
22 |
+
mixed_precision: fp16
|
23 |
+
fp32_attention: true
|
24 |
+
load_from:
|
25 |
+
resume_from:
|
26 |
+
aspect_ratio_type: ASPECT_RATIO_512
|
27 |
+
multi_scale: true
|
28 |
+
#pe_interpolation: 1.
|
29 |
+
attn_type: linear
|
30 |
+
linear_head_dim: 32
|
31 |
+
ffn_type: glumbconv
|
32 |
+
mlp_acts:
|
33 |
+
- silu
|
34 |
+
- silu
|
35 |
+
- null
|
36 |
+
mlp_ratio: 2.5
|
37 |
+
use_pe: false
|
38 |
+
qk_norm: false
|
39 |
+
class_dropout_prob: 0.1
|
40 |
+
# VAE setting
|
41 |
+
vae:
|
42 |
+
vae_type: dc-ae
|
43 |
+
vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
|
44 |
+
scale_factor: 0.41407
|
45 |
+
vae_latent_dim: 32
|
46 |
+
vae_downsample_rate: 32
|
47 |
+
sample_posterior: true
|
48 |
+
# text encoder
|
49 |
+
text_encoder:
|
50 |
+
text_encoder_name: gemma-2-2b-it
|
51 |
+
y_norm: true
|
52 |
+
y_norm_scale_factor: 0.01
|
53 |
+
model_max_length: 300
|
54 |
+
# CHI
|
55 |
+
chi_prompt:
|
56 |
+
- 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
|
57 |
+
- '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
|
58 |
+
- '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
|
59 |
+
- 'Here are examples of how to transform or refine prompts:'
|
60 |
+
- '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
|
61 |
+
- '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
|
62 |
+
- 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
|
63 |
+
- 'User Prompt: '
|
64 |
+
# Sana schedule Flow
|
65 |
+
scheduler:
|
66 |
+
predict_v: true
|
67 |
+
noise_schedule: linear_flow
|
68 |
+
pred_sigma: false
|
69 |
+
flow_shift: 1.0
|
70 |
+
# logit-normal timestep
|
71 |
+
weighting_scheme: logit_normal
|
72 |
+
logit_mean: 0.0
|
73 |
+
logit_std: 1.0
|
74 |
+
vis_sampler: flow_dpm-solver
|
75 |
+
# training setting
|
76 |
+
train:
|
77 |
+
num_workers: 10
|
78 |
+
seed: 1
|
79 |
+
train_batch_size: 128
|
80 |
+
num_epochs: 100
|
81 |
+
gradient_accumulation_steps: 1
|
82 |
+
grad_checkpointing: true
|
83 |
+
gradient_clip: 0.1
|
84 |
+
optimizer:
|
85 |
+
betas:
|
86 |
+
- 0.9
|
87 |
+
- 0.999
|
88 |
+
- 0.9999
|
89 |
+
eps:
|
90 |
+
- 1.0e-30
|
91 |
+
- 1.0e-16
|
92 |
+
lr: 0.0001
|
93 |
+
type: CAMEWrapper
|
94 |
+
weight_decay: 0.0
|
95 |
+
lr_schedule: constant
|
96 |
+
lr_schedule_args:
|
97 |
+
num_warmup_steps: 2000
|
98 |
+
local_save_vis: true # if save log image locally
|
99 |
+
visualize: true
|
100 |
+
eval_sampling_steps: 500
|
101 |
+
log_interval: 20
|
102 |
+
save_model_epochs: 5
|
103 |
+
save_model_steps: 500
|
104 |
+
work_dir: output/debug
|
105 |
+
online_metric: false
|
106 |
+
eval_metric_step: 2000
|
107 |
+
online_metric_dir: metric_helper
|
configs/sana_config/512ms/ci_Sana_600M_img512.yaml
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
data_dir: [data/data_public/vaef32c32_v2_512/dir1]
|
3 |
+
image_size: 512
|
4 |
+
caption_proportion:
|
5 |
+
prompt: 1
|
6 |
+
external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
|
7 |
+
external_clipscore_suffixes:
|
8 |
+
- _InternVL2-26B_clip_score
|
9 |
+
- _VILA1-5-13B_clip_score
|
10 |
+
- _prompt_clip_score
|
11 |
+
clip_thr_temperature: 0.1
|
12 |
+
clip_thr: 25.0
|
13 |
+
load_text_feat: false
|
14 |
+
load_vae_feat: false
|
15 |
+
transform: default_train
|
16 |
+
type: SanaWebDatasetMS
|
17 |
+
sort_dataset: false
|
18 |
+
# model config
|
19 |
+
model:
|
20 |
+
model: SanaMS_600M_P1_D28
|
21 |
+
image_size: 512
|
22 |
+
mixed_precision: fp16
|
23 |
+
fp32_attention: true
|
24 |
+
load_from:
|
25 |
+
resume_from:
|
26 |
+
aspect_ratio_type: ASPECT_RATIO_512
|
27 |
+
multi_scale: true
|
28 |
+
#pe_interpolation: 1.
|
29 |
+
attn_type: linear
|
30 |
+
linear_head_dim: 32
|
31 |
+
ffn_type: glumbconv
|
32 |
+
mlp_acts:
|
33 |
+
- silu
|
34 |
+
- silu
|
35 |
+
- null
|
36 |
+
mlp_ratio: 2.5
|
37 |
+
use_pe: false
|
38 |
+
qk_norm: false
|
39 |
+
class_dropout_prob: 0.1
|
40 |
+
# VAE setting
|
41 |
+
vae:
|
42 |
+
vae_type: dc-ae
|
43 |
+
vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
|
44 |
+
scale_factor: 0.41407
|
45 |
+
vae_latent_dim: 32
|
46 |
+
vae_downsample_rate: 32
|
47 |
+
sample_posterior: true
|
48 |
+
# text encoder
|
49 |
+
text_encoder:
|
50 |
+
text_encoder_name: gemma-2-2b-it
|
51 |
+
y_norm: true
|
52 |
+
y_norm_scale_factor: 0.01
|
53 |
+
model_max_length: 300
|
54 |
+
# CHI
|
55 |
+
chi_prompt:
|
56 |
+
- 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
|
57 |
+
- '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
|
58 |
+
- '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
|
59 |
+
- 'Here are examples of how to transform or refine prompts:'
|
60 |
+
- '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
|
61 |
+
- '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
|
62 |
+
- 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
|
63 |
+
- 'User Prompt: '
|
64 |
+
# Sana schedule Flow
|
65 |
+
scheduler:
|
66 |
+
predict_v: true
|
67 |
+
noise_schedule: linear_flow
|
68 |
+
pred_sigma: false
|
69 |
+
flow_shift: 1.0
|
70 |
+
# logit-normal timestep
|
71 |
+
weighting_scheme: logit_normal
|
72 |
+
logit_mean: 0.0
|
73 |
+
logit_std: 1.0
|
74 |
+
vis_sampler: flow_dpm-solver
|
75 |
+
# training setting
|
76 |
+
train:
|
77 |
+
num_workers: 10
|
78 |
+
seed: 1
|
79 |
+
train_batch_size: 64
|
80 |
+
num_epochs: 1
|
81 |
+
gradient_accumulation_steps: 1
|
82 |
+
grad_checkpointing: true
|
83 |
+
gradient_clip: 0.1
|
84 |
+
optimizer:
|
85 |
+
betas:
|
86 |
+
- 0.9
|
87 |
+
- 0.999
|
88 |
+
- 0.9999
|
89 |
+
eps:
|
90 |
+
- 1.0e-30
|
91 |
+
- 1.0e-16
|
92 |
+
lr: 0.0001
|
93 |
+
type: CAMEWrapper
|
94 |
+
weight_decay: 0.0
|
95 |
+
lr_schedule: constant
|
96 |
+
lr_schedule_args:
|
97 |
+
num_warmup_steps: 2000
|
98 |
+
local_save_vis: true # if save log image locally
|
99 |
+
visualize: true
|
100 |
+
eval_sampling_steps: 500
|
101 |
+
log_interval: 20
|
102 |
+
save_model_epochs: 5
|
103 |
+
save_model_steps: 500
|
104 |
+
work_dir: output/debug
|
105 |
+
online_metric: false
|
106 |
+
eval_metric_step: 2000
|
107 |
+
online_metric_dir: metric_helper
|
configs/sana_config/512ms/sample_dataset.yaml
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
data_dir: [asset/example_data]
|
3 |
+
image_size: 512
|
4 |
+
caption_proportion:
|
5 |
+
prompt: 1
|
6 |
+
external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B] # json fils
|
7 |
+
external_clipscore_suffixes: # json files
|
8 |
+
- _InternVL2-26B_clip_score
|
9 |
+
- _VILA1-5-13B_clip_score
|
10 |
+
- _prompt_clip_score
|
11 |
+
clip_thr_temperature: 0.1
|
12 |
+
clip_thr: 25.0
|
13 |
+
load_text_feat: false
|
14 |
+
load_vae_feat: false
|
15 |
+
transform: default_train
|
16 |
+
type: SanaImgDataset
|
17 |
+
sort_dataset: false
|
18 |
+
# model config
|
19 |
+
model:
|
20 |
+
model: SanaMS_600M_P1_D28
|
21 |
+
image_size: 512
|
22 |
+
mixed_precision: fp16
|
23 |
+
fp32_attention: true
|
24 |
+
load_from:
|
25 |
+
resume_from:
|
26 |
+
aspect_ratio_type: ASPECT_RATIO_512
|
27 |
+
multi_scale: false
|
28 |
+
#pe_interpolation: 1.
|
29 |
+
attn_type: linear
|
30 |
+
linear_head_dim: 32
|
31 |
+
ffn_type: glumbconv
|
32 |
+
mlp_acts:
|
33 |
+
- silu
|
34 |
+
- silu
|
35 |
+
- null
|
36 |
+
mlp_ratio: 2.5
|
37 |
+
use_pe: false
|
38 |
+
qk_norm: false
|
39 |
+
class_dropout_prob: 0.1
|
40 |
+
# VAE setting
|
41 |
+
vae:
|
42 |
+
vae_type: dc-ae
|
43 |
+
vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
|
44 |
+
scale_factor: 0.41407
|
45 |
+
vae_latent_dim: 32
|
46 |
+
vae_downsample_rate: 32
|
47 |
+
sample_posterior: true
|
48 |
+
# text encoder
|
49 |
+
text_encoder:
|
50 |
+
text_encoder_name: gemma-2-2b-it
|
51 |
+
y_norm: true
|
52 |
+
y_norm_scale_factor: 0.01
|
53 |
+
model_max_length: 300
|
54 |
+
# CHI
|
55 |
+
chi_prompt:
|
56 |
+
- 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
|
57 |
+
- '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
|
58 |
+
- '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
|
59 |
+
- 'Here are examples of how to transform or refine prompts:'
|
60 |
+
- '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
|
61 |
+
- '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
|
62 |
+
- 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
|
63 |
+
- 'User Prompt: '
|
64 |
+
# Sana schedule Flow
|
65 |
+
scheduler:
|
66 |
+
predict_v: true
|
67 |
+
noise_schedule: linear_flow
|
68 |
+
pred_sigma: false
|
69 |
+
flow_shift: 1.0
|
70 |
+
# logit-normal timestep
|
71 |
+
weighting_scheme: logit_normal
|
72 |
+
logit_mean: 0.0
|
73 |
+
logit_std: 1.0
|
74 |
+
vis_sampler: flow_dpm-solver
|
75 |
+
# training setting
|
76 |
+
train:
|
77 |
+
num_workers: 10
|
78 |
+
seed: 1
|
79 |
+
train_batch_size: 128
|
80 |
+
num_epochs: 100
|
81 |
+
gradient_accumulation_steps: 1
|
82 |
+
grad_checkpointing: true
|
83 |
+
gradient_clip: 0.1
|
84 |
+
optimizer:
|
85 |
+
betas:
|
86 |
+
- 0.9
|
87 |
+
- 0.999
|
88 |
+
- 0.9999
|
89 |
+
eps:
|
90 |
+
- 1.0e-30
|
91 |
+
- 1.0e-16
|
92 |
+
lr: 0.0001
|
93 |
+
type: CAMEWrapper
|
94 |
+
weight_decay: 0.0
|
95 |
+
lr_schedule: constant
|
96 |
+
lr_schedule_args:
|
97 |
+
num_warmup_steps: 2000
|
98 |
+
local_save_vis: true # if save log image locally
|
99 |
+
visualize: true
|
100 |
+
eval_sampling_steps: 500
|
101 |
+
log_interval: 20
|
102 |
+
save_model_epochs: 5
|
103 |
+
save_model_steps: 500
|
104 |
+
work_dir: output/debug
|
105 |
+
online_metric: false
|
106 |
+
eval_metric_step: 2000
|
107 |
+
online_metric_dir: metric_helper
|
diffusion/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
from .dpm_solver import DPMS
|
7 |
+
from .flow_euler_sampler import FlowEuler
|
8 |
+
from .iddpm import Scheduler
|
9 |
+
from .sa_sampler import SASolverSampler
|
diffusion/data/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .datasets import *
|
2 |
+
from .transforms import get_transform
|
diffusion/data/builder.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
import os
|
18 |
+
import time
|
19 |
+
|
20 |
+
from mmcv import Registry, build_from_cfg
|
21 |
+
from termcolor import colored
|
22 |
+
from torch.utils.data import DataLoader
|
23 |
+
|
24 |
+
from diffusion.data.transforms import get_transform
|
25 |
+
from diffusion.utils.logger import get_root_logger
|
26 |
+
|
27 |
+
DATASETS = Registry("datasets")
|
28 |
+
|
29 |
+
DATA_ROOT = "data"
|
30 |
+
|
31 |
+
|
32 |
+
def set_data_root(data_root):
|
33 |
+
global DATA_ROOT
|
34 |
+
DATA_ROOT = data_root
|
35 |
+
|
36 |
+
|
37 |
+
def get_data_path(data_dir):
|
38 |
+
if os.path.isabs(data_dir):
|
39 |
+
return data_dir
|
40 |
+
global DATA_ROOT
|
41 |
+
return os.path.join(DATA_ROOT, data_dir)
|
42 |
+
|
43 |
+
|
44 |
+
def get_data_root_and_path(data_dir):
|
45 |
+
if os.path.isabs(data_dir):
|
46 |
+
return data_dir
|
47 |
+
global DATA_ROOT
|
48 |
+
return DATA_ROOT, os.path.join(DATA_ROOT, data_dir)
|
49 |
+
|
50 |
+
|
51 |
+
def build_dataset(cfg, resolution=224, **kwargs):
|
52 |
+
logger = get_root_logger()
|
53 |
+
|
54 |
+
dataset_type = cfg.get("type")
|
55 |
+
logger.info(f"Constructing dataset {dataset_type}...")
|
56 |
+
t = time.time()
|
57 |
+
transform = cfg.pop("transform", "default_train")
|
58 |
+
transform = get_transform(transform, resolution)
|
59 |
+
dataset = build_from_cfg(cfg, DATASETS, default_args=dict(transform=transform, resolution=resolution, **kwargs))
|
60 |
+
logger.info(
|
61 |
+
f"{colored(f'Dataset {dataset_type} constructed: ', 'green', attrs=['bold'])}"
|
62 |
+
f"time: {(time.time() - t):.2f} s, length (use/ori): {len(dataset)}/{dataset.ori_imgs_nums}"
|
63 |
+
)
|
64 |
+
return dataset
|
65 |
+
|
66 |
+
|
67 |
+
def build_dataloader(dataset, batch_size=256, num_workers=4, shuffle=True, **kwargs):
|
68 |
+
if "batch_sampler" in kwargs:
|
69 |
+
dataloader = DataLoader(
|
70 |
+
dataset, batch_sampler=kwargs["batch_sampler"], num_workers=num_workers, pin_memory=True
|
71 |
+
)
|
72 |
+
else:
|
73 |
+
dataloader = DataLoader(
|
74 |
+
dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True, **kwargs
|
75 |
+
)
|
76 |
+
return dataloader
|
diffusion/data/datasets/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .sana_data import SanaImgDataset, SanaWebDataset
|
2 |
+
from .sana_data_multi_scale import DummyDatasetMS, SanaWebDatasetMS
|
3 |
+
from .utils import *
|
diffusion/data/datasets/sana_data.py
ADDED
@@ -0,0 +1,467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
# This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
|
18 |
+
import getpass
|
19 |
+
import json
|
20 |
+
import os
|
21 |
+
import os.path as osp
|
22 |
+
import random
|
23 |
+
|
24 |
+
import numpy as np
|
25 |
+
import torch
|
26 |
+
import torch.distributed as dist
|
27 |
+
from PIL import Image
|
28 |
+
from termcolor import colored
|
29 |
+
from torch.utils.data import Dataset
|
30 |
+
|
31 |
+
from diffusion.data.builder import DATASETS, get_data_path
|
32 |
+
from diffusion.data.wids import ShardListDataset, ShardListDatasetMulti, lru_json_load
|
33 |
+
from diffusion.utils.logger import get_root_logger
|
34 |
+
|
35 |
+
|
36 |
+
@DATASETS.register_module()
|
37 |
+
class SanaImgDataset(torch.utils.data.Dataset):
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
data_dir="",
|
41 |
+
transform=None,
|
42 |
+
resolution=256,
|
43 |
+
load_vae_feat=False,
|
44 |
+
load_text_feat=False,
|
45 |
+
max_length=300,
|
46 |
+
config=None,
|
47 |
+
caption_proportion=None,
|
48 |
+
external_caption_suffixes=None,
|
49 |
+
external_clipscore_suffixes=None,
|
50 |
+
clip_thr=0.0,
|
51 |
+
clip_thr_temperature=1.0,
|
52 |
+
img_extension=".png",
|
53 |
+
**kwargs,
|
54 |
+
):
|
55 |
+
if external_caption_suffixes is None:
|
56 |
+
external_caption_suffixes = []
|
57 |
+
if external_clipscore_suffixes is None:
|
58 |
+
external_clipscore_suffixes = []
|
59 |
+
|
60 |
+
self.logger = (
|
61 |
+
get_root_logger() if config is None else get_root_logger(osp.join(config.work_dir, "train_log.log"))
|
62 |
+
)
|
63 |
+
self.transform = transform if not load_vae_feat else None
|
64 |
+
self.load_vae_feat = load_vae_feat
|
65 |
+
self.load_text_feat = load_text_feat
|
66 |
+
self.resolution = resolution
|
67 |
+
self.max_length = max_length
|
68 |
+
self.caption_proportion = caption_proportion if caption_proportion is not None else {"prompt": 1.0}
|
69 |
+
self.external_caption_suffixes = external_caption_suffixes
|
70 |
+
self.external_clipscore_suffixes = external_clipscore_suffixes
|
71 |
+
self.clip_thr = clip_thr
|
72 |
+
self.clip_thr_temperature = clip_thr_temperature
|
73 |
+
self.default_prompt = "prompt"
|
74 |
+
self.img_extension = img_extension
|
75 |
+
|
76 |
+
self.data_dirs = data_dir if isinstance(data_dir, list) else [data_dir]
|
77 |
+
# self.meta_datas = [osp.join(data_dir, "meta_data.json") for data_dir in self.data_dirs]
|
78 |
+
self.dataset = []
|
79 |
+
for data_dir in self.data_dirs:
|
80 |
+
meta_data = json.load(open(osp.join(data_dir, "meta_data.json")))
|
81 |
+
self.dataset.extend([osp.join(data_dir, i) for i in meta_data["img_names"]])
|
82 |
+
|
83 |
+
self.dataset = self.dataset * 2000
|
84 |
+
self.logger.info(colored("Dataset is repeat 2000 times for toy dataset", "red", attrs=["bold"]))
|
85 |
+
self.ori_imgs_nums = len(self)
|
86 |
+
self.logger.info(f"Dataset samples: {len(self.dataset)}")
|
87 |
+
|
88 |
+
self.logger.info(f"Loading external caption json from: original_filename{external_caption_suffixes}.json")
|
89 |
+
self.logger.info(f"Loading external clipscore json from: original_filename{external_clipscore_suffixes}.json")
|
90 |
+
self.logger.info(f"external caption clipscore threshold: {clip_thr}, temperature: {clip_thr_temperature}")
|
91 |
+
self.logger.info(f"T5 max token length: {self.max_length}")
|
92 |
+
|
93 |
+
def getdata(self, idx):
|
94 |
+
data = self.dataset[idx]
|
95 |
+
self.key = data.split("/")[-1]
|
96 |
+
# info = json.load(open(f"{data}.json"))[self.key]
|
97 |
+
info = {}
|
98 |
+
with open(f"{data}.txt") as f:
|
99 |
+
info[self.default_prompt] = f.readlines()[0].strip()
|
100 |
+
|
101 |
+
# external json file
|
102 |
+
for suffix in self.external_caption_suffixes:
|
103 |
+
caption_json_path = f"{data}{suffix}.json"
|
104 |
+
if os.path.exists(caption_json_path):
|
105 |
+
try:
|
106 |
+
caption_json = lru_json_load(caption_json_path)
|
107 |
+
except:
|
108 |
+
caption_json = {}
|
109 |
+
if self.key in caption_json:
|
110 |
+
info.update(caption_json[self.key])
|
111 |
+
|
112 |
+
caption_type, caption_clipscore = self.weighted_sample_clipscore(data, info)
|
113 |
+
caption_type = caption_type if caption_type in info else self.default_prompt
|
114 |
+
txt_fea = "" if info[caption_type] is None else info[caption_type]
|
115 |
+
|
116 |
+
data_info = {
|
117 |
+
"img_hw": torch.tensor([self.resolution, self.resolution], dtype=torch.float32),
|
118 |
+
"aspect_ratio": torch.tensor(1.0),
|
119 |
+
}
|
120 |
+
|
121 |
+
if self.load_vae_feat:
|
122 |
+
assert ValueError("Load VAE is not supported now")
|
123 |
+
else:
|
124 |
+
img = f"{data}{self.img_extension}"
|
125 |
+
img = Image.open(img)
|
126 |
+
if self.transform:
|
127 |
+
img = self.transform(img)
|
128 |
+
|
129 |
+
attention_mask = torch.ones(1, 1, self.max_length, dtype=torch.int16) # 1x1xT
|
130 |
+
if self.load_text_feat:
|
131 |
+
npz_path = f"{self.key}.npz"
|
132 |
+
txt_info = np.load(npz_path)
|
133 |
+
txt_fea = torch.from_numpy(txt_info["caption_feature"]) # 1xTx4096
|
134 |
+
if "attention_mask" in txt_info:
|
135 |
+
attention_mask = torch.from_numpy(txt_info["attention_mask"])[None]
|
136 |
+
# make sure the feature length are the same
|
137 |
+
if txt_fea.shape[1] != self.max_length:
|
138 |
+
txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_length - txt_fea.shape[1], 1)], dim=1)
|
139 |
+
attention_mask = torch.cat(
|
140 |
+
[attention_mask, torch.zeros(1, 1, self.max_length - attention_mask.shape[-1])], dim=-1
|
141 |
+
)
|
142 |
+
|
143 |
+
return (
|
144 |
+
img,
|
145 |
+
txt_fea,
|
146 |
+
attention_mask.to(torch.int16),
|
147 |
+
data_info,
|
148 |
+
idx,
|
149 |
+
caption_type,
|
150 |
+
"",
|
151 |
+
str(caption_clipscore),
|
152 |
+
)
|
153 |
+
|
154 |
+
def __getitem__(self, idx):
|
155 |
+
for _ in range(10):
|
156 |
+
try:
|
157 |
+
data = self.getdata(idx)
|
158 |
+
return data
|
159 |
+
except Exception as e:
|
160 |
+
print(f"Error details: {str(e)}")
|
161 |
+
idx = idx + 1
|
162 |
+
raise RuntimeError("Too many bad data.")
|
163 |
+
|
164 |
+
def __len__(self):
|
165 |
+
return len(self.dataset)
|
166 |
+
|
167 |
+
def weighted_sample_fix_prob(self):
|
168 |
+
labels = list(self.caption_proportion.keys())
|
169 |
+
weights = list(self.caption_proportion.values())
|
170 |
+
sampled_label = random.choices(labels, weights=weights, k=1)[0]
|
171 |
+
return sampled_label
|
172 |
+
|
173 |
+
def weighted_sample_clipscore(self, data, info):
|
174 |
+
labels = []
|
175 |
+
weights = []
|
176 |
+
fallback_label = None
|
177 |
+
max_clip_score = float("-inf")
|
178 |
+
|
179 |
+
for suffix in self.external_clipscore_suffixes:
|
180 |
+
clipscore_json_path = f"{data}{suffix}.json"
|
181 |
+
|
182 |
+
if os.path.exists(clipscore_json_path):
|
183 |
+
try:
|
184 |
+
clipscore_json = lru_json_load(clipscore_json_path)
|
185 |
+
except:
|
186 |
+
clipscore_json = {}
|
187 |
+
if self.key in clipscore_json:
|
188 |
+
clip_scores = clipscore_json[self.key]
|
189 |
+
|
190 |
+
for caption_type, clip_score in clip_scores.items():
|
191 |
+
clip_score = float(clip_score)
|
192 |
+
if caption_type in info:
|
193 |
+
if clip_score >= self.clip_thr:
|
194 |
+
labels.append(caption_type)
|
195 |
+
weights.append(clip_score)
|
196 |
+
|
197 |
+
if clip_score > max_clip_score:
|
198 |
+
max_clip_score = clip_score
|
199 |
+
fallback_label = caption_type
|
200 |
+
|
201 |
+
if not labels and fallback_label:
|
202 |
+
return fallback_label, max_clip_score
|
203 |
+
|
204 |
+
if not labels:
|
205 |
+
return self.default_prompt, 0.0
|
206 |
+
|
207 |
+
adjusted_weights = np.array(weights) ** (1.0 / max(self.clip_thr_temperature, 0.01))
|
208 |
+
normalized_weights = adjusted_weights / np.sum(adjusted_weights)
|
209 |
+
sampled_label = random.choices(labels, weights=normalized_weights, k=1)[0]
|
210 |
+
# sampled_label = random.choices(labels, weights=[1]*len(weights), k=1)[0]
|
211 |
+
index = labels.index(sampled_label)
|
212 |
+
original_weight = weights[index]
|
213 |
+
|
214 |
+
return sampled_label, original_weight
|
215 |
+
|
216 |
+
|
217 |
+
@DATASETS.register_module()
|
218 |
+
class SanaWebDataset(torch.utils.data.Dataset):
|
219 |
+
def __init__(
|
220 |
+
self,
|
221 |
+
data_dir="",
|
222 |
+
meta_path=None,
|
223 |
+
cache_dir="/cache/data/sana-webds-meta",
|
224 |
+
max_shards_to_load=None,
|
225 |
+
transform=None,
|
226 |
+
resolution=256,
|
227 |
+
load_vae_feat=False,
|
228 |
+
load_text_feat=False,
|
229 |
+
max_length=300,
|
230 |
+
config=None,
|
231 |
+
caption_proportion=None,
|
232 |
+
sort_dataset=False,
|
233 |
+
num_replicas=None,
|
234 |
+
external_caption_suffixes=None,
|
235 |
+
external_clipscore_suffixes=None,
|
236 |
+
clip_thr=0.0,
|
237 |
+
clip_thr_temperature=1.0,
|
238 |
+
**kwargs,
|
239 |
+
):
|
240 |
+
if external_caption_suffixes is None:
|
241 |
+
external_caption_suffixes = []
|
242 |
+
if external_clipscore_suffixes is None:
|
243 |
+
external_clipscore_suffixes = []
|
244 |
+
|
245 |
+
self.logger = (
|
246 |
+
get_root_logger() if config is None else get_root_logger(osp.join(config.work_dir, "train_log.log"))
|
247 |
+
)
|
248 |
+
self.transform = transform if not load_vae_feat else None
|
249 |
+
self.load_vae_feat = load_vae_feat
|
250 |
+
self.load_text_feat = load_text_feat
|
251 |
+
self.resolution = resolution
|
252 |
+
self.max_length = max_length
|
253 |
+
self.caption_proportion = caption_proportion if caption_proportion is not None else {"prompt": 1.0}
|
254 |
+
self.external_caption_suffixes = external_caption_suffixes
|
255 |
+
self.external_clipscore_suffixes = external_clipscore_suffixes
|
256 |
+
self.clip_thr = clip_thr
|
257 |
+
self.clip_thr_temperature = clip_thr_temperature
|
258 |
+
self.default_prompt = "prompt"
|
259 |
+
|
260 |
+
data_dirs = data_dir if isinstance(data_dir, list) else [data_dir]
|
261 |
+
meta_paths = meta_path if isinstance(meta_path, list) else [meta_path] * len(data_dirs)
|
262 |
+
self.meta_paths = []
|
263 |
+
for data_path, meta_path in zip(data_dirs, meta_paths):
|
264 |
+
self.data_path = osp.expanduser(data_path)
|
265 |
+
self.meta_path = osp.expanduser(meta_path) if meta_path is not None else None
|
266 |
+
|
267 |
+
_local_meta_path = osp.join(self.data_path, "wids-meta.json")
|
268 |
+
if meta_path is None and osp.exists(_local_meta_path):
|
269 |
+
self.logger.info(f"loading from {_local_meta_path}")
|
270 |
+
self.meta_path = meta_path = _local_meta_path
|
271 |
+
|
272 |
+
if meta_path is None:
|
273 |
+
self.meta_path = osp.join(
|
274 |
+
osp.expanduser(cache_dir),
|
275 |
+
self.data_path.replace("/", "--") + f".max_shards:{max_shards_to_load}" + ".wdsmeta.json",
|
276 |
+
)
|
277 |
+
|
278 |
+
assert osp.exists(self.meta_path), f"meta path not found in [{self.meta_path}] or [{_local_meta_path}]"
|
279 |
+
self.logger.info(f"[SimplyInternal] Loading meta information {self.meta_path}")
|
280 |
+
self.meta_paths.append(self.meta_path)
|
281 |
+
|
282 |
+
self._initialize_dataset(num_replicas, sort_dataset)
|
283 |
+
|
284 |
+
self.logger.info(f"Loading external caption json from: original_filename{external_caption_suffixes}.json")
|
285 |
+
self.logger.info(f"Loading external clipscore json from: original_filename{external_clipscore_suffixes}.json")
|
286 |
+
self.logger.info(f"external caption clipscore threshold: {clip_thr}, temperature: {clip_thr_temperature}")
|
287 |
+
self.logger.info(f"T5 max token length: {self.max_length}")
|
288 |
+
self.logger.warning(f"Sort the dataset: {sort_dataset}")
|
289 |
+
|
290 |
+
def _initialize_dataset(self, num_replicas, sort_dataset):
|
291 |
+
# uuid = abs(hash(self.meta_path)) % (10 ** 8)
|
292 |
+
import hashlib
|
293 |
+
|
294 |
+
uuid = hashlib.sha256(self.meta_path.encode()).hexdigest()[:8]
|
295 |
+
if len(self.meta_paths) > 0:
|
296 |
+
self.dataset = ShardListDatasetMulti(
|
297 |
+
self.meta_paths,
|
298 |
+
cache_dir=osp.expanduser(f"~/.cache/_wids_cache/{getpass.getuser()}-{uuid}"),
|
299 |
+
sort_data_inseq=sort_dataset,
|
300 |
+
num_replicas=num_replicas or dist.get_world_size(),
|
301 |
+
)
|
302 |
+
else:
|
303 |
+
# TODO: tmp to ensure there is no bug
|
304 |
+
self.dataset = ShardListDataset(
|
305 |
+
self.meta_path,
|
306 |
+
cache_dir=osp.expanduser(f"~/.cache/_wids_cache/{getpass.getuser()}-{uuid}"),
|
307 |
+
)
|
308 |
+
self.ori_imgs_nums = len(self)
|
309 |
+
self.logger.info(f"{self.dataset.data_info}")
|
310 |
+
|
311 |
+
def getdata(self, idx):
|
312 |
+
data = self.dataset[idx]
|
313 |
+
info = data[".json"]
|
314 |
+
self.key = data["__key__"]
|
315 |
+
dataindex_info = {
|
316 |
+
"index": data["__index__"],
|
317 |
+
"shard": "/".join(data["__shard__"].rsplit("/", 2)[-2:]),
|
318 |
+
"shardindex": data["__shardindex__"],
|
319 |
+
}
|
320 |
+
|
321 |
+
# external json file
|
322 |
+
for suffix in self.external_caption_suffixes:
|
323 |
+
caption_json_path = data["__shard__"].replace(".tar", f"{suffix}.json")
|
324 |
+
if os.path.exists(caption_json_path):
|
325 |
+
try:
|
326 |
+
caption_json = lru_json_load(caption_json_path)
|
327 |
+
except:
|
328 |
+
caption_json = {}
|
329 |
+
if self.key in caption_json:
|
330 |
+
info.update(caption_json[self.key])
|
331 |
+
|
332 |
+
caption_type, caption_clipscore = self.weighted_sample_clipscore(data, info)
|
333 |
+
caption_type = caption_type if caption_type in info else self.default_prompt
|
334 |
+
txt_fea = "" if info[caption_type] is None else info[caption_type]
|
335 |
+
|
336 |
+
data_info = {
|
337 |
+
"img_hw": torch.tensor([self.resolution, self.resolution], dtype=torch.float32),
|
338 |
+
"aspect_ratio": torch.tensor(1.0),
|
339 |
+
}
|
340 |
+
|
341 |
+
if self.load_vae_feat:
|
342 |
+
img = data[".npy"]
|
343 |
+
else:
|
344 |
+
img = data[".png"] if ".png" in data else data[".jpg"]
|
345 |
+
if self.transform:
|
346 |
+
img = self.transform(img)
|
347 |
+
|
348 |
+
attention_mask = torch.ones(1, 1, self.max_length, dtype=torch.int16) # 1x1xT
|
349 |
+
if self.load_text_feat:
|
350 |
+
npz_path = f"{self.key}.npz"
|
351 |
+
txt_info = np.load(npz_path)
|
352 |
+
txt_fea = torch.from_numpy(txt_info["caption_feature"]) # 1xTx4096
|
353 |
+
if "attention_mask" in txt_info:
|
354 |
+
attention_mask = torch.from_numpy(txt_info["attention_mask"])[None]
|
355 |
+
# make sure the feature length are the same
|
356 |
+
if txt_fea.shape[1] != self.max_length:
|
357 |
+
txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_length - txt_fea.shape[1], 1)], dim=1)
|
358 |
+
attention_mask = torch.cat(
|
359 |
+
[attention_mask, torch.zeros(1, 1, self.max_length - attention_mask.shape[-1])], dim=-1
|
360 |
+
)
|
361 |
+
|
362 |
+
return (
|
363 |
+
img,
|
364 |
+
txt_fea,
|
365 |
+
attention_mask.to(torch.int16),
|
366 |
+
data_info,
|
367 |
+
idx,
|
368 |
+
caption_type,
|
369 |
+
dataindex_info,
|
370 |
+
str(caption_clipscore),
|
371 |
+
)
|
372 |
+
|
373 |
+
def __getitem__(self, idx):
|
374 |
+
for _ in range(10):
|
375 |
+
try:
|
376 |
+
data = self.getdata(idx)
|
377 |
+
return data
|
378 |
+
except Exception as e:
|
379 |
+
print(f"Error details: {str(e)}")
|
380 |
+
idx = idx + 1
|
381 |
+
raise RuntimeError("Too many bad data.")
|
382 |
+
|
383 |
+
def __len__(self):
|
384 |
+
return len(self.dataset)
|
385 |
+
|
386 |
+
def weighted_sample_fix_prob(self):
|
387 |
+
labels = list(self.caption_proportion.keys())
|
388 |
+
weights = list(self.caption_proportion.values())
|
389 |
+
sampled_label = random.choices(labels, weights=weights, k=1)[0]
|
390 |
+
return sampled_label
|
391 |
+
|
392 |
+
def weighted_sample_clipscore(self, data, info):
|
393 |
+
labels = []
|
394 |
+
weights = []
|
395 |
+
fallback_label = None
|
396 |
+
max_clip_score = float("-inf")
|
397 |
+
|
398 |
+
for suffix in self.external_clipscore_suffixes:
|
399 |
+
clipscore_json_path = data["__shard__"].replace(".tar", f"{suffix}.json")
|
400 |
+
|
401 |
+
if os.path.exists(clipscore_json_path):
|
402 |
+
try:
|
403 |
+
clipscore_json = lru_json_load(clipscore_json_path)
|
404 |
+
except:
|
405 |
+
clipscore_json = {}
|
406 |
+
if self.key in clipscore_json:
|
407 |
+
clip_scores = clipscore_json[self.key]
|
408 |
+
|
409 |
+
for caption_type, clip_score in clip_scores.items():
|
410 |
+
clip_score = float(clip_score)
|
411 |
+
if caption_type in info:
|
412 |
+
if clip_score >= self.clip_thr:
|
413 |
+
labels.append(caption_type)
|
414 |
+
weights.append(clip_score)
|
415 |
+
|
416 |
+
if clip_score > max_clip_score:
|
417 |
+
max_clip_score = clip_score
|
418 |
+
fallback_label = caption_type
|
419 |
+
|
420 |
+
if not labels and fallback_label:
|
421 |
+
return fallback_label, max_clip_score
|
422 |
+
|
423 |
+
if not labels:
|
424 |
+
return self.default_prompt, 0.0
|
425 |
+
|
426 |
+
adjusted_weights = np.array(weights) ** (1.0 / max(self.clip_thr_temperature, 0.01))
|
427 |
+
normalized_weights = adjusted_weights / np.sum(adjusted_weights)
|
428 |
+
sampled_label = random.choices(labels, weights=normalized_weights, k=1)[0]
|
429 |
+
# sampled_label = random.choices(labels, weights=[1]*len(weights), k=1)[0]
|
430 |
+
index = labels.index(sampled_label)
|
431 |
+
original_weight = weights[index]
|
432 |
+
|
433 |
+
return sampled_label, original_weight
|
434 |
+
|
435 |
+
def get_data_info(self, idx):
|
436 |
+
try:
|
437 |
+
data = self.dataset[idx]
|
438 |
+
info = data[".json"]
|
439 |
+
key = data["__key__"]
|
440 |
+
version = info.get("version", "others")
|
441 |
+
return {"height": info["height"], "width": info["width"], "version": version, "key": key}
|
442 |
+
except Exception as e:
|
443 |
+
print(f"Error details: {str(e)}")
|
444 |
+
return None
|
445 |
+
|
446 |
+
|
447 |
+
if __name__ == "__main__":
|
448 |
+
from torch.utils.data import DataLoader
|
449 |
+
|
450 |
+
from diffusion.data.transforms import get_transform
|
451 |
+
|
452 |
+
image_size = 1024 # 256
|
453 |
+
transform = get_transform("default_train", image_size)
|
454 |
+
train_dataset = SanaWebDataset(
|
455 |
+
data_dir="debug_data_train/vaef32c32/debug_data",
|
456 |
+
resolution=image_size,
|
457 |
+
transform=transform,
|
458 |
+
max_length=300,
|
459 |
+
load_vae_feat=True,
|
460 |
+
num_replicas=1,
|
461 |
+
)
|
462 |
+
dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False, num_workers=4)
|
463 |
+
|
464 |
+
for data in dataloader:
|
465 |
+
img, txt_fea, attention_mask, data_info = data
|
466 |
+
print(txt_fea)
|
467 |
+
break
|
diffusion/data/datasets/sana_data_multi_scale.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
# This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
|
18 |
+
import os
|
19 |
+
import random
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
from torchvision import transforms as T
|
24 |
+
from torchvision.transforms.functional import InterpolationMode
|
25 |
+
from tqdm import tqdm
|
26 |
+
|
27 |
+
from diffusion.data.builder import DATASETS
|
28 |
+
from diffusion.data.datasets.sana_data import SanaWebDataset
|
29 |
+
from diffusion.data.datasets.utils import *
|
30 |
+
from diffusion.data.wids import lru_json_load
|
31 |
+
|
32 |
+
|
33 |
+
def get_closest_ratio(height: float, width: float, ratios: dict):
|
34 |
+
aspect_ratio = height / width
|
35 |
+
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
|
36 |
+
return ratios[closest_ratio], float(closest_ratio)
|
37 |
+
|
38 |
+
|
39 |
+
@DATASETS.register_module()
|
40 |
+
class SanaWebDatasetMS(SanaWebDataset):
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
data_dir="",
|
44 |
+
meta_path=None,
|
45 |
+
cache_dir="/cache/data/sana-webds-meta",
|
46 |
+
max_shards_to_load=None,
|
47 |
+
transform=None,
|
48 |
+
resolution=256,
|
49 |
+
sample_subset=None,
|
50 |
+
load_vae_feat=False,
|
51 |
+
load_text_feat=False,
|
52 |
+
input_size=32,
|
53 |
+
patch_size=2,
|
54 |
+
max_length=300,
|
55 |
+
config=None,
|
56 |
+
caption_proportion=None,
|
57 |
+
sort_dataset=False,
|
58 |
+
num_replicas=None,
|
59 |
+
external_caption_suffixes=None,
|
60 |
+
external_clipscore_suffixes=None,
|
61 |
+
clip_thr=0.0,
|
62 |
+
clip_thr_temperature=1.0,
|
63 |
+
vae_downsample_rate=32,
|
64 |
+
**kwargs,
|
65 |
+
):
|
66 |
+
super().__init__(
|
67 |
+
data_dir=data_dir,
|
68 |
+
meta_path=meta_path,
|
69 |
+
cache_dir=cache_dir,
|
70 |
+
max_shards_to_load=max_shards_to_load,
|
71 |
+
transform=transform,
|
72 |
+
resolution=resolution,
|
73 |
+
sample_subset=sample_subset,
|
74 |
+
load_vae_feat=load_vae_feat,
|
75 |
+
load_text_feat=load_text_feat,
|
76 |
+
input_size=input_size,
|
77 |
+
patch_size=patch_size,
|
78 |
+
max_length=max_length,
|
79 |
+
config=config,
|
80 |
+
caption_proportion=caption_proportion,
|
81 |
+
sort_dataset=sort_dataset,
|
82 |
+
num_replicas=num_replicas,
|
83 |
+
external_caption_suffixes=external_caption_suffixes,
|
84 |
+
external_clipscore_suffixes=external_clipscore_suffixes,
|
85 |
+
clip_thr=clip_thr,
|
86 |
+
clip_thr_temperature=clip_thr_temperature,
|
87 |
+
vae_downsample_rate=32,
|
88 |
+
**kwargs,
|
89 |
+
)
|
90 |
+
self.base_size = int(kwargs["aspect_ratio_type"].split("_")[-1])
|
91 |
+
self.aspect_ratio = eval(kwargs.pop("aspect_ratio_type")) # base aspect ratio
|
92 |
+
self.ratio_index = {}
|
93 |
+
self.ratio_nums = {}
|
94 |
+
self.interpolate_model = InterpolationMode.BICUBIC
|
95 |
+
self.interpolate_model = (
|
96 |
+
InterpolationMode.BICUBIC
|
97 |
+
if self.aspect_ratio not in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]
|
98 |
+
else InterpolationMode.LANCZOS
|
99 |
+
)
|
100 |
+
|
101 |
+
for k, v in self.aspect_ratio.items():
|
102 |
+
self.ratio_index[float(k)] = []
|
103 |
+
self.ratio_nums[float(k)] = 0
|
104 |
+
|
105 |
+
self.vae_downsample_rate = vae_downsample_rate
|
106 |
+
|
107 |
+
def __getitem__(self, idx):
|
108 |
+
for _ in range(10):
|
109 |
+
try:
|
110 |
+
data = self.getdata(idx)
|
111 |
+
return data
|
112 |
+
except Exception as e:
|
113 |
+
print(f"Error details: {str(e)}")
|
114 |
+
idx = random.choice(self.ratio_index[self.closest_ratio])
|
115 |
+
raise RuntimeError("Too many bad data.")
|
116 |
+
|
117 |
+
def getdata(self, idx):
|
118 |
+
data = self.dataset[idx]
|
119 |
+
info = data[".json"]
|
120 |
+
self.key = data["__key__"]
|
121 |
+
dataindex_info = {
|
122 |
+
"index": data["__index__"],
|
123 |
+
"shard": "/".join(data["__shard__"].rsplit("/", 2)[-2:]),
|
124 |
+
"shardindex": data["__shardindex__"],
|
125 |
+
}
|
126 |
+
|
127 |
+
# external json file
|
128 |
+
for suffix in self.external_caption_suffixes:
|
129 |
+
caption_json_path = data["__shard__"].replace(".tar", f"{suffix}.json")
|
130 |
+
if os.path.exists(caption_json_path):
|
131 |
+
try:
|
132 |
+
caption_json = lru_json_load(caption_json_path)
|
133 |
+
except:
|
134 |
+
caption_json = {}
|
135 |
+
if self.key in caption_json:
|
136 |
+
info.update(caption_json[self.key])
|
137 |
+
|
138 |
+
data_info = {}
|
139 |
+
ori_h, ori_w = info["height"], info["width"]
|
140 |
+
|
141 |
+
# Calculate the closest aspect ratio and resize & crop image[w, h]
|
142 |
+
closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
|
143 |
+
closest_size = list(map(lambda x: int(x), closest_size))
|
144 |
+
self.closest_ratio = closest_ratio
|
145 |
+
|
146 |
+
data_info["img_hw"] = torch.tensor([ori_h, ori_w], dtype=torch.float32)
|
147 |
+
data_info["aspect_ratio"] = closest_ratio
|
148 |
+
|
149 |
+
caption_type, caption_clipscore = self.weighted_sample_clipscore(data, info)
|
150 |
+
caption_type = caption_type if caption_type in info else self.default_prompt
|
151 |
+
txt_fea = "" if info[caption_type] is None else info[caption_type]
|
152 |
+
|
153 |
+
if self.load_vae_feat:
|
154 |
+
img = data[".npy"]
|
155 |
+
if len(img.shape) == 4 and img.shape[0] == 1:
|
156 |
+
img = img[0]
|
157 |
+
h, w = (img.shape[1], img.shape[2])
|
158 |
+
assert h == int(closest_size[0] // self.vae_downsample_rate) and w == int(
|
159 |
+
closest_size[1] // self.vae_downsample_rate
|
160 |
+
), f"h: {h}, w: {w}, ori_hw: {closest_size}, data_info: {dataindex_info}"
|
161 |
+
else:
|
162 |
+
img = data[".png"] if ".png" in data else data[".jpg"]
|
163 |
+
if closest_size[0] / ori_h > closest_size[1] / ori_w:
|
164 |
+
resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h)
|
165 |
+
else:
|
166 |
+
resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1]
|
167 |
+
self.transform = T.Compose(
|
168 |
+
[
|
169 |
+
T.Lambda(lambda img: img.convert("RGB")),
|
170 |
+
T.Resize(resize_size, interpolation=self.interpolate_model), # Image.BICUBIC
|
171 |
+
T.CenterCrop(closest_size),
|
172 |
+
T.ToTensor(),
|
173 |
+
T.Normalize([0.5], [0.5]),
|
174 |
+
]
|
175 |
+
)
|
176 |
+
if idx not in self.ratio_index[closest_ratio]:
|
177 |
+
self.ratio_index[closest_ratio].append(idx)
|
178 |
+
|
179 |
+
if self.transform:
|
180 |
+
img = self.transform(img)
|
181 |
+
|
182 |
+
attention_mask = torch.ones(1, 1, self.max_length, dtype=torch.int16) # 1x1xT
|
183 |
+
if self.load_text_feat:
|
184 |
+
npz_path = f"{self.key}.npz"
|
185 |
+
txt_info = np.load(npz_path)
|
186 |
+
txt_fea = torch.from_numpy(txt_info["caption_feature"]) # 1xTx4096
|
187 |
+
if "attention_mask" in txt_info:
|
188 |
+
attention_mask = torch.from_numpy(txt_info["attention_mask"])[None]
|
189 |
+
# make sure the feature length are the same
|
190 |
+
if txt_fea.shape[1] != self.max_length:
|
191 |
+
txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_length - txt_fea.shape[1], 1)], dim=1)
|
192 |
+
attention_mask = torch.cat(
|
193 |
+
[attention_mask, torch.zeros(1, 1, self.max_length - attention_mask.shape[-1])], dim=-1
|
194 |
+
)
|
195 |
+
|
196 |
+
return (
|
197 |
+
img,
|
198 |
+
txt_fea,
|
199 |
+
attention_mask.to(torch.int16),
|
200 |
+
data_info,
|
201 |
+
idx,
|
202 |
+
caption_type,
|
203 |
+
dataindex_info,
|
204 |
+
str(caption_clipscore),
|
205 |
+
)
|
206 |
+
|
207 |
+
def __len__(self):
|
208 |
+
return len(self.dataset)
|
209 |
+
|
210 |
+
|
211 |
+
@DATASETS.register_module()
|
212 |
+
class DummyDatasetMS(SanaWebDatasetMS):
|
213 |
+
def __init__(self, **kwargs):
|
214 |
+
self.base_size = int(kwargs["aspect_ratio_type"].split("_")[-1])
|
215 |
+
self.aspect_ratio = eval(kwargs.pop("aspect_ratio_type")) # base aspect ratio
|
216 |
+
self.ratio_index = {}
|
217 |
+
self.ratio_nums = {}
|
218 |
+
self.interpolate_model = InterpolationMode.BICUBIC
|
219 |
+
self.interpolate_model = (
|
220 |
+
InterpolationMode.BICUBIC
|
221 |
+
if self.aspect_ratio not in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]
|
222 |
+
else InterpolationMode.LANCZOS
|
223 |
+
)
|
224 |
+
|
225 |
+
for k, v in self.aspect_ratio.items():
|
226 |
+
self.ratio_index[float(k)] = []
|
227 |
+
self.ratio_nums[float(k)] = 0
|
228 |
+
|
229 |
+
self.ori_imgs_nums = 1_000_000
|
230 |
+
self.height = 384
|
231 |
+
self.width = 672
|
232 |
+
|
233 |
+
def __getitem__(self, idx):
|
234 |
+
img = torch.randn((3, self.height, self.width))
|
235 |
+
txt_fea = "The image depicts a young woman standing in the middle of a street, leaning against a silver car. She is dressed in a stylish outfit consisting of a blue blouse and black pants. Her hair is long and dark, and she is looking directly at the camera with a confident expression. The street is lined with colorful buildings, and the trees have autumn leaves, suggesting the season is fall. The lighting is warm, with sunlight casting long shadows on the street. There are a few people in the background, and the overall atmosphere is vibrant and lively."
|
236 |
+
attention_mask = torch.ones(1, 1, 300, dtype=torch.int16) # 1x1xT
|
237 |
+
data_info = {"img_hw": torch.tensor([816.0, 1456.0]), "aspect_ratio": 0.57}
|
238 |
+
idx = 2500
|
239 |
+
caption_type = self.default_prompt
|
240 |
+
dataindex_info = {"index": 2500, "shard": "data_for_test_after_change/00000000.tar", "shardindex": 2500}
|
241 |
+
return img, txt_fea, attention_mask, data_info, idx, caption_type, dataindex_info
|
242 |
+
|
243 |
+
def __len__(self):
|
244 |
+
return self.ori_imgs_nums
|
245 |
+
|
246 |
+
def get_data_info(self, idx):
|
247 |
+
return {"height": self.height, "width": self.width, "version": "1.0", "key": "dummpy_key"}
|
248 |
+
|
249 |
+
|
250 |
+
if __name__ == "__main__":
|
251 |
+
from torch.utils.data import DataLoader
|
252 |
+
|
253 |
+
from diffusion.data.datasets.utils import ASPECT_RATIO_1024
|
254 |
+
from diffusion.data.transforms import get_transform
|
255 |
+
|
256 |
+
image_size = 256
|
257 |
+
transform = get_transform("default_train", image_size)
|
258 |
+
data_dir = ["data/debug_data_train/debug_data"]
|
259 |
+
for data_path in data_dir:
|
260 |
+
train_dataset = SanaWebDatasetMS(data_dir=data_path, resolution=image_size, transform=transform, max_length=300)
|
261 |
+
dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=4)
|
262 |
+
|
263 |
+
for data in tqdm(dataloader):
|
264 |
+
break
|
265 |
+
print(dataloader.dataset.index_info)
|
diffusion/data/datasets/utils.py
ADDED
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
# This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
|
18 |
+
ASPECT_RATIO_4096 = {
|
19 |
+
"0.25": [2048.0, 8192.0],
|
20 |
+
"0.26": [2048.0, 7936.0],
|
21 |
+
"0.27": [2048.0, 7680.0],
|
22 |
+
"0.28": [2048.0, 7424.0],
|
23 |
+
"0.32": [2304.0, 7168.0],
|
24 |
+
"0.33": [2304.0, 6912.0],
|
25 |
+
"0.35": [2304.0, 6656.0],
|
26 |
+
"0.4": [2560.0, 6400.0],
|
27 |
+
"0.42": [2560.0, 6144.0],
|
28 |
+
"0.48": [2816.0, 5888.0],
|
29 |
+
"0.5": [2816.0, 5632.0],
|
30 |
+
"0.52": [2816.0, 5376.0],
|
31 |
+
"0.57": [3072.0, 5376.0],
|
32 |
+
"0.6": [3072.0, 5120.0],
|
33 |
+
"0.68": [3328.0, 4864.0],
|
34 |
+
"0.72": [3328.0, 4608.0],
|
35 |
+
"0.78": [3584.0, 4608.0],
|
36 |
+
"0.82": [3584.0, 4352.0],
|
37 |
+
"0.88": [3840.0, 4352.0],
|
38 |
+
"0.94": [3840.0, 4096.0],
|
39 |
+
"1.0": [4096.0, 4096.0],
|
40 |
+
"1.07": [4096.0, 3840.0],
|
41 |
+
"1.13": [4352.0, 3840.0],
|
42 |
+
"1.21": [4352.0, 3584.0],
|
43 |
+
"1.29": [4608.0, 3584.0],
|
44 |
+
"1.38": [4608.0, 3328.0],
|
45 |
+
"1.46": [4864.0, 3328.0],
|
46 |
+
"1.67": [5120.0, 3072.0],
|
47 |
+
"1.75": [5376.0, 3072.0],
|
48 |
+
"2.0": [5632.0, 2816.0],
|
49 |
+
"2.09": [5888.0, 2816.0],
|
50 |
+
"2.4": [6144.0, 2560.0],
|
51 |
+
"2.5": [6400.0, 2560.0],
|
52 |
+
"2.89": [6656.0, 2304.0],
|
53 |
+
"3.0": [6912.0, 2304.0],
|
54 |
+
"3.11": [7168.0, 2304.0],
|
55 |
+
"3.62": [7424.0, 2048.0],
|
56 |
+
"3.75": [7680.0, 2048.0],
|
57 |
+
"3.88": [7936.0, 2048.0],
|
58 |
+
"4.0": [8192.0, 2048.0],
|
59 |
+
}
|
60 |
+
|
61 |
+
ASPECT_RATIO_2880 = {
|
62 |
+
"0.25": [1408.0, 5760.0],
|
63 |
+
"0.26": [1408.0, 5568.0],
|
64 |
+
"0.27": [1408.0, 5376.0],
|
65 |
+
"0.28": [1408.0, 5184.0],
|
66 |
+
"0.32": [1600.0, 4992.0],
|
67 |
+
"0.33": [1600.0, 4800.0],
|
68 |
+
"0.34": [1600.0, 4672.0],
|
69 |
+
"0.4": [1792.0, 4480.0],
|
70 |
+
"0.42": [1792.0, 4288.0],
|
71 |
+
"0.47": [1920.0, 4096.0],
|
72 |
+
"0.49": [1920.0, 3904.0],
|
73 |
+
"0.51": [1920.0, 3776.0],
|
74 |
+
"0.55": [2112.0, 3840.0],
|
75 |
+
"0.59": [2112.0, 3584.0],
|
76 |
+
"0.68": [2304.0, 3392.0],
|
77 |
+
"0.72": [2304.0, 3200.0],
|
78 |
+
"0.78": [2496.0, 3200.0],
|
79 |
+
"0.83": [2496.0, 3008.0],
|
80 |
+
"0.89": [2688.0, 3008.0],
|
81 |
+
"0.93": [2688.0, 2880.0],
|
82 |
+
"1.0": [2880.0, 2880.0],
|
83 |
+
"1.07": [2880.0, 2688.0],
|
84 |
+
"1.12": [3008.0, 2688.0],
|
85 |
+
"1.21": [3008.0, 2496.0],
|
86 |
+
"1.28": [3200.0, 2496.0],
|
87 |
+
"1.39": [3200.0, 2304.0],
|
88 |
+
"1.47": [3392.0, 2304.0],
|
89 |
+
"1.7": [3584.0, 2112.0],
|
90 |
+
"1.82": [3840.0, 2112.0],
|
91 |
+
"2.03": [3904.0, 1920.0],
|
92 |
+
"2.13": [4096.0, 1920.0],
|
93 |
+
"2.39": [4288.0, 1792.0],
|
94 |
+
"2.5": [4480.0, 1792.0],
|
95 |
+
"2.92": [4672.0, 1600.0],
|
96 |
+
"3.0": [4800.0, 1600.0],
|
97 |
+
"3.12": [4992.0, 1600.0],
|
98 |
+
"3.68": [5184.0, 1408.0],
|
99 |
+
"3.82": [5376.0, 1408.0],
|
100 |
+
"3.95": [5568.0, 1408.0],
|
101 |
+
"4.0": [5760.0, 1408.0],
|
102 |
+
}
|
103 |
+
|
104 |
+
ASPECT_RATIO_2048 = {
|
105 |
+
"0.25": [1024.0, 4096.0],
|
106 |
+
"0.26": [1024.0, 3968.0],
|
107 |
+
"0.27": [1024.0, 3840.0],
|
108 |
+
"0.28": [1024.0, 3712.0],
|
109 |
+
"0.32": [1152.0, 3584.0],
|
110 |
+
"0.33": [1152.0, 3456.0],
|
111 |
+
"0.35": [1152.0, 3328.0],
|
112 |
+
"0.4": [1280.0, 3200.0],
|
113 |
+
"0.42": [1280.0, 3072.0],
|
114 |
+
"0.48": [1408.0, 2944.0],
|
115 |
+
"0.5": [1408.0, 2816.0],
|
116 |
+
"0.52": [1408.0, 2688.0],
|
117 |
+
"0.57": [1536.0, 2688.0],
|
118 |
+
"0.6": [1536.0, 2560.0],
|
119 |
+
"0.68": [1664.0, 2432.0],
|
120 |
+
"0.72": [1664.0, 2304.0],
|
121 |
+
"0.78": [1792.0, 2304.0],
|
122 |
+
"0.82": [1792.0, 2176.0],
|
123 |
+
"0.88": [1920.0, 2176.0],
|
124 |
+
"0.94": [1920.0, 2048.0],
|
125 |
+
"1.0": [2048.0, 2048.0],
|
126 |
+
"1.07": [2048.0, 1920.0],
|
127 |
+
"1.13": [2176.0, 1920.0],
|
128 |
+
"1.21": [2176.0, 1792.0],
|
129 |
+
"1.29": [2304.0, 1792.0],
|
130 |
+
"1.38": [2304.0, 1664.0],
|
131 |
+
"1.46": [2432.0, 1664.0],
|
132 |
+
"1.67": [2560.0, 1536.0],
|
133 |
+
"1.75": [2688.0, 1536.0],
|
134 |
+
"2.0": [2816.0, 1408.0],
|
135 |
+
"2.09": [2944.0, 1408.0],
|
136 |
+
"2.4": [3072.0, 1280.0],
|
137 |
+
"2.5": [3200.0, 1280.0],
|
138 |
+
"2.89": [3328.0, 1152.0],
|
139 |
+
"3.0": [3456.0, 1152.0],
|
140 |
+
"3.11": [3584.0, 1152.0],
|
141 |
+
"3.62": [3712.0, 1024.0],
|
142 |
+
"3.75": [3840.0, 1024.0],
|
143 |
+
"3.88": [3968.0, 1024.0],
|
144 |
+
"4.0": [4096.0, 1024.0],
|
145 |
+
}
|
146 |
+
|
147 |
+
ASPECT_RATIO_1024 = {
|
148 |
+
"0.25": [512.0, 2048.0],
|
149 |
+
"0.26": [512.0, 1984.0],
|
150 |
+
"0.27": [512.0, 1920.0],
|
151 |
+
"0.28": [512.0, 1856.0],
|
152 |
+
"0.32": [576.0, 1792.0],
|
153 |
+
"0.33": [576.0, 1728.0],
|
154 |
+
"0.35": [576.0, 1664.0],
|
155 |
+
"0.4": [640.0, 1600.0],
|
156 |
+
"0.42": [640.0, 1536.0],
|
157 |
+
"0.48": [704.0, 1472.0],
|
158 |
+
"0.5": [704.0, 1408.0],
|
159 |
+
"0.52": [704.0, 1344.0],
|
160 |
+
"0.57": [768.0, 1344.0],
|
161 |
+
"0.6": [768.0, 1280.0],
|
162 |
+
"0.68": [832.0, 1216.0],
|
163 |
+
"0.72": [832.0, 1152.0],
|
164 |
+
"0.78": [896.0, 1152.0],
|
165 |
+
"0.82": [896.0, 1088.0],
|
166 |
+
"0.88": [960.0, 1088.0],
|
167 |
+
"0.94": [960.0, 1024.0],
|
168 |
+
"1.0": [1024.0, 1024.0],
|
169 |
+
"1.07": [1024.0, 960.0],
|
170 |
+
"1.13": [1088.0, 960.0],
|
171 |
+
"1.21": [1088.0, 896.0],
|
172 |
+
"1.29": [1152.0, 896.0],
|
173 |
+
"1.38": [1152.0, 832.0],
|
174 |
+
"1.46": [1216.0, 832.0],
|
175 |
+
"1.67": [1280.0, 768.0],
|
176 |
+
"1.75": [1344.0, 768.0],
|
177 |
+
"2.0": [1408.0, 704.0],
|
178 |
+
"2.09": [1472.0, 704.0],
|
179 |
+
"2.4": [1536.0, 640.0],
|
180 |
+
"2.5": [1600.0, 640.0],
|
181 |
+
"2.89": [1664.0, 576.0],
|
182 |
+
"3.0": [1728.0, 576.0],
|
183 |
+
"3.11": [1792.0, 576.0],
|
184 |
+
"3.62": [1856.0, 512.0],
|
185 |
+
"3.75": [1920.0, 512.0],
|
186 |
+
"3.88": [1984.0, 512.0],
|
187 |
+
"4.0": [2048.0, 512.0],
|
188 |
+
}
|
189 |
+
|
190 |
+
ASPECT_RATIO_512 = {
|
191 |
+
"0.25": [256.0, 1024.0],
|
192 |
+
"0.26": [256.0, 992.0],
|
193 |
+
"0.27": [256.0, 960.0],
|
194 |
+
"0.28": [256.0, 928.0],
|
195 |
+
"0.32": [288.0, 896.0],
|
196 |
+
"0.33": [288.0, 864.0],
|
197 |
+
"0.35": [288.0, 832.0],
|
198 |
+
"0.4": [320.0, 800.0],
|
199 |
+
"0.42": [320.0, 768.0],
|
200 |
+
"0.48": [352.0, 736.0],
|
201 |
+
"0.5": [352.0, 704.0],
|
202 |
+
"0.52": [352.0, 672.0],
|
203 |
+
"0.57": [384.0, 672.0],
|
204 |
+
"0.6": [384.0, 640.0],
|
205 |
+
"0.68": [416.0, 608.0],
|
206 |
+
"0.72": [416.0, 576.0],
|
207 |
+
"0.78": [448.0, 576.0],
|
208 |
+
"0.82": [448.0, 544.0],
|
209 |
+
"0.88": [480.0, 544.0],
|
210 |
+
"0.94": [480.0, 512.0],
|
211 |
+
"1.0": [512.0, 512.0],
|
212 |
+
"1.07": [512.0, 480.0],
|
213 |
+
"1.13": [544.0, 480.0],
|
214 |
+
"1.21": [544.0, 448.0],
|
215 |
+
"1.29": [576.0, 448.0],
|
216 |
+
"1.38": [576.0, 416.0],
|
217 |
+
"1.46": [608.0, 416.0],
|
218 |
+
"1.67": [640.0, 384.0],
|
219 |
+
"1.75": [672.0, 384.0],
|
220 |
+
"2.0": [704.0, 352.0],
|
221 |
+
"2.09": [736.0, 352.0],
|
222 |
+
"2.4": [768.0, 320.0],
|
223 |
+
"2.5": [800.0, 320.0],
|
224 |
+
"2.89": [832.0, 288.0],
|
225 |
+
"3.0": [864.0, 288.0],
|
226 |
+
"3.11": [896.0, 288.0],
|
227 |
+
"3.62": [928.0, 256.0],
|
228 |
+
"3.75": [960.0, 256.0],
|
229 |
+
"3.88": [992.0, 256.0],
|
230 |
+
"4.0": [1024.0, 256.0],
|
231 |
+
}
|
232 |
+
|
233 |
+
ASPECT_RATIO_256 = {
|
234 |
+
"0.25": [128.0, 512.0],
|
235 |
+
"0.26": [128.0, 496.0],
|
236 |
+
"0.27": [128.0, 480.0],
|
237 |
+
"0.28": [128.0, 464.0],
|
238 |
+
"0.32": [144.0, 448.0],
|
239 |
+
"0.33": [144.0, 432.0],
|
240 |
+
"0.35": [144.0, 416.0],
|
241 |
+
"0.4": [160.0, 400.0],
|
242 |
+
"0.42": [160.0, 384.0],
|
243 |
+
"0.48": [176.0, 368.0],
|
244 |
+
"0.5": [176.0, 352.0],
|
245 |
+
"0.52": [176.0, 336.0],
|
246 |
+
"0.57": [192.0, 336.0],
|
247 |
+
"0.6": [192.0, 320.0],
|
248 |
+
"0.68": [208.0, 304.0],
|
249 |
+
"0.72": [208.0, 288.0],
|
250 |
+
"0.78": [224.0, 288.0],
|
251 |
+
"0.82": [224.0, 272.0],
|
252 |
+
"0.88": [240.0, 272.0],
|
253 |
+
"0.94": [240.0, 256.0],
|
254 |
+
"1.0": [256.0, 256.0],
|
255 |
+
"1.07": [256.0, 240.0],
|
256 |
+
"1.13": [272.0, 240.0],
|
257 |
+
"1.21": [272.0, 224.0],
|
258 |
+
"1.29": [288.0, 224.0],
|
259 |
+
"1.38": [288.0, 208.0],
|
260 |
+
"1.46": [304.0, 208.0],
|
261 |
+
"1.67": [320.0, 192.0],
|
262 |
+
"1.75": [336.0, 192.0],
|
263 |
+
"2.0": [352.0, 176.0],
|
264 |
+
"2.09": [368.0, 176.0],
|
265 |
+
"2.4": [384.0, 160.0],
|
266 |
+
"2.5": [400.0, 160.0],
|
267 |
+
"2.89": [416.0, 144.0],
|
268 |
+
"3.0": [432.0, 144.0],
|
269 |
+
"3.11": [448.0, 144.0],
|
270 |
+
"3.62": [464.0, 128.0],
|
271 |
+
"3.75": [480.0, 128.0],
|
272 |
+
"3.88": [496.0, 128.0],
|
273 |
+
"4.0": [512.0, 128.0],
|
274 |
+
}
|
275 |
+
|
276 |
+
ASPECT_RATIO_256_TEST = {
|
277 |
+
"0.25": [128.0, 512.0],
|
278 |
+
"0.28": [128.0, 464.0],
|
279 |
+
"0.32": [144.0, 448.0],
|
280 |
+
"0.33": [144.0, 432.0],
|
281 |
+
"0.35": [144.0, 416.0],
|
282 |
+
"0.4": [160.0, 400.0],
|
283 |
+
"0.42": [160.0, 384.0],
|
284 |
+
"0.48": [176.0, 368.0],
|
285 |
+
"0.5": [176.0, 352.0],
|
286 |
+
"0.52": [176.0, 336.0],
|
287 |
+
"0.57": [192.0, 336.0],
|
288 |
+
"0.6": [192.0, 320.0],
|
289 |
+
"0.68": [208.0, 304.0],
|
290 |
+
"0.72": [208.0, 288.0],
|
291 |
+
"0.78": [224.0, 288.0],
|
292 |
+
"0.82": [224.0, 272.0],
|
293 |
+
"0.88": [240.0, 272.0],
|
294 |
+
"0.94": [240.0, 256.0],
|
295 |
+
"1.0": [256.0, 256.0],
|
296 |
+
"1.07": [256.0, 240.0],
|
297 |
+
"1.13": [272.0, 240.0],
|
298 |
+
"1.21": [272.0, 224.0],
|
299 |
+
"1.29": [288.0, 224.0],
|
300 |
+
"1.38": [288.0, 208.0],
|
301 |
+
"1.46": [304.0, 208.0],
|
302 |
+
"1.67": [320.0, 192.0],
|
303 |
+
"1.75": [336.0, 192.0],
|
304 |
+
"2.0": [352.0, 176.0],
|
305 |
+
"2.09": [368.0, 176.0],
|
306 |
+
"2.4": [384.0, 160.0],
|
307 |
+
"2.5": [400.0, 160.0],
|
308 |
+
"3.0": [432.0, 144.0],
|
309 |
+
"4.0": [512.0, 128.0],
|
310 |
+
}
|
311 |
+
|
312 |
+
ASPECT_RATIO_512_TEST = {
|
313 |
+
"0.25": [256.0, 1024.0],
|
314 |
+
"0.28": [256.0, 928.0],
|
315 |
+
"0.32": [288.0, 896.0],
|
316 |
+
"0.33": [288.0, 864.0],
|
317 |
+
"0.35": [288.0, 832.0],
|
318 |
+
"0.4": [320.0, 800.0],
|
319 |
+
"0.42": [320.0, 768.0],
|
320 |
+
"0.48": [352.0, 736.0],
|
321 |
+
"0.5": [352.0, 704.0],
|
322 |
+
"0.52": [352.0, 672.0],
|
323 |
+
"0.57": [384.0, 672.0],
|
324 |
+
"0.6": [384.0, 640.0],
|
325 |
+
"0.68": [416.0, 608.0],
|
326 |
+
"0.72": [416.0, 576.0],
|
327 |
+
"0.78": [448.0, 576.0],
|
328 |
+
"0.82": [448.0, 544.0],
|
329 |
+
"0.88": [480.0, 544.0],
|
330 |
+
"0.94": [480.0, 512.0],
|
331 |
+
"1.0": [512.0, 512.0],
|
332 |
+
"1.07": [512.0, 480.0],
|
333 |
+
"1.13": [544.0, 480.0],
|
334 |
+
"1.21": [544.0, 448.0],
|
335 |
+
"1.29": [576.0, 448.0],
|
336 |
+
"1.38": [576.0, 416.0],
|
337 |
+
"1.46": [608.0, 416.0],
|
338 |
+
"1.67": [640.0, 384.0],
|
339 |
+
"1.75": [672.0, 384.0],
|
340 |
+
"2.0": [704.0, 352.0],
|
341 |
+
"2.09": [736.0, 352.0],
|
342 |
+
"2.4": [768.0, 320.0],
|
343 |
+
"2.5": [800.0, 320.0],
|
344 |
+
"3.0": [864.0, 288.0],
|
345 |
+
"4.0": [1024.0, 256.0],
|
346 |
+
}
|
347 |
+
|
348 |
+
ASPECT_RATIO_1024_TEST = {
|
349 |
+
"0.25": [512.0, 2048.0],
|
350 |
+
"0.28": [512.0, 1856.0],
|
351 |
+
"0.32": [576.0, 1792.0],
|
352 |
+
"0.33": [576.0, 1728.0],
|
353 |
+
"0.35": [576.0, 1664.0],
|
354 |
+
"0.4": [640.0, 1600.0],
|
355 |
+
"0.42": [640.0, 1536.0],
|
356 |
+
"0.48": [704.0, 1472.0],
|
357 |
+
"0.5": [704.0, 1408.0],
|
358 |
+
"0.52": [704.0, 1344.0],
|
359 |
+
"0.57": [768.0, 1344.0],
|
360 |
+
"0.6": [768.0, 1280.0],
|
361 |
+
"0.68": [832.0, 1216.0],
|
362 |
+
"0.72": [832.0, 1152.0],
|
363 |
+
"0.78": [896.0, 1152.0],
|
364 |
+
"0.82": [896.0, 1088.0],
|
365 |
+
"0.88": [960.0, 1088.0],
|
366 |
+
"0.94": [960.0, 1024.0],
|
367 |
+
"1.0": [1024.0, 1024.0],
|
368 |
+
"1.07": [1024.0, 960.0],
|
369 |
+
"1.13": [1088.0, 960.0],
|
370 |
+
"1.21": [1088.0, 896.0],
|
371 |
+
"1.29": [1152.0, 896.0],
|
372 |
+
"1.38": [1152.0, 832.0],
|
373 |
+
"1.46": [1216.0, 832.0],
|
374 |
+
"1.67": [1280.0, 768.0],
|
375 |
+
"1.75": [1344.0, 768.0],
|
376 |
+
"2.0": [1408.0, 704.0],
|
377 |
+
"2.09": [1472.0, 704.0],
|
378 |
+
"2.4": [1536.0, 640.0],
|
379 |
+
"2.5": [1600.0, 640.0],
|
380 |
+
"3.0": [1728.0, 576.0],
|
381 |
+
"4.0": [2048.0, 512.0],
|
382 |
+
}
|
383 |
+
|
384 |
+
ASPECT_RATIO_2048_TEST = {
|
385 |
+
"0.25": [1024.0, 4096.0],
|
386 |
+
"0.26": [1024.0, 3968.0],
|
387 |
+
"0.32": [1152.0, 3584.0],
|
388 |
+
"0.33": [1152.0, 3456.0],
|
389 |
+
"0.35": [1152.0, 3328.0],
|
390 |
+
"0.4": [1280.0, 3200.0],
|
391 |
+
"0.42": [1280.0, 3072.0],
|
392 |
+
"0.48": [1408.0, 2944.0],
|
393 |
+
"0.5": [1408.0, 2816.0],
|
394 |
+
"0.52": [1408.0, 2688.0],
|
395 |
+
"0.57": [1536.0, 2688.0],
|
396 |
+
"0.6": [1536.0, 2560.0],
|
397 |
+
"0.68": [1664.0, 2432.0],
|
398 |
+
"0.72": [1664.0, 2304.0],
|
399 |
+
"0.78": [1792.0, 2304.0],
|
400 |
+
"0.82": [1792.0, 2176.0],
|
401 |
+
"0.88": [1920.0, 2176.0],
|
402 |
+
"0.94": [1920.0, 2048.0],
|
403 |
+
"1.0": [2048.0, 2048.0],
|
404 |
+
"1.07": [2048.0, 1920.0],
|
405 |
+
"1.13": [2176.0, 1920.0],
|
406 |
+
"1.21": [2176.0, 1792.0],
|
407 |
+
"1.29": [2304.0, 1792.0],
|
408 |
+
"1.38": [2304.0, 1664.0],
|
409 |
+
"1.46": [2432.0, 1664.0],
|
410 |
+
"1.67": [2560.0, 1536.0],
|
411 |
+
"1.75": [2688.0, 1536.0],
|
412 |
+
"2.0": [2816.0, 1408.0],
|
413 |
+
"2.09": [2944.0, 1408.0],
|
414 |
+
"2.4": [3072.0, 1280.0],
|
415 |
+
"2.5": [3200.0, 1280.0],
|
416 |
+
"3.0": [3456.0, 1152.0],
|
417 |
+
"4.0": [4096.0, 1024.0],
|
418 |
+
}
|
419 |
+
|
420 |
+
ASPECT_RATIO_2880_TEST = {
|
421 |
+
"0.25": [2048.0, 8192.0],
|
422 |
+
"0.26": [2048.0, 7936.0],
|
423 |
+
"0.32": [2304.0, 7168.0],
|
424 |
+
"0.33": [2304.0, 6912.0],
|
425 |
+
"0.35": [2304.0, 6656.0],
|
426 |
+
"0.4": [2560.0, 6400.0],
|
427 |
+
"0.42": [2560.0, 6144.0],
|
428 |
+
"0.48": [2816.0, 5888.0],
|
429 |
+
"0.5": [2816.0, 5632.0],
|
430 |
+
"0.52": [2816.0, 5376.0],
|
431 |
+
"0.57": [3072.0, 5376.0],
|
432 |
+
"0.6": [3072.0, 5120.0],
|
433 |
+
"0.68": [3328.0, 4864.0],
|
434 |
+
"0.72": [3328.0, 4608.0],
|
435 |
+
"0.78": [3584.0, 4608.0],
|
436 |
+
"0.82": [3584.0, 4352.0],
|
437 |
+
"0.88": [3840.0, 4352.0],
|
438 |
+
"0.94": [3840.0, 4096.0],
|
439 |
+
"1.0": [4096.0, 4096.0],
|
440 |
+
"1.07": [4096.0, 3840.0],
|
441 |
+
"1.13": [4352.0, 3840.0],
|
442 |
+
"1.21": [4352.0, 3584.0],
|
443 |
+
"1.29": [4608.0, 3584.0],
|
444 |
+
"1.38": [4608.0, 3328.0],
|
445 |
+
"1.46": [4864.0, 3328.0],
|
446 |
+
"1.67": [5120.0, 3072.0],
|
447 |
+
"1.75": [5376.0, 3072.0],
|
448 |
+
"2.0": [5632.0, 2816.0],
|
449 |
+
"2.09": [5888.0, 2816.0],
|
450 |
+
"2.4": [6144.0, 2560.0],
|
451 |
+
"2.5": [6400.0, 2560.0],
|
452 |
+
"3.0": [6912.0, 2304.0],
|
453 |
+
"4.0": [8192.0, 2048.0],
|
454 |
+
}
|
455 |
+
|
456 |
+
ASPECT_RATIO_4096_TEST = {
|
457 |
+
"0.25": [2048.0, 8192.0],
|
458 |
+
"0.26": [2048.0, 7936.0],
|
459 |
+
"0.27": [2048.0, 7680.0],
|
460 |
+
"0.28": [2048.0, 7424.0],
|
461 |
+
"0.32": [2304.0, 7168.0],
|
462 |
+
"0.33": [2304.0, 6912.0],
|
463 |
+
"0.35": [2304.0, 6656.0],
|
464 |
+
"0.4": [2560.0, 6400.0],
|
465 |
+
"0.42": [2560.0, 6144.0],
|
466 |
+
"0.48": [2816.0, 5888.0],
|
467 |
+
"0.5": [2816.0, 5632.0],
|
468 |
+
"0.52": [2816.0, 5376.0],
|
469 |
+
"0.57": [3072.0, 5376.0],
|
470 |
+
"0.6": [3072.0, 5120.0],
|
471 |
+
"0.68": [3328.0, 4864.0],
|
472 |
+
"0.72": [3328.0, 4608.0],
|
473 |
+
"0.78": [3584.0, 4608.0],
|
474 |
+
"0.82": [3584.0, 4352.0],
|
475 |
+
"0.88": [3840.0, 4352.0],
|
476 |
+
"0.94": [3840.0, 4096.0],
|
477 |
+
"1.0": [4096.0, 4096.0],
|
478 |
+
"1.07": [4096.0, 3840.0],
|
479 |
+
"1.13": [4352.0, 3840.0],
|
480 |
+
"1.21": [4352.0, 3584.0],
|
481 |
+
"1.29": [4608.0, 3584.0],
|
482 |
+
"1.38": [4608.0, 3328.0],
|
483 |
+
"1.46": [4864.0, 3328.0],
|
484 |
+
"1.67": [5120.0, 3072.0],
|
485 |
+
"1.75": [5376.0, 3072.0],
|
486 |
+
"2.0": [5632.0, 2816.0],
|
487 |
+
"2.09": [5888.0, 2816.0],
|
488 |
+
"2.4": [6144.0, 2560.0],
|
489 |
+
"2.5": [6400.0, 2560.0],
|
490 |
+
"2.89": [6656.0, 2304.0],
|
491 |
+
"3.0": [6912.0, 2304.0],
|
492 |
+
"3.11": [7168.0, 2304.0],
|
493 |
+
"3.62": [7424.0, 2048.0],
|
494 |
+
"3.75": [7680.0, 2048.0],
|
495 |
+
"3.88": [7936.0, 2048.0],
|
496 |
+
"4.0": [8192.0, 2048.0],
|
497 |
+
}
|
498 |
+
|
499 |
+
ASPECT_RATIO_1280_TEST = {"1.0": [1280.0, 1280.0]}
|
500 |
+
ASPECT_RATIO_1536_TEST = {"1.0": [1536.0, 1536.0]}
|
501 |
+
ASPECT_RATIO_768_TEST = {"1.0": [768.0, 768.0]}
|
502 |
+
|
503 |
+
|
504 |
+
def get_chunks(lst, n):
|
505 |
+
for i in range(0, len(lst), n):
|
506 |
+
yield lst[i : i + n]
|
diffusion/data/transforms.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
import torchvision.transforms as T
|
18 |
+
|
19 |
+
TRANSFORMS = dict()
|
20 |
+
|
21 |
+
|
22 |
+
def register_transform(transform):
|
23 |
+
name = transform.__name__
|
24 |
+
if name in TRANSFORMS:
|
25 |
+
raise RuntimeError(f"Transform {name} has already registered.")
|
26 |
+
TRANSFORMS.update({name: transform})
|
27 |
+
|
28 |
+
|
29 |
+
def get_transform(type, resolution):
|
30 |
+
transform = TRANSFORMS[type](resolution)
|
31 |
+
transform = T.Compose(transform)
|
32 |
+
transform.image_size = resolution
|
33 |
+
return transform
|
34 |
+
|
35 |
+
|
36 |
+
@register_transform
|
37 |
+
def default_train(n_px):
|
38 |
+
transform = [
|
39 |
+
T.Lambda(lambda img: img.convert("RGB")),
|
40 |
+
T.Resize(n_px), # Image.BICUBIC
|
41 |
+
T.CenterCrop(n_px),
|
42 |
+
# T.RandomHorizontalFlip(),
|
43 |
+
T.ToTensor(),
|
44 |
+
T.Normalize([0.5], [0.5]),
|
45 |
+
]
|
46 |
+
return transform
|
diffusion/data/wids/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
|
2 |
+
# This file is part of the WebDataset library.
|
3 |
+
# See the LICENSE file for licensing terms (BSD-style).
|
4 |
+
#
|
5 |
+
# flake8: noqa
|
6 |
+
|
7 |
+
from .wids import (
|
8 |
+
ChunkedSampler,
|
9 |
+
DistributedChunkedSampler,
|
10 |
+
DistributedLocalSampler,
|
11 |
+
DistributedRangedSampler,
|
12 |
+
ShardedSampler,
|
13 |
+
ShardListDataset,
|
14 |
+
ShardListDatasetMulti,
|
15 |
+
lru_json_load,
|
16 |
+
)
|
diffusion/data/wids/wids.py
ADDED
@@ -0,0 +1,1051 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
# This file is modified from https://github.com/NVlabs/VILA/tree/main/llava/wids
|
18 |
+
import base64
|
19 |
+
import gzip
|
20 |
+
import hashlib
|
21 |
+
import io
|
22 |
+
import json
|
23 |
+
import math
|
24 |
+
import os
|
25 |
+
import os.path as osp
|
26 |
+
import random
|
27 |
+
import re
|
28 |
+
import sqlite3
|
29 |
+
import sys
|
30 |
+
import tempfile
|
31 |
+
import uuid
|
32 |
+
import warnings
|
33 |
+
from functools import lru_cache, partial
|
34 |
+
from typing import Any, BinaryIO, Dict, Optional, TypeVar, Union
|
35 |
+
from urllib.parse import quote, urlparse
|
36 |
+
|
37 |
+
import numpy as np
|
38 |
+
import torch
|
39 |
+
import torch.distributed as dist
|
40 |
+
from torch.utils.data.distributed import DistributedSampler
|
41 |
+
|
42 |
+
from .wids_dl import download_and_open
|
43 |
+
from .wids_lru import LRUCache
|
44 |
+
from .wids_mmtar import MMIndexedTar
|
45 |
+
from .wids_specs import load_dsdesc_and_resolve, urldir
|
46 |
+
from .wids_tar import TarFileReader, find_index_file
|
47 |
+
|
48 |
+
try:
|
49 |
+
from torch.utils.data import Dataset, Sampler
|
50 |
+
except ImportError:
|
51 |
+
|
52 |
+
class Dataset:
|
53 |
+
pass
|
54 |
+
|
55 |
+
class Sampler:
|
56 |
+
pass
|
57 |
+
|
58 |
+
|
59 |
+
T = TypeVar("T")
|
60 |
+
|
61 |
+
T_co = TypeVar("T_co", covariant=True)
|
62 |
+
|
63 |
+
|
64 |
+
def compute_file_md5sum(fname: Union[str, BinaryIO], chunksize: int = 1000000) -> str:
|
65 |
+
"""Compute the md5sum of a file in chunks.
|
66 |
+
|
67 |
+
Parameters
|
68 |
+
----------
|
69 |
+
fname : Union[str, BinaryIO]
|
70 |
+
Filename or file object
|
71 |
+
chunksize : int, optional
|
72 |
+
Chunk size in bytes, by default 1000000
|
73 |
+
|
74 |
+
Returns
|
75 |
+
-------
|
76 |
+
str
|
77 |
+
MD5 sum of the file
|
78 |
+
|
79 |
+
Examples
|
80 |
+
--------
|
81 |
+
>>> compute_file_md5sum("test.txt")
|
82 |
+
'd41d8cd98f00b204e9800998ecf8427e'
|
83 |
+
"""
|
84 |
+
md5 = hashlib.md5()
|
85 |
+
if isinstance(fname, str):
|
86 |
+
with open(fname, "rb") as f:
|
87 |
+
for chunk in iter(lambda: f.read(chunksize), b""):
|
88 |
+
md5.update(chunk)
|
89 |
+
else:
|
90 |
+
fname.seek(0)
|
91 |
+
for chunk in iter(lambda: fname.read(chunksize), b""):
|
92 |
+
md5.update(chunk)
|
93 |
+
return md5.hexdigest()
|
94 |
+
|
95 |
+
|
96 |
+
def compute_file_md5sum(fname: Union[str, BinaryIO], chunksize: int = 1000000) -> str:
|
97 |
+
"""Compute the md5sum of a file in chunks."""
|
98 |
+
md5 = hashlib.md5()
|
99 |
+
if isinstance(fname, str):
|
100 |
+
with open(fname, "rb") as f:
|
101 |
+
for chunk in iter(lambda: f.read(chunksize), b""):
|
102 |
+
md5.update(chunk)
|
103 |
+
else:
|
104 |
+
fname.seek(0)
|
105 |
+
for chunk in iter(lambda: fname.read(chunksize), b""):
|
106 |
+
md5.update(chunk)
|
107 |
+
return md5.hexdigest()
|
108 |
+
|
109 |
+
|
110 |
+
def compute_num_samples(fname):
|
111 |
+
ds = IndexedTarSamples(fname)
|
112 |
+
return len(ds)
|
113 |
+
|
114 |
+
|
115 |
+
def splitname(fname):
|
116 |
+
"""Returns the basename and extension of a filename"""
|
117 |
+
assert "." in fname, "Filename must have an extension"
|
118 |
+
# basename, extension = re.match(r"^((?:.*/)?.*?)(\..*)$", fname).groups()
|
119 |
+
basename, extension = os.path.splitext(fname)
|
120 |
+
return basename, extension
|
121 |
+
|
122 |
+
|
123 |
+
# NOTE(ligeng): change to ordered mapping to more flexbile dict
|
124 |
+
# TODO(ligeng): submit a PR to fix the mapping issue.
|
125 |
+
def group_by_key(names):
|
126 |
+
"""Group the file names by key.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
names: A list of file names.
|
130 |
+
|
131 |
+
Returns:
|
132 |
+
A list of lists of indices, where each sublist contains indices of files
|
133 |
+
with the same key.
|
134 |
+
"""
|
135 |
+
groups = []
|
136 |
+
kmaps = {}
|
137 |
+
for i, fname in enumerate(names):
|
138 |
+
# Ignore files that are not in a subdirectory.
|
139 |
+
if "." not in fname:
|
140 |
+
print(f"Warning: Ignoring file {fname} (no '.')")
|
141 |
+
continue
|
142 |
+
if fname == ".":
|
143 |
+
print(f"Warning: Ignoring the '.' file.")
|
144 |
+
continue
|
145 |
+
key, ext = splitname(fname)
|
146 |
+
if key not in kmaps:
|
147 |
+
kmaps[key] = []
|
148 |
+
kmaps[key].append(i)
|
149 |
+
for k, v in kmaps.items():
|
150 |
+
groups.append(v)
|
151 |
+
return groups
|
152 |
+
|
153 |
+
|
154 |
+
def default_decoder(sample: Dict[str, Any], format: Optional[Union[bool, str]] = True):
|
155 |
+
"""A default decoder for webdataset.
|
156 |
+
|
157 |
+
This handles common file extensions: .txt, .cls, .cls2,
|
158 |
+
.jpg, .png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl.
|
159 |
+
These are the most common extensions used in webdataset.
|
160 |
+
For other extensions, users can provide their own decoder.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
sample: sample, modified in place
|
164 |
+
"""
|
165 |
+
sample = dict(sample)
|
166 |
+
for key, stream in sample.items():
|
167 |
+
extensions = key.split(".")
|
168 |
+
if len(extensions) < 1:
|
169 |
+
continue
|
170 |
+
extension = extensions[-1]
|
171 |
+
if extension in ["gz"]:
|
172 |
+
decompressed = gzip.decompress(stream.read())
|
173 |
+
stream = io.BytesIO(decompressed)
|
174 |
+
if len(extensions) < 2:
|
175 |
+
sample[key] = stream
|
176 |
+
continue
|
177 |
+
extension = extensions[-2]
|
178 |
+
if key.startswith("__"):
|
179 |
+
continue
|
180 |
+
elif extension in ["txt", "text"]:
|
181 |
+
value = stream.read()
|
182 |
+
sample[key] = value.decode("utf-8")
|
183 |
+
elif extension in ["cls", "cls2"]:
|
184 |
+
value = stream.read()
|
185 |
+
sample[key] = int(value.decode("utf-8"))
|
186 |
+
elif extension in ["jpg", "png", "ppm", "pgm", "pbm", "pnm"]:
|
187 |
+
if format == "PIL":
|
188 |
+
import PIL.Image
|
189 |
+
|
190 |
+
sample[key] = PIL.Image.open(stream)
|
191 |
+
elif format == "numpy":
|
192 |
+
import numpy as np
|
193 |
+
|
194 |
+
sample[key] = np.asarray(PIL.Image.open(stream))
|
195 |
+
else:
|
196 |
+
raise ValueError(f"Unknown format: {format}")
|
197 |
+
elif extension == "json":
|
198 |
+
import json
|
199 |
+
|
200 |
+
value = stream.read()
|
201 |
+
sample[key] = json.loads(value)
|
202 |
+
elif extension == "npy":
|
203 |
+
import numpy as np
|
204 |
+
|
205 |
+
sample[key] = np.load(stream)
|
206 |
+
elif extension == "mp":
|
207 |
+
import msgpack
|
208 |
+
|
209 |
+
value = stream.read()
|
210 |
+
sample[key] = msgpack.unpackb(value, raw=False)
|
211 |
+
elif extension in ["pt", "pth"]:
|
212 |
+
import torch
|
213 |
+
|
214 |
+
sample[key] = torch.load(stream)
|
215 |
+
elif extension in ["pickle", "pkl"]:
|
216 |
+
import pickle
|
217 |
+
|
218 |
+
sample[key] = pickle.load(stream)
|
219 |
+
elif extension == "mp4":
|
220 |
+
# Write stream to a temporary file
|
221 |
+
# with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmpfile:
|
222 |
+
# tmpfile.write(stream.read())
|
223 |
+
# tmpfile_path = tmpfile.name
|
224 |
+
|
225 |
+
# sample[key] = tmpfile_path
|
226 |
+
sample[key] = io.BytesIO(stream.read())
|
227 |
+
return sample
|
228 |
+
|
229 |
+
|
230 |
+
def update_dict_with_extend(original_dict, update_dict):
|
231 |
+
for key, value in update_dict.items():
|
232 |
+
if key in original_dict and isinstance(original_dict[key], list) and isinstance(value, list):
|
233 |
+
original_dict[key].extend(value)
|
234 |
+
else:
|
235 |
+
original_dict[key] = value
|
236 |
+
|
237 |
+
|
238 |
+
open_itfs = {}
|
239 |
+
|
240 |
+
|
241 |
+
class IndexedTarSamples:
|
242 |
+
"""A class that accesses samples in a tar file. The tar file must follow
|
243 |
+
WebDataset conventions. The tar file is indexed when the IndexedTarSamples
|
244 |
+
object is created. The samples are accessed by index using the __getitem__
|
245 |
+
method. The __getitem__ method returns a dictionary containing the files
|
246 |
+
for the sample. The key for each file is the extension of the file name.
|
247 |
+
The key "__key__" is reserved for the key of the sample (the basename of
|
248 |
+
each file without the extension). For example, if the tar file contains
|
249 |
+
the files "sample1.jpg" and "sample1.txt", then the sample with key
|
250 |
+
"sample1" will be returned as the dictionary {"jpg": ..., "txt": ...}.
|
251 |
+
"""
|
252 |
+
|
253 |
+
def __init__(
|
254 |
+
self,
|
255 |
+
*,
|
256 |
+
path=None,
|
257 |
+
stream=None,
|
258 |
+
md5sum=None,
|
259 |
+
expected_size=None,
|
260 |
+
use_mmap=True,
|
261 |
+
index_file=find_index_file,
|
262 |
+
):
|
263 |
+
assert path is not None or stream is not None
|
264 |
+
|
265 |
+
# Create TarFileReader object to read from tar_file
|
266 |
+
self.path = path
|
267 |
+
stream = self.stream = stream or open(path, "rb")
|
268 |
+
|
269 |
+
# verify the MD5 sum
|
270 |
+
if md5sum is not None:
|
271 |
+
stream.seek(0)
|
272 |
+
got = compute_file_md5sum(stream)
|
273 |
+
assert got == md5sum, f"MD5 sum mismatch: expected {md5sum}, got {got}"
|
274 |
+
stream.seek(0)
|
275 |
+
|
276 |
+
# use either the mmap or the stream based implementation
|
277 |
+
# NOTE(ligeng): https://stackoverflow.com/questions/11072705/twitter-trends-api-unicodedecodeerror-utf8-codec-cant-decode-byte-0x8b-in-po
|
278 |
+
# import gzip
|
279 |
+
# print("convert to gzip IO stream")
|
280 |
+
# stream = gzip.GzipFile(fileobj=stream)
|
281 |
+
|
282 |
+
if use_mmap:
|
283 |
+
self.reader = MMIndexedTar(stream)
|
284 |
+
else:
|
285 |
+
self.reader = TarFileReader(stream, index_file=index_file)
|
286 |
+
|
287 |
+
# Get list of all files in stream
|
288 |
+
all_files = self.reader.names()
|
289 |
+
|
290 |
+
# Group files by key into samples
|
291 |
+
self.samples = group_by_key(all_files)
|
292 |
+
# print("DEBUG:", list(all_files)[:20])
|
293 |
+
# print("DEBUG:", self.samples[:20])
|
294 |
+
|
295 |
+
# check that the number of samples is correct
|
296 |
+
if expected_size is not None:
|
297 |
+
assert len(self) == expected_size, f"Expected {expected_size} samples, got {len(self)}"
|
298 |
+
|
299 |
+
self.uuid = str(uuid.uuid4())
|
300 |
+
|
301 |
+
def close(self):
|
302 |
+
self.reader.close()
|
303 |
+
if not self.stream.closed:
|
304 |
+
self.stream.close()
|
305 |
+
|
306 |
+
def __len__(self):
|
307 |
+
return len(self.samples)
|
308 |
+
|
309 |
+
def __getitem__(self, idx):
|
310 |
+
# Get indexes of files for the sample at index idx
|
311 |
+
try:
|
312 |
+
indexes = self.samples[idx]
|
313 |
+
except IndexError as e:
|
314 |
+
print(f"[wids-debug] curr idx: {idx}, total sample length: {len(self.samples)} {e}")
|
315 |
+
raise e
|
316 |
+
sample = {}
|
317 |
+
key = None
|
318 |
+
for i in indexes:
|
319 |
+
# Get filename and data for the file at index i
|
320 |
+
fname, data = self.reader.get_file(i)
|
321 |
+
# Split filename into key and extension
|
322 |
+
k, ext = splitname(fname)
|
323 |
+
# Make sure all files in sample have same key
|
324 |
+
key = key or k
|
325 |
+
assert key == k
|
326 |
+
sample[ext] = data
|
327 |
+
# Add key to sample
|
328 |
+
sample["__key__"] = key
|
329 |
+
return sample
|
330 |
+
|
331 |
+
def __str__(self):
|
332 |
+
return f"<IndexedTarSamples-{id(self)} {self.path}>"
|
333 |
+
|
334 |
+
def __repr__(self):
|
335 |
+
return str(self)
|
336 |
+
|
337 |
+
|
338 |
+
def hash_localname(dldir="/tmp/_wids_cache"):
|
339 |
+
os.makedirs(dldir, exist_ok=True)
|
340 |
+
|
341 |
+
connection = sqlite3.connect(os.path.join(dldir, "cache.db"))
|
342 |
+
cursor = connection.cursor()
|
343 |
+
cursor.execute("CREATE TABLE IF NOT EXISTS cache (url TEXT PRIMARY KEY, path TEXT, checksum TEXT)")
|
344 |
+
connection.commit()
|
345 |
+
|
346 |
+
def f(shard):
|
347 |
+
"""Given a URL, return a local name for the shard."""
|
348 |
+
if shard.startswith("pipe:"):
|
349 |
+
# uuencode the entire URL string
|
350 |
+
hex32 = base64.urlsafe_b64encode(hashlib.sha256(shard.encode()).digest())[:32].decode()
|
351 |
+
return os.path.join(dldir, "pipe__" + hex32)
|
352 |
+
else:
|
353 |
+
# we hash the host and directory components into a 16 character string
|
354 |
+
dirname = urldir(shard)
|
355 |
+
hex16 = base64.urlsafe_b64encode(hashlib.sha256(dirname.encode()).digest())[:16].decode()
|
356 |
+
# the cache name is the concatenation of the hex16 string and the file name component of the URL
|
357 |
+
cachename = "data__" + hex16 + "__" + os.path.basename(urlparse(shard).path)
|
358 |
+
checksum = None
|
359 |
+
cursor.execute(
|
360 |
+
"INSERT OR REPLACE INTO cache VALUES (?, ?, ?)",
|
361 |
+
(shard, cachename, checksum),
|
362 |
+
)
|
363 |
+
connection.commit()
|
364 |
+
return os.path.join(dldir, cachename)
|
365 |
+
|
366 |
+
return f
|
367 |
+
|
368 |
+
|
369 |
+
def cache_localname(cachedir):
|
370 |
+
os.makedirs(cachedir, exist_ok=True)
|
371 |
+
|
372 |
+
def f(shard):
|
373 |
+
"""Given a URL, return a local name for the shard."""
|
374 |
+
path = urlparse(shard).path
|
375 |
+
fname = os.path.basename(path)
|
376 |
+
return os.path.join(cachedir, fname)
|
377 |
+
|
378 |
+
return f
|
379 |
+
|
380 |
+
|
381 |
+
def default_localname(dldir="/tmp/_wids_cache"):
|
382 |
+
os.makedirs(dldir, exist_ok=True)
|
383 |
+
|
384 |
+
def f(shard):
|
385 |
+
"""Given a URL, return a local name for the shard."""
|
386 |
+
cachename = quote(shard, safe="+-")
|
387 |
+
return os.path.join(dldir, cachename)
|
388 |
+
|
389 |
+
return f
|
390 |
+
|
391 |
+
|
392 |
+
class LRUShards:
|
393 |
+
"""A class that manages a cache of shards. The cache is a LRU cache that
|
394 |
+
stores the local names of the shards as keys and the downloaded paths as
|
395 |
+
values. The shards are downloaded to a directory specified by dldir.
|
396 |
+
The local name of a shard is computed by the localname function, which
|
397 |
+
takes the shard URL as an argument. If keep is True, the downloaded files
|
398 |
+
are not deleted when they are no longer needed.
|
399 |
+
"""
|
400 |
+
|
401 |
+
def __init__(self, lru_size, keep=False, localname=default_localname()):
|
402 |
+
self.localname = localname
|
403 |
+
# the cache contains the local name as the key and the downloaded path as the value
|
404 |
+
self.lru = LRUCache(lru_size, release_handler=self.release_handler)
|
405 |
+
# keep statistics
|
406 |
+
self.reset_stats()
|
407 |
+
|
408 |
+
def reset_stats(self):
|
409 |
+
self.accesses = 0
|
410 |
+
self.misses = 0
|
411 |
+
|
412 |
+
def __len__(self):
|
413 |
+
return len(self.lru)
|
414 |
+
|
415 |
+
def release_handler(self, key, value):
|
416 |
+
value.close()
|
417 |
+
|
418 |
+
def clear(self):
|
419 |
+
self.lru.clear()
|
420 |
+
|
421 |
+
def get_shard(self, url):
|
422 |
+
assert isinstance(url, str)
|
423 |
+
self.accesses += 1
|
424 |
+
if url not in self.lru:
|
425 |
+
local = self.localname(url)
|
426 |
+
with download_and_open(url, local) as stream:
|
427 |
+
itf = IndexedTarSamples(path=local, stream=stream)
|
428 |
+
self.lru[url] = itf
|
429 |
+
self.misses += 1
|
430 |
+
self.last_missed = True
|
431 |
+
else:
|
432 |
+
self.last_missed = False
|
433 |
+
return self.lru[url]
|
434 |
+
|
435 |
+
|
436 |
+
def interpret_transformations(transformations):
|
437 |
+
"""Interpret the transformations argument.
|
438 |
+
|
439 |
+
This takes care of transformations specified as string shortcuts
|
440 |
+
and returns a list of callables.
|
441 |
+
"""
|
442 |
+
if not isinstance(transformations, list):
|
443 |
+
transformations = [transformations]
|
444 |
+
|
445 |
+
result = []
|
446 |
+
|
447 |
+
for transformation in transformations:
|
448 |
+
if transformation == "PIL":
|
449 |
+
transformation = partial(default_decoder, format="PIL")
|
450 |
+
elif transformation == "numpy":
|
451 |
+
transformation = partial(default_decoder, format="numpy")
|
452 |
+
else:
|
453 |
+
assert callable(transformation)
|
454 |
+
result.append(transformation)
|
455 |
+
|
456 |
+
return result
|
457 |
+
|
458 |
+
|
459 |
+
def hash_dataset_name(input_string):
|
460 |
+
"""Compute a hash of the input string and return the first 16 characters of the hash."""
|
461 |
+
# Compute SHA256 hash of the input string
|
462 |
+
hash_object = hashlib.sha256(input_string.encode())
|
463 |
+
hash_digest = hash_object.digest()
|
464 |
+
|
465 |
+
# Encode the hash in base64
|
466 |
+
base64_encoded_hash = base64.urlsafe_b64encode(hash_digest)
|
467 |
+
|
468 |
+
# Return the first 16 characters of the base64-encoded hash
|
469 |
+
return base64_encoded_hash[:16].decode("ascii")
|
470 |
+
|
471 |
+
|
472 |
+
@lru_cache(maxsize=16)
|
473 |
+
def lru_json_load(fpath):
|
474 |
+
with open(fpath) as fp:
|
475 |
+
return json.load(fp)
|
476 |
+
|
477 |
+
|
478 |
+
class ShardListDataset(Dataset[T]):
|
479 |
+
"""An indexable dataset based on a list of shards.
|
480 |
+
|
481 |
+
The dataset is either given as a list of shards with optional options and name,
|
482 |
+
or as a URL pointing to a JSON descriptor file.
|
483 |
+
|
484 |
+
Datasets can reference other datasets via `source_url`.
|
485 |
+
|
486 |
+
Shard references within a dataset are resolve relative to an explicitly
|
487 |
+
given `base` property, or relative to the URL from which the dataset
|
488 |
+
descriptor was loaded.
|
489 |
+
"""
|
490 |
+
|
491 |
+
def __init__(
|
492 |
+
self,
|
493 |
+
shards,
|
494 |
+
*,
|
495 |
+
cache_size=int(1e12),
|
496 |
+
cache_dir=None,
|
497 |
+
lru_size=10,
|
498 |
+
dataset_name=None,
|
499 |
+
localname=None,
|
500 |
+
transformations="PIL",
|
501 |
+
keep=False,
|
502 |
+
base=None,
|
503 |
+
options=None,
|
504 |
+
):
|
505 |
+
"""Create a ShardListDataset.
|
506 |
+
|
507 |
+
Args:
|
508 |
+
shards: a list of (filename, length) pairs or a URL pointing to a JSON descriptor file
|
509 |
+
cache_size: the number of shards to keep in the cache
|
510 |
+
lru_size: the number of shards to keep in the LRU cache
|
511 |
+
localname: a function that maps URLs to local filenames
|
512 |
+
|
513 |
+
Note that there are two caches: an on-disk directory, and an in-memory LRU cache.
|
514 |
+
"""
|
515 |
+
if options is None:
|
516 |
+
options = {}
|
517 |
+
super().__init__()
|
518 |
+
# shards is a list of (filename, length) pairs. We'll need to
|
519 |
+
# keep track of the lengths and cumulative lengths to know how
|
520 |
+
# to map indices to shards and indices within shards.
|
521 |
+
if isinstance(shards, (str, io.IOBase)):
|
522 |
+
if base is None and isinstance(shards, str):
|
523 |
+
shards = osp.expanduser(shards)
|
524 |
+
base = urldir(shards)
|
525 |
+
self.base = base
|
526 |
+
self.spec = load_dsdesc_and_resolve(shards, options=options, base=base)
|
527 |
+
self.shards = self.spec.get("shardlist", [])
|
528 |
+
self.dataset_name = self.spec.get("name") or hash_dataset_name(str(shards))
|
529 |
+
else:
|
530 |
+
raise NotImplementedError("Only support taking path/url to JSON descriptor file.")
|
531 |
+
self.base = None
|
532 |
+
self.spec = options
|
533 |
+
self.shards = shards
|
534 |
+
self.dataset_name = dataset_name or hash_dataset_name(str(shards))
|
535 |
+
|
536 |
+
self.lengths = [shard["nsamples"] for shard in self.shards]
|
537 |
+
self.cum_lengths = np.cumsum(self.lengths)
|
538 |
+
self.total_length = self.cum_lengths[-1]
|
539 |
+
|
540 |
+
if cache_dir is not None:
|
541 |
+
# when a cache dir is explicitly given, we download files into
|
542 |
+
# that directory without any changes
|
543 |
+
self.cache_dir = cache_dir
|
544 |
+
self.localname = cache_localname(cache_dir)
|
545 |
+
elif localname is not None:
|
546 |
+
# when a localname function is given, we use that
|
547 |
+
self.cache_dir = None
|
548 |
+
self.localname = localname
|
549 |
+
else:
|
550 |
+
import getpass
|
551 |
+
|
552 |
+
# when no cache dir or localname are given, use the cache from the environment
|
553 |
+
self.cache_dir = os.environ.get("WIDS_CACHE", f"~/.cache/_wids_cache")
|
554 |
+
self.cache_dir = osp.expanduser(self.cache_dir)
|
555 |
+
self.localname = default_localname(self.cache_dir)
|
556 |
+
|
557 |
+
self.data_info = (
|
558 |
+
f"[WebShardedList] {str(shards)}, base: {self.base,}, name: {self.spec.get('name')}, "
|
559 |
+
f"nfiles: {str(len(self.shards))}"
|
560 |
+
)
|
561 |
+
if True or int(os.environ.get("WIDS_VERBOSE", 0)):
|
562 |
+
nbytes = sum(shard.get("filesize", 0) for shard in self.shards)
|
563 |
+
nsamples = sum(shard["nsamples"] for shard in self.shards)
|
564 |
+
self.data_info += f"nbytes: {str(nbytes)}, samples: {str(nsamples),}, cache: {self.cache_dir} "
|
565 |
+
# print(
|
566 |
+
# "[WebShardedList]",
|
567 |
+
# str(shards),
|
568 |
+
# "base:",
|
569 |
+
# self.base,
|
570 |
+
# "name:",
|
571 |
+
# self.spec.get("name"),
|
572 |
+
# "nfiles:",
|
573 |
+
# len(self.shards),
|
574 |
+
# "nbytes:",
|
575 |
+
# nbytes,
|
576 |
+
# "samples:",
|
577 |
+
# nsamples,
|
578 |
+
# "cache:",
|
579 |
+
# self.cache_dir,
|
580 |
+
# file=sys.stderr,
|
581 |
+
# )
|
582 |
+
self.transformations = interpret_transformations(transformations)
|
583 |
+
|
584 |
+
if lru_size > 200:
|
585 |
+
warnings.warn("LRU size is very large; consider reducing it to avoid running out of file descriptors")
|
586 |
+
self.cache = LRUShards(lru_size, localname=self.localname, keep=keep)
|
587 |
+
|
588 |
+
def add_transform(self, transform):
|
589 |
+
"""Add a transformation to the dataset."""
|
590 |
+
self.transformations.append(transform)
|
591 |
+
return self
|
592 |
+
|
593 |
+
def __len__(self):
|
594 |
+
"""Return the total number of samples in the dataset."""
|
595 |
+
return self.total_length
|
596 |
+
|
597 |
+
def get_stats(self):
|
598 |
+
"""Return the number of cache accesses and misses."""
|
599 |
+
return self.cache.accesses, self.cache.misses
|
600 |
+
|
601 |
+
def check_cache_misses(self):
|
602 |
+
"""Check if the cache miss rate is too high."""
|
603 |
+
accesses, misses = self.get_stats()
|
604 |
+
if accesses > 100 and misses / accesses > 0.3:
|
605 |
+
# output a warning only once
|
606 |
+
self.check_cache_misses = lambda: None
|
607 |
+
print(f"Warning: ShardListDataset has a cache miss rate of {misses * 100.0 / accesses:.1%}%")
|
608 |
+
|
609 |
+
def get_shard(self, index):
|
610 |
+
"""Get the shard and index within the shard corresponding to the given index."""
|
611 |
+
# Find the shard corresponding to the given index.
|
612 |
+
shard_idx = np.searchsorted(self.cum_lengths, index, side="right")
|
613 |
+
|
614 |
+
# Figure out which index within the shard corresponds to the
|
615 |
+
# given index.
|
616 |
+
if shard_idx == 0:
|
617 |
+
inner_idx = index
|
618 |
+
else:
|
619 |
+
inner_idx = index - self.cum_lengths[shard_idx - 1]
|
620 |
+
|
621 |
+
# Get the shard and return the corresponding element.
|
622 |
+
desc = self.shards[shard_idx]
|
623 |
+
url = desc["url"]
|
624 |
+
if url.startswith(("https://", "http://", "gs://", "/", "~")):
|
625 |
+
# absolute path or url path
|
626 |
+
url = url
|
627 |
+
else:
|
628 |
+
# concat relative path
|
629 |
+
if self.base is None and "base_path" not in self.spec:
|
630 |
+
raise FileNotFoundError("passing a relative path in shardlist but no base found.")
|
631 |
+
base_path = self.spec["base_path"] if "base_path" in self.spec else self.base
|
632 |
+
url = osp.abspath(osp.join(osp.expanduser(base_path), url))
|
633 |
+
|
634 |
+
desc["url"] = url
|
635 |
+
try:
|
636 |
+
shard = self.cache.get_shard(url)
|
637 |
+
except UnicodeDecodeError as e:
|
638 |
+
print("UnicodeDecodeError:", desc)
|
639 |
+
raise e
|
640 |
+
return shard, inner_idx, desc
|
641 |
+
|
642 |
+
def __getitem__(self, index):
|
643 |
+
"""Return the sample corresponding to the given index."""
|
644 |
+
shard, inner_idx, desc = self.get_shard(index)
|
645 |
+
sample = shard[inner_idx]
|
646 |
+
|
647 |
+
# Check if we're missing the cache too often.
|
648 |
+
self.check_cache_misses()
|
649 |
+
|
650 |
+
sample["__dataset__"] = desc.get("dataset")
|
651 |
+
sample["__index__"] = index
|
652 |
+
sample["__shard__"] = desc["url"]
|
653 |
+
sample["__shardindex__"] = inner_idx
|
654 |
+
|
655 |
+
# Apply transformations
|
656 |
+
for transform in self.transformations:
|
657 |
+
sample = transform(sample)
|
658 |
+
|
659 |
+
return sample
|
660 |
+
|
661 |
+
def close(self):
|
662 |
+
"""Close the dataset."""
|
663 |
+
self.cache.clear()
|
664 |
+
|
665 |
+
|
666 |
+
class ShardListDatasetMulti(ShardListDataset):
|
667 |
+
"""An indexable dataset based on a list of shards.
|
668 |
+
|
669 |
+
The dataset is either given as a list of shards with optional options and name,
|
670 |
+
or as a URL pointing to a JSON descriptor file.
|
671 |
+
|
672 |
+
Datasets can reference other datasets via `source_url`.
|
673 |
+
|
674 |
+
Shard references within a dataset are resolve relative to an explicitly
|
675 |
+
given `base` property, or relative to the URL from which the dataset
|
676 |
+
descriptor was loaded.
|
677 |
+
"""
|
678 |
+
|
679 |
+
def __init__(
|
680 |
+
self,
|
681 |
+
shards,
|
682 |
+
*,
|
683 |
+
cache_size=int(1e12),
|
684 |
+
cache_dir=None,
|
685 |
+
lru_size=10,
|
686 |
+
dataset_name=None,
|
687 |
+
localname=None,
|
688 |
+
transformations="PIL",
|
689 |
+
keep=False,
|
690 |
+
base=None,
|
691 |
+
options=None,
|
692 |
+
sort_data_inseq=False,
|
693 |
+
num_replicas=None,
|
694 |
+
):
|
695 |
+
"""Create a ShardListDataset.
|
696 |
+
|
697 |
+
Args:
|
698 |
+
shards: a list of (filename, length) pairs or a URL pointing to a JSON descriptor file
|
699 |
+
cache_size: the number of shards to keep in the cache
|
700 |
+
lru_size: the number of shards to keep in the LRU cache
|
701 |
+
localname: a function that maps URLs to local filenames
|
702 |
+
|
703 |
+
Note that there are two caches: an on-disk directory, and an in-memory LRU cache.
|
704 |
+
"""
|
705 |
+
if options is None:
|
706 |
+
options = {}
|
707 |
+
# shards is a list of (filename, length) pairs. We'll need to
|
708 |
+
# keep track of the lengths and cumulative lengths to know how
|
709 |
+
# to map indices to shards and indices within shards.
|
710 |
+
shards_lists = shards if isinstance(shards, list) else [shards]
|
711 |
+
bases = base if isinstance(base, list) else [base] * len(shards_lists)
|
712 |
+
self.spec = {}
|
713 |
+
self.shards = []
|
714 |
+
self.num_per_dir = {}
|
715 |
+
for base, shards in zip(bases, shards_lists):
|
716 |
+
if isinstance(shards, (str, io.IOBase)):
|
717 |
+
if base is None and isinstance(shards, str):
|
718 |
+
shards = osp.expanduser(shards)
|
719 |
+
base = urldir(shards)
|
720 |
+
self.base = base
|
721 |
+
_spec = load_dsdesc_and_resolve(shards, options=options, base=base)
|
722 |
+
update_dict_with_extend(self.spec, _spec)
|
723 |
+
self.num_per_dir[os.path.basename(os.path.dirname(shards))] = sum(
|
724 |
+
[shard["nsamples"] for shard in _spec.get("shardlist", [])]
|
725 |
+
)
|
726 |
+
else:
|
727 |
+
raise NotImplementedError("Only support taking path/url to JSON descriptor file.")
|
728 |
+
self.base = None
|
729 |
+
self.spec = options
|
730 |
+
self.shards = shards
|
731 |
+
self.dataset_name = dataset_name or hash_dataset_name(str(shards))
|
732 |
+
|
733 |
+
if sort_data_inseq and len(self.spec.get("shardlist", [])) > 0:
|
734 |
+
num_replicas = num_replicas or dist.get_world_size()
|
735 |
+
self.spec["shardlist"] = split_and_recombine(self.spec["shardlist"], num_replicas)
|
736 |
+
|
737 |
+
self.shards.extend(self.spec.get("shardlist", []))
|
738 |
+
self.dataset_name = self.spec.get("name") or hash_dataset_name(str(shards))
|
739 |
+
|
740 |
+
self.lengths = [shard["nsamples"] for shard in self.shards]
|
741 |
+
self.cum_lengths = np.cumsum(self.lengths)
|
742 |
+
self.total_length = self.cum_lengths[-1]
|
743 |
+
|
744 |
+
if cache_dir is not None:
|
745 |
+
# when a cache dir is explicitly given, we download files into
|
746 |
+
# that directory without any changes
|
747 |
+
self.cache_dir = cache_dir
|
748 |
+
self.localname = cache_localname(cache_dir)
|
749 |
+
elif localname is not None:
|
750 |
+
# when a localname function is given, we use that
|
751 |
+
self.cache_dir = None
|
752 |
+
self.localname = localname
|
753 |
+
else:
|
754 |
+
import getpass
|
755 |
+
|
756 |
+
# when no cache dir or localname are given, use the cache from the environment
|
757 |
+
self.cache_dir = os.environ.get("WIDS_CACHE", f"~/.cache/_wids_cache")
|
758 |
+
self.cache_dir = osp.expanduser(self.cache_dir)
|
759 |
+
self.localname = default_localname(self.cache_dir)
|
760 |
+
|
761 |
+
self.data_info = (
|
762 |
+
f"[WebShardedList] {str(shards)}, base: {self.base,}, name: {self.spec.get('name')}, "
|
763 |
+
f"nfiles: {str(len(self.shards))}"
|
764 |
+
)
|
765 |
+
if True or int(os.environ.get("WIDS_VERBOSE", 0)):
|
766 |
+
nbytes = sum(shard.get("filesize", 0) for shard in self.shards)
|
767 |
+
nsamples = sum(shard["nsamples"] for shard in self.shards)
|
768 |
+
self.data_info += f"nbytes: {str(nbytes)}, samples: {str(nsamples),}, cache: {self.cache_dir} "
|
769 |
+
self.transformations = interpret_transformations(transformations)
|
770 |
+
|
771 |
+
if lru_size > 200:
|
772 |
+
warnings.warn("LRU size is very large; consider reducing it to avoid running out of file descriptors")
|
773 |
+
self.cache = LRUShards(lru_size, localname=self.localname, keep=keep)
|
774 |
+
|
775 |
+
|
776 |
+
def split_and_recombine(lst, n):
|
777 |
+
from collections import OrderedDict
|
778 |
+
|
779 |
+
def extract_prefix(i):
|
780 |
+
return i["url"].split("/")[-2]
|
781 |
+
|
782 |
+
unique_parts = list(OrderedDict((extract_prefix(item), None) for item in lst).keys())
|
783 |
+
split_dict = {part: [] for part in unique_parts}
|
784 |
+
|
785 |
+
for part in unique_parts:
|
786 |
+
part_list = [item for item in lst if extract_prefix(item) == part]
|
787 |
+
chunk_size = max(1, len(part_list) // n) # 确保 chunk_size 至少为 1
|
788 |
+
chunks = [part_list[i * chunk_size : (i + 1) * chunk_size] for i in range(n)]
|
789 |
+
|
790 |
+
# 处理最后一个 chunk,如果数量不均匀,将剩余的元素添加到最后一个 chunk
|
791 |
+
if len(part_list) % n != 0:
|
792 |
+
chunks[-1].extend(part_list[n * chunk_size :])
|
793 |
+
|
794 |
+
split_dict[part] = chunks
|
795 |
+
|
796 |
+
recombined_list = []
|
797 |
+
for i in range(n):
|
798 |
+
for part in unique_parts:
|
799 |
+
recombined_list.extend(split_dict[part][i])
|
800 |
+
|
801 |
+
return recombined_list
|
802 |
+
|
803 |
+
|
804 |
+
def lengths_to_ranges(lengths):
|
805 |
+
"""Convert a list of lengths to a list of ranges."""
|
806 |
+
ranges = []
|
807 |
+
start = 0
|
808 |
+
for length in lengths:
|
809 |
+
ranges.append((start, start + length))
|
810 |
+
start += length
|
811 |
+
return ranges
|
812 |
+
|
813 |
+
|
814 |
+
def intersect_range(a, b):
|
815 |
+
"""Return the intersection of the two half-open integer intervals."""
|
816 |
+
result = max(a[0], b[0]), min(a[1], b[1])
|
817 |
+
if result[0] >= result[1]:
|
818 |
+
return None
|
819 |
+
return result
|
820 |
+
|
821 |
+
|
822 |
+
def intersect_ranges(rangelist, r):
|
823 |
+
"""Return the intersection of the half-open integer interval r with the list of half-open integer intervals."""
|
824 |
+
result = []
|
825 |
+
for a in rangelist:
|
826 |
+
x = intersect_range(a, r)
|
827 |
+
if x is not None:
|
828 |
+
result.append(x)
|
829 |
+
return result
|
830 |
+
|
831 |
+
|
832 |
+
def iterate_ranges(ranges, rng, indexshuffle=True, shardshuffle=True):
|
833 |
+
"""Iterate over the ranges in a random order."""
|
834 |
+
shard_indexes = list(range(len(ranges)))
|
835 |
+
if shardshuffle:
|
836 |
+
rng.shuffle(shard_indexes)
|
837 |
+
for i in shard_indexes:
|
838 |
+
lo, hi = ranges[i]
|
839 |
+
sample_indexes = list(range(lo, hi))
|
840 |
+
if indexshuffle:
|
841 |
+
rng.shuffle(sample_indexes)
|
842 |
+
yield from sample_indexes
|
843 |
+
|
844 |
+
|
845 |
+
class ShardListSampler(Sampler):
|
846 |
+
"""A sampler that samples consistent with a ShardListDataset.
|
847 |
+
|
848 |
+
This sampler is used to sample from a ShardListDataset in a way that
|
849 |
+
preserves locality.
|
850 |
+
|
851 |
+
This returns a permutation of the indexes by shard, then a permutation of
|
852 |
+
indexes within each shard. This ensures that the data is accessed in a
|
853 |
+
way that preserves locality.
|
854 |
+
|
855 |
+
Note that how this ends up splitting data between multiple workers ends up
|
856 |
+
on the details of the DataLoader. Generally, it will likely load samples from the
|
857 |
+
same shard in each worker.
|
858 |
+
|
859 |
+
Other more sophisticated shard-aware samplers are possible and will likely
|
860 |
+
be added.
|
861 |
+
"""
|
862 |
+
|
863 |
+
def __init__(self, dataset, *, lengths=None, seed=0, shufflefirst=False):
|
864 |
+
if lengths is None:
|
865 |
+
lengths = list(dataset.lengths)
|
866 |
+
self.ranges = lengths_to_ranges(lengths)
|
867 |
+
self.seed = seed
|
868 |
+
self.shufflefirst = shufflefirst
|
869 |
+
self.epoch = 0
|
870 |
+
|
871 |
+
def __iter__(self):
|
872 |
+
self.rng = random.Random(self.seed + 1289738273 * self.epoch)
|
873 |
+
shardshuffle = self.shufflefirst or self.epoch > 0
|
874 |
+
yield from iterate_ranges(self.ranges, self.rng, shardshuffle=shardshuffle)
|
875 |
+
self.epoch += 1
|
876 |
+
|
877 |
+
|
878 |
+
ShardedSampler = ShardListSampler
|
879 |
+
|
880 |
+
|
881 |
+
class ChunkedSampler(Sampler):
|
882 |
+
"""A sampler that samples in chunks and then shuffles the samples within each chunk.
|
883 |
+
|
884 |
+
This preserves locality of reference while still shuffling the data.
|
885 |
+
"""
|
886 |
+
|
887 |
+
def __init__(
|
888 |
+
self,
|
889 |
+
dataset,
|
890 |
+
*,
|
891 |
+
num_samples=None,
|
892 |
+
chunksize=2000,
|
893 |
+
seed=0,
|
894 |
+
shuffle=False,
|
895 |
+
shufflefirst=False,
|
896 |
+
):
|
897 |
+
if isinstance(num_samples, int):
|
898 |
+
lo, hi = 0, num_samples
|
899 |
+
elif num_samples is None:
|
900 |
+
lo, hi = 0, len(dataset)
|
901 |
+
else:
|
902 |
+
lo, hi = num_samples
|
903 |
+
self.ranges = [(i, min(i + chunksize, hi)) for i in range(lo, hi, chunksize)]
|
904 |
+
self.seed = seed
|
905 |
+
self.shuffle = shuffle
|
906 |
+
self.shufflefirst = shufflefirst
|
907 |
+
self.epoch = 0
|
908 |
+
|
909 |
+
def set_epoch(self, epoch):
|
910 |
+
self.epoch = epoch
|
911 |
+
|
912 |
+
def __iter__(self):
|
913 |
+
self.rng = random.Random(self.seed + 1289738273 * self.epoch)
|
914 |
+
shardshuffle = self.shufflefirst or self.epoch > 0
|
915 |
+
yield from iterate_ranges(
|
916 |
+
self.ranges,
|
917 |
+
self.rng,
|
918 |
+
indexshuffle=self.shuffle,
|
919 |
+
shardshuffle=(self.shuffle and shardshuffle),
|
920 |
+
)
|
921 |
+
self.epoch += 1
|
922 |
+
|
923 |
+
def __len__(self):
|
924 |
+
return len(self.ranges)
|
925 |
+
|
926 |
+
|
927 |
+
def DistributedChunkedSampler(
|
928 |
+
dataset: Dataset,
|
929 |
+
*,
|
930 |
+
num_replicas: Optional[int] = None,
|
931 |
+
num_samples: Optional[int] = None,
|
932 |
+
rank: Optional[int] = None,
|
933 |
+
shuffle: bool = True,
|
934 |
+
shufflefirst: bool = False,
|
935 |
+
seed: int = 0,
|
936 |
+
drop_last: bool = None,
|
937 |
+
chunksize: int = 1000000,
|
938 |
+
) -> ChunkedSampler:
|
939 |
+
"""Return a ChunkedSampler for the current worker in distributed training.
|
940 |
+
|
941 |
+
Reverts to a simple ChunkedSampler if not running in distributed mode.
|
942 |
+
|
943 |
+
Since the split among workers takes place before the chunk shuffle,
|
944 |
+
workers end up with a fixed set of shards they need to download. The
|
945 |
+
more workers, the fewer shards are used by each worker.
|
946 |
+
"""
|
947 |
+
if drop_last is not None:
|
948 |
+
warnings.warn("DistributedChunkedSampler does not support drop_last, thus it will be ignored")
|
949 |
+
if not dist.is_initialized():
|
950 |
+
warnings.warn("DistributedChunkedSampler is called without distributed initialized; assuming single process")
|
951 |
+
num_replicas = 1
|
952 |
+
rank = 0
|
953 |
+
else:
|
954 |
+
num_replicas = num_replicas or dist.get_world_size()
|
955 |
+
rank = rank or dist.get_rank()
|
956 |
+
assert rank >= 0 and rank < num_replicas
|
957 |
+
|
958 |
+
num_samples = num_samples or len(dataset)
|
959 |
+
worker_chunk = (num_samples + num_replicas - 1) // num_replicas
|
960 |
+
worker_start = rank * worker_chunk
|
961 |
+
worker_end = min(worker_start + worker_chunk, num_samples)
|
962 |
+
return ChunkedSampler(
|
963 |
+
dataset,
|
964 |
+
num_samples=(worker_start, worker_end),
|
965 |
+
chunksize=chunksize,
|
966 |
+
seed=seed,
|
967 |
+
shuffle=shuffle,
|
968 |
+
shufflefirst=shufflefirst,
|
969 |
+
)
|
970 |
+
|
971 |
+
|
972 |
+
class DistributedRangedSampler(Sampler):
|
973 |
+
"""A sampler that samples in chunks and then shuffles the samples within each chunk.
|
974 |
+
|
975 |
+
This preserves locality of reference while still shuffling the data.
|
976 |
+
"""
|
977 |
+
|
978 |
+
def __init__(
|
979 |
+
self,
|
980 |
+
dataset: Dataset,
|
981 |
+
num_replicas: Optional[int] = None,
|
982 |
+
num_samples: Optional[int] = None,
|
983 |
+
rank: Optional[int] = None,
|
984 |
+
drop_last: bool = None,
|
985 |
+
):
|
986 |
+
if drop_last is not None:
|
987 |
+
warnings.warn("DistributedChunkedSampler does not support drop_last, thus it will be ignored")
|
988 |
+
if not dist.is_initialized():
|
989 |
+
warnings.warn(
|
990 |
+
"DistributedChunkedSampler is called without distributed initialized; assuming single process"
|
991 |
+
)
|
992 |
+
num_replicas = 1
|
993 |
+
rank = 0
|
994 |
+
else:
|
995 |
+
num_replicas = num_replicas or dist.get_world_size()
|
996 |
+
rank = rank or dist.get_rank()
|
997 |
+
assert rank >= 0 and rank < num_replicas
|
998 |
+
num_samples = num_samples or len(dataset)
|
999 |
+
self.worker_chunk = num_samples // num_replicas
|
1000 |
+
self.worker_start = rank * self.worker_chunk
|
1001 |
+
self.worker_end = min((rank + 1) * self.worker_chunk, num_samples)
|
1002 |
+
self.ranges = range(self.worker_start, self.worker_end)
|
1003 |
+
self.epoch = 0
|
1004 |
+
self.step_start = 0
|
1005 |
+
|
1006 |
+
def set_epoch(self, epoch):
|
1007 |
+
self.epoch = epoch
|
1008 |
+
|
1009 |
+
def __len__(self):
|
1010 |
+
return len(self.ranges)
|
1011 |
+
|
1012 |
+
def set_start(self, start):
|
1013 |
+
self.step_start = start
|
1014 |
+
|
1015 |
+
def __iter__(self):
|
1016 |
+
yield from self.ranges[self.step_start :]
|
1017 |
+
self.epoch += 1
|
1018 |
+
|
1019 |
+
|
1020 |
+
class DistributedLocalSampler(DistributedSampler):
|
1021 |
+
def __iter__(self):
|
1022 |
+
if self.shuffle:
|
1023 |
+
# deterministically shuffle based on epoch and seed
|
1024 |
+
g = torch.Generator()
|
1025 |
+
g.manual_seed(self.seed + self.epoch)
|
1026 |
+
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
|
1027 |
+
else:
|
1028 |
+
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
|
1029 |
+
|
1030 |
+
if not self.drop_last:
|
1031 |
+
# add extra samples to make it evenly divisible
|
1032 |
+
padding_size = self.total_size - len(indices)
|
1033 |
+
if padding_size <= len(indices):
|
1034 |
+
indices += indices[:padding_size]
|
1035 |
+
else:
|
1036 |
+
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
1037 |
+
else:
|
1038 |
+
# remove tail of data to make it evenly divisible.
|
1039 |
+
indices = indices[: self.total_size]
|
1040 |
+
assert len(indices) == self.total_size
|
1041 |
+
|
1042 |
+
# subsample
|
1043 |
+
# indices = indices[self.rank:self.total_size:self.num_replicas]
|
1044 |
+
chunk_size = self.total_size // self.num_replicas
|
1045 |
+
begin_idx = chunk_size * self.rank
|
1046 |
+
stop_idx = chunk_size * (self.rank + 1)
|
1047 |
+
indices = indices[begin_idx:stop_idx]
|
1048 |
+
|
1049 |
+
# print("[SamplerIndices: ]", indices)
|
1050 |
+
assert len(indices) == self.num_samples
|
1051 |
+
return iter(indices)
|
diffusion/data/wids/wids_dl.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
# This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
|
18 |
+
import fcntl
|
19 |
+
import os
|
20 |
+
import shutil
|
21 |
+
import sys
|
22 |
+
import time
|
23 |
+
from collections import deque
|
24 |
+
from datetime import datetime
|
25 |
+
from urllib.parse import urlparse
|
26 |
+
|
27 |
+
recent_downloads = deque(maxlen=1000)
|
28 |
+
|
29 |
+
open_objects = {}
|
30 |
+
max_open_objects = 100
|
31 |
+
|
32 |
+
|
33 |
+
class ULockFile:
|
34 |
+
"""A simple locking class. We don't need any of the third
|
35 |
+
party libraries since we rely on POSIX semantics for linking
|
36 |
+
below anyway."""
|
37 |
+
|
38 |
+
def __init__(self, path):
|
39 |
+
self.lockfile_path = path
|
40 |
+
self.lockfile = None
|
41 |
+
|
42 |
+
def __enter__(self):
|
43 |
+
self.lockfile = open(self.lockfile_path, "w")
|
44 |
+
fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_EX)
|
45 |
+
return self
|
46 |
+
|
47 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
48 |
+
fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_UN)
|
49 |
+
self.lockfile.close()
|
50 |
+
self.lockfile = None
|
51 |
+
try:
|
52 |
+
os.unlink(self.lockfile_path)
|
53 |
+
except FileNotFoundError:
|
54 |
+
pass
|
55 |
+
|
56 |
+
|
57 |
+
def pipe_download(remote, local):
|
58 |
+
"""Perform a download for a pipe: url."""
|
59 |
+
assert remote.startswith("pipe:")
|
60 |
+
cmd = remote[5:]
|
61 |
+
cmd = cmd.format(local=local)
|
62 |
+
assert os.system(cmd) == 0, "Command failed: %s" % cmd
|
63 |
+
|
64 |
+
|
65 |
+
def copy_file(remote, local):
|
66 |
+
remote = urlparse(remote)
|
67 |
+
assert remote.scheme in ["file", ""]
|
68 |
+
# use absolute path
|
69 |
+
remote = os.path.abspath(remote.path)
|
70 |
+
local = urlparse(local)
|
71 |
+
assert local.scheme in ["file", ""]
|
72 |
+
local = os.path.abspath(local.path)
|
73 |
+
if remote == local:
|
74 |
+
return
|
75 |
+
# check if the local file exists
|
76 |
+
shutil.copyfile(remote, local)
|
77 |
+
|
78 |
+
|
79 |
+
verbose_cmd = int(os.environ.get("WIDS_VERBOSE_CMD", "0"))
|
80 |
+
|
81 |
+
|
82 |
+
def vcmd(flag, verbose_flag=""):
|
83 |
+
return verbose_flag if verbose_cmd else flag
|
84 |
+
|
85 |
+
|
86 |
+
default_cmds = {
|
87 |
+
"posixpath": copy_file,
|
88 |
+
"file": copy_file,
|
89 |
+
"pipe": pipe_download,
|
90 |
+
"http": "curl " + vcmd("-s") + " -L {url} -o {local}",
|
91 |
+
"https": "curl " + vcmd("-s") + " -L {url} -o {local}",
|
92 |
+
"ftp": "curl " + vcmd("-s") + " -L {url} -o {local}",
|
93 |
+
"ftps": "curl " + vcmd("-s") + " -L {url} -o {local}",
|
94 |
+
"gs": "gsutil " + vcmd("-q") + " cp {url} {local}",
|
95 |
+
"s3": "aws s3 cp {url} {local}",
|
96 |
+
}
|
97 |
+
|
98 |
+
|
99 |
+
# TODO(ligeng): change HTTPS download to python requests library
|
100 |
+
|
101 |
+
|
102 |
+
def download_file_no_log(remote, local, handlers=default_cmds):
|
103 |
+
"""Download a file from a remote url to a local path.
|
104 |
+
The remote url can be a pipe: url, in which case the remainder of
|
105 |
+
the url is treated as a command template that is executed to perform the download.
|
106 |
+
"""
|
107 |
+
|
108 |
+
if remote.startswith("pipe:"):
|
109 |
+
schema = "pipe"
|
110 |
+
else:
|
111 |
+
schema = urlparse(remote).scheme
|
112 |
+
if schema is None or schema == "":
|
113 |
+
schema = "posixpath"
|
114 |
+
# get the handler
|
115 |
+
handler = handlers.get(schema)
|
116 |
+
if handler is None:
|
117 |
+
raise ValueError("Unknown schema: %s" % schema)
|
118 |
+
# call the handler
|
119 |
+
if callable(handler):
|
120 |
+
handler(remote, local)
|
121 |
+
else:
|
122 |
+
assert isinstance(handler, str)
|
123 |
+
cmd = handler.format(url=remote, local=local)
|
124 |
+
assert os.system(cmd) == 0, "Command failed: %s" % cmd
|
125 |
+
return local
|
126 |
+
|
127 |
+
|
128 |
+
def download_file(remote, local, handlers=default_cmds, verbose=False):
|
129 |
+
start = time.time()
|
130 |
+
try:
|
131 |
+
return download_file_no_log(remote, local, handlers=handlers)
|
132 |
+
finally:
|
133 |
+
recent_downloads.append((remote, local, time.time(), time.time() - start))
|
134 |
+
if verbose:
|
135 |
+
print(
|
136 |
+
"downloaded",
|
137 |
+
remote,
|
138 |
+
"to",
|
139 |
+
local,
|
140 |
+
"in",
|
141 |
+
time.time() - start,
|
142 |
+
"seconds",
|
143 |
+
file=sys.stderr,
|
144 |
+
)
|
145 |
+
|
146 |
+
|
147 |
+
def download_and_open(remote, local, mode="rb", handlers=default_cmds, verbose=False):
|
148 |
+
with ULockFile(local + ".lock"):
|
149 |
+
if os.path.exists(remote):
|
150 |
+
# print("enter1", remote, local, mode)
|
151 |
+
result = open(remote, mode)
|
152 |
+
else:
|
153 |
+
# print("enter2", remote, local, mode)
|
154 |
+
if not os.path.exists(local):
|
155 |
+
if verbose:
|
156 |
+
print("downloading", remote, "to", local, file=sys.stderr)
|
157 |
+
download_file(remote, local, handlers=handlers)
|
158 |
+
else:
|
159 |
+
if verbose:
|
160 |
+
print("using cached", local, file=sys.stderr)
|
161 |
+
result = open(local, mode)
|
162 |
+
|
163 |
+
# input()
|
164 |
+
|
165 |
+
if open_objects is not None:
|
166 |
+
for k, v in list(open_objects.items()):
|
167 |
+
if v.closed:
|
168 |
+
del open_objects[k]
|
169 |
+
if len(open_objects) > max_open_objects:
|
170 |
+
raise RuntimeError("Too many open objects")
|
171 |
+
current_time = datetime.now().strftime("%Y%m%d%H%M%S")
|
172 |
+
key = tuple(str(x) for x in [remote, local, mode, current_time])
|
173 |
+
open_objects[key] = result
|
174 |
+
return result
|
diffusion/data/wids/wids_lru.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
# This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
|
18 |
+
from collections import OrderedDict
|
19 |
+
|
20 |
+
|
21 |
+
class LRUCache:
|
22 |
+
def __init__(self, capacity: int, release_handler=None):
|
23 |
+
"""Initialize a new LRU cache with the given capacity."""
|
24 |
+
self.capacity = capacity
|
25 |
+
self.cache = OrderedDict()
|
26 |
+
self.release_handler = release_handler
|
27 |
+
|
28 |
+
def __getitem__(self, key):
|
29 |
+
"""Return the value associated with the given key, or None."""
|
30 |
+
if key not in self.cache:
|
31 |
+
return None
|
32 |
+
self.cache.move_to_end(key)
|
33 |
+
return self.cache[key]
|
34 |
+
|
35 |
+
def __setitem__(self, key, value):
|
36 |
+
"""Associate the given value with the given key."""
|
37 |
+
if key in self.cache:
|
38 |
+
self.cache.move_to_end(key)
|
39 |
+
self.cache[key] = value
|
40 |
+
if len(self.cache) > self.capacity:
|
41 |
+
key, value = self.cache.popitem(last=False)
|
42 |
+
if self.release_handler is not None:
|
43 |
+
self.release_handler(key, value)
|
44 |
+
|
45 |
+
def __delitem__(self, key):
|
46 |
+
"""Remove the given key from the cache."""
|
47 |
+
if key in self.cache:
|
48 |
+
if self.release_handler is not None:
|
49 |
+
value = self.cache[key]
|
50 |
+
self.release_handler(key, value)
|
51 |
+
del self.cache[key]
|
52 |
+
|
53 |
+
def __len__(self):
|
54 |
+
"""Return the number of entries in the cache."""
|
55 |
+
return len(self.cache)
|
56 |
+
|
57 |
+
def __contains__(self, key):
|
58 |
+
"""Return whether the cache contains the given key."""
|
59 |
+
return key in self.cache
|
60 |
+
|
61 |
+
def items(self):
|
62 |
+
"""Return an iterator over the keys of the cache."""
|
63 |
+
return self.cache.items()
|
64 |
+
|
65 |
+
def keys(self):
|
66 |
+
"""Return an iterator over the keys of the cache."""
|
67 |
+
return self.cache.keys()
|
68 |
+
|
69 |
+
def values(self):
|
70 |
+
"""Return an iterator over the values of the cache."""
|
71 |
+
return self.cache.values()
|
72 |
+
|
73 |
+
def clear(self):
|
74 |
+
for key in list(self.keys()):
|
75 |
+
value = self.cache[key]
|
76 |
+
if self.release_handler is not None:
|
77 |
+
self.release_handler(key, value)
|
78 |
+
del self[key]
|
79 |
+
|
80 |
+
def __del__(self):
|
81 |
+
self.clear()
|
diffusion/data/wids/wids_mmtar.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
# This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
|
18 |
+
import collections
|
19 |
+
import fcntl
|
20 |
+
import io
|
21 |
+
import mmap
|
22 |
+
import os
|
23 |
+
import struct
|
24 |
+
|
25 |
+
TarHeader = collections.namedtuple(
|
26 |
+
"TarHeader",
|
27 |
+
[
|
28 |
+
"name",
|
29 |
+
"mode",
|
30 |
+
"uid",
|
31 |
+
"gid",
|
32 |
+
"size",
|
33 |
+
"mtime",
|
34 |
+
"chksum",
|
35 |
+
"typeflag",
|
36 |
+
"linkname",
|
37 |
+
"magic",
|
38 |
+
"version",
|
39 |
+
"uname",
|
40 |
+
"gname",
|
41 |
+
"devmajor",
|
42 |
+
"devminor",
|
43 |
+
"prefix",
|
44 |
+
],
|
45 |
+
)
|
46 |
+
|
47 |
+
|
48 |
+
def parse_tar_header(header_bytes):
|
49 |
+
header = struct.unpack("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s", header_bytes)
|
50 |
+
return TarHeader(*header)
|
51 |
+
|
52 |
+
|
53 |
+
def next_header(offset, header):
|
54 |
+
block_size = 512
|
55 |
+
size = header.size.decode("utf-8").strip("\x00")
|
56 |
+
if size == "":
|
57 |
+
return -1
|
58 |
+
size = int(size, 8)
|
59 |
+
# compute the file size rounded up to the next block size if it is a partial block
|
60 |
+
padded_file_size = (size + block_size - 1) // block_size * block_size
|
61 |
+
return offset + block_size + padded_file_size
|
62 |
+
|
63 |
+
|
64 |
+
# TODO(ligeng): support gzip stream
|
65 |
+
class MMIndexedTar:
|
66 |
+
def __init__(self, fname, index_file=None, verbose=True, cleanup_callback=None):
|
67 |
+
self.verbose = verbose
|
68 |
+
self.cleanup_callback = cleanup_callback
|
69 |
+
if isinstance(fname, str):
|
70 |
+
self.stream = open(fname, "rb")
|
71 |
+
self.fname = fname
|
72 |
+
elif isinstance(fname, io.IOBase):
|
73 |
+
self.stream = fname
|
74 |
+
self.fname = None
|
75 |
+
self.mmapped_file = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ)
|
76 |
+
if cleanup_callback:
|
77 |
+
cleanup_callback(fname, self.stream.fileno(), "start")
|
78 |
+
self._build_index()
|
79 |
+
|
80 |
+
def close(self, dispose=False):
|
81 |
+
if self.cleanup_callback:
|
82 |
+
self.cleanup_callback(self.fname, self.stream.fileno(), "end")
|
83 |
+
self.mmapped_file.close()
|
84 |
+
self.stream.close()
|
85 |
+
|
86 |
+
def _build_index(self):
|
87 |
+
self.by_name = {}
|
88 |
+
self.by_index = []
|
89 |
+
offset = 0
|
90 |
+
while offset >= 0 and offset < len(self.mmapped_file):
|
91 |
+
header = parse_tar_header(self.mmapped_file[offset : offset + 500])
|
92 |
+
name = header.name.decode("utf-8").strip("\x00")
|
93 |
+
typeflag = header.typeflag.decode("utf-8").strip("\x00")
|
94 |
+
if name != "" and name != "././@PaxHeader" and typeflag in ["0", ""]:
|
95 |
+
try:
|
96 |
+
size = int(header.size.decode("utf-8")[:-1], 8)
|
97 |
+
except ValueError as exn:
|
98 |
+
print(header)
|
99 |
+
raise exn
|
100 |
+
self.by_name[name] = offset
|
101 |
+
self.by_index.append((name, offset, size))
|
102 |
+
offset = next_header(offset, header)
|
103 |
+
|
104 |
+
def names(self):
|
105 |
+
return self.by_name.keys()
|
106 |
+
|
107 |
+
def get_at_offset(self, offset):
|
108 |
+
header = parse_tar_header(self.mmapped_file[offset : offset + 500])
|
109 |
+
name = header.name.decode("utf-8").strip("\x00")
|
110 |
+
start = offset + 512
|
111 |
+
end = start + int(header.size.decode("utf-8")[:-1], 8)
|
112 |
+
return name, self.mmapped_file[start:end]
|
113 |
+
|
114 |
+
def get_at_index(self, index):
|
115 |
+
name, offset, size = self.by_index[index]
|
116 |
+
return self.get_at_offset(offset)
|
117 |
+
|
118 |
+
def get_by_name(self, name):
|
119 |
+
offset = self.by_name[name]
|
120 |
+
return self.get_at_offset(offset)
|
121 |
+
|
122 |
+
def __iter__(self):
|
123 |
+
for name, offset, size in self.by_index:
|
124 |
+
yield name, self.mmapped_file[offset + 512 : offset + 512 + size]
|
125 |
+
|
126 |
+
def __getitem__(self, key):
|
127 |
+
if isinstance(key, int):
|
128 |
+
return self.get_at_index(key)
|
129 |
+
else:
|
130 |
+
return self.get_by_name(key)
|
131 |
+
|
132 |
+
def __len__(self):
|
133 |
+
return len(self.by_index)
|
134 |
+
|
135 |
+
def get_file(self, i):
|
136 |
+
fname, data = self.get_at_index(i)
|
137 |
+
return fname, io.BytesIO(data)
|
138 |
+
|
139 |
+
|
140 |
+
def keep_while_reading(fname, fd, phase, delay=0.0):
|
141 |
+
"""This is a possible cleanup callback for cleanup_callback of MIndexedTar.
|
142 |
+
|
143 |
+
It assumes that as long as there are some readers for a file,
|
144 |
+
more readers may be trying to open it.
|
145 |
+
|
146 |
+
Note that on Linux, unlinking the file doesn't matter after
|
147 |
+
it has been mmapped. The contents will only be deleted when
|
148 |
+
all readers close the file. The unlinking merely makes the file
|
149 |
+
unavailable to new readers, since the downloader checks first
|
150 |
+
whether the file exists.
|
151 |
+
"""
|
152 |
+
assert delay == 0.0, "delay not implemented"
|
153 |
+
if fd < 0 or fname is None:
|
154 |
+
return
|
155 |
+
if phase == "start":
|
156 |
+
fcntl.flock(fd, fcntl.LOCK_SH)
|
157 |
+
elif phase == "end":
|
158 |
+
try:
|
159 |
+
fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
|
160 |
+
os.unlink(fname)
|
161 |
+
except FileNotFoundError:
|
162 |
+
# someone else deleted it already
|
163 |
+
pass
|
164 |
+
except BlockingIOError:
|
165 |
+
# we couldn't get an exclusive lock, so someone else is still reading
|
166 |
+
pass
|
167 |
+
else:
|
168 |
+
raise ValueError(f"Unknown phase {phase}")
|
diffusion/data/wids/wids_specs.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
# This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
|
18 |
+
import io
|
19 |
+
import json
|
20 |
+
import os
|
21 |
+
import tempfile
|
22 |
+
from urllib.parse import urlparse, urlunparse
|
23 |
+
|
24 |
+
from .wids_dl import download_and_open
|
25 |
+
|
26 |
+
|
27 |
+
def urldir(url):
|
28 |
+
"""Return the directory part of a url."""
|
29 |
+
parsed_url = urlparse(url)
|
30 |
+
path = parsed_url.path
|
31 |
+
directory = os.path.dirname(path)
|
32 |
+
return parsed_url._replace(path=directory).geturl()
|
33 |
+
|
34 |
+
|
35 |
+
def urlmerge(base, url):
|
36 |
+
"""Merge a base URL and a relative URL.
|
37 |
+
|
38 |
+
The function fills in any missing part of the url from the base,
|
39 |
+
except for params, query, and fragment, which are taken only from the 'url'.
|
40 |
+
For the pathname component, it merges the paths like os.path.join:
|
41 |
+
an absolute path in 'url' overrides the base path, otherwise the paths are merged.
|
42 |
+
|
43 |
+
Parameters:
|
44 |
+
base (str): The base URL.
|
45 |
+
url (str): The URL to merge with the base.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
str: The merged URL.
|
49 |
+
"""
|
50 |
+
# Parse the base and the relative URL
|
51 |
+
parsed_base = urlparse(base)
|
52 |
+
parsed_url = urlparse(url)
|
53 |
+
|
54 |
+
# Merge paths using os.path.join
|
55 |
+
# If the url path is absolute, it overrides the base path
|
56 |
+
if parsed_url.path.startswith("/"):
|
57 |
+
merged_path = parsed_url.path
|
58 |
+
else:
|
59 |
+
merged_path = os.path.normpath(os.path.join(parsed_base.path, parsed_url.path))
|
60 |
+
|
61 |
+
# Construct the merged URL
|
62 |
+
merged_url = urlunparse(
|
63 |
+
(
|
64 |
+
parsed_url.scheme or parsed_base.scheme,
|
65 |
+
parsed_url.netloc or parsed_base.netloc,
|
66 |
+
merged_path,
|
67 |
+
parsed_url.params, # Use params from the url only
|
68 |
+
parsed_url.query, # Use query from the url only
|
69 |
+
parsed_url.fragment, # Use fragment from the url only
|
70 |
+
)
|
71 |
+
)
|
72 |
+
|
73 |
+
return merged_url
|
74 |
+
|
75 |
+
|
76 |
+
def check_shards(l):
|
77 |
+
"""Check that a list of shards is well-formed.
|
78 |
+
|
79 |
+
This checks that the list is a list of dictionaries, and that
|
80 |
+
each dictionary has a "url" and a "nsamples" key.
|
81 |
+
"""
|
82 |
+
assert isinstance(l, list)
|
83 |
+
for shard in l:
|
84 |
+
assert isinstance(shard, dict)
|
85 |
+
assert "url" in shard
|
86 |
+
assert "nsamples" in shard
|
87 |
+
return l
|
88 |
+
|
89 |
+
|
90 |
+
def set_all(l, k, v):
|
91 |
+
"""Set a key to a value in a list of dictionaries."""
|
92 |
+
if v is None:
|
93 |
+
return
|
94 |
+
for x in l:
|
95 |
+
if k not in x:
|
96 |
+
x[k] = v
|
97 |
+
|
98 |
+
|
99 |
+
def load_remote_dsdesc_raw(source):
|
100 |
+
"""Load a remote or local dataset description in JSON format."""
|
101 |
+
if isinstance(source, str):
|
102 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
103 |
+
dlname = os.path.join(tmpdir, "dataset.json")
|
104 |
+
with download_and_open(source, dlname) as f:
|
105 |
+
dsdesc = json.load(f)
|
106 |
+
elif isinstance(source, io.IOBase):
|
107 |
+
dsdesc = json.load(source)
|
108 |
+
else:
|
109 |
+
# FIXME: use gopen
|
110 |
+
import requests
|
111 |
+
|
112 |
+
jsondata = requests.get(source).text
|
113 |
+
dsdesc = json.loads(jsondata)
|
114 |
+
return dsdesc
|
115 |
+
|
116 |
+
|
117 |
+
def rebase_shardlist(shardlist, base):
|
118 |
+
"""Rebase the URLs in a shardlist."""
|
119 |
+
if base is None:
|
120 |
+
return shardlist
|
121 |
+
for shard in shardlist:
|
122 |
+
shard["url"] = urlmerge(base, shard["url"])
|
123 |
+
return shardlist
|
124 |
+
|
125 |
+
|
126 |
+
def resolve_dsdesc(dsdesc, *, options=None, base=None):
|
127 |
+
"""Resolve a dataset description.
|
128 |
+
|
129 |
+
This rebases the shards as necessary and loads any remote references.
|
130 |
+
|
131 |
+
Dataset descriptions are JSON files. They must have the following format;
|
132 |
+
|
133 |
+
{
|
134 |
+
"wids_version": 1,
|
135 |
+
# optional immediate shardlist
|
136 |
+
"shardlist": [
|
137 |
+
{"url": "http://example.com/file.tar", "nsamples": 1000},
|
138 |
+
...
|
139 |
+
],
|
140 |
+
# sub-datasets
|
141 |
+
"datasets": [
|
142 |
+
{"source_url": "http://example.com/dataset.json"},
|
143 |
+
{"shardlist": [
|
144 |
+
{"url": "http://example.com/file.tar", "nsamples": 1000},
|
145 |
+
...
|
146 |
+
]}
|
147 |
+
...
|
148 |
+
]
|
149 |
+
}
|
150 |
+
"""
|
151 |
+
if options is None:
|
152 |
+
options = {}
|
153 |
+
assert isinstance(dsdesc, dict)
|
154 |
+
dsdesc = dict(dsdesc, **options)
|
155 |
+
shardlist = rebase_shardlist(dsdesc.get("shardlist", []), base)
|
156 |
+
assert shardlist is not None
|
157 |
+
set_all(shardlist, "weight", dsdesc.get("weight"))
|
158 |
+
set_all(shardlist, "name", dsdesc.get("name"))
|
159 |
+
check_shards(shardlist)
|
160 |
+
assert "wids_version" in dsdesc, "No wids_version in dataset description"
|
161 |
+
assert dsdesc["wids_version"] == 1, "Unknown wids_version"
|
162 |
+
for component in dsdesc.get("datasets", []):
|
163 |
+
# we use the weight from the reference to the dataset,
|
164 |
+
# regardless of remote loading
|
165 |
+
weight = component.get("weight")
|
166 |
+
# follow any source_url dsdescs through remote loading
|
167 |
+
source_url = None
|
168 |
+
if "source_url" in component:
|
169 |
+
source_url = component["source_url"]
|
170 |
+
component = load_remote_dsdesc_raw(source_url)
|
171 |
+
assert "source_url" not in component, "double indirection in dataset description"
|
172 |
+
assert "shardlist" in component, "no shardlist in dataset description"
|
173 |
+
# if the component has a base, use it to rebase the shardlist
|
174 |
+
# otherwise use the base from the source_url, if any
|
175 |
+
subbase = component.get("base", urldir(source_url) if source_url else None)
|
176 |
+
if subbase is not None:
|
177 |
+
rebase_shardlist(component["shardlist"], subbase)
|
178 |
+
l = check_shards(component["shardlist"])
|
179 |
+
set_all(l, "weight", weight)
|
180 |
+
set_all(l, "source_url", source_url)
|
181 |
+
set_all(l, "dataset", component.get("name"))
|
182 |
+
shardlist.extend(l)
|
183 |
+
assert len(shardlist) > 0, "No shards found"
|
184 |
+
dsdesc["shardlist"] = shardlist
|
185 |
+
return dsdesc
|
186 |
+
|
187 |
+
|
188 |
+
def load_dsdesc_and_resolve(source, *, options=None, base=None):
|
189 |
+
if options is None:
|
190 |
+
options = {}
|
191 |
+
dsdesc = load_remote_dsdesc_raw(source)
|
192 |
+
return resolve_dsdesc(dsdesc, base=base, options=options)
|
diffusion/data/wids/wids_tar.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
# This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
|
18 |
+
import io
|
19 |
+
import os
|
20 |
+
import os.path
|
21 |
+
import pickle
|
22 |
+
import re
|
23 |
+
import tarfile
|
24 |
+
|
25 |
+
import numpy as np
|
26 |
+
|
27 |
+
|
28 |
+
def find_index_file(file):
|
29 |
+
prefix, last_ext = os.path.splitext(file)
|
30 |
+
if re.match("._[0-9]+_$", last_ext):
|
31 |
+
return prefix + ".index"
|
32 |
+
else:
|
33 |
+
return file + ".index"
|
34 |
+
|
35 |
+
|
36 |
+
class TarFileReader:
|
37 |
+
def __init__(self, file, index_file=find_index_file, verbose=True):
|
38 |
+
self.verbose = verbose
|
39 |
+
if callable(index_file):
|
40 |
+
index_file = index_file(file)
|
41 |
+
self.index_file = index_file
|
42 |
+
|
43 |
+
# Open the tar file and keep it open
|
44 |
+
if isinstance(file, str):
|
45 |
+
self.tar_file = tarfile.open(file, "r")
|
46 |
+
else:
|
47 |
+
self.tar_file = tarfile.open(fileobj=file, mode="r")
|
48 |
+
|
49 |
+
# Create the index
|
50 |
+
self._create_tar_index()
|
51 |
+
|
52 |
+
def _create_tar_index(self):
|
53 |
+
if self.index_file is not None and os.path.exists(self.index_file):
|
54 |
+
if self.verbose:
|
55 |
+
print("Loading tar index from", self.index_file)
|
56 |
+
with open(self.index_file, "rb") as stream:
|
57 |
+
self.fnames, self.index = pickle.load(stream)
|
58 |
+
return
|
59 |
+
# Create an empty list for the index
|
60 |
+
self.fnames = []
|
61 |
+
self.index = []
|
62 |
+
|
63 |
+
if self.verbose:
|
64 |
+
print("Creating tar index for", self.tar_file.name, "at", self.index_file)
|
65 |
+
# Iterate over the members of the tar file
|
66 |
+
for member in self.tar_file:
|
67 |
+
# If the member is a file, add it to the index
|
68 |
+
if member.isfile():
|
69 |
+
# Get the file's offset
|
70 |
+
offset = self.tar_file.fileobj.tell()
|
71 |
+
self.fnames.append(member.name)
|
72 |
+
self.index.append([offset, member.size])
|
73 |
+
if self.verbose:
|
74 |
+
print("Done creating tar index for", self.tar_file.name, "at", self.index_file)
|
75 |
+
self.index = np.array(self.index)
|
76 |
+
if self.index_file is not None:
|
77 |
+
if os.path.exists(self.index_file + ".temp"):
|
78 |
+
os.unlink(self.index_file + ".temp")
|
79 |
+
with open(self.index_file + ".temp", "wb") as stream:
|
80 |
+
pickle.dump((self.fnames, self.index), stream)
|
81 |
+
os.rename(self.index_file + ".temp", self.index_file)
|
82 |
+
|
83 |
+
def names(self):
|
84 |
+
return self.fnames
|
85 |
+
|
86 |
+
def __len__(self):
|
87 |
+
return len(self.index)
|
88 |
+
|
89 |
+
def get_file(self, i):
|
90 |
+
name = self.fnames[i]
|
91 |
+
offset, size = self.index[i]
|
92 |
+
self.tar_file.fileobj.seek(offset)
|
93 |
+
file_bytes = self.tar_file.fileobj.read(size)
|
94 |
+
return name, io.BytesIO(file_bytes)
|
95 |
+
|
96 |
+
def close(self):
|
97 |
+
# Close the tar file
|
98 |
+
self.tar_file.close()
|
diffusion/dpm_solver.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
import torch
|
18 |
+
|
19 |
+
from .model import gaussian_diffusion as gd
|
20 |
+
from .model.dpm_solver import DPM_Solver, NoiseScheduleFlow, NoiseScheduleVP, model_wrapper
|
21 |
+
|
22 |
+
|
23 |
+
def DPMS(
|
24 |
+
model,
|
25 |
+
condition,
|
26 |
+
uncondition,
|
27 |
+
cfg_scale,
|
28 |
+
pag_scale=1.0,
|
29 |
+
pag_applied_layers=None,
|
30 |
+
model_type="noise", # or "x_start" or "v" or "score", "flow"
|
31 |
+
noise_schedule="linear",
|
32 |
+
guidance_type="classifier-free",
|
33 |
+
model_kwargs=None,
|
34 |
+
diffusion_steps=1000,
|
35 |
+
schedule="VP",
|
36 |
+
interval_guidance=None,
|
37 |
+
):
|
38 |
+
if pag_applied_layers is None:
|
39 |
+
pag_applied_layers = []
|
40 |
+
if model_kwargs is None:
|
41 |
+
model_kwargs = {}
|
42 |
+
if interval_guidance is None:
|
43 |
+
interval_guidance = [0, 1.0]
|
44 |
+
betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
|
45 |
+
|
46 |
+
## 1. Define the noise schedule.
|
47 |
+
if schedule == "VP":
|
48 |
+
noise_schedule = NoiseScheduleVP(schedule="discrete", betas=betas)
|
49 |
+
elif schedule == "FLOW":
|
50 |
+
noise_schedule = NoiseScheduleFlow(schedule="discrete_flow")
|
51 |
+
|
52 |
+
## 2. Convert your discrete-time `model` to the continuous-time
|
53 |
+
## noise prediction model. Here is an example for a diffusion model
|
54 |
+
## `model` with the noise prediction type ("noise") .
|
55 |
+
model_fn = model_wrapper(
|
56 |
+
model,
|
57 |
+
noise_schedule,
|
58 |
+
model_type=model_type,
|
59 |
+
model_kwargs=model_kwargs,
|
60 |
+
guidance_type=guidance_type,
|
61 |
+
pag_scale=pag_scale,
|
62 |
+
pag_applied_layers=pag_applied_layers,
|
63 |
+
condition=condition,
|
64 |
+
unconditional_condition=uncondition,
|
65 |
+
guidance_scale=cfg_scale,
|
66 |
+
interval_guidance=interval_guidance,
|
67 |
+
)
|
68 |
+
## 3. Define dpm-solver and sample by multistep DPM-Solver.
|
69 |
+
return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
|
diffusion/flow_euler_sampler.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
import os
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
21 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
22 |
+
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
|
23 |
+
from tqdm import tqdm
|
24 |
+
|
25 |
+
|
26 |
+
class FlowEuler:
|
27 |
+
def __init__(self, model_fn, condition, uncondition, cfg_scale, model_kwargs):
|
28 |
+
self.model = model_fn
|
29 |
+
self.condition = condition
|
30 |
+
self.uncondition = uncondition
|
31 |
+
self.cfg_scale = cfg_scale
|
32 |
+
self.model_kwargs = model_kwargs
|
33 |
+
# repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
|
34 |
+
self.scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0)
|
35 |
+
|
36 |
+
def sample(self, latents, steps=28):
|
37 |
+
device = self.condition.device
|
38 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, steps, device, None)
|
39 |
+
do_classifier_free_guidance = True
|
40 |
+
|
41 |
+
prompt_embeds = self.condition
|
42 |
+
if do_classifier_free_guidance:
|
43 |
+
prompt_embeds = torch.cat([self.uncondition, self.condition], dim=0)
|
44 |
+
|
45 |
+
for i, t in tqdm(list(enumerate(timesteps)), disable=os.getenv("DPM_TQDM", "False") == "True"):
|
46 |
+
|
47 |
+
# expand the latents if we are doing classifier free guidance
|
48 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
49 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
50 |
+
timestep = t.expand(latent_model_input.shape[0])
|
51 |
+
|
52 |
+
noise_pred = self.model(
|
53 |
+
latent_model_input,
|
54 |
+
timestep,
|
55 |
+
prompt_embeds,
|
56 |
+
**self.model_kwargs,
|
57 |
+
)
|
58 |
+
|
59 |
+
if isinstance(noise_pred, Transformer2DModelOutput):
|
60 |
+
noise_pred = noise_pred[0]
|
61 |
+
|
62 |
+
# perform guidance
|
63 |
+
if do_classifier_free_guidance:
|
64 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
65 |
+
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
66 |
+
|
67 |
+
# compute the previous noisy sample x_t -> x_t-1
|
68 |
+
latents_dtype = latents.dtype
|
69 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
70 |
+
|
71 |
+
if latents.dtype != latents_dtype:
|
72 |
+
latents = latents.to(latents_dtype)
|
73 |
+
|
74 |
+
return latents
|
diffusion/iddpm.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
# Modified from OpenAI's diffusion repos
|
18 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
19 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
20 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
21 |
+
from diffusion.model.respace import SpacedDiffusion, space_timesteps
|
22 |
+
|
23 |
+
from .model import gaussian_diffusion as gd
|
24 |
+
|
25 |
+
|
26 |
+
def Scheduler(
|
27 |
+
timestep_respacing,
|
28 |
+
noise_schedule="linear",
|
29 |
+
use_kl=False,
|
30 |
+
sigma_small=False,
|
31 |
+
predict_xstart=False,
|
32 |
+
predict_v=False,
|
33 |
+
learn_sigma=True,
|
34 |
+
pred_sigma=True,
|
35 |
+
rescale_learned_sigmas=False,
|
36 |
+
diffusion_steps=1000,
|
37 |
+
snr=False,
|
38 |
+
return_startx=False,
|
39 |
+
flow_shift=1.0,
|
40 |
+
):
|
41 |
+
betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
|
42 |
+
if use_kl:
|
43 |
+
loss_type = gd.LossType.RESCALED_KL
|
44 |
+
elif rescale_learned_sigmas:
|
45 |
+
loss_type = gd.LossType.RESCALED_MSE
|
46 |
+
else:
|
47 |
+
loss_type = gd.LossType.MSE
|
48 |
+
if timestep_respacing is None or timestep_respacing == "":
|
49 |
+
timestep_respacing = [diffusion_steps]
|
50 |
+
if predict_xstart:
|
51 |
+
model_mean_type = gd.ModelMeanType.START_X
|
52 |
+
elif predict_v:
|
53 |
+
model_mean_type = gd.ModelMeanType.VELOCITY
|
54 |
+
else:
|
55 |
+
model_mean_type = gd.ModelMeanType.EPSILON
|
56 |
+
return SpacedDiffusion(
|
57 |
+
use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
|
58 |
+
betas=betas,
|
59 |
+
model_mean_type=model_mean_type,
|
60 |
+
model_var_type=(
|
61 |
+
(
|
62 |
+
(gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
|
63 |
+
if not learn_sigma
|
64 |
+
else gd.ModelVarType.LEARNED_RANGE
|
65 |
+
)
|
66 |
+
if pred_sigma
|
67 |
+
else None
|
68 |
+
),
|
69 |
+
loss_type=loss_type,
|
70 |
+
snr=snr,
|
71 |
+
return_startx=return_startx,
|
72 |
+
# rescale_timesteps=rescale_timesteps,
|
73 |
+
flow="flow" in noise_schedule,
|
74 |
+
flow_shift=flow_shift,
|
75 |
+
diffusion_steps=diffusion_steps,
|
76 |
+
)
|
diffusion/lcm_scheduler.py
ADDED
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
16 |
+
# and https://github.com/hojonathanho/diffusion
|
17 |
+
|
18 |
+
import math
|
19 |
+
from dataclasses import dataclass
|
20 |
+
from typing import List, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
from diffusers import ConfigMixin, SchedulerMixin
|
25 |
+
from diffusers.configuration_utils import register_to_config
|
26 |
+
from diffusers.utils import BaseOutput
|
27 |
+
|
28 |
+
|
29 |
+
@dataclass
|
30 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
|
31 |
+
class LCMSchedulerOutput(BaseOutput):
|
32 |
+
"""
|
33 |
+
Output class for the scheduler's `step` function output.
|
34 |
+
Args:
|
35 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
36 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
37 |
+
denoising loop.
|
38 |
+
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
39 |
+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
40 |
+
`pred_original_sample` can be used to preview progress or for guidance.
|
41 |
+
"""
|
42 |
+
|
43 |
+
prev_sample: torch.FloatTensor
|
44 |
+
denoised: Optional[torch.FloatTensor] = None
|
45 |
+
|
46 |
+
|
47 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
48 |
+
def betas_for_alpha_bar(
|
49 |
+
num_diffusion_timesteps,
|
50 |
+
max_beta=0.999,
|
51 |
+
alpha_transform_type="cosine",
|
52 |
+
):
|
53 |
+
"""
|
54 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
55 |
+
(1-beta) over time from t = [0,1].
|
56 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
57 |
+
to that part of the diffusion process.
|
58 |
+
Args:
|
59 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
60 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
61 |
+
prevent singularities.
|
62 |
+
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
63 |
+
Choose from `cosine` or `exp`
|
64 |
+
Returns:
|
65 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
66 |
+
"""
|
67 |
+
if alpha_transform_type == "cosine":
|
68 |
+
|
69 |
+
def alpha_bar_fn(t):
|
70 |
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
71 |
+
|
72 |
+
elif alpha_transform_type == "exp":
|
73 |
+
|
74 |
+
def alpha_bar_fn(t):
|
75 |
+
return math.exp(t * -12.0)
|
76 |
+
|
77 |
+
else:
|
78 |
+
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
79 |
+
|
80 |
+
betas = []
|
81 |
+
for i in range(num_diffusion_timesteps):
|
82 |
+
t1 = i / num_diffusion_timesteps
|
83 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
84 |
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
85 |
+
return torch.tensor(betas, dtype=torch.float32)
|
86 |
+
|
87 |
+
|
88 |
+
def rescale_zero_terminal_snr(betas):
|
89 |
+
"""
|
90 |
+
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
91 |
+
Args:
|
92 |
+
betas (`torch.FloatTensor`):
|
93 |
+
the betas that the scheduler is being initialized with.
|
94 |
+
Returns:
|
95 |
+
`torch.FloatTensor`: rescaled betas with zero terminal SNR
|
96 |
+
"""
|
97 |
+
# Convert betas to alphas_bar_sqrt
|
98 |
+
alphas = 1.0 - betas
|
99 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
100 |
+
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
101 |
+
|
102 |
+
# Store old values.
|
103 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
104 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
105 |
+
|
106 |
+
# Shift so the last timestep is zero.
|
107 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
108 |
+
|
109 |
+
# Scale so the first timestep is back to the old value.
|
110 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
111 |
+
|
112 |
+
# Convert alphas_bar_sqrt to betas
|
113 |
+
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
114 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
115 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
116 |
+
betas = 1 - alphas
|
117 |
+
|
118 |
+
return betas
|
119 |
+
|
120 |
+
|
121 |
+
class LCMScheduler(SchedulerMixin, ConfigMixin):
|
122 |
+
"""
|
123 |
+
`LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
|
124 |
+
non-Markovian guidance.
|
125 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
126 |
+
methods the library implements for all schedulers such as loading and saving.
|
127 |
+
Args:
|
128 |
+
num_train_timesteps (`int`, defaults to 1000):
|
129 |
+
The number of diffusion steps to train the model.
|
130 |
+
beta_start (`float`, defaults to 0.0001):
|
131 |
+
The starting `beta` value of inference.
|
132 |
+
beta_end (`float`, defaults to 0.02):
|
133 |
+
The final `beta` value.
|
134 |
+
beta_schedule (`str`, defaults to `"linear"`):
|
135 |
+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
136 |
+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
137 |
+
trained_betas (`np.ndarray`, *optional*):
|
138 |
+
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
139 |
+
clip_sample (`bool`, defaults to `True`):
|
140 |
+
Clip the predicted sample for numerical stability.
|
141 |
+
clip_sample_range (`float`, defaults to 1.0):
|
142 |
+
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
143 |
+
set_alpha_to_one (`bool`, defaults to `True`):
|
144 |
+
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
|
145 |
+
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
146 |
+
otherwise it uses the alpha value at step 0.
|
147 |
+
steps_offset (`int`, defaults to 0):
|
148 |
+
An offset added to the inference steps. You can use a combination of `offset=1` and
|
149 |
+
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
150 |
+
Diffusion.
|
151 |
+
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
152 |
+
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
153 |
+
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
154 |
+
Video](https://imagen.research.google/video/paper.pdf) paper).
|
155 |
+
thresholding (`bool`, defaults to `False`):
|
156 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
157 |
+
as Stable Diffusion.
|
158 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
159 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
160 |
+
sample_max_value (`float`, defaults to 1.0):
|
161 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
162 |
+
timestep_spacing (`str`, defaults to `"leading"`):
|
163 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
164 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
165 |
+
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
166 |
+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
167 |
+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
168 |
+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
169 |
+
"""
|
170 |
+
|
171 |
+
# _compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
172 |
+
order = 1
|
173 |
+
|
174 |
+
@register_to_config
|
175 |
+
def __init__(
|
176 |
+
self,
|
177 |
+
num_train_timesteps: int = 1000,
|
178 |
+
beta_start: float = 0.0001,
|
179 |
+
beta_end: float = 0.02,
|
180 |
+
beta_schedule: str = "linear",
|
181 |
+
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
182 |
+
clip_sample: bool = True,
|
183 |
+
set_alpha_to_one: bool = True,
|
184 |
+
steps_offset: int = 0,
|
185 |
+
prediction_type: str = "epsilon",
|
186 |
+
thresholding: bool = False,
|
187 |
+
dynamic_thresholding_ratio: float = 0.995,
|
188 |
+
clip_sample_range: float = 1.0,
|
189 |
+
sample_max_value: float = 1.0,
|
190 |
+
timestep_spacing: str = "leading",
|
191 |
+
rescale_betas_zero_snr: bool = False,
|
192 |
+
):
|
193 |
+
if trained_betas is not None:
|
194 |
+
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
195 |
+
elif beta_schedule == "linear":
|
196 |
+
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
197 |
+
elif beta_schedule == "scaled_linear":
|
198 |
+
# this schedule is very specific to the latent diffusion model.
|
199 |
+
self.betas = (
|
200 |
+
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
201 |
+
)
|
202 |
+
elif beta_schedule == "squaredcos_cap_v2":
|
203 |
+
# Glide cosine schedule
|
204 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
205 |
+
else:
|
206 |
+
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
207 |
+
|
208 |
+
# Rescale for zero SNR
|
209 |
+
if rescale_betas_zero_snr:
|
210 |
+
self.betas = rescale_zero_terminal_snr(self.betas)
|
211 |
+
|
212 |
+
self.alphas = 1.0 - self.betas
|
213 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
214 |
+
|
215 |
+
# At every step in ddim, we are looking into the previous alphas_cumprod
|
216 |
+
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
217 |
+
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
218 |
+
# whether we use the final alpha of the "non-previous" one.
|
219 |
+
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
220 |
+
|
221 |
+
# standard deviation of the initial noise distribution
|
222 |
+
self.init_noise_sigma = 1.0
|
223 |
+
|
224 |
+
# setable values
|
225 |
+
self.num_inference_steps = None
|
226 |
+
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
227 |
+
|
228 |
+
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
229 |
+
"""
|
230 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
231 |
+
current timestep.
|
232 |
+
Args:
|
233 |
+
sample (`torch.FloatTensor`):
|
234 |
+
The input sample.
|
235 |
+
timestep (`int`, *optional*):
|
236 |
+
The current timestep in the diffusion chain.
|
237 |
+
Returns:
|
238 |
+
`torch.FloatTensor`:
|
239 |
+
A scaled input sample.
|
240 |
+
"""
|
241 |
+
return sample
|
242 |
+
|
243 |
+
def _get_variance(self, timestep, prev_timestep):
|
244 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
245 |
+
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
246 |
+
beta_prod_t = 1 - alpha_prod_t
|
247 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
248 |
+
|
249 |
+
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
250 |
+
|
251 |
+
return variance
|
252 |
+
|
253 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
254 |
+
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
255 |
+
"""
|
256 |
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
257 |
+
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
258 |
+
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
259 |
+
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
260 |
+
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
261 |
+
https://arxiv.org/abs/2205.11487
|
262 |
+
"""
|
263 |
+
dtype = sample.dtype
|
264 |
+
batch_size, channels, height, width = sample.shape
|
265 |
+
|
266 |
+
if dtype not in (torch.float32, torch.float64):
|
267 |
+
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
268 |
+
|
269 |
+
# Flatten sample for doing quantile calculation along each image
|
270 |
+
sample = sample.reshape(batch_size, channels * height * width)
|
271 |
+
|
272 |
+
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
273 |
+
|
274 |
+
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
275 |
+
s = torch.clamp(
|
276 |
+
s, min=1, max=self.config.sample_max_value
|
277 |
+
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
278 |
+
|
279 |
+
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
280 |
+
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
281 |
+
|
282 |
+
sample = sample.reshape(batch_size, channels, height, width)
|
283 |
+
sample = sample.to(dtype)
|
284 |
+
|
285 |
+
return sample
|
286 |
+
|
287 |
+
def set_timesteps(self, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None):
|
288 |
+
"""
|
289 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
290 |
+
Args:
|
291 |
+
num_inference_steps (`int`):
|
292 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
293 |
+
"""
|
294 |
+
|
295 |
+
if num_inference_steps > self.config.num_train_timesteps:
|
296 |
+
raise ValueError(
|
297 |
+
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
298 |
+
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
299 |
+
f" maximal {self.config.num_train_timesteps} timesteps."
|
300 |
+
)
|
301 |
+
|
302 |
+
self.num_inference_steps = num_inference_steps
|
303 |
+
|
304 |
+
# LCM Timesteps Setting: # Linear Spacing
|
305 |
+
c = self.config.num_train_timesteps // lcm_origin_steps
|
306 |
+
lcm_origin_timesteps = np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1 # LCM Training Steps Schedule
|
307 |
+
skipping_step = len(lcm_origin_timesteps) // num_inference_steps
|
308 |
+
timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule
|
309 |
+
|
310 |
+
self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
|
311 |
+
|
312 |
+
def get_scalings_for_boundary_condition_discrete(self, t):
|
313 |
+
self.sigma_data = 0.5 # Default: 0.5
|
314 |
+
|
315 |
+
# By dividing 0.1: This is almost a delta function at t=0.
|
316 |
+
c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
|
317 |
+
c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
|
318 |
+
return c_skip, c_out
|
319 |
+
|
320 |
+
def step(
|
321 |
+
self,
|
322 |
+
model_output: torch.FloatTensor,
|
323 |
+
timeindex: int,
|
324 |
+
timestep: int,
|
325 |
+
sample: torch.FloatTensor,
|
326 |
+
eta: float = 0.0,
|
327 |
+
use_clipped_model_output: bool = False,
|
328 |
+
generator=None,
|
329 |
+
variance_noise: Optional[torch.FloatTensor] = None,
|
330 |
+
return_dict: bool = True,
|
331 |
+
) -> Union[LCMSchedulerOutput, Tuple]:
|
332 |
+
"""
|
333 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
334 |
+
process from the learned model outputs (most often the predicted noise).
|
335 |
+
Args:
|
336 |
+
model_output (`torch.FloatTensor`):
|
337 |
+
The direct output from learned diffusion model.
|
338 |
+
timestep (`float`):
|
339 |
+
The current discrete timestep in the diffusion chain.
|
340 |
+
sample (`torch.FloatTensor`):
|
341 |
+
A current instance of a sample created by the diffusion process.
|
342 |
+
eta (`float`):
|
343 |
+
The weight of noise for added noise in diffusion step.
|
344 |
+
use_clipped_model_output (`bool`, defaults to `False`):
|
345 |
+
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
|
346 |
+
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
|
347 |
+
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
|
348 |
+
`use_clipped_model_output` has no effect.
|
349 |
+
generator (`torch.Generator`, *optional*):
|
350 |
+
A random number generator.
|
351 |
+
variance_noise (`torch.FloatTensor`):
|
352 |
+
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
353 |
+
itself. Useful for methods such as [`CycleDiffusion`].
|
354 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
355 |
+
Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
|
356 |
+
Returns:
|
357 |
+
[`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
|
358 |
+
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
|
359 |
+
tuple is returned where the first element is the sample tensor.
|
360 |
+
"""
|
361 |
+
if self.num_inference_steps is None:
|
362 |
+
raise ValueError(
|
363 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
364 |
+
)
|
365 |
+
|
366 |
+
# 1. get previous step value
|
367 |
+
prev_timeindex = timeindex + 1
|
368 |
+
if prev_timeindex < len(self.timesteps):
|
369 |
+
prev_timestep = self.timesteps[prev_timeindex]
|
370 |
+
else:
|
371 |
+
prev_timestep = timestep
|
372 |
+
|
373 |
+
# 2. compute alphas, betas
|
374 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
375 |
+
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
376 |
+
|
377 |
+
beta_prod_t = 1 - alpha_prod_t
|
378 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
379 |
+
|
380 |
+
# 3. Get scalings for boundary conditions
|
381 |
+
c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
|
382 |
+
|
383 |
+
# 4. Different Parameterization:
|
384 |
+
parameterization = self.config.prediction_type
|
385 |
+
|
386 |
+
if parameterization == "epsilon": # noise-prediction
|
387 |
+
pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
|
388 |
+
|
389 |
+
elif parameterization == "sample": # x-prediction
|
390 |
+
pred_x0 = model_output
|
391 |
+
|
392 |
+
elif parameterization == "v_prediction": # v-prediction
|
393 |
+
pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
|
394 |
+
|
395 |
+
# 4. Denoise model output using boundary conditions
|
396 |
+
denoised = c_out * pred_x0 + c_skip * sample
|
397 |
+
|
398 |
+
# 5. Sample z ~ N(0, I), For MultiStep Inference
|
399 |
+
# Noise is not used for one-step sampling.
|
400 |
+
if len(self.timesteps) > 1:
|
401 |
+
noise = torch.randn(model_output.shape).to(model_output.device)
|
402 |
+
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
|
403 |
+
else:
|
404 |
+
prev_sample = denoised
|
405 |
+
|
406 |
+
if not return_dict:
|
407 |
+
return (prev_sample, denoised)
|
408 |
+
|
409 |
+
return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
|
410 |
+
|
411 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
412 |
+
def add_noise(
|
413 |
+
self,
|
414 |
+
original_samples: torch.FloatTensor,
|
415 |
+
noise: torch.FloatTensor,
|
416 |
+
timesteps: torch.IntTensor,
|
417 |
+
) -> torch.FloatTensor:
|
418 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
419 |
+
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
420 |
+
timesteps = timesteps.to(original_samples.device)
|
421 |
+
|
422 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
423 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
424 |
+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
425 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
426 |
+
|
427 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
428 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
429 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
430 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
431 |
+
|
432 |
+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
433 |
+
return noisy_samples
|
434 |
+
|
435 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
|
436 |
+
def get_velocity(
|
437 |
+
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
|
438 |
+
) -> torch.FloatTensor:
|
439 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as sample
|
440 |
+
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
|
441 |
+
timesteps = timesteps.to(sample.device)
|
442 |
+
|
443 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
444 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
445 |
+
while len(sqrt_alpha_prod.shape) < len(sample.shape):
|
446 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
447 |
+
|
448 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
449 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
450 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
|
451 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
452 |
+
|
453 |
+
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
454 |
+
return velocity
|
455 |
+
|
456 |
+
def __len__(self):
|
457 |
+
return self.config.num_train_timesteps
|
diffusion/model/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .nets import *
|