Upload 81 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .DS_Store +0 -0
- .gitattributes +24 -0
- .gitignore +170 -0
- LICENSE +21 -0
- OmniGen/__init__.py +4 -0
- OmniGen/__pycache__/__init__.cpython-310.pyc +0 -0
- OmniGen/__pycache__/model.cpython-310.pyc +0 -0
- OmniGen/__pycache__/pipeline.cpython-310.pyc +0 -0
- OmniGen/__pycache__/processor.cpython-310.pyc +0 -0
- OmniGen/__pycache__/scheduler.cpython-310.pyc +0 -0
- OmniGen/__pycache__/transformer.cpython-310.pyc +0 -0
- OmniGen/__pycache__/utils.cpython-310.pyc +0 -0
- OmniGen/model.py +468 -0
- OmniGen/pipeline.py +289 -0
- OmniGen/processor.py +335 -0
- OmniGen/scheduler.py +55 -0
- OmniGen/train_helper/__init__.py +2 -0
- OmniGen/train_helper/data.py +116 -0
- OmniGen/train_helper/loss.py +68 -0
- OmniGen/transformer.py +159 -0
- OmniGen/utils.py +110 -0
- README.md +93 -14
- app.py +359 -0
- docs/fine-tuning.md +172 -0
- docs/inference.md +96 -0
- imgs/.DS_Store +0 -0
- imgs/demo_cases.png +3 -0
- imgs/demo_cases/AI_Pioneers.jpg +0 -0
- imgs/demo_cases/edit.png +3 -0
- imgs/demo_cases/entity.png +3 -0
- imgs/demo_cases/reasoning.png +3 -0
- imgs/demo_cases/same_pose.png +3 -0
- imgs/demo_cases/skeletal.png +0 -0
- imgs/demo_cases/skeletal2img.png +3 -0
- imgs/demo_cases/t2i_woman_with_book.png +3 -0
- imgs/overall.jpg +3 -0
- imgs/referring.png +3 -0
- imgs/test_cases/1.jpg +3 -0
- imgs/test_cases/2.jpg +3 -0
- imgs/test_cases/3.jpg +3 -0
- imgs/test_cases/4.jpg +3 -0
- imgs/test_cases/Amanda.jpg +3 -0
- imgs/test_cases/control.jpg +3 -0
- imgs/test_cases/icl1.jpg +0 -0
- imgs/test_cases/icl2.jpg +0 -0
- imgs/test_cases/icl3.jpg +0 -0
- imgs/test_cases/lecun.png +0 -0
- imgs/test_cases/mckenna.jpg +3 -0
- imgs/test_cases/pose.png +0 -0
- imgs/test_cases/rose.jpg +0 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
.gitattributes
CHANGED
@@ -33,3 +33,27 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
imgs/demo_cases.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
imgs/demo_cases/edit.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
imgs/demo_cases/entity.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
imgs/demo_cases/reasoning.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
imgs/demo_cases/same_pose.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
imgs/demo_cases/skeletal2img.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
imgs/demo_cases/t2i_woman_with_book.png filter=lfs diff=lfs merge=lfs -text
|
43 |
+
imgs/overall.jpg filter=lfs diff=lfs merge=lfs -text
|
44 |
+
imgs/referring.png filter=lfs diff=lfs merge=lfs -text
|
45 |
+
imgs/test_cases/1.jpg filter=lfs diff=lfs merge=lfs -text
|
46 |
+
imgs/test_cases/2.jpg filter=lfs diff=lfs merge=lfs -text
|
47 |
+
imgs/test_cases/3.jpg filter=lfs diff=lfs merge=lfs -text
|
48 |
+
imgs/test_cases/4.jpg filter=lfs diff=lfs merge=lfs -text
|
49 |
+
imgs/test_cases/Amanda.jpg filter=lfs diff=lfs merge=lfs -text
|
50 |
+
imgs/test_cases/control.jpg filter=lfs diff=lfs merge=lfs -text
|
51 |
+
imgs/test_cases/mckenna.jpg filter=lfs diff=lfs merge=lfs -text
|
52 |
+
imgs/test_cases/two_man.jpg filter=lfs diff=lfs merge=lfs -text
|
53 |
+
imgs/test_cases/woman.png filter=lfs diff=lfs merge=lfs -text
|
54 |
+
toy_data/images/cat.png filter=lfs diff=lfs merge=lfs -text
|
55 |
+
toy_data/images/dog2.jpeg filter=lfs diff=lfs merge=lfs -text
|
56 |
+
toy_data/images/dog3.jpeg filter=lfs diff=lfs merge=lfs -text
|
57 |
+
toy_data/images/dog4.jpeg filter=lfs diff=lfs merge=lfs -text
|
58 |
+
toy_data/images/dog5.jpeg filter=lfs diff=lfs merge=lfs -text
|
59 |
+
toy_data/images/walking.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
110 |
+
.pdm.toml
|
111 |
+
.pdm-python
|
112 |
+
.pdm-build/
|
113 |
+
|
114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
115 |
+
__pypackages__/
|
116 |
+
|
117 |
+
# Celery stuff
|
118 |
+
celerybeat-schedule
|
119 |
+
celerybeat.pid
|
120 |
+
|
121 |
+
# SageMath parsed files
|
122 |
+
*.sage.py
|
123 |
+
|
124 |
+
# Environments
|
125 |
+
.env
|
126 |
+
.venv
|
127 |
+
env/
|
128 |
+
venv/
|
129 |
+
ENV/
|
130 |
+
env.bak/
|
131 |
+
venv.bak/
|
132 |
+
|
133 |
+
# Spyder project settings
|
134 |
+
.spyderproject
|
135 |
+
.spyproject
|
136 |
+
|
137 |
+
# Rope project settings
|
138 |
+
.ropeproject
|
139 |
+
|
140 |
+
# mkdocs documentation
|
141 |
+
/site
|
142 |
+
|
143 |
+
# mypy
|
144 |
+
.mypy_cache/
|
145 |
+
.dmypy.json
|
146 |
+
dmypy.json
|
147 |
+
|
148 |
+
# Pyre type checker
|
149 |
+
.pyre/
|
150 |
+
|
151 |
+
# pytype static type analyzer
|
152 |
+
.pytype/
|
153 |
+
|
154 |
+
# Cython debug symbols
|
155 |
+
cython_debug/
|
156 |
+
|
157 |
+
# PyCharm
|
158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
+
#.idea/
|
163 |
+
|
164 |
+
# myfile
|
165 |
+
.results/
|
166 |
+
local.ipynb
|
167 |
+
convert_to_safetensor.py
|
168 |
+
ttt.ipynb
|
169 |
+
imgs/ttt/
|
170 |
+
*.bak
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 VectorSpaceLab
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
OmniGen/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .model import OmniGen
|
2 |
+
from .processor import OmniGenProcessor
|
3 |
+
from .scheduler import OmniGenScheduler
|
4 |
+
from .pipeline import OmniGenPipeline
|
OmniGen/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (320 Bytes). View file
|
|
OmniGen/__pycache__/model.cpython-310.pyc
ADDED
Binary file (14.4 kB). View file
|
|
OmniGen/__pycache__/pipeline.cpython-310.pyc
ADDED
Binary file (8.89 kB). View file
|
|
OmniGen/__pycache__/processor.cpython-310.pyc
ADDED
Binary file (11.2 kB). View file
|
|
OmniGen/__pycache__/scheduler.cpython-310.pyc
ADDED
Binary file (2.74 kB). View file
|
|
OmniGen/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (3.94 kB). View file
|
|
OmniGen/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.52 kB). View file
|
|
OmniGen/model.py
ADDED
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# The code is revised from DiT
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import numpy as np
|
6 |
+
import math
|
7 |
+
from typing import Dict
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from diffusers.loaders import PeftAdapterMixin
|
11 |
+
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
12 |
+
from huggingface_hub import snapshot_download
|
13 |
+
from safetensors.torch import load_file
|
14 |
+
|
15 |
+
from OmniGen.transformer import Phi3Config, Phi3Transformer
|
16 |
+
|
17 |
+
|
18 |
+
def modulate(x, shift, scale):
|
19 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
20 |
+
|
21 |
+
|
22 |
+
class TimestepEmbedder(nn.Module):
|
23 |
+
"""
|
24 |
+
Embeds scalar timesteps into vector representations.
|
25 |
+
"""
|
26 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
27 |
+
super().__init__()
|
28 |
+
self.mlp = nn.Sequential(
|
29 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
30 |
+
nn.SiLU(),
|
31 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
32 |
+
)
|
33 |
+
self.frequency_embedding_size = frequency_embedding_size
|
34 |
+
|
35 |
+
@staticmethod
|
36 |
+
def timestep_embedding(t, dim, max_period=10000):
|
37 |
+
"""
|
38 |
+
Create sinusoidal timestep embeddings.
|
39 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
40 |
+
These may be fractional.
|
41 |
+
:param dim: the dimension of the output.
|
42 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
43 |
+
:return: an (N, D) Tensor of positional embeddings.
|
44 |
+
"""
|
45 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
46 |
+
half = dim // 2
|
47 |
+
freqs = torch.exp(
|
48 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
49 |
+
).to(device=t.device)
|
50 |
+
args = t[:, None].float() * freqs[None]
|
51 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
52 |
+
if dim % 2:
|
53 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
54 |
+
return embedding
|
55 |
+
|
56 |
+
def forward(self, t, dtype=torch.float32):
|
57 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
58 |
+
t_emb = self.mlp(t_freq)
|
59 |
+
return t_emb
|
60 |
+
|
61 |
+
|
62 |
+
class FinalLayer(nn.Module):
|
63 |
+
"""
|
64 |
+
The final layer of DiT.
|
65 |
+
"""
|
66 |
+
def __init__(self, hidden_size, patch_size, out_channels):
|
67 |
+
super().__init__()
|
68 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
69 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
70 |
+
self.adaLN_modulation = nn.Sequential(
|
71 |
+
nn.SiLU(),
|
72 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
73 |
+
)
|
74 |
+
|
75 |
+
def forward(self, x, c):
|
76 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
77 |
+
x = modulate(self.norm_final(x), shift, scale)
|
78 |
+
x = self.linear(x)
|
79 |
+
return x
|
80 |
+
|
81 |
+
|
82 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
|
83 |
+
"""
|
84 |
+
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
85 |
+
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
86 |
+
"""
|
87 |
+
if isinstance(grid_size, int):
|
88 |
+
grid_size = (grid_size, grid_size)
|
89 |
+
|
90 |
+
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
|
91 |
+
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
|
92 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
93 |
+
grid = np.stack(grid, axis=0)
|
94 |
+
|
95 |
+
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
96 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
97 |
+
if cls_token and extra_tokens > 0:
|
98 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
99 |
+
return pos_embed
|
100 |
+
|
101 |
+
|
102 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
103 |
+
assert embed_dim % 2 == 0
|
104 |
+
|
105 |
+
# use half of dimensions to encode grid_h
|
106 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
107 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
108 |
+
|
109 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
110 |
+
return emb
|
111 |
+
|
112 |
+
|
113 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
114 |
+
"""
|
115 |
+
embed_dim: output dimension for each position
|
116 |
+
pos: a list of positions to be encoded: size (M,)
|
117 |
+
out: (M, D)
|
118 |
+
"""
|
119 |
+
assert embed_dim % 2 == 0
|
120 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
121 |
+
omega /= embed_dim / 2.
|
122 |
+
omega = 1. / 10000**omega # (D/2,)
|
123 |
+
|
124 |
+
pos = pos.reshape(-1) # (M,)
|
125 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
126 |
+
|
127 |
+
emb_sin = np.sin(out) # (M, D/2)
|
128 |
+
emb_cos = np.cos(out) # (M, D/2)
|
129 |
+
|
130 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
131 |
+
return emb
|
132 |
+
|
133 |
+
|
134 |
+
class PatchEmbedMR(nn.Module):
|
135 |
+
""" 2D Image to Patch Embedding
|
136 |
+
"""
|
137 |
+
def __init__(
|
138 |
+
self,
|
139 |
+
patch_size: int = 2,
|
140 |
+
in_chans: int = 4,
|
141 |
+
embed_dim: int = 768,
|
142 |
+
bias: bool = True,
|
143 |
+
):
|
144 |
+
super().__init__()
|
145 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
146 |
+
|
147 |
+
def forward(self, x):
|
148 |
+
x = self.proj(x)
|
149 |
+
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
150 |
+
return x
|
151 |
+
|
152 |
+
|
153 |
+
class Int8Quantized(nn.Module):
|
154 |
+
def __init__(self, tensor, scale_factor=None):
|
155 |
+
super().__init__()
|
156 |
+
if scale_factor is None:
|
157 |
+
max_val = torch.max(torch.abs(tensor))
|
158 |
+
scale_factor = max_val / 127.0
|
159 |
+
# Store quantized weights and scale factor
|
160 |
+
self.register_buffer('quantized_weight', torch.round(tensor / scale_factor).to(torch.int8))
|
161 |
+
self.register_buffer('scale_factor', torch.tensor(scale_factor))
|
162 |
+
|
163 |
+
def forward(self, dtype=None):
|
164 |
+
# Dequantize and convert to specified dtype
|
165 |
+
weight = self.quantized_weight.float() * self.scale_factor
|
166 |
+
if dtype is not None:
|
167 |
+
weight = weight.to(dtype)
|
168 |
+
return weight
|
169 |
+
|
170 |
+
|
171 |
+
|
172 |
+
class QuantizedLinear(nn.Module):
|
173 |
+
def __init__(self, weight, bias=None):
|
174 |
+
super().__init__()
|
175 |
+
self.weight_quantized = Int8Quantized(weight)
|
176 |
+
if bias is not None:
|
177 |
+
self.register_buffer('bias', bias)
|
178 |
+
else:
|
179 |
+
self.bias = None
|
180 |
+
|
181 |
+
def forward(self, x):
|
182 |
+
# Dequantize weight to match input dtype
|
183 |
+
weight = self.weight_quantized(dtype=x.dtype)
|
184 |
+
return F.linear(x, weight, self.bias)
|
185 |
+
|
186 |
+
|
187 |
+
class OmniGen(nn.Module, PeftAdapterMixin):
|
188 |
+
"""
|
189 |
+
Diffusion model with a Transformer backbone.
|
190 |
+
"""
|
191 |
+
def __init__(
|
192 |
+
self,
|
193 |
+
transformer_config: Phi3Config,
|
194 |
+
patch_size=2,
|
195 |
+
in_channels=4,
|
196 |
+
pe_interpolation: float = 1.0,
|
197 |
+
pos_embed_max_size: int = 192,
|
198 |
+
):
|
199 |
+
super().__init__()
|
200 |
+
|
201 |
+
self.in_channels = in_channels
|
202 |
+
self.out_channels = in_channels
|
203 |
+
self.patch_size = patch_size
|
204 |
+
self.pos_embed_max_size = pos_embed_max_size
|
205 |
+
|
206 |
+
hidden_size = transformer_config.hidden_size
|
207 |
+
|
208 |
+
self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
209 |
+
self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
210 |
+
|
211 |
+
self.time_token = TimestepEmbedder(hidden_size)
|
212 |
+
self.t_embedder = TimestepEmbedder(hidden_size)
|
213 |
+
|
214 |
+
self.pe_interpolation = pe_interpolation
|
215 |
+
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
|
216 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
|
217 |
+
|
218 |
+
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
|
219 |
+
|
220 |
+
self.initialize_weights()
|
221 |
+
|
222 |
+
self.llm = Phi3Transformer(config=transformer_config)
|
223 |
+
self.llm.config.use_cache = False
|
224 |
+
|
225 |
+
def _quantize_module(self, module):
|
226 |
+
"""
|
227 |
+
Quantize a module to 8-bit precision
|
228 |
+
"""
|
229 |
+
for name, child in module.named_children():
|
230 |
+
if isinstance(child, nn.Linear):
|
231 |
+
setattr(module, name, QuantizedLinear(child.weight.data, child.bias.data if child.bias is not None else None))
|
232 |
+
elif isinstance(child, nn.LayerNorm):
|
233 |
+
# Skip quantization for LayerNorm
|
234 |
+
continue
|
235 |
+
else:
|
236 |
+
self._quantize_module(child)
|
237 |
+
|
238 |
+
@classmethod
|
239 |
+
def from_pretrained(cls, model_name, quantize=False): # Add quantize parameter
|
240 |
+
if not os.path.exists(model_name):
|
241 |
+
cache_folder = os.getenv('HF_HUB_CACHE')
|
242 |
+
model_name = snapshot_download(repo_id=model_name,
|
243 |
+
cache_dir=cache_folder,
|
244 |
+
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
245 |
+
config = Phi3Config.from_pretrained(model_name)
|
246 |
+
model = cls(config)
|
247 |
+
if os.path.exists(os.path.join(model_name, 'model.safetensors')):
|
248 |
+
print("Loading safetensors")
|
249 |
+
ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
|
250 |
+
else:
|
251 |
+
ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
|
252 |
+
|
253 |
+
# Load weights first
|
254 |
+
model.load_state_dict(ckpt)
|
255 |
+
|
256 |
+
# Only quantize if explicitly requested
|
257 |
+
if quantize:
|
258 |
+
print("Quantizing weights to 8-bit...")
|
259 |
+
model._quantize_module(model.llm)
|
260 |
+
|
261 |
+
return model
|
262 |
+
def initialize_weights(self):
|
263 |
+
assert not hasattr(self, "llama")
|
264 |
+
|
265 |
+
# Initialize transformer layers:
|
266 |
+
def _basic_init(module):
|
267 |
+
if isinstance(module, nn.Linear):
|
268 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
269 |
+
if module.bias is not None:
|
270 |
+
nn.init.constant_(module.bias, 0)
|
271 |
+
self.apply(_basic_init)
|
272 |
+
|
273 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
274 |
+
w = self.x_embedder.proj.weight.data
|
275 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
276 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
277 |
+
|
278 |
+
w = self.input_x_embedder.proj.weight.data
|
279 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
280 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
281 |
+
|
282 |
+
|
283 |
+
# Initialize timestep embedding MLP:
|
284 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
285 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
286 |
+
nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
|
287 |
+
nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)
|
288 |
+
|
289 |
+
# Zero-out output layers:
|
290 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
291 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
292 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
293 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
294 |
+
|
295 |
+
def unpatchify(self, x, h, w):
|
296 |
+
"""
|
297 |
+
x: (N, T, patch_size**2 * C)
|
298 |
+
imgs: (N, H, W, C)
|
299 |
+
"""
|
300 |
+
c = self.out_channels
|
301 |
+
|
302 |
+
x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
|
303 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
304 |
+
imgs = x.reshape(shape=(x.shape[0], c, h, w))
|
305 |
+
return imgs
|
306 |
+
|
307 |
+
|
308 |
+
def cropped_pos_embed(self, height, width):
|
309 |
+
"""Crops positional embeddings for SD3 compatibility."""
|
310 |
+
if self.pos_embed_max_size is None:
|
311 |
+
raise ValueError("`pos_embed_max_size` must be set for cropping.")
|
312 |
+
|
313 |
+
height = height // self.patch_size
|
314 |
+
width = width // self.patch_size
|
315 |
+
if height > self.pos_embed_max_size:
|
316 |
+
raise ValueError(
|
317 |
+
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
318 |
+
)
|
319 |
+
if width > self.pos_embed_max_size:
|
320 |
+
raise ValueError(
|
321 |
+
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
322 |
+
)
|
323 |
+
|
324 |
+
top = (self.pos_embed_max_size - height) // 2
|
325 |
+
left = (self.pos_embed_max_size - width) // 2
|
326 |
+
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
|
327 |
+
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
|
328 |
+
# print(top, top + height, left, left + width, spatial_pos_embed.size())
|
329 |
+
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
330 |
+
return spatial_pos_embed
|
331 |
+
|
332 |
+
|
333 |
+
def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
|
334 |
+
if isinstance(latents, list):
|
335 |
+
return_list = False
|
336 |
+
if padding_latent is None:
|
337 |
+
padding_latent = [None] * len(latents)
|
338 |
+
return_list = True
|
339 |
+
patched_latents, num_tokens, shapes = [], [], []
|
340 |
+
for latent, padding in zip(latents, padding_latent):
|
341 |
+
height, width = latent.shape[-2:]
|
342 |
+
if is_input_images:
|
343 |
+
latent = self.input_x_embedder(latent)
|
344 |
+
else:
|
345 |
+
latent = self.x_embedder(latent)
|
346 |
+
pos_embed = self.cropped_pos_embed(height, width)
|
347 |
+
latent = latent + pos_embed
|
348 |
+
if padding is not None:
|
349 |
+
latent = torch.cat([latent, padding], dim=-2)
|
350 |
+
patched_latents.append(latent)
|
351 |
+
|
352 |
+
num_tokens.append(pos_embed.size(1))
|
353 |
+
shapes.append([height, width])
|
354 |
+
if not return_list:
|
355 |
+
latents = torch.cat(patched_latents, dim=0)
|
356 |
+
else:
|
357 |
+
latents = patched_latents
|
358 |
+
else:
|
359 |
+
height, width = latents.shape[-2:]
|
360 |
+
if is_input_images:
|
361 |
+
latents = self.input_x_embedder(latents)
|
362 |
+
else:
|
363 |
+
latents = self.x_embedder(latents)
|
364 |
+
pos_embed = self.cropped_pos_embed(height, width)
|
365 |
+
latents = latents + pos_embed
|
366 |
+
num_tokens = latents.size(1)
|
367 |
+
shapes = [height, width]
|
368 |
+
return latents, num_tokens, shapes
|
369 |
+
|
370 |
+
|
371 |
+
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True):
|
372 |
+
"""
|
373 |
+
|
374 |
+
"""
|
375 |
+
input_is_list = isinstance(x, list)
|
376 |
+
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
|
377 |
+
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
|
378 |
+
|
379 |
+
if input_img_latents is not None:
|
380 |
+
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
|
381 |
+
if input_ids is not None:
|
382 |
+
condition_embeds = self.llm.embed_tokens(input_ids).clone()
|
383 |
+
input_img_inx = 0
|
384 |
+
for b_inx in input_image_sizes.keys():
|
385 |
+
for start_inx, end_inx in input_image_sizes[b_inx]:
|
386 |
+
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
|
387 |
+
input_img_inx += 1
|
388 |
+
if input_img_latents is not None:
|
389 |
+
assert input_img_inx == len(input_latents)
|
390 |
+
|
391 |
+
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
392 |
+
else:
|
393 |
+
input_emb = torch.cat([time_token, x], dim=1)
|
394 |
+
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values)
|
395 |
+
output, past_key_values = output.last_hidden_state, output.past_key_values
|
396 |
+
if input_is_list:
|
397 |
+
image_embedding = output[:, -max(num_tokens):]
|
398 |
+
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
399 |
+
x = self.final_layer(image_embedding, time_emb)
|
400 |
+
latents = []
|
401 |
+
for i in range(x.size(0)):
|
402 |
+
latent = x[i:i+1, :num_tokens[i]]
|
403 |
+
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
|
404 |
+
latents.append(latent)
|
405 |
+
else:
|
406 |
+
image_embedding = output[:, -num_tokens:]
|
407 |
+
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
408 |
+
x = self.final_layer(image_embedding, time_emb)
|
409 |
+
latents = self.unpatchify(x, shapes[0], shapes[1])
|
410 |
+
|
411 |
+
if return_past_key_values:
|
412 |
+
return latents, past_key_values
|
413 |
+
return latents
|
414 |
+
|
415 |
+
@torch.no_grad()
|
416 |
+
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
|
417 |
+
"""
|
418 |
+
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
419 |
+
"""
|
420 |
+
self.llm.config.use_cache = use_kv_cache
|
421 |
+
model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True)
|
422 |
+
if use_img_cfg:
|
423 |
+
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
|
424 |
+
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
425 |
+
model_out = [cond, cond, cond]
|
426 |
+
else:
|
427 |
+
cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
|
428 |
+
cond = uncond + cfg_scale * (cond - uncond)
|
429 |
+
model_out = [cond, cond]
|
430 |
+
|
431 |
+
return torch.cat(model_out, dim=0), past_key_values
|
432 |
+
|
433 |
+
|
434 |
+
@torch.no_grad()
|
435 |
+
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, return_past_key_values=True):
|
436 |
+
"""
|
437 |
+
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
438 |
+
"""
|
439 |
+
self.llm.config.use_cache = use_kv_cache
|
440 |
+
if past_key_values is None:
|
441 |
+
past_key_values = [None] * len(attention_mask)
|
442 |
+
|
443 |
+
x = torch.split(x, len(x) // len(attention_mask), dim=0)
|
444 |
+
timestep = timestep.to(x[0].dtype)
|
445 |
+
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
|
446 |
+
|
447 |
+
model_out, pask_key_values = [], []
|
448 |
+
for i in range(len(input_ids)):
|
449 |
+
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values[i])
|
450 |
+
model_out.append(temp_out)
|
451 |
+
pask_key_values.append(temp_pask_key_values)
|
452 |
+
|
453 |
+
if len(model_out) == 3:
|
454 |
+
cond, uncond, img_cond = model_out
|
455 |
+
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
456 |
+
model_out = [cond, cond, cond]
|
457 |
+
elif len(model_out) == 2:
|
458 |
+
cond, uncond = model_out
|
459 |
+
cond = uncond + cfg_scale * (cond - uncond)
|
460 |
+
model_out = [cond, cond]
|
461 |
+
else:
|
462 |
+
return model_out[0]
|
463 |
+
|
464 |
+
return torch.cat(model_out, dim=0), pask_key_values
|
465 |
+
|
466 |
+
|
467 |
+
|
468 |
+
|
OmniGen/pipeline.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import inspect
|
3 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from huggingface_hub import snapshot_download
|
9 |
+
from peft import LoraConfig, PeftModel
|
10 |
+
from diffusers.models import AutoencoderKL
|
11 |
+
from diffusers.utils import (
|
12 |
+
USE_PEFT_BACKEND,
|
13 |
+
is_torch_xla_available,
|
14 |
+
logging,
|
15 |
+
replace_example_docstring,
|
16 |
+
scale_lora_layers,
|
17 |
+
unscale_lora_layers,
|
18 |
+
)
|
19 |
+
from safetensors.torch import load_file
|
20 |
+
|
21 |
+
from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
|
22 |
+
|
23 |
+
import gc # For clearing unused objects
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__)
|
26 |
+
|
27 |
+
EXAMPLE_DOC_STRING = """
|
28 |
+
Examples:
|
29 |
+
```py
|
30 |
+
>>> from OmniGen import OmniGenPipeline
|
31 |
+
>>> pipe = FluxControlNetPipeline.from_pretrained(
|
32 |
+
... base_model
|
33 |
+
... )
|
34 |
+
>>> prompt = "A woman holds a bouquet of flowers and faces the camera"
|
35 |
+
>>> image = pipe(
|
36 |
+
... prompt,
|
37 |
+
... guidance_scale=3.0,
|
38 |
+
... num_inference_steps=50,
|
39 |
+
... ).images[0]
|
40 |
+
>>> image.save("t2i.png")
|
41 |
+
```
|
42 |
+
"""
|
43 |
+
|
44 |
+
class OmniGenPipeline:
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
vae: AutoencoderKL,
|
48 |
+
model: OmniGen,
|
49 |
+
processor: OmniGenProcessor,
|
50 |
+
|
51 |
+
):
|
52 |
+
self.vae = vae
|
53 |
+
self.model = model
|
54 |
+
self.processor = processor
|
55 |
+
|
56 |
+
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
57 |
+
self.model.to(self.device)
|
58 |
+
self.model.eval()
|
59 |
+
self.vae.to(self.device)
|
60 |
+
|
61 |
+
@classmethod
|
62 |
+
def from_pretrained(cls, model_name, vae_path: str=None, Quantization: bool=False):
|
63 |
+
if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"):
|
64 |
+
logger.info("Model not found, downloading...")
|
65 |
+
cache_folder = os.getenv('HF_HUB_CACHE')
|
66 |
+
model_name = snapshot_download(repo_id=model_name,
|
67 |
+
cache_dir=cache_folder,
|
68 |
+
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt'])
|
69 |
+
logger.info(f"Downloaded model to {model_name}")
|
70 |
+
|
71 |
+
# Pass Quantization parameter to OmniGen's from_pretrained
|
72 |
+
model = OmniGen.from_pretrained(model_name, quantize=Quantization)
|
73 |
+
|
74 |
+
processor = OmniGenProcessor.from_pretrained(model_name)
|
75 |
+
|
76 |
+
if os.path.exists(os.path.join(model_name, "vae")):
|
77 |
+
vae = AutoencoderKL.from_pretrained(os.path.join(model_name, "vae"))
|
78 |
+
elif vae_path is not None:
|
79 |
+
vae = AutoencoderKL.from_pretrained(vae_path)
|
80 |
+
else:
|
81 |
+
logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF")
|
82 |
+
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
|
83 |
+
|
84 |
+
return cls(vae, model, processor)
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
def merge_lora(self, lora_path: str):
|
89 |
+
model = PeftModel.from_pretrained(self.model, lora_path)
|
90 |
+
model.merge_and_unload()
|
91 |
+
|
92 |
+
self.model = model
|
93 |
+
|
94 |
+
def to(self, device: Union[str, torch.device]):
|
95 |
+
if isinstance(device, str):
|
96 |
+
device = torch.device(device)
|
97 |
+
self.model.to(device)
|
98 |
+
self.vae.to(device)
|
99 |
+
|
100 |
+
def vae_encode(self, x, dtype):
|
101 |
+
if self.vae.config.shift_factor is not None:
|
102 |
+
x = self.vae.encode(x).latent_dist.sample()
|
103 |
+
x = (x - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
104 |
+
else:
|
105 |
+
x = self.vae.encode(x).latent_dist.sample().mul_(self.vae.config.scaling_factor)
|
106 |
+
x = x.to(dtype)
|
107 |
+
return x
|
108 |
+
|
109 |
+
def move_to_device(self, data):
|
110 |
+
if isinstance(data, list):
|
111 |
+
return [x.to(self.device) for x in data]
|
112 |
+
return data.to(self.device)
|
113 |
+
|
114 |
+
|
115 |
+
@torch.no_grad()
|
116 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
117 |
+
def __call__(
|
118 |
+
self,
|
119 |
+
prompt: Union[str, List[str]],
|
120 |
+
input_images: Union[List[str], List[List[str]]] = None,
|
121 |
+
height: int = 1024,
|
122 |
+
width: int = 1024,
|
123 |
+
num_inference_steps: int = 50,
|
124 |
+
guidance_scale: float = 3,
|
125 |
+
use_img_guidance: bool = True,
|
126 |
+
img_guidance_scale: float = 1.6,
|
127 |
+
separate_cfg_infer: bool = False,
|
128 |
+
use_kv_cache: bool = True,
|
129 |
+
dtype: torch.dtype = torch.bfloat16,
|
130 |
+
seed: int = None,
|
131 |
+
Quantization: bool = False,
|
132 |
+
):
|
133 |
+
|
134 |
+
r"""
|
135 |
+
Function invoked when calling the pipeline for generation.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
prompt (`str` or `List[str]`):
|
139 |
+
The prompt or prompts to guide the image generation.
|
140 |
+
input_images (`List[str]` or `List[List[str]]`, *optional*):
|
141 |
+
The list of input images. We will replace the "<|image_i|>" in prompt with the 1-th image in list.
|
142 |
+
height (`int`, *optional*, defaults to 1024):
|
143 |
+
The height in pixels of the generated image. The number must be a multiple of 16.
|
144 |
+
width (`int`, *optional*, defaults to 1024):
|
145 |
+
The width in pixels of the generated image. The number must be a multiple of 16.
|
146 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
147 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
|
148 |
+
guidance_scale (`float`, *optional*, defaults to 4.0):
|
149 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
150 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
151 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
152 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
153 |
+
usually at the expense of lower image quality.
|
154 |
+
use_img_guidance (`bool`, *optional*, defaults to True):
|
155 |
+
Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
|
156 |
+
img_guidance_scale (`float`, *optional*, defaults to 1.6):
|
157 |
+
Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
|
158 |
+
separate_cfg_infer (`bool`, *optional*, defaults to False):
|
159 |
+
Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
|
160 |
+
use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
|
161 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
162 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
163 |
+
to make generation deterministic.
|
164 |
+
Examples:
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
A list with the generated images.
|
168 |
+
"""
|
169 |
+
|
170 |
+
|
171 |
+
assert height%16 == 0 and width%16 == 0
|
172 |
+
if separate_cfg_infer:
|
173 |
+
use_kv_cache = False
|
174 |
+
# raise "Currently, don't support both use_kv_cache and separate_cfg_infer"
|
175 |
+
if input_images is None:
|
176 |
+
use_img_guidance = False
|
177 |
+
if isinstance(prompt, str):
|
178 |
+
prompt = [prompt]
|
179 |
+
input_images = [input_images] if input_images is not None else None
|
180 |
+
|
181 |
+
|
182 |
+
input_data = self.processor(prompt, input_images, height=height, width=width, use_img_cfg=use_img_guidance, separate_cfg_input=separate_cfg_infer)
|
183 |
+
|
184 |
+
num_prompt = len(prompt)
|
185 |
+
num_cfg = 2 if use_img_guidance else 1
|
186 |
+
latent_size_h, latent_size_w = height // 8, width // 8
|
187 |
+
|
188 |
+
if seed is not None:
|
189 |
+
generator = torch.Generator(device=self.device).manual_seed(seed)
|
190 |
+
else:
|
191 |
+
generator = None
|
192 |
+
latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator)
|
193 |
+
latents = torch.cat([latents] * (1 + num_cfg), 0).to(dtype)
|
194 |
+
|
195 |
+
|
196 |
+
# Load VAE into VRAM (GPU) in bfloat16
|
197 |
+
self.vae.to(self.device, dtype=torch.bfloat16)
|
198 |
+
|
199 |
+
|
200 |
+
|
201 |
+
|
202 |
+
input_img_latents = []
|
203 |
+
if separate_cfg_infer:
|
204 |
+
for temp_pixel_values in input_data['input_pixel_values']:
|
205 |
+
temp_input_latents = []
|
206 |
+
for img in temp_pixel_values:
|
207 |
+
img = self.vae_encode(img.to(self.device, dtype=torch.bfloat16), dtype)
|
208 |
+
|
209 |
+
temp_input_latents.append(img)
|
210 |
+
input_img_latents.append(temp_input_latents)
|
211 |
+
else:
|
212 |
+
for img in input_data['input_pixel_values']:
|
213 |
+
img = self.vae_encode(img.to(self.device, dtype=torch.bfloat16), dtype)
|
214 |
+
|
215 |
+
input_img_latents.append(img)
|
216 |
+
|
217 |
+
|
218 |
+
|
219 |
+
model_kwargs = dict(input_ids=self.move_to_device(input_data['input_ids']),
|
220 |
+
input_img_latents=input_img_latents,
|
221 |
+
input_image_sizes=input_data['input_image_sizes'],
|
222 |
+
attention_mask=self.move_to_device(input_data["attention_mask"]),
|
223 |
+
position_ids=self.move_to_device(input_data["position_ids"]),
|
224 |
+
cfg_scale=guidance_scale,
|
225 |
+
img_cfg_scale=img_guidance_scale,
|
226 |
+
use_img_cfg=use_img_guidance,
|
227 |
+
use_kv_cache=use_kv_cache)
|
228 |
+
|
229 |
+
|
230 |
+
#unlode vae to cpu
|
231 |
+
self.vae.to('cpu')
|
232 |
+
torch.cuda.empty_cache() # Clear VRAM
|
233 |
+
gc.collect() # Run garbage collection to free system RAM
|
234 |
+
|
235 |
+
|
236 |
+
|
237 |
+
if separate_cfg_infer:
|
238 |
+
func = self.model.forward_with_separate_cfg
|
239 |
+
else:
|
240 |
+
func = self.model.forward_with_cfg
|
241 |
+
|
242 |
+
|
243 |
+
#move main model to gpu
|
244 |
+
self.model.to(self.device, dtype=dtype)
|
245 |
+
|
246 |
+
|
247 |
+
scheduler = OmniGenScheduler(num_steps=num_inference_steps)
|
248 |
+
samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache)
|
249 |
+
samples = samples.chunk((1 + num_cfg), dim=0)[0]
|
250 |
+
|
251 |
+
if self.vae.config.shift_factor is not None:
|
252 |
+
samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
|
253 |
+
else:
|
254 |
+
samples = samples / self.vae.config.scaling_factor
|
255 |
+
|
256 |
+
#unlode main model to cpu
|
257 |
+
self.model.to('cpu')
|
258 |
+
torch.cuda.empty_cache() # Clear VRAM
|
259 |
+
gc.collect() # Run garbage collection to free system RAM
|
260 |
+
|
261 |
+
# Move samples to GPU and ensure they are in bfloat16 (for the VAE)
|
262 |
+
samples = samples.to(self.device, dtype=torch.bfloat16)
|
263 |
+
|
264 |
+
# Load VAE into VRAM (GPU) in bfloat16
|
265 |
+
self.vae.to(self.device, dtype=torch.bfloat16)
|
266 |
+
|
267 |
+
# Decode the samples using the VAE
|
268 |
+
samples = self.vae.decode(samples).sample
|
269 |
+
|
270 |
+
#unlode vae to cpu
|
271 |
+
self.vae.to('cpu')
|
272 |
+
torch.cuda.empty_cache() # Clear VRAM
|
273 |
+
gc.collect() # Run garbage collection to free system RAM
|
274 |
+
|
275 |
+
|
276 |
+
# Convert samples back to float32 for further processing
|
277 |
+
samples = samples.to(torch.float32)
|
278 |
+
|
279 |
+
|
280 |
+
# Convert samples to uint8 for final image output
|
281 |
+
output_samples = (samples * 0.5 + 0.5).clamp(0, 1) * 255
|
282 |
+
output_samples = output_samples.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
|
283 |
+
# Create output images
|
284 |
+
output_images = []
|
285 |
+
for i, sample in enumerate(output_samples):
|
286 |
+
output_images.append(Image.fromarray(sample))
|
287 |
+
|
288 |
+
# Return the generated images
|
289 |
+
return output_images
|
OmniGen/processor.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
from typing import Dict, List
|
4 |
+
import json
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import random
|
9 |
+
from PIL import Image
|
10 |
+
from torchvision import transforms
|
11 |
+
from transformers import AutoTokenizer
|
12 |
+
from huggingface_hub import snapshot_download
|
13 |
+
|
14 |
+
from OmniGen.utils import (
|
15 |
+
create_logger,
|
16 |
+
update_ema,
|
17 |
+
requires_grad,
|
18 |
+
center_crop_arr,
|
19 |
+
crop_arr,
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
class OmniGenProcessor:
|
26 |
+
def __init__(self,
|
27 |
+
text_tokenizer,
|
28 |
+
max_image_size: int=1024):
|
29 |
+
self.text_tokenizer = text_tokenizer
|
30 |
+
self.max_image_size = max_image_size
|
31 |
+
|
32 |
+
self.image_transform = transforms.Compose([
|
33 |
+
transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
|
34 |
+
transforms.ToTensor(),
|
35 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
36 |
+
])
|
37 |
+
|
38 |
+
self.collator = OmniGenCollator()
|
39 |
+
self.separate_collator = OmniGenSeparateCollator()
|
40 |
+
|
41 |
+
@classmethod
|
42 |
+
def from_pretrained(cls, model_name):
|
43 |
+
if not os.path.exists(model_name):
|
44 |
+
cache_folder = os.getenv('HF_HUB_CACHE')
|
45 |
+
model_name = snapshot_download(repo_id=model_name,
|
46 |
+
cache_dir=cache_folder,
|
47 |
+
allow_patterns="*.json")
|
48 |
+
text_tokenizer = AutoTokenizer.from_pretrained(model_name)
|
49 |
+
|
50 |
+
return cls(text_tokenizer)
|
51 |
+
|
52 |
+
|
53 |
+
def process_image(self, image):
|
54 |
+
image = Image.open(image).convert('RGB')
|
55 |
+
return self.image_transform(image)
|
56 |
+
|
57 |
+
def process_multi_modal_prompt(self, text, input_images):
|
58 |
+
text = self.add_prefix_instruction(text)
|
59 |
+
if input_images is None or len(input_images) == 0:
|
60 |
+
model_inputs = self.text_tokenizer(text)
|
61 |
+
return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
|
62 |
+
|
63 |
+
pattern = r"<\|image_\d+\|>"
|
64 |
+
prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
|
65 |
+
|
66 |
+
for i in range(1, len(prompt_chunks)):
|
67 |
+
if prompt_chunks[i][0] == 1:
|
68 |
+
prompt_chunks[i] = prompt_chunks[i][1:]
|
69 |
+
|
70 |
+
image_tags = re.findall(pattern, text)
|
71 |
+
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
|
72 |
+
|
73 |
+
unique_image_ids = sorted(list(set(image_ids)))
|
74 |
+
assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
|
75 |
+
# total images must be the same as the number of image tags
|
76 |
+
assert len(unique_image_ids) == len(input_images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
|
77 |
+
|
78 |
+
input_images = [input_images[x-1] for x in image_ids]
|
79 |
+
|
80 |
+
all_input_ids = []
|
81 |
+
img_inx = []
|
82 |
+
idx = 0
|
83 |
+
for i in range(len(prompt_chunks)):
|
84 |
+
all_input_ids.extend(prompt_chunks[i])
|
85 |
+
if i != len(prompt_chunks) -1:
|
86 |
+
start_inx = len(all_input_ids)
|
87 |
+
size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
|
88 |
+
img_inx.append([start_inx, start_inx+size])
|
89 |
+
all_input_ids.extend([0]*size)
|
90 |
+
|
91 |
+
return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
|
92 |
+
|
93 |
+
|
94 |
+
def add_prefix_instruction(self, prompt):
|
95 |
+
user_prompt = '<|user|>\n'
|
96 |
+
generation_prompt = 'Generate an image according to the following instructions\n'
|
97 |
+
assistant_prompt = '<|assistant|>\n<|diffusion|>'
|
98 |
+
prompt_suffix = "<|end|>\n"
|
99 |
+
prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
|
100 |
+
return prompt
|
101 |
+
|
102 |
+
|
103 |
+
def __call__(self,
|
104 |
+
instructions: List[str],
|
105 |
+
input_images: List[List[str]] = None,
|
106 |
+
height: int = 1024,
|
107 |
+
width: int = 1024,
|
108 |
+
negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
|
109 |
+
use_img_cfg: bool = True,
|
110 |
+
separate_cfg_input: bool = False,
|
111 |
+
) -> Dict:
|
112 |
+
|
113 |
+
if input_images is None:
|
114 |
+
use_img_cfg = False
|
115 |
+
if isinstance(instructions, str):
|
116 |
+
instructions = [instructions]
|
117 |
+
input_images = [input_images]
|
118 |
+
|
119 |
+
input_data = []
|
120 |
+
for i in range(len(instructions)):
|
121 |
+
cur_instruction = instructions[i]
|
122 |
+
cur_input_images = None if input_images is None else input_images[i]
|
123 |
+
if cur_input_images is not None and len(cur_input_images) > 0:
|
124 |
+
cur_input_images = [self.process_image(x) for x in cur_input_images]
|
125 |
+
else:
|
126 |
+
cur_input_images = None
|
127 |
+
assert "<img><|image_1|></img>" not in cur_instruction
|
128 |
+
|
129 |
+
mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
|
130 |
+
|
131 |
+
|
132 |
+
neg_mllm_input, img_cfg_mllm_input = None, None
|
133 |
+
neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
|
134 |
+
if use_img_cfg:
|
135 |
+
if cur_input_images is not None and len(cur_input_images) >= 1:
|
136 |
+
img_cfg_prompt = [f"<img><|image_{i+1}|></img>" for i in range(len(cur_input_images))]
|
137 |
+
img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
|
138 |
+
else:
|
139 |
+
img_cfg_mllm_input = neg_mllm_input
|
140 |
+
|
141 |
+
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
|
142 |
+
|
143 |
+
if separate_cfg_input:
|
144 |
+
return self.separate_collator(input_data)
|
145 |
+
return self.collator(input_data)
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
class OmniGenCollator:
|
151 |
+
def __init__(self, pad_token_id=2, hidden_size=3072):
|
152 |
+
self.pad_token_id = pad_token_id
|
153 |
+
self.hidden_size = hidden_size
|
154 |
+
|
155 |
+
def create_position(self, attention_mask, num_tokens_for_output_images):
|
156 |
+
position_ids = []
|
157 |
+
text_length = attention_mask.size(-1)
|
158 |
+
img_length = max(num_tokens_for_output_images)
|
159 |
+
for mask in attention_mask:
|
160 |
+
temp_l = torch.sum(mask)
|
161 |
+
temp_position = [0]*(text_length-temp_l) + [i for i in range(temp_l+img_length+1)] # we add a time embedding into the sequence, so add one more token
|
162 |
+
position_ids.append(temp_position)
|
163 |
+
return torch.LongTensor(position_ids)
|
164 |
+
|
165 |
+
def create_mask(self, attention_mask, num_tokens_for_output_images):
|
166 |
+
extended_mask = []
|
167 |
+
padding_images = []
|
168 |
+
text_length = attention_mask.size(-1)
|
169 |
+
img_length = max(num_tokens_for_output_images)
|
170 |
+
seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
|
171 |
+
inx = 0
|
172 |
+
for mask in attention_mask:
|
173 |
+
temp_l = torch.sum(mask)
|
174 |
+
pad_l = text_length - temp_l
|
175 |
+
|
176 |
+
temp_mask = torch.tril(torch.ones(size=(temp_l+1, temp_l+1)))
|
177 |
+
|
178 |
+
image_mask = torch.zeros(size=(temp_l+1, img_length))
|
179 |
+
temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
|
180 |
+
|
181 |
+
image_mask = torch.ones(size=(img_length, temp_l+img_length+1))
|
182 |
+
temp_mask = torch.cat([temp_mask, image_mask], dim=0)
|
183 |
+
|
184 |
+
if pad_l > 0:
|
185 |
+
pad_mask = torch.zeros(size=(temp_l+1+img_length, pad_l))
|
186 |
+
temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
|
187 |
+
|
188 |
+
pad_mask = torch.ones(size=(pad_l, seq_len))
|
189 |
+
temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
|
190 |
+
|
191 |
+
true_img_length = num_tokens_for_output_images[inx]
|
192 |
+
pad_img_length = img_length - true_img_length
|
193 |
+
if pad_img_length > 0:
|
194 |
+
temp_mask[:, -pad_img_length:] = 0
|
195 |
+
temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
|
196 |
+
else:
|
197 |
+
temp_padding_imgs = None
|
198 |
+
|
199 |
+
extended_mask.append(temp_mask.unsqueeze(0))
|
200 |
+
padding_images.append(temp_padding_imgs)
|
201 |
+
inx += 1
|
202 |
+
return torch.cat(extended_mask, dim=0), padding_images
|
203 |
+
|
204 |
+
def adjust_attention_for_input_images(self, attention_mask, image_sizes):
|
205 |
+
for b_inx in image_sizes.keys():
|
206 |
+
for start_inx, end_inx in image_sizes[b_inx]:
|
207 |
+
attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
|
208 |
+
|
209 |
+
return attention_mask
|
210 |
+
|
211 |
+
def pad_input_ids(self, input_ids, image_sizes):
|
212 |
+
max_l = max([len(x) for x in input_ids])
|
213 |
+
padded_ids = []
|
214 |
+
attention_mask = []
|
215 |
+
new_image_sizes = []
|
216 |
+
|
217 |
+
for i in range(len(input_ids)):
|
218 |
+
temp_ids = input_ids[i]
|
219 |
+
temp_l = len(temp_ids)
|
220 |
+
pad_l = max_l - temp_l
|
221 |
+
if pad_l == 0:
|
222 |
+
attention_mask.append([1]*max_l)
|
223 |
+
padded_ids.append(temp_ids)
|
224 |
+
else:
|
225 |
+
attention_mask.append([0]*pad_l+[1]*temp_l)
|
226 |
+
padded_ids.append([self.pad_token_id]*pad_l+temp_ids)
|
227 |
+
|
228 |
+
if i in image_sizes:
|
229 |
+
new_inx = []
|
230 |
+
for old_inx in image_sizes[i]:
|
231 |
+
new_inx.append([x+pad_l for x in old_inx])
|
232 |
+
image_sizes[i] = new_inx
|
233 |
+
|
234 |
+
return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
|
235 |
+
|
236 |
+
|
237 |
+
def process_mllm_input(self, mllm_inputs, target_img_size):
|
238 |
+
num_tokens_for_output_images = []
|
239 |
+
for img_size in target_img_size:
|
240 |
+
num_tokens_for_output_images.append(img_size[0]*img_size[1]//16//16)
|
241 |
+
|
242 |
+
pixel_values, image_sizes = [], {}
|
243 |
+
b_inx = 0
|
244 |
+
for x in mllm_inputs:
|
245 |
+
if x['pixel_values'] is not None:
|
246 |
+
pixel_values.extend(x['pixel_values'])
|
247 |
+
for size in x['image_sizes']:
|
248 |
+
if b_inx not in image_sizes:
|
249 |
+
image_sizes[b_inx] = [size]
|
250 |
+
else:
|
251 |
+
image_sizes[b_inx].append(size)
|
252 |
+
b_inx += 1
|
253 |
+
pixel_values = [x.unsqueeze(0) for x in pixel_values]
|
254 |
+
|
255 |
+
|
256 |
+
input_ids = [x['input_ids'] for x in mllm_inputs]
|
257 |
+
padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
|
258 |
+
position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
|
259 |
+
attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
|
260 |
+
attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
|
261 |
+
|
262 |
+
return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
|
263 |
+
|
264 |
+
|
265 |
+
def __call__(self, features):
|
266 |
+
mllm_inputs = [f[0] for f in features]
|
267 |
+
cfg_mllm_inputs = [f[1] for f in features]
|
268 |
+
img_cfg_mllm_input = [f[2] for f in features]
|
269 |
+
target_img_size = [f[3] for f in features]
|
270 |
+
|
271 |
+
|
272 |
+
if img_cfg_mllm_input[0] is not None:
|
273 |
+
mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
|
274 |
+
target_img_size = target_img_size + target_img_size + target_img_size
|
275 |
+
else:
|
276 |
+
mllm_inputs = mllm_inputs + cfg_mllm_inputs
|
277 |
+
target_img_size = target_img_size + target_img_size
|
278 |
+
|
279 |
+
|
280 |
+
all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
|
281 |
+
|
282 |
+
data = {"input_ids": all_padded_input_ids,
|
283 |
+
"attention_mask": all_attention_mask,
|
284 |
+
"position_ids": all_position_ids,
|
285 |
+
"input_pixel_values": all_pixel_values,
|
286 |
+
"input_image_sizes": all_image_sizes,
|
287 |
+
"padding_images": all_padding_images,
|
288 |
+
}
|
289 |
+
return data
|
290 |
+
|
291 |
+
|
292 |
+
class OmniGenSeparateCollator(OmniGenCollator):
|
293 |
+
def __call__(self, features):
|
294 |
+
mllm_inputs = [f[0] for f in features]
|
295 |
+
cfg_mllm_inputs = [f[1] for f in features]
|
296 |
+
img_cfg_mllm_input = [f[2] for f in features]
|
297 |
+
target_img_size = [f[3] for f in features]
|
298 |
+
|
299 |
+
|
300 |
+
all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []
|
301 |
+
|
302 |
+
|
303 |
+
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
|
304 |
+
all_padded_input_ids.append(padded_input_ids)
|
305 |
+
all_attention_mask.append(attention_mask)
|
306 |
+
all_position_ids.append(position_ids)
|
307 |
+
all_pixel_values.append(pixel_values)
|
308 |
+
all_image_sizes.append(image_sizes)
|
309 |
+
all_padding_images.append(padding_images)
|
310 |
+
|
311 |
+
if cfg_mllm_inputs[0] is not None:
|
312 |
+
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(cfg_mllm_inputs, target_img_size)
|
313 |
+
all_padded_input_ids.append(padded_input_ids)
|
314 |
+
all_attention_mask.append(attention_mask)
|
315 |
+
all_position_ids.append(position_ids)
|
316 |
+
all_pixel_values.append(pixel_values)
|
317 |
+
all_image_sizes.append(image_sizes)
|
318 |
+
all_padding_images.append(padding_images)
|
319 |
+
if img_cfg_mllm_input[0] is not None:
|
320 |
+
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(img_cfg_mllm_input, target_img_size)
|
321 |
+
all_padded_input_ids.append(padded_input_ids)
|
322 |
+
all_attention_mask.append(attention_mask)
|
323 |
+
all_position_ids.append(position_ids)
|
324 |
+
all_pixel_values.append(pixel_values)
|
325 |
+
all_image_sizes.append(image_sizes)
|
326 |
+
all_padding_images.append(padding_images)
|
327 |
+
|
328 |
+
data = {"input_ids": all_padded_input_ids,
|
329 |
+
"attention_mask": all_attention_mask,
|
330 |
+
"position_ids": all_position_ids,
|
331 |
+
"input_pixel_values": all_pixel_values,
|
332 |
+
"input_image_sizes": all_image_sizes,
|
333 |
+
"padding_images": all_padding_images,
|
334 |
+
}
|
335 |
+
return data
|
OmniGen/scheduler.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from tqdm import tqdm
|
3 |
+
from transformers.cache_utils import Cache, DynamicCache
|
4 |
+
|
5 |
+
class OmniGenScheduler:
|
6 |
+
def __init__(self, num_steps: int=50, time_shifting_factor: int=1):
|
7 |
+
self.num_steps = num_steps
|
8 |
+
self.time_shift = time_shifting_factor
|
9 |
+
|
10 |
+
t = torch.linspace(0, 1, num_steps+1)
|
11 |
+
t = t / (t + time_shifting_factor - time_shifting_factor * t)
|
12 |
+
self.sigma = t
|
13 |
+
|
14 |
+
def crop_kv_cache(self, past_key_values, num_tokens_for_img):
|
15 |
+
crop_past_key_values = ()
|
16 |
+
for layer_idx in range(len(past_key_values)):
|
17 |
+
key_states, value_states = past_key_values[layer_idx][:2]
|
18 |
+
crop_past_key_values += ((key_states[..., :-(num_tokens_for_img+1), :], value_states[..., :-(num_tokens_for_img+1), :], ),)
|
19 |
+
return crop_past_key_values
|
20 |
+
# return DynamicCache.from_legacy_cache(crop_past_key_values)
|
21 |
+
|
22 |
+
def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
|
23 |
+
if isinstance(position_ids, list):
|
24 |
+
for i in range(len(position_ids)):
|
25 |
+
position_ids[i] = position_ids[i][:, -(num_tokens_for_img+1):]
|
26 |
+
else:
|
27 |
+
position_ids = position_ids[:, -(num_tokens_for_img+1):]
|
28 |
+
return position_ids
|
29 |
+
|
30 |
+
def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for_img):
|
31 |
+
if isinstance(attention_mask, list):
|
32 |
+
return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
|
33 |
+
return attention_mask[..., -(num_tokens_for_img+1):, :]
|
34 |
+
|
35 |
+
def __call__(self, z, func, model_kwargs, use_kv_cache: bool=True):
|
36 |
+
past_key_values = None
|
37 |
+
for i in tqdm(range(self.num_steps)):
|
38 |
+
timesteps = torch.zeros(size=(len(z), )).to(z.device) + self.sigma[i]
|
39 |
+
pred, temp_past_key_values = func(z, timesteps, past_key_values=past_key_values, **model_kwargs)
|
40 |
+
sigma_next = self.sigma[i+1]
|
41 |
+
sigma = self.sigma[i]
|
42 |
+
z = z + (sigma_next - sigma) * pred
|
43 |
+
if i == 0 and use_kv_cache:
|
44 |
+
num_tokens_for_img = z.size(-1)*z.size(-2) // 4
|
45 |
+
if isinstance(temp_past_key_values, list):
|
46 |
+
past_key_values = [self.crop_kv_cache(x, num_tokens_for_img) for x in temp_past_key_values]
|
47 |
+
model_kwargs['input_ids'] = [None] * len(temp_past_key_values)
|
48 |
+
else:
|
49 |
+
past_key_values = self.crop_kv_cache(temp_past_key_values, num_tokens_for_img)
|
50 |
+
model_kwargs['input_ids'] = None
|
51 |
+
|
52 |
+
model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
|
53 |
+
model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
|
54 |
+
return z
|
55 |
+
|
OmniGen/train_helper/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .data import DatasetFromJson, TrainDataCollator
|
2 |
+
from .loss import training_losses
|
OmniGen/train_helper/data.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import datasets
|
3 |
+
from datasets import load_dataset, ClassLabel, concatenate_datasets
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import random
|
7 |
+
from PIL import Image
|
8 |
+
import json
|
9 |
+
import copy
|
10 |
+
# import torchvision.transforms as T
|
11 |
+
from torchvision import transforms
|
12 |
+
import pickle
|
13 |
+
import re
|
14 |
+
|
15 |
+
from OmniGen import OmniGenProcessor
|
16 |
+
from OmniGen.processor import OmniGenCollator
|
17 |
+
|
18 |
+
|
19 |
+
class DatasetFromJson(torch.utils.data.Dataset):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
json_file: str,
|
23 |
+
image_path: str,
|
24 |
+
processer: OmniGenProcessor,
|
25 |
+
image_transform,
|
26 |
+
max_input_length_limit: int = 18000,
|
27 |
+
condition_dropout_prob: float = 0.1,
|
28 |
+
keep_raw_resolution: bool = True,
|
29 |
+
):
|
30 |
+
|
31 |
+
self.image_transform = image_transform
|
32 |
+
self.processer = processer
|
33 |
+
self.condition_dropout_prob = condition_dropout_prob
|
34 |
+
self.max_input_length_limit = max_input_length_limit
|
35 |
+
self.keep_raw_resolution = keep_raw_resolution
|
36 |
+
|
37 |
+
self.data = load_dataset('json', data_files=json_file)['train']
|
38 |
+
self.image_path = image_path
|
39 |
+
|
40 |
+
def process_image(self, image_file):
|
41 |
+
if self.image_path is not None:
|
42 |
+
image_file = os.path.join(self.image_path, image_file)
|
43 |
+
image = Image.open(image_file).convert('RGB')
|
44 |
+
return self.image_transform(image)
|
45 |
+
|
46 |
+
def get_example(self, index):
|
47 |
+
example = self.data[index]
|
48 |
+
|
49 |
+
instruction, input_images, output_image = example['instruction'], example['input_images'], example['output_image']
|
50 |
+
if random.random() < self.condition_dropout_prob:
|
51 |
+
instruction = '<cfg>'
|
52 |
+
input_images = None
|
53 |
+
if input_images is not None:
|
54 |
+
input_images = [self.process_image(x) for x in input_images]
|
55 |
+
mllm_input = self.processer.process_multi_modal_prompt(instruction, input_images)
|
56 |
+
|
57 |
+
output_image = self.process_image(output_image)
|
58 |
+
|
59 |
+
return (mllm_input, output_image)
|
60 |
+
|
61 |
+
|
62 |
+
def __getitem__(self, index):
|
63 |
+
return self.get_example(index)
|
64 |
+
for _ in range(8):
|
65 |
+
try:
|
66 |
+
mllm_input, output_image = self.get_example(index)
|
67 |
+
if len(mllm_input['input_ids']) > self.max_input_length_limit:
|
68 |
+
raise RuntimeError(f"cur number of tokens={len(mllm_input['input_ids'])}, larger than max_input_length_limit={self.max_input_length_limit}")
|
69 |
+
return mllm_input, output_image
|
70 |
+
except Exception as e:
|
71 |
+
print("error when loading data: ", e)
|
72 |
+
print(self.data[index])
|
73 |
+
index = random.randint(0, len(self.data)-1)
|
74 |
+
raise RuntimeError("Too many bad data.")
|
75 |
+
|
76 |
+
|
77 |
+
def __len__(self):
|
78 |
+
return len(self.data)
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
class TrainDataCollator(OmniGenCollator):
|
83 |
+
def __init__(self, pad_token_id: int, hidden_size: int, keep_raw_resolution: bool):
|
84 |
+
self.pad_token_id = pad_token_id
|
85 |
+
self.hidden_size = hidden_size
|
86 |
+
self.keep_raw_resolution = keep_raw_resolution
|
87 |
+
|
88 |
+
def __call__(self, features):
|
89 |
+
mllm_inputs = [f[0] for f in features]
|
90 |
+
|
91 |
+
output_images = [f[1].unsqueeze(0) for f in features]
|
92 |
+
target_img_size = [[x.size(-2), x.size(-1)] for x in output_images]
|
93 |
+
|
94 |
+
all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
|
95 |
+
|
96 |
+
if not self.keep_raw_resolution:
|
97 |
+
output_image = torch.cat(output_image, dim=0)
|
98 |
+
if len(pixel_values) > 0:
|
99 |
+
all_pixel_values = torch.cat(all_pixel_values, dim=0)
|
100 |
+
else:
|
101 |
+
all_pixel_values = None
|
102 |
+
|
103 |
+
data = {"input_ids": all_padded_input_ids,
|
104 |
+
"attention_mask": all_attention_mask,
|
105 |
+
"position_ids": all_position_ids,
|
106 |
+
"input_pixel_values": all_pixel_values,
|
107 |
+
"input_image_sizes": all_image_sizes,
|
108 |
+
"padding_images": all_padding_images,
|
109 |
+
"output_images": output_images,
|
110 |
+
}
|
111 |
+
return data
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
|
OmniGen/train_helper/loss.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def sample_x0(x1):
|
5 |
+
"""Sampling x0 & t based on shape of x1 (if needed)
|
6 |
+
Args:
|
7 |
+
x1 - data point; [batch, *dim]
|
8 |
+
"""
|
9 |
+
if isinstance(x1, (list, tuple)):
|
10 |
+
x0 = [torch.randn_like(img_start) for img_start in x1]
|
11 |
+
else:
|
12 |
+
x0 = torch.randn_like(x1)
|
13 |
+
|
14 |
+
return x0
|
15 |
+
|
16 |
+
def sample_timestep(x1):
|
17 |
+
u = torch.normal(mean=0.0, std=1.0, size=(len(x1),))
|
18 |
+
t = 1 / (1 + torch.exp(-u))
|
19 |
+
t = t.to(x1[0])
|
20 |
+
return t
|
21 |
+
|
22 |
+
|
23 |
+
def training_losses(model, x1, model_kwargs=None, snr_type='uniform'):
|
24 |
+
"""Loss for training torche score model
|
25 |
+
Args:
|
26 |
+
- model: backbone model; could be score, noise, or velocity
|
27 |
+
- x1: datapoint
|
28 |
+
- model_kwargs: additional arguments for torche model
|
29 |
+
"""
|
30 |
+
if model_kwargs == None:
|
31 |
+
model_kwargs = {}
|
32 |
+
|
33 |
+
B = len(x1)
|
34 |
+
|
35 |
+
x0 = sample_x0(x1)
|
36 |
+
t = sample_timestep(x1)
|
37 |
+
|
38 |
+
if isinstance(x1, (list, tuple)):
|
39 |
+
xt = [t[i] * x1[i] + (1 - t[i]) * x0[i] for i in range(B)]
|
40 |
+
ut = [x1[i] - x0[i] for i in range(B)]
|
41 |
+
else:
|
42 |
+
dims = [1] * (len(x1.size()) - 1)
|
43 |
+
t_ = t.view(t.size(0), *dims)
|
44 |
+
xt = t_ * x1 + (1 - t_) * x0
|
45 |
+
ut = x1 - x0
|
46 |
+
|
47 |
+
model_output = model(xt, t, **model_kwargs)
|
48 |
+
|
49 |
+
terms = {}
|
50 |
+
|
51 |
+
if isinstance(x1, (list, tuple)):
|
52 |
+
assert len(model_output) == len(ut) == len(x1)
|
53 |
+
for i in range(B):
|
54 |
+
terms["loss"] = torch.stack(
|
55 |
+
[((ut[i] - model_output[i]) ** 2).mean() for i in range(B)],
|
56 |
+
dim=0,
|
57 |
+
)
|
58 |
+
else:
|
59 |
+
terms["loss"] = mean_flat(((model_output - ut) ** 2))
|
60 |
+
|
61 |
+
return terms
|
62 |
+
|
63 |
+
|
64 |
+
def mean_flat(x):
|
65 |
+
"""
|
66 |
+
Take torche mean over all non-batch dimensions.
|
67 |
+
"""
|
68 |
+
return torch.mean(x, dim=list(range(1, len(x.size()))))
|
OmniGen/transformer.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from typing import List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.utils.checkpoint
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
9 |
+
from huggingface_hub import snapshot_download
|
10 |
+
|
11 |
+
from transformers.modeling_outputs import (
|
12 |
+
BaseModelOutputWithPast,
|
13 |
+
CausalLMOutputWithPast,
|
14 |
+
SequenceClassifierOutputWithPast,
|
15 |
+
TokenClassifierOutput,
|
16 |
+
)
|
17 |
+
from transformers.modeling_utils import PreTrainedModel
|
18 |
+
from transformers import Phi3Config, Phi3Model
|
19 |
+
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
20 |
+
from transformers.utils import logging
|
21 |
+
|
22 |
+
logger = logging.get_logger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class Phi3Transformer(Phi3Model):
|
26 |
+
"""
|
27 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
|
28 |
+
We only modified the attention mask
|
29 |
+
Args:
|
30 |
+
config: Phi3Config
|
31 |
+
"""
|
32 |
+
|
33 |
+
def forward(
|
34 |
+
self,
|
35 |
+
input_ids: torch.LongTensor = None,
|
36 |
+
attention_mask: Optional[torch.Tensor] = None,
|
37 |
+
position_ids: Optional[torch.LongTensor] = None,
|
38 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
39 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
40 |
+
use_cache: Optional[bool] = None,
|
41 |
+
output_attentions: Optional[bool] = None,
|
42 |
+
output_hidden_states: Optional[bool] = None,
|
43 |
+
return_dict: Optional[bool] = None,
|
44 |
+
cache_position: Optional[torch.LongTensor] = None,
|
45 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
46 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
47 |
+
output_hidden_states = (
|
48 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
49 |
+
)
|
50 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
51 |
+
|
52 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
53 |
+
|
54 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
55 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
56 |
+
|
57 |
+
if self.gradient_checkpointing and self.training:
|
58 |
+
if use_cache:
|
59 |
+
logger.warning_once(
|
60 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
61 |
+
)
|
62 |
+
use_cache = False
|
63 |
+
|
64 |
+
# kept for BC (non `Cache` `past_key_values` inputs)
|
65 |
+
return_legacy_cache = False
|
66 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
67 |
+
return_legacy_cache = True
|
68 |
+
if past_key_values is None:
|
69 |
+
past_key_values = DynamicCache()
|
70 |
+
else:
|
71 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
72 |
+
logger.warning_once(
|
73 |
+
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
74 |
+
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
75 |
+
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
76 |
+
)
|
77 |
+
|
78 |
+
if inputs_embeds is None:
|
79 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
80 |
+
|
81 |
+
if cache_position is None:
|
82 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
83 |
+
cache_position = torch.arange(
|
84 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
85 |
+
)
|
86 |
+
if position_ids is None:
|
87 |
+
position_ids = cache_position.unsqueeze(0)
|
88 |
+
|
89 |
+
if attention_mask is not None and attention_mask.dim() == 3:
|
90 |
+
dtype = inputs_embeds.dtype
|
91 |
+
min_dtype = torch.finfo(dtype).min
|
92 |
+
attention_mask = (1 - attention_mask) * min_dtype
|
93 |
+
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
|
94 |
+
else:
|
95 |
+
raise
|
96 |
+
# causal_mask = self._update_causal_mask(
|
97 |
+
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
98 |
+
# )
|
99 |
+
|
100 |
+
hidden_states = inputs_embeds
|
101 |
+
|
102 |
+
# decoder layers
|
103 |
+
all_hidden_states = () if output_hidden_states else None
|
104 |
+
all_self_attns = () if output_attentions else None
|
105 |
+
next_decoder_cache = None
|
106 |
+
|
107 |
+
for decoder_layer in self.layers:
|
108 |
+
if output_hidden_states:
|
109 |
+
all_hidden_states += (hidden_states,)
|
110 |
+
|
111 |
+
if self.gradient_checkpointing and self.training:
|
112 |
+
layer_outputs = self._gradient_checkpointing_func(
|
113 |
+
decoder_layer.__call__,
|
114 |
+
hidden_states,
|
115 |
+
attention_mask,
|
116 |
+
position_ids,
|
117 |
+
past_key_values,
|
118 |
+
output_attentions,
|
119 |
+
use_cache,
|
120 |
+
cache_position,
|
121 |
+
)
|
122 |
+
else:
|
123 |
+
layer_outputs = decoder_layer(
|
124 |
+
hidden_states,
|
125 |
+
attention_mask=attention_mask,
|
126 |
+
position_ids=position_ids,
|
127 |
+
past_key_value=past_key_values,
|
128 |
+
output_attentions=output_attentions,
|
129 |
+
use_cache=use_cache,
|
130 |
+
cache_position=cache_position,
|
131 |
+
)
|
132 |
+
|
133 |
+
hidden_states = layer_outputs[0]
|
134 |
+
|
135 |
+
if use_cache:
|
136 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
137 |
+
|
138 |
+
if output_attentions:
|
139 |
+
all_self_attns += (layer_outputs[1],)
|
140 |
+
|
141 |
+
hidden_states = self.norm(hidden_states)
|
142 |
+
|
143 |
+
# add hidden states from the last decoder layer
|
144 |
+
if output_hidden_states:
|
145 |
+
all_hidden_states += (hidden_states,)
|
146 |
+
|
147 |
+
next_cache = next_decoder_cache if use_cache else None
|
148 |
+
if return_legacy_cache:
|
149 |
+
next_cache = next_cache.to_legacy_cache()
|
150 |
+
|
151 |
+
if not return_dict:
|
152 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
153 |
+
return BaseModelOutputWithPast(
|
154 |
+
last_hidden_state=hidden_states,
|
155 |
+
past_key_values=next_cache,
|
156 |
+
hidden_states=all_hidden_states,
|
157 |
+
attentions=all_self_attns,
|
158 |
+
)
|
159 |
+
|
OmniGen/utils.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
from PIL import Image
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
def create_logger(logging_dir):
|
8 |
+
"""
|
9 |
+
Create a logger that writes to a log file and stdout.
|
10 |
+
"""
|
11 |
+
logging.basicConfig(
|
12 |
+
level=logging.INFO,
|
13 |
+
format='[\033[34m%(asctime)s\033[0m] %(message)s',
|
14 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
15 |
+
handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
|
16 |
+
)
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
return logger
|
19 |
+
|
20 |
+
|
21 |
+
@torch.no_grad()
|
22 |
+
def update_ema(ema_model, model, decay=0.9999):
|
23 |
+
"""
|
24 |
+
Step the EMA model towards the current model.
|
25 |
+
"""
|
26 |
+
ema_params = dict(ema_model.named_parameters())
|
27 |
+
for name, param in model.named_parameters():
|
28 |
+
# TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
|
29 |
+
ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
def requires_grad(model, flag=True):
|
35 |
+
"""
|
36 |
+
Set requires_grad flag for all parameters in a model.
|
37 |
+
"""
|
38 |
+
for p in model.parameters():
|
39 |
+
p.requires_grad = flag
|
40 |
+
|
41 |
+
|
42 |
+
def center_crop_arr(pil_image, image_size):
|
43 |
+
"""
|
44 |
+
Center cropping implementation from ADM.
|
45 |
+
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
46 |
+
"""
|
47 |
+
while min(*pil_image.size) >= 2 * image_size:
|
48 |
+
pil_image = pil_image.resize(
|
49 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
50 |
+
)
|
51 |
+
|
52 |
+
scale = image_size / min(*pil_image.size)
|
53 |
+
pil_image = pil_image.resize(
|
54 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
55 |
+
)
|
56 |
+
|
57 |
+
arr = np.array(pil_image)
|
58 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
59 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
60 |
+
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
def crop_arr(pil_image, max_image_size):
|
65 |
+
while min(*pil_image.size) >= 2 * max_image_size:
|
66 |
+
pil_image = pil_image.resize(
|
67 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
68 |
+
)
|
69 |
+
|
70 |
+
if max(*pil_image.size) > max_image_size:
|
71 |
+
scale = max_image_size / max(*pil_image.size)
|
72 |
+
pil_image = pil_image.resize(
|
73 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
74 |
+
)
|
75 |
+
|
76 |
+
if min(*pil_image.size) < 16:
|
77 |
+
scale = 16 / min(*pil_image.size)
|
78 |
+
pil_image = pil_image.resize(
|
79 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
80 |
+
)
|
81 |
+
|
82 |
+
arr = np.array(pil_image)
|
83 |
+
crop_y1 = (arr.shape[0] % 16) // 2
|
84 |
+
crop_y2 = arr.shape[0] % 16 - crop_y1
|
85 |
+
|
86 |
+
crop_x1 = (arr.shape[1] % 16) // 2
|
87 |
+
crop_x2 = arr.shape[1] % 16 - crop_x1
|
88 |
+
|
89 |
+
arr = arr[crop_y1:arr.shape[0]-crop_y2, crop_x1:arr.shape[1]-crop_x2]
|
90 |
+
return Image.fromarray(arr)
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
def vae_encode(vae, x, weight_dtype):
|
95 |
+
if x is not None:
|
96 |
+
if vae.config.shift_factor is not None:
|
97 |
+
x = vae.encode(x).latent_dist.sample()
|
98 |
+
x = (x - vae.config.shift_factor) * vae.config.scaling_factor
|
99 |
+
else:
|
100 |
+
x = vae.encode(x).latent_dist.sample().mul_(vae.config.scaling_factor)
|
101 |
+
x = x.to(weight_dtype)
|
102 |
+
return x
|
103 |
+
|
104 |
+
def vae_encode_list(vae, x, weight_dtype):
|
105 |
+
latents = []
|
106 |
+
for img in x:
|
107 |
+
img = vae_encode(vae, img, weight_dtype)
|
108 |
+
latents.append(img)
|
109 |
+
return latents
|
110 |
+
|
README.md
CHANGED
@@ -1,14 +1,93 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<h1 align="center">OmniGen: Unified Image Generation</h1>
|
2 |
+
|
3 |
+
|
4 |
+
<p align="center">
|
5 |
+
<a href="">
|
6 |
+
<img alt="Build" src="https://img.shields.io/badge/Project%20Page-OmniGen-yellow">
|
7 |
+
</a>
|
8 |
+
<a href="https://arxiv.org/abs/2409.11340">
|
9 |
+
<img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-2409.11340-b31b1b.svg">
|
10 |
+
</a>
|
11 |
+
<a href="https://huggingface.co/spaces/Shitao/OmniGen">
|
12 |
+
<img alt="License" src="https://img.shields.io/badge/HF%20Demo-🤗-lightblue">
|
13 |
+
</a>
|
14 |
+
<a href="https://huggingface.co/Shitao/OmniGen-v1">
|
15 |
+
<img alt="Build" src="https://img.shields.io/badge/HF%20Model-🤗-yellow">
|
16 |
+
</a>
|
17 |
+
</p>
|
18 |
+
|
19 |
+
<h4 align="center">
|
20 |
+
<p>
|
21 |
+
<a href=#2-news>Credits for Quantized version</a> |
|
22 |
+
<a href=#3-methodology>Methodology</a> |
|
23 |
+
<a href=#4-what-can-omnigen-do>Capabilities</a> |
|
24 |
+
<a href="#license">License</a> |
|
25 |
+
<a href="#citation">Citation</a>
|
26 |
+
<p>
|
27 |
+
</h4>
|
28 |
+
|
29 |
+
|
30 |
+
## 1. Overview
|
31 |
+
|
32 |
+
OmniGen is a unified image generation model that can generate a wide range of images from multi-modal prompts. It is designed to be simple, flexible and easy to use. We provide [inference code](#5-quick-start) so that everyone can explore more functionalities of OmniGen.
|
33 |
+
|
34 |
+
Existing image generation models often require loading several additional network modules (such as ControlNet, IP-Adapter, Reference-Net, etc.) and performing extra preprocessing steps (e.g., face detection, pose estimation, cropping, etc.) to generate a satisfactory image. However, **we believe that the future image generation paradigm should be more simple and flexible, that is, generating various images directly through arbitrarily multi-modal instructions without the need for additional plugins and operations, similar to how GPT works in language generation.**
|
35 |
+
|
36 |
+
Due to the limited resources, OmniGen still has room for improvement. We will continue to optimize it, and hope it inspire more universal image generation models. You can also easily fine-tune OmniGen without worrying about designing networks for specific tasks; you just need to prepare the corresponding data, and then run the [script](#6-finetune). Imagination is no longer limited; everyone can construct any image generation task, and perhaps we can achieve very interesting, wonderful and creative things.
|
37 |
+
|
38 |
+
If you have any questions, ideas or interesting tasks you want OmniGen to accomplish, feel free to discuss with us: [email protected], [email protected], [email protected]. We welcome any feedback to help us improve the model.
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
## 2. Credits for Quantized version
|
43 |
+
- https://github.com/Manni1000
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
## 3. Methodology
|
48 |
+
|
49 |
+
You can see details in our [paper](https://arxiv.org/abs/2409.11340).
|
50 |
+
|
51 |
+
|
52 |
+
## 4. What Can OmniGen do?
|
53 |
+
|
54 |
+
|
55 |
+
OmniGen is a unified image generation model that you can use to perform various tasks, including but not limited to text-to-image generation, subject-driven generation, Identity-Preserving Generation, image editing, and image-conditioned generation. **OmniGen don't need additional plugins or operations, it can automatically identify the features (e.g., required object, human pose, depth mapping) in input images according the text prompt.**
|
56 |
+
We showcase some examples in [inference.ipynb](inference.ipynb). And in [inference_demo.ipynb](inference_demo.ipynb), we show an interesting pipeline to generate and modify a image.
|
57 |
+
|
58 |
+
Here is the illustration of OmniGen's capabilities:
|
59 |
+
- You can control the image generation flexibly via OmniGen
|
60 |
+
![demo](./imgs/demo_cases.png)
|
61 |
+
- Referring Expression Generation: You can generate images by simply referring to objects, and OmniGen will automatically recognize the required objects in the image.
|
62 |
+
![demo](./imgs/referring.png)
|
63 |
+
|
64 |
+
If you are not entirely satisfied with certain functionalities or wish to add new capabilities, you can try [fine-tuning OmniGen](#6-finetune).
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
## 5. Quick Start
|
69 |
+
|
70 |
+
### Please refer youtube video for installation
|
71 |
+
|
72 |
+
https://www.youtube.com/watch?v=9ZXmXA2AJZ4
|
73 |
+
|
74 |
+
|
75 |
+
## License
|
76 |
+
This repo is licensed under the [MIT License](LICENSE).
|
77 |
+
|
78 |
+
|
79 |
+
## Citation
|
80 |
+
If you find this repository useful, please consider giving a star ⭐ and citation
|
81 |
+
```
|
82 |
+
@article{xiao2024omnigen,
|
83 |
+
title={Omnigen: Unified image generation},
|
84 |
+
author={Xiao, Shitao and Wang, Yueze and Zhou, Junjie and Yuan, Huaying and Xing, Xingrun and Yan, Ruiran and Wang, Shuting and Huang, Tiejun and Liu, Zheng},
|
85 |
+
journal={arXiv preprint arXiv:2409.11340},
|
86 |
+
year={2024}
|
87 |
+
}
|
88 |
+
```
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
app.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from PIL import Image
|
3 |
+
import os
|
4 |
+
from threading import Lock
|
5 |
+
|
6 |
+
from OmniGen import OmniGenPipeline
|
7 |
+
|
8 |
+
class OmniGenManager:
|
9 |
+
def __init__(self):
|
10 |
+
self.pipe = None
|
11 |
+
self.lock = Lock()
|
12 |
+
self.current_quantization = None
|
13 |
+
|
14 |
+
def get_pipeline(self, quantization: bool) -> OmniGenPipeline:
|
15 |
+
"""
|
16 |
+
Get or initialize the pipeline with the specified quantization setting.
|
17 |
+
Uses a lock to ensure thread safety.
|
18 |
+
"""
|
19 |
+
with self.lock:
|
20 |
+
# Only reinitialize if quantization setting changed or pipeline doesn't exist
|
21 |
+
if self.pipe is None or self.current_quantization != quantization:
|
22 |
+
# Clear any existing pipeline
|
23 |
+
if self.pipe is not None:
|
24 |
+
del self.pipe
|
25 |
+
self.pipe = None
|
26 |
+
|
27 |
+
# Initialize new pipeline
|
28 |
+
self.pipe = OmniGenPipeline.from_pretrained(
|
29 |
+
"Shitao/OmniGen-v1",
|
30 |
+
Quantization=quantization
|
31 |
+
)
|
32 |
+
self.current_quantization = quantization
|
33 |
+
|
34 |
+
return self.pipe
|
35 |
+
|
36 |
+
# Create a single instance of the manager
|
37 |
+
pipeline_manager = OmniGenManager()
|
38 |
+
|
39 |
+
def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, quantization):
|
40 |
+
input_images = [img1, img2, img3]
|
41 |
+
# 去除 None
|
42 |
+
input_images = [img for img in input_images if img is not None]
|
43 |
+
if len(input_images) == 0:
|
44 |
+
input_images = None
|
45 |
+
|
46 |
+
# Get or initialize pipeline with current settings
|
47 |
+
pipe = pipeline_manager.get_pipeline(quantization)
|
48 |
+
|
49 |
+
# Generate image
|
50 |
+
output = pipe(
|
51 |
+
prompt=text,
|
52 |
+
input_images=input_images,
|
53 |
+
height=height,
|
54 |
+
width=width,
|
55 |
+
guidance_scale=guidance_scale,
|
56 |
+
img_guidance_scale=1.6,
|
57 |
+
num_inference_steps=inference_steps,
|
58 |
+
separate_cfg_infer=True, # set False can speed up the inference process
|
59 |
+
use_kv_cache=False,
|
60 |
+
seed=seed,
|
61 |
+
)
|
62 |
+
img = output[0]
|
63 |
+
return img
|
64 |
+
# def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps):
|
65 |
+
# input_images = []
|
66 |
+
# if img1:
|
67 |
+
# input_images.append(Image.open(img1))
|
68 |
+
# if img2:
|
69 |
+
# input_images.append(Image.open(img2))
|
70 |
+
# if img3:
|
71 |
+
# input_images.append(Image.open(img3))
|
72 |
+
|
73 |
+
# return input_images[0] if input_images else None
|
74 |
+
|
75 |
+
|
76 |
+
def get_example():
|
77 |
+
case = [
|
78 |
+
[
|
79 |
+
"A curly-haired man in a red shirt is drinking tea.",
|
80 |
+
None,
|
81 |
+
None,
|
82 |
+
None,
|
83 |
+
1024,
|
84 |
+
1024,
|
85 |
+
2.5,
|
86 |
+
1.6,
|
87 |
+
50,
|
88 |
+
0,
|
89 |
+
],
|
90 |
+
[
|
91 |
+
"The woman in <img><|image_1|></img> waves her hand happily in the crowd",
|
92 |
+
"./imgs/test_cases/zhang.png",
|
93 |
+
None,
|
94 |
+
None,
|
95 |
+
1024,
|
96 |
+
1024,
|
97 |
+
2.5,
|
98 |
+
1.9,
|
99 |
+
50,
|
100 |
+
128,
|
101 |
+
],
|
102 |
+
[
|
103 |
+
"A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
|
104 |
+
"./imgs/test_cases/two_man.jpg",
|
105 |
+
None,
|
106 |
+
None,
|
107 |
+
1024,
|
108 |
+
1024,
|
109 |
+
2.5,
|
110 |
+
1.6,
|
111 |
+
50,
|
112 |
+
0,
|
113 |
+
],
|
114 |
+
[
|
115 |
+
"Two woman are raising fried chicken legs in a bar. A woman is <img><|image_1|></img>. The other woman is <img><|image_2|></img>.",
|
116 |
+
"./imgs/test_cases/mckenna.jpg",
|
117 |
+
"./imgs/test_cases/Amanda.jpg",
|
118 |
+
None,
|
119 |
+
1024,
|
120 |
+
1024,
|
121 |
+
2.5,
|
122 |
+
1.8,
|
123 |
+
50,
|
124 |
+
168,
|
125 |
+
],
|
126 |
+
[
|
127 |
+
"A man and a short-haired woman with a wrinkled face are standing in front of a bookshelf in a library. The man is the man in the middle of <img><|image_1|></img>, and the woman is oldest woman in <img><|image_2|></img>",
|
128 |
+
"./imgs/test_cases/1.jpg",
|
129 |
+
"./imgs/test_cases/2.jpg",
|
130 |
+
None,
|
131 |
+
1024,
|
132 |
+
1024,
|
133 |
+
2.5,
|
134 |
+
1.6,
|
135 |
+
50,
|
136 |
+
60,
|
137 |
+
],
|
138 |
+
[
|
139 |
+
"A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <img><|image_1|></img>. The woman is the woman on the left of <img><|image_2|></img>",
|
140 |
+
"./imgs/test_cases/3.jpg",
|
141 |
+
"./imgs/test_cases/4.jpg",
|
142 |
+
None,
|
143 |
+
1024,
|
144 |
+
1024,
|
145 |
+
2.5,
|
146 |
+
1.8,
|
147 |
+
50,
|
148 |
+
66,
|
149 |
+
],
|
150 |
+
[
|
151 |
+
"The flower <img><|image_1|><\/img> is placed in the vase which is in the middle of <img><|image_2|><\/img> on a wooden table of a living room",
|
152 |
+
"./imgs/test_cases/rose.jpg",
|
153 |
+
"./imgs/test_cases/vase.jpg",
|
154 |
+
None,
|
155 |
+
1024,
|
156 |
+
1024,
|
157 |
+
2.5,
|
158 |
+
1.6,
|
159 |
+
50,
|
160 |
+
0,
|
161 |
+
],
|
162 |
+
[
|
163 |
+
"<img><|image_1|><img>\n Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola.",
|
164 |
+
"./imgs/demo_cases/t2i_woman_with_book.png",
|
165 |
+
None,
|
166 |
+
None,
|
167 |
+
1024,
|
168 |
+
1024,
|
169 |
+
2.5,
|
170 |
+
1.6,
|
171 |
+
50,
|
172 |
+
222,
|
173 |
+
],
|
174 |
+
[
|
175 |
+
"Detect the skeleton of human in this image: <img><|image_1|></img>.",
|
176 |
+
"./imgs/test_cases/control.jpg",
|
177 |
+
None,
|
178 |
+
None,
|
179 |
+
1024,
|
180 |
+
1024,
|
181 |
+
2.0,
|
182 |
+
1.6,
|
183 |
+
50,
|
184 |
+
0,
|
185 |
+
],
|
186 |
+
[
|
187 |
+
"Generate a new photo using the following picture and text as conditions: <img><|image_1|><img>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
|
188 |
+
"./imgs/demo_cases/skeletal.png",
|
189 |
+
None,
|
190 |
+
None,
|
191 |
+
1024,
|
192 |
+
1024,
|
193 |
+
2,
|
194 |
+
1.6,
|
195 |
+
50,
|
196 |
+
42,
|
197 |
+
],
|
198 |
+
[
|
199 |
+
"Following the pose of this image <img><|image_1|><img>, generate a new photo: A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
|
200 |
+
"./imgs/demo_cases/edit.png",
|
201 |
+
None,
|
202 |
+
None,
|
203 |
+
1024,
|
204 |
+
1024,
|
205 |
+
2.0,
|
206 |
+
1.6,
|
207 |
+
50,
|
208 |
+
123,
|
209 |
+
],
|
210 |
+
[
|
211 |
+
"Following the depth mapping of this image <img><|image_1|><img>, generate a new photo: A young girl is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
|
212 |
+
"./imgs/demo_cases/edit.png",
|
213 |
+
None,
|
214 |
+
None,
|
215 |
+
1024,
|
216 |
+
1024,
|
217 |
+
2.0,
|
218 |
+
1.6,
|
219 |
+
50,
|
220 |
+
1,
|
221 |
+
],
|
222 |
+
[
|
223 |
+
"<img><|image_1|><\/img> What item can be used to see the current time? Please remove it.",
|
224 |
+
"./imgs/test_cases/watch.jpg",
|
225 |
+
None,
|
226 |
+
None,
|
227 |
+
1024,
|
228 |
+
1024,
|
229 |
+
2.5,
|
230 |
+
1.6,
|
231 |
+
50,
|
232 |
+
0,
|
233 |
+
],
|
234 |
+
[
|
235 |
+
"According to the following examples, generate an output for the input.\nInput: <img><|image_1|></img>\nOutput: <img><|image_2|></img>\n\nInput: <img><|image_3|></img>\nOutput: ",
|
236 |
+
"./imgs/test_cases/icl1.jpg",
|
237 |
+
"./imgs/test_cases/icl2.jpg",
|
238 |
+
"./imgs/test_cases/icl3.jpg",
|
239 |
+
1024,
|
240 |
+
1024,
|
241 |
+
2.5,
|
242 |
+
1.6,
|
243 |
+
50,
|
244 |
+
1,
|
245 |
+
],
|
246 |
+
]
|
247 |
+
return case
|
248 |
+
|
249 |
+
def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed):
|
250 |
+
return generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed)
|
251 |
+
|
252 |
+
description = """
|
253 |
+
OmniGen is a unified image generation model that you can use to perform various tasks, including but not limited to text-to-image generation, subject-driven generation, Identity-Preserving Generation, and image-conditioned generation.
|
254 |
+
|
255 |
+
For multi-modal to image generation, you should pass a string as `prompt`, and a list of image paths as `input_images`. The placeholder in the prompt should be in the format of `<img><|image_*|></img>` (for the first image, the placeholder is <img><|image_1|></img>. for the second image, the the placeholder is <img><|image_2|></img>).
|
256 |
+
For example, use an image of a woman to generate a new image:
|
257 |
+
prompt = "A woman holds a bouquet of flowers and faces the camera. Thw woman is \<img\>\<|image_1|\>\</img\>."
|
258 |
+
|
259 |
+
Tips:
|
260 |
+
- Oversaturated: If the image appears oversaturated, please reduce the `guidance_scale`.
|
261 |
+
- Low-quality: More detailed prompt will lead to better results.
|
262 |
+
- Animate Style: If the genereate images is in animate style, you can try to add `photo` to the prompt`.
|
263 |
+
- Edit generated image. If you generate a image by omnigen and then want to edit it, you cannot use the same seed to edit this image. For example, use seed=0 to generate image, and should use seed=1 to edit this image.
|
264 |
+
- For image editing tasks, we recommend placing the image before the editing instruction. For example, use `<img><|image_1|></img> remove suit`, rather than `remove suit <img><|image_1|></img>`.
|
265 |
+
"""
|
266 |
+
|
267 |
+
# Gradio 接口
|
268 |
+
with gr.Blocks() as demo:
|
269 |
+
gr.Markdown("# OmniGen: Unified Image Generation [paper](https://arxiv.org/abs/2409.11340) [code](https://github.com/VectorSpaceLab/OmniGen)")
|
270 |
+
gr.Markdown(description)
|
271 |
+
with gr.Row():
|
272 |
+
with gr.Column():
|
273 |
+
# 文本输入框
|
274 |
+
prompt_input = gr.Textbox(
|
275 |
+
label="Enter your prompt, use <img><|image_i|></img> to represent i-th input image", placeholder="Type your prompt here..."
|
276 |
+
)
|
277 |
+
|
278 |
+
with gr.Row(equal_height=True):
|
279 |
+
# 图片上传框
|
280 |
+
image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath")
|
281 |
+
image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath")
|
282 |
+
image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath")
|
283 |
+
|
284 |
+
# 高度和宽度滑块
|
285 |
+
height_input = gr.Slider(
|
286 |
+
label="Height", minimum=256, maximum=2048, value=1024, step=16
|
287 |
+
)
|
288 |
+
width_input = gr.Slider(
|
289 |
+
label="Width", minimum=256, maximum=2048, value=1024, step=16
|
290 |
+
)
|
291 |
+
|
292 |
+
# 引导尺度输入
|
293 |
+
guidance_scale_input = gr.Slider(
|
294 |
+
label="Guidance Scale", minimum=1.0, maximum=5.0, value=2.5, step=0.1
|
295 |
+
)
|
296 |
+
|
297 |
+
img_guidance_scale_input = gr.Slider(
|
298 |
+
label="img_guidance_scale", minimum=1.0, maximum=2.0, value=1.6, step=0.1
|
299 |
+
)
|
300 |
+
|
301 |
+
num_inference_steps = gr.Slider(
|
302 |
+
label="Inference Steps", minimum=1, maximum=100, value=50, step=1
|
303 |
+
)
|
304 |
+
|
305 |
+
Quantization = gr.Checkbox(
|
306 |
+
label="Low VRAM (8-bit Quantization)", value=True
|
307 |
+
)
|
308 |
+
|
309 |
+
seed_input = gr.Slider(
|
310 |
+
label="Seed", minimum=0, maximum=2147483647, value=42, step=1
|
311 |
+
)
|
312 |
+
|
313 |
+
# 生成按钮
|
314 |
+
generate_button = gr.Button("Generate Image")
|
315 |
+
|
316 |
+
with gr.Column():
|
317 |
+
# 输出图像框
|
318 |
+
output_image = gr.Image(label="Output Image")
|
319 |
+
|
320 |
+
# 按钮点击事件
|
321 |
+
generate_button.click(
|
322 |
+
generate_image,
|
323 |
+
inputs=[
|
324 |
+
prompt_input,
|
325 |
+
image_input_1,
|
326 |
+
image_input_2,
|
327 |
+
image_input_3,
|
328 |
+
height_input,
|
329 |
+
width_input,
|
330 |
+
guidance_scale_input,
|
331 |
+
img_guidance_scale_input,
|
332 |
+
num_inference_steps,
|
333 |
+
seed_input,
|
334 |
+
Quantization,
|
335 |
+
],
|
336 |
+
outputs=output_image,
|
337 |
+
)
|
338 |
+
|
339 |
+
gr.Examples(
|
340 |
+
examples=get_example(),
|
341 |
+
fn=run_for_examples,
|
342 |
+
inputs=[
|
343 |
+
prompt_input,
|
344 |
+
image_input_1,
|
345 |
+
image_input_2,
|
346 |
+
image_input_3,
|
347 |
+
height_input,
|
348 |
+
width_input,
|
349 |
+
guidance_scale_input,
|
350 |
+
img_guidance_scale_input,
|
351 |
+
num_inference_steps,
|
352 |
+
seed_input,
|
353 |
+
Quantization,
|
354 |
+
],
|
355 |
+
outputs=output_image,
|
356 |
+
)
|
357 |
+
|
358 |
+
# 启动应用
|
359 |
+
demo.launch()
|
docs/fine-tuning.md
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Fine-tuning OmniGen
|
2 |
+
|
3 |
+
Fine-tuning Omnigen can better help you handle specific image generation tasks. For example, by fine-tuning on a person's images, you can generate multiple pictures of that person while maintaining task consistency.
|
4 |
+
|
5 |
+
A lot of previous work focused on designing new networks to facilitate specific tasks. For instance, ControlNet was proposed to handle image conditions, and IP-Adapter was constructed to maintain ID features. If you want to perform new tasks, you need to build new architectures and repeatedly debug them. Adding and adjusting extra network parameters is usually time-consuming and labor-intensive, which is not user-friendly and cost-efficient enough. However, with Omnigen, all of this becomes very simple.
|
6 |
+
|
7 |
+
By comparison, Omnigen can accept multi-modal conditional inputs and has been pre-trained on various tasks. You can fine-tune it on any task without designing specialized networks like ControlNet or IP-Adapter for a specific task.
|
8 |
+
|
9 |
+
**All you need to do is prepare the data and start training. You can break the limitations of previous models, allowing Omnigen to accomplish a variety of interesting tasks, even those that have never been done before.**
|
10 |
+
|
11 |
+
|
12 |
+
## Installation
|
13 |
+
|
14 |
+
```bash
|
15 |
+
git clone https://github.com/VectorSpaceLab/OmniGen.git
|
16 |
+
cd OmniGen
|
17 |
+
pip install -e .
|
18 |
+
```
|
19 |
+
|
20 |
+
|
21 |
+
## Full fine-tuning
|
22 |
+
|
23 |
+
### Fine-tuning command
|
24 |
+
|
25 |
+
```bash
|
26 |
+
accelerate launch \
|
27 |
+
--num_processes=1 \
|
28 |
+
--use_fsdp \
|
29 |
+
--fsdp_offload_params false \
|
30 |
+
--fsdp_sharding_strategy SHARD_GRAD_OP \
|
31 |
+
--fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP \
|
32 |
+
--fsdp_transformer_layer_cls_to_wrap Phi3DecoderLayer \
|
33 |
+
--fsdp_state_dict_type FULL_STATE_DICT \
|
34 |
+
--fsdp_forward_prefetch false \
|
35 |
+
--fsdp_use_orig_params True \
|
36 |
+
--fsdp_cpu_ram_efficient_loading false \
|
37 |
+
--fsdp_sync_module_states True \
|
38 |
+
train.py \
|
39 |
+
--model_name_or_path Shitao/OmniGen-v1 \
|
40 |
+
--json_file ./toy_data/toy_data.jsonl \
|
41 |
+
--image_path ./toy_data/images \
|
42 |
+
--batch_size_per_device 1 \
|
43 |
+
--lr 2e-5 \
|
44 |
+
--keep_raw_resolution \
|
45 |
+
--max_image_size 1024 \
|
46 |
+
--gradient_accumulation_steps 1 \
|
47 |
+
--ckpt_every 100 \
|
48 |
+
--epochs 100 \
|
49 |
+
--log_every 1 \
|
50 |
+
--results_dir ./results/toy_finetune
|
51 |
+
```
|
52 |
+
|
53 |
+
Some important arguments:
|
54 |
+
- `num_processes`: number of GPU to use for training
|
55 |
+
- `model_name_or_path`: path to the pretrained model
|
56 |
+
- `json_file`: path to the json file containing the training data, e.g., ./toy_data/toy_data.jsonl
|
57 |
+
- `image_path`: path to the image folder, e.g., ./toy_data/images
|
58 |
+
- `batch_size_per_device`: batch size per device
|
59 |
+
- `lr`: learning rate
|
60 |
+
- `keep_raw_resolution`: whether to keep the original resolution of the image, if not, all images will be resized to (max_image_size, max_image_size)
|
61 |
+
- `max_image_size`: max image size
|
62 |
+
- `gradient_accumulation_steps`: number of steps to accumulate gradients
|
63 |
+
- `ckpt_every`: number of steps to save checkpoint
|
64 |
+
- `epochs`: number of epochs
|
65 |
+
- `log_every`: number of steps to log
|
66 |
+
- `results_dir`: path to the results folder
|
67 |
+
|
68 |
+
The data format of json_file is as follows:
|
69 |
+
```
|
70 |
+
{
|
71 |
+
"instruction": str,
|
72 |
+
"input_images": [str, str, ...],
|
73 |
+
"output_images": str
|
74 |
+
}
|
75 |
+
```
|
76 |
+
You can see a toy example in `./toy_data/toy_data.jsonl`.
|
77 |
+
|
78 |
+
If an OOM(Out of Memory) issue occurs, you can try to decrease the `batch_size_per_device` or `max_image_size`. You can also try to use LoRA instead of full fine-tuning.
|
79 |
+
|
80 |
+
|
81 |
+
### Inference
|
82 |
+
|
83 |
+
The checkpoint can be found at `{results_dir}/checkpoints/*`. You can use the following command to load saved checkpoint:
|
84 |
+
```python
|
85 |
+
from OmniGen import OmniGenPipeline
|
86 |
+
|
87 |
+
pipe = OmniGenPipeline.from_pretrained("checkpoint_path") # e.g., ./results/toy_finetune/checkpoints/0000200
|
88 |
+
```
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
## LoRA fine-tuning
|
95 |
+
LoRA fine-tuning is a simple way to fine-tune OmniGen with less GPU memory. To use lora, you should add `--use_lora` and `--lora_rank` to the command.
|
96 |
+
|
97 |
+
```bash
|
98 |
+
accelerate launch \
|
99 |
+
--num_processes=1 \
|
100 |
+
train.py \
|
101 |
+
--model_name_or_path Shitao/OmniGen-v1 \
|
102 |
+
--batch_size_per_device 2 \
|
103 |
+
--condition_dropout_prob 0.01 \
|
104 |
+
--lr 3e-4 \
|
105 |
+
--use_lora \
|
106 |
+
--lora_rank 8 \
|
107 |
+
--json_file ./toy_data/toy_data.jsonl \
|
108 |
+
--image_path ./toy_data/images \
|
109 |
+
--max_input_length_limit 18000 \
|
110 |
+
--keep_raw_resolution \
|
111 |
+
--max_image_size 1024 \
|
112 |
+
--gradient_accumulation_steps 1 \
|
113 |
+
--ckpt_every 100 \
|
114 |
+
--epochs 100 \
|
115 |
+
--log_every 1 \
|
116 |
+
--results_dir ./results/toy_finetune_lora
|
117 |
+
```
|
118 |
+
|
119 |
+
### Inference
|
120 |
+
|
121 |
+
The checkpoint can be found at `{results_dir}/checkpoints/*`. You can use the following command to load checkpoint:
|
122 |
+
```python
|
123 |
+
from OmniGen import OmniGenPipeline
|
124 |
+
|
125 |
+
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
|
126 |
+
pipe.merge_lora("checkpoint_path") # e.g., ./results/toy_finetune_lora/checkpoints/0000100
|
127 |
+
```
|
128 |
+
|
129 |
+
|
130 |
+
## A simple example
|
131 |
+
|
132 |
+
Here is an example for learning new concepts: "sks dog". We use five images of one dog from [dog-example](https://huggingface.co/datasets/diffusers/dog-example).
|
133 |
+
|
134 |
+
The json file is `./toy_data/toy_subject_data.jsonl`, and the images have been saved in `./toy_data/images`.
|
135 |
+
|
136 |
+
```bash
|
137 |
+
accelerate launch \
|
138 |
+
--num_processes=1 \
|
139 |
+
train.py \
|
140 |
+
--model_name_or_path Shitao/OmniGen-v1 \
|
141 |
+
--batch_size_per_device 2 \
|
142 |
+
--condition_dropout_prob 0.01 \
|
143 |
+
--lr 1e-3 \
|
144 |
+
--use_lora \
|
145 |
+
--lora_rank 8 \
|
146 |
+
--json_file ./toy_data/toy_subject_data.jsonl \
|
147 |
+
--image_path ./toy_data/images \
|
148 |
+
--max_input_length_limit 18000 \
|
149 |
+
--keep_raw_resolution \
|
150 |
+
--max_image_size 1024 \
|
151 |
+
--gradient_accumulation_steps 1 \
|
152 |
+
--ckpt_every 100 \
|
153 |
+
--epochs 200 \
|
154 |
+
--log_every 1 \
|
155 |
+
--results_dir ./results/toy_finetune_lora
|
156 |
+
```
|
157 |
+
|
158 |
+
After training, you can use the following command to generate images:
|
159 |
+
```python
|
160 |
+
from OmniGen import OmniGenPipeline
|
161 |
+
|
162 |
+
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
|
163 |
+
pipe.merge_lora("checkpoint_path") # e.g., ./results/toy_finetune_lora/checkpoints/0000200
|
164 |
+
|
165 |
+
images = pipe(
|
166 |
+
prompt="a photo of sks dog running in the snow",
|
167 |
+
height=1024,
|
168 |
+
width=1024,
|
169 |
+
guidance_scale=3
|
170 |
+
)
|
171 |
+
images[0].save("example_sks_dog_snow.png")
|
172 |
+
```
|
docs/inference.md
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Inference with OmniGen
|
2 |
+
|
3 |
+
To handle some complex tasks, image generation models are becoming increasingly sophisticated, leading to more and more cumbersome workflows. Existing image generation models like SD and Flux require loading many additional network modules (such as ControlNet, IP-Adapter, Reference-Net) and extra preprocessing steps (e.g., face detection, pose detection, image cropping) to generate a satisfactory image. This complex workflow is not user-friendly. We believe that future image generation models should be simpler, generating various images directly through instructions, similar to how GPT works in language generation.
|
4 |
+
|
5 |
+
Therefore, we propose OmniGen, a model capable of handling various image generation tasks within a single framework. The goal of OmniGen is to complete various image generation tasks without relying on any additional components or image preprocessing steps. OmniGen supports tasks including text-to-image generation, image editing, subject-driven image generation, and classical vision tasks, among others. More capabilities can be found in our examples. We provide inference code so you can explore more unknown functionalities yourself.
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
## Install
|
10 |
+
```bash
|
11 |
+
git clone https://github.com/staoxiao/OmniGen.git
|
12 |
+
cd OmniGen
|
13 |
+
pip install -e .
|
14 |
+
```
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
## Generate Images
|
19 |
+
You can use the following code to generate images:
|
20 |
+
```python
|
21 |
+
from OmniGen import OmniGenPipeline
|
22 |
+
|
23 |
+
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
|
24 |
+
|
25 |
+
# Text to Image
|
26 |
+
images = pipe(
|
27 |
+
prompt="A curly-haired man in a red shirt is drinking tea.",
|
28 |
+
height=1024,
|
29 |
+
width=1024,
|
30 |
+
guidance_scale=2.5,
|
31 |
+
seed=0,
|
32 |
+
)
|
33 |
+
images[0].save("example_t2i.png") # save output PIL Image
|
34 |
+
|
35 |
+
# Multi-modal to Image
|
36 |
+
# In prompt, we use the placeholder to represent the image. The image placeholder should be in the format of <img><|image_*|></img>
|
37 |
+
# You can add multiple images in the input_images. Please ensure that each image has its placeholder. For example, for the list input_images [img1_path, img2_path], the prompt needs to have two placeholders: <img><|image_1|></img>, <img><|image_2|></img>.
|
38 |
+
images = pipe(
|
39 |
+
prompt="A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>."
|
40 |
+
input_images=["./imgs/test_cases/two_man.jpg"]
|
41 |
+
height=1024,
|
42 |
+
width=1024,
|
43 |
+
separate_cfg_infer=False, # if OOM, you can set separate_cfg_infer=True
|
44 |
+
guidance_scale=2.5,
|
45 |
+
img_guidance_scale=1.6
|
46 |
+
)
|
47 |
+
images[0].save("example_ti2i.png") # save output PIL image
|
48 |
+
```
|
49 |
+
|
50 |
+
Some important arguments:
|
51 |
+
- `guidance_scale`: The strength of the guidance. Based on our experience, it is usually best to set it between 2 and 3. The higher the value, the more similar the generated image will be to the prompt. If the image appears oversaturated, please reduce the scale.
|
52 |
+
- `height` and `width`: The height and width of the generated image. The default value is 1024x1024. OmniGen support any size, but these number must be divisible by 16.
|
53 |
+
- `num_inference_steps`: The number of steps to take in the diffusion process. The higher the value, the more detailed the generated image will be.
|
54 |
+
- `separate_cfg_infer`: Whether to use separate inference process for CFG guidance. If set to True, memory cost will be lower but the generation speed will be slower. Default is False.
|
55 |
+
- `use_kv_cache`: Whether to use key-value cache. Default is True.
|
56 |
+
- `seed`: The seed for random number generator.
|
57 |
+
|
58 |
+
**More examples please refer to [inference.ipynb](../inference.ipynb)**
|
59 |
+
|
60 |
+
|
61 |
+
#### Input data
|
62 |
+
OmniGen can accept multi-modal input data. Specifically, you should pass two arguments: `prompt` and `input_images`.
|
63 |
+
For text to image generation, you can pass a string as `prompt`, or pass a list of strings as `prompt` to generate multiple images.
|
64 |
+
|
65 |
+
For multi-modal to image generation, you should pass a string as `prompt`, and a list of image paths as `input_images`. The placeholder in the prompt should be in the format of `<img><|image_*|></img>`.
|
66 |
+
For example, if you want to generate an image with a person holding a bouquet of flowers, you can pass the following prompt:
|
67 |
+
```
|
68 |
+
prompt = "A woman holds a bouquet of flowers and faces the camera. Thw woman is <img><|image_1|></img>."
|
69 |
+
input_images = ["./imgs/test_cases/liuyifei.png"]
|
70 |
+
```
|
71 |
+
The placeholder `<|image_1|>` will be replaced by the image at `input_images[0]`, i.e., `./imgs/test_cases/liuyifei.png`.
|
72 |
+
|
73 |
+
If you want to generate multiple images, you can pass a list of prompts and a list of image paths. For example:
|
74 |
+
```
|
75 |
+
prompt = ["A woman holds a bouquet of flowers and faces the camera.", "A woman holds a bouquet of flowers and faces the camera. Thw woman is <img><|image_1|></img>."]
|
76 |
+
input_images = [[], ["./imgs/test_cases/liuyifei.png"]]
|
77 |
+
```
|
78 |
+
|
79 |
+
|
80 |
+
#### Gradio Demo
|
81 |
+
We have constructed a online demo in [Huggingface](https://huggingface.co/spaces/Shitao/OmniGen).
|
82 |
+
|
83 |
+
For the local gradio demo, you can run with the following command:
|
84 |
+
```python
|
85 |
+
python app.py
|
86 |
+
```
|
87 |
+
|
88 |
+
|
89 |
+
## Tips
|
90 |
+
- OOM issue: If you encounter OOM issue, you can try to set `separate_cfg_infer=True`. This will reduce the memory usage but increase the generation latecy. You also can reduce the size of the image, e.g., `height=768, width=512`.
|
91 |
+
- Oversaturated: If the image appears oversaturated, please reduce the `guidance_scale`.
|
92 |
+
- Not match the prompt: If the image does not match the prompt, please try to increase the `guidance_scale`.
|
93 |
+
- Low-quality: More detailed prompt will lead to better results. Besides, larger size of the image (`height` and `width`) will also help.
|
94 |
+
- Animate Style: If the genereate images is in animate style, you can try to add `photo` to the prompt`.
|
95 |
+
- Edit generated image. If you generate a image by omnigen and then want to edit it, you cannot use the same seed to edit this image. For example, use seed=0 to generate image, and should use seed=1 to edit this image.
|
96 |
+
- For image editing tasks, we recommend placing the image before the editing instruction. For example, use `<img><|image_1|></img> remove suit`, rather than `remove suit <img><|image_1|></img>`.
|
imgs/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
imgs/demo_cases.png
ADDED
Git LFS Details
|
imgs/demo_cases/AI_Pioneers.jpg
ADDED
imgs/demo_cases/edit.png
ADDED
Git LFS Details
|
imgs/demo_cases/entity.png
ADDED
Git LFS Details
|
imgs/demo_cases/reasoning.png
ADDED
Git LFS Details
|
imgs/demo_cases/same_pose.png
ADDED
Git LFS Details
|
imgs/demo_cases/skeletal.png
ADDED
imgs/demo_cases/skeletal2img.png
ADDED
Git LFS Details
|
imgs/demo_cases/t2i_woman_with_book.png
ADDED
Git LFS Details
|
imgs/overall.jpg
ADDED
Git LFS Details
|
imgs/referring.png
ADDED
Git LFS Details
|
imgs/test_cases/1.jpg
ADDED
Git LFS Details
|
imgs/test_cases/2.jpg
ADDED
Git LFS Details
|
imgs/test_cases/3.jpg
ADDED
Git LFS Details
|
imgs/test_cases/4.jpg
ADDED
Git LFS Details
|
imgs/test_cases/Amanda.jpg
ADDED
Git LFS Details
|
imgs/test_cases/control.jpg
ADDED
Git LFS Details
|
imgs/test_cases/icl1.jpg
ADDED
imgs/test_cases/icl2.jpg
ADDED
imgs/test_cases/icl3.jpg
ADDED
imgs/test_cases/lecun.png
ADDED
imgs/test_cases/mckenna.jpg
ADDED
Git LFS Details
|
imgs/test_cases/pose.png
ADDED
imgs/test_cases/rose.jpg
ADDED