Upload 23 files
- .gitattributes +1 -0
- Dockerfile +69 -0
- README.md +37 -11
- assets/example.gif +3 -0
- server/config.py +48 -0
- server/main.py +211 -0
- server/requirements.txt +12 -0
- server/wrapper.py +117 -0
- start.sh +2 -0
- view/README.md +46 -0
- view/package-lock.json +0 -0
- view/package.json +47 -0
- view/pnpm-lock.yaml +0 -0
- view/public/favicon.ico +0 -0
- view/public/images/white.jpg +0 -0
- view/public/index.html +46 -0
- view/public/manifest.json +25 -0
- view/public/robots.txt +3 -0
- view/src/App.tsx +72 -0
- view/src/index.css +8 -0
- view/src/index.tsx +19 -0
- view/src/react-app-env.d.ts +1 -0
- view/src/reportWebVitals.ts +15 -0
- view/tsconfig.json +26 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/example.gif filter=lfs diff=lfs merge=lfs -text
Dockerfile
ADDED
@@ -0,0 +1,69 @@
+# build frontend with node
+FROM node:20-alpine AS frontend
+RUN apk add --no-cache libc6-compat
+WORKDIR /app
+
+COPY view .
+RUN npm ci
+RUN npm run build
+
+# build backend on CUDA
+FROM nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04 AS backend
+WORKDIR /app
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV NODE_MAJOR=20
+
+RUN apt-get update && \
+    apt-get upgrade -y && \
+    apt-get install -y --no-install-recommends \
+    git \
+    git-lfs \
+    wget \
+    curl \
+    # python build dependencies \
+    build-essential \
+    libssl-dev \
+    zlib1g-dev \
+    libbz2-dev \
+    libreadline-dev \
+    libsqlite3-dev \
+    libncursesw5-dev \
+    xz-utils \
+    tk-dev \
+    libxml2-dev \
+    libxmlsec1-dev \
+    libffi-dev \
+    liblzma-dev && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
+
+USER root
+
+RUN useradd -m -u 1000 user
+USER user
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+WORKDIR $HOME/app
+
+RUN curl https://pyenv.run | bash
+ENV PATH=$HOME/.pyenv/shims:$HOME/.pyenv/bin:$PATH
+ARG PYTHON_VERSION=3.10.12
+RUN pyenv install $PYTHON_VERSION && \
+    pyenv global $PYTHON_VERSION && \
+    pyenv rehash && \
+    pip install --no-cache-dir -U pip setuptools wheel
+
+COPY --chown=user:user . .
+# change dir since pip needs to seed whl folder
+RUN cd server && pip install --no-cache-dir --upgrade -r requirements.txt
+RUN --mount=type=secret,id=GITHUB_TOKEN,mode=0444,required=true \
+    pip install git+https://$(cat /run/secrets/GITHUB_TOKEN)@github.com/cumulo-autumn/StreamDiffusion.git@main#egg=stream-diffusion[tensorrt]
+RUN python -m streamdiffusion.tools.install-tensorrt
+
+COPY --from=frontend /app/build ./view/build
+
+WORKDIR $HOME/app/server
+
+USER user
+CMD ["python", "main.py"]
README.md
CHANGED
@@ -1,11 +1,37 @@
+# Txt2Img Example
+
+<p align="center">
+  <img src="./assets/example.gif" width=80%>
+</p>
+
+This example provides a simple implementation that uses StreamDiffusion to generate images from text.
+
+## Usage
+
+```bash
+chmod +x ./start.sh && ./start.sh
+```
+
+or
+
+```bash
+cd server
+python3 main.py &
+cd ../view
+npm start
+```
+
+## Docker
+
+Build:
+`GITHUB_TOKEN` is temporary until the project is public.
+```bash
+docker build --secret id=GITHUB_TOKEN,src=./github_token.txt -t realtime-txt2img .
+```
+
+Run:
+```bash
+docker run -ti -p 9090:9090 -e HF_HOME=/data -v ~/.cache/huggingface:/data --gpus all realtime-txt2img
+```
+
+`-e HF_HOME=/data -v ~/.cache/huggingface:/data` mounts your local Hugging Face cache into the container so that you don't need to download the model again.
assets/example.gif
ADDED
Git LFS Details
server/config.py
ADDED
@@ -0,0 +1,48 @@
+from dataclasses import dataclass, field
+from typing import List
+
+import torch
+
+
+@dataclass
+class Config:
+    """
+    The configuration for the API.
+    """
+
+    ####################################################################
+    # Server
+    ####################################################################
+    # In most cases, you should leave this as it is.
+    host: str = "0.0.0.0"
+    port: int = 9090
+    workers: int = 1
+
+    ####################################################################
+    # Generation configuration
+    ####################################################################
+    # The threshold for the Levenstein distance.
+    levenstein_distance_threshold: int = 3
+
+    ####################################################################
+    # Model configuration
+    ####################################################################
+    # SD1.x variant model
+    model_id: str = "SimianLuo/LCM_Dreamshaper_v7"
+
+    # LCM-LORA model
+    lcm_lora_id: str = "latent-consistency/lcm-lora-sdv1-5"
+    # TinyVAE model
+    vae_id: str = "madebyollin/taesd"
+    # Device to use
+    device: torch.device = torch.device("cuda")
+    # Data type
+    dtype: torch.dtype = torch.float16
+
+    ####################################################################
+    # Inference configuration
+    ####################################################################
+    # Number of inference steps
+    t_index_list: List[int] = field(default_factory=lambda: [0, 16, 32, 45])
+    # Number of warmup steps
+    warmup: int = 10
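Since `Config` is a plain dataclass, any of the fields above can be overridden at construction time rather than by editing the file. A minimal sketch (field names taken from the class above; the values are purely illustrative):

```python
from config import Config

# Override a few fields; everything else keeps the defaults defined in config.py.
config = Config(port=8080, warmup=5, t_index_list=[0, 16, 32, 45])
```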
server/main.py
ADDED
@@ -0,0 +1,211 @@
+import asyncio
+import base64
+import logging
+from io import BytesIO
+from pathlib import Path
+
+import uvicorn
+from config import Config
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
+
+from PIL import Image
+from pydantic import BaseModel
+from wrapper import StreamDiffusionWrapper
+
+
+logger = logging.getLogger("uvicorn")
+PROJECT_DIR = Path(__file__).parent.parent
+
+
+class PredictInputModel(BaseModel):
+    """
+    The input model for the /predict endpoint.
+    """
+
+    prompt: str
+
+
+class PredictResponseModel(BaseModel):
+    """
+    The response model for the /predict endpoint.
+    """
+
+    base64_images: list[str]
+
+
+class UpdatePromptResponseModel(BaseModel):
+    """
+    The response model for the /update_prompt endpoint.
+    """
+
+    prompt: str
+
+
+class Api:
+    def __init__(self, config: Config) -> None:
+        """
+        Initialize the API.
+
+        Parameters
+        ----------
+        config : Config
+            The configuration.
+        """
+        self.config = config
+        self.stream_diffusion = StreamDiffusionWrapper(
+            model_id=config.model_id,
+            lcm_lora_id=config.lcm_lora_id,
+            vae_id=config.vae_id,
+            device=config.device,
+            dtype=config.dtype,
+            t_index_list=config.t_index_list,
+            warmup=config.warmup,
+        )
+        self.app = FastAPI()
+        self.app.add_api_route(
+            "/api/predict",
+            self._predict,
+            methods=["POST"],
+            response_model=PredictResponseModel,
+        )
+        self.app.add_middleware(
+            CORSMiddleware,
+            allow_origins=["*"],
+            allow_credentials=True,
+            allow_methods=["*"],
+            allow_headers=["*"],
+        )
+        self.app.mount(
+            "/", StaticFiles(directory="../view/build", html=True), name="public"
+        )
+
+        self._predict_lock = asyncio.Lock()
+        self._update_prompt_lock = asyncio.Lock()
+
+        self.last_prompt: str = ""
+        self.last_images: list[str] = [""]
+
+    async def _predict(self, inp: PredictInputModel) -> PredictResponseModel:
+        """
+        Predict an image and return.
+
+        Parameters
+        ----------
+        inp : PredictInputModel
+            The input.
+
+        Returns
+        -------
+        PredictResponseModel
+            The prediction result.
+        """
+        async with self._predict_lock:
+            if (
+                self._calc_levenstein_distance(inp.prompt, self.last_prompt)
+                < self.config.levenstein_distance_threshold
+            ):
+                logger.info("Using cached images")
+                return PredictResponseModel(base64_images=self.last_images)
+            self.last_prompt = inp.prompt
+            self.last_images = [self._pil_to_base64(image) for image in self.stream_diffusion(inp.prompt)]
+            return PredictResponseModel(base64_images=self.last_images)
+
+    def _pil_to_base64(self, image: Image.Image, format: str = "JPEG") -> bytes:
+        """
+        Convert a PIL image to base64.
+
+        Parameters
+        ----------
+        image : Image.Image
+            The PIL image.
+
+        format : str
+            The image format, by default "JPEG".
+
+        Returns
+        -------
+        bytes
+            The base64 image.
+        """
+        buffered = BytesIO()
+        image.convert("RGB").save(buffered, format=format)
+        return base64.b64encode(buffered.getvalue()).decode("ascii")
+
+    def _base64_to_pil(self, base64_image: str) -> Image.Image:
+        """
+        Convert a base64 image to PIL.
+
+        Parameters
+        ----------
+        base64_image : str
+            The base64 image.
+
+        Returns
+        -------
+        Image.Image
+            The PIL image.
+        """
+        if "base64," in base64_image:
+            base64_image = base64_image.split("base64,")[1]
+        return Image.open(BytesIO(base64.b64decode(base64_image))).convert("RGB")
+
+    def _calc_levenstein_distance(self, a: str, b: str) -> int:
+        """
+        Calculate the Levenstein distance between two strings.
+
+        Parameters
+        ----------
+        a : str
+            The first string.
+
+        b : str
+            The second string.
+
+        Returns
+        -------
+        int
+            The Levenstein distance.
+        """
+        if a == b:
+            return 0
+        a_k = len(a)
+        b_k = len(b)
+        if a == "":
+            return b_k
+        if b == "":
+            return a_k
+        matrix = [[] for i in range(a_k + 1)]
+        for i in range(a_k + 1):
+            matrix[i] = [0 for j in range(b_k + 1)]
+        for i in range(a_k + 1):
+            matrix[i][0] = i
+        for j in range(b_k + 1):
+            matrix[0][j] = j
+        for i in range(1, a_k + 1):
+            ac = a[i - 1]
+            for j in range(1, b_k + 1):
+                bc = b[j - 1]
+                cost = 0 if (ac == bc) else 1
+                matrix[i][j] = min(
+                    [
+                        matrix[i - 1][j] + 1,
+                        matrix[i][j - 1] + 1,
+                        matrix[i - 1][j - 1] + cost,
+                    ]
+                )
+        return matrix[a_k][b_k]
+
+
+if __name__ == "__main__":
+    from config import Config
+
+    config = Config()
+
+    uvicorn.run(
+        Api(config).app,
+        host=config.host,
+        port=config.port,
+        workers=config.workers,
+    )
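The `/api/predict` route above accepts a JSON body with a single `prompt` field and returns base64-encoded JPEGs in `base64_images`; prompts within the Levenstein-distance threshold of the previous one are served from cache. A minimal client sketch, assuming the server is running on the default host and port from `config.py` and that `requests` is installed (the helper below is not part of the commit):

```python
# Minimal client sketch for the /api/predict route.
import base64
import requests

resp = requests.post(
    "http://localhost:9090/api/predict",
    json={"prompt": "a photo of a cat"},  # matches PredictInputModel
    timeout=120,
)
resp.raise_for_status()
images = resp.json()["base64_images"]  # matches PredictResponseModel

# Decode and save each returned JPEG to disk.
for i, b64 in enumerate(images):
    with open(f"output_{i}.jpg", "wb") as f:
        f.write(base64.b64decode(b64))
```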
server/requirements.txt
ADDED
@@ -0,0 +1,12 @@
+xformers
+uvicorn[standard]==0.24.0.post1
+fastapi==0.104
+accelerate
+# git+https://github.com/cumulo-autumn/StreamDiffusion.git@main#egg=stream-diffusion
+--extra-index-url https://download.pytorch.org/whl/cu121
+torch
+torchvision
+torchaudio
+triton
+# https://github.com/chengzeyi/stable-fast --index-url https://download.pytorch.org/whl/cu121
+https://github.com/chengzeyi/stable-fast/releases/download/v0.0.14/stable_fast-0.0.14+torch210cu121-cp310-cp310-manylinux2014_x86_64.whl
server/wrapper.py
ADDED
@@ -0,0 +1,117 @@
+import io
+import os
+from typing import List
+
+import PIL.Image
+import requests
+import torch
+from diffusers import AutoencoderTiny, StableDiffusionPipeline
+
+from streamdiffusion import StreamDiffusion
+from streamdiffusion.acceleration.sfast import accelerate_with_stable_fast
+from streamdiffusion.image_utils import postprocess_image
+
+
+def download_image(url: str):
+    response = requests.get(url)
+    image = PIL.Image.open(io.BytesIO(response.content))
+    return image
+
+
+class StreamDiffusionWrapper:
+    def __init__(
+        self,
+        model_id: str,
+        lcm_lora_id: str,
+        vae_id: str,
+        device: str,
+        dtype: str,
+        t_index_list: List[int],
+        warmup: int,
+    ):
+        self.device = device
+        self.dtype = dtype
+        self.prompt = ""
+
+        self.stream = self._load_model(
+            model_id=model_id,
+            lcm_lora_id=lcm_lora_id,
+            vae_id=vae_id,
+            t_index_list=t_index_list,
+            warmup=warmup,
+        )
+
+    def _load_model(
+        self,
+        model_id: str,
+        lcm_lora_id: str,
+        vae_id: str,
+        t_index_list: List[int],
+        warmup: int,
+    ):
+        if os.path.exists(model_id):
+            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file(model_id).to(
+                device=self.device, dtype=self.dtype
+            )
+        else:
+            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(model_id).to(
+                device=self.device, dtype=self.dtype
+            )
+
+        stream = StreamDiffusion(
+            pipe=pipe,
+            t_index_list=t_index_list,
+            torch_dtype=self.dtype,
+            is_drawing=True,
+        )
+        stream.load_lcm_lora(lcm_lora_id)
+        stream.fuse_lora()
+        stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(device=pipe.device, dtype=pipe.dtype)
+        stream = accelerate_with_stable_fast(stream)
+
+        stream.prepare(
+            "",
+            num_inference_steps=50,
+            generator=torch.manual_seed(2),
+        )
+
+        # warmup
+        for _ in range(warmup):
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+
+            start.record()
+            stream.txt2img()
+            end.record()
+
+            torch.cuda.synchronize()
+
+        return stream
+
+    def __call__(self, prompt: str) -> List[PIL.Image.Image]:
+        self.stream.prepare("")
+
+        images = []
+        for i in range(9 + 3):
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+
+            start.record()
+
+            if self.prompt != prompt:
+                self.stream.update_prompt(prompt)
+                self.prompt = prompt
+
+            x_output = self.stream.txt2img()
+            if i >= 3:
+                images.append(postprocess_image(x_output, output_type="pil")[0])
+            end.record()
+
+            torch.cuda.synchronize()
+
+        return images
+
+
+if __name__ == "__main__":
+    wrapper = StreamDiffusionWrapper(10, 10)
+    wrapper()
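Note that the `__main__` block above calls the constructor with placeholder arguments; in practice the wrapper is driven from `main.py` via the `Config` dataclass. A standalone usage sketch under those same assumptions (not part of the commit; requires a CUDA device, and the output filenames are illustrative):

```python
# Standalone usage sketch mirroring how server/main.py constructs the wrapper.
from config import Config
from wrapper import StreamDiffusionWrapper

config = Config()
wrapper = StreamDiffusionWrapper(
    model_id=config.model_id,
    lcm_lora_id=config.lcm_lora_id,
    vae_id=config.vae_id,
    device=config.device,
    dtype=config.dtype,
    t_index_list=config.t_index_list,
    warmup=config.warmup,
)

# __call__ runs 12 txt2img iterations and keeps the last 9 images.
images = wrapper("an astronaut riding a horse")
for i, image in enumerate(images):
    image.save(f"sample_{i}.png")
```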
start.sh
ADDED
@@ -0,0 +1,2 @@
+cd view && npm run build && cd ..
+cd server && python3 main.py
view/README.md
ADDED
@@ -0,0 +1,46 @@
+# Getting Started with Create React App
+
+This project was bootstrapped with [Create React App](https://github.com/facebook/create-react-app).
+
+## Available Scripts
+
+In the project directory, you can run:
+
+### `npm start`
+
+Runs the app in the development mode.\
+Open [http://localhost:3000](http://localhost:3000) to view it in the browser.
+
+The page will reload if you make edits.\
+You will also see any lint errors in the console.
+
+### `npm test`
+
+Launches the test runner in the interactive watch mode.\
+See the section about [running tests](https://facebook.github.io/create-react-app/docs/running-tests) for more information.
+
+### `npm run build`
+
+Builds the app for production to the `build` folder.\
+It correctly bundles React in production mode and optimizes the build for the best performance.
+
+The build is minified and the filenames include the hashes.\
+Your app is ready to be deployed!
+
+See the section about [deployment](https://facebook.github.io/create-react-app/docs/deployment) for more information.
+
+### `npm run eject`
+
+**Note: this is a one-way operation. Once you `eject`, you can’t go back!**
+
+If you aren’t satisfied with the build tool and configuration choices, you can `eject` at any time. This command will remove the single build dependency from your project.
+
+Instead, it will copy all the configuration files and the transitive dependencies (webpack, Babel, ESLint, etc) right into your project so you have full control over them. All of the commands except `eject` will still work, but they will point to the copied scripts so you can tweak them. At this point you’re on your own.
+
+You don’t have to ever use `eject`. The curated feature set is suitable for small and middle deployments, and you shouldn’t feel obligated to use this feature. However we understand that this tool wouldn’t be useful if you couldn’t customize it when you are ready for it.
+
+## Learn More
+
+You can learn more in the [Create React App documentation](https://facebook.github.io/create-react-app/docs/getting-started).
+
+To learn React, check out the [React documentation](https://reactjs.org/).
view/package-lock.json
ADDED
The diff for this file is too large to render.
view/package.json
ADDED
@@ -0,0 +1,47 @@
+{
+  "name": "view",
+  "version": "0.1.0",
+  "private": true,
+  "dependencies": {
+    "@emotion/react": "^11.11.1",
+    "@emotion/styled": "^11.11.0",
+    "@mui/material": "^5.14.18",
+    "@testing-library/jest-dom": "^5.17.0",
+    "@testing-library/react": "^13.4.0",
+    "@testing-library/user-event": "^13.5.0",
+    "@types/jest": "^27.5.2",
+    "@types/node": "^16.18.64",
+    "@types/react": "^18.2.38",
+    "@types/react-dom": "^18.2.17",
+    "react": "^18.2.0",
+    "react-dom": "^18.2.0",
+    "react-scripts": "5.0.1",
+    "typescript": "^4.9.5",
+    "web-vitals": "^2.1.4"
+  },
+  "scripts": {
+    "start": "react-scripts start",
+    "build": "react-scripts build",
+    "test": "react-scripts test",
+    "eject": "react-scripts eject"
+  },
+  "eslintConfig": {
+    "extends": [
+      "react-app",
+      "react-app/jest"
+    ]
+  },
+  "browserslist": {
+    "production": [
+      ">0.2%",
+      "not dead",
+      "not op_mini all"
+    ],
+    "development": [
+      "last 1 chrome version",
+      "last 1 firefox version",
+      "last 1 safari version"
+    ]
+  },
+  "proxy": "http://localhost:9090"
+}
view/pnpm-lock.yaml
ADDED
The diff for this file is too large to render.
view/public/favicon.ico
ADDED
view/public/images/white.jpg
ADDED
view/public/index.html
ADDED
@@ -0,0 +1,46 @@
+<!DOCTYPE html>
+<html lang="en">
+  <head>
+    <meta charset="utf-8" />
+    <link rel="icon" href="%PUBLIC_URL%/favicon.ico" />
+    <link rel="preconnect" href="https://fonts.googleapis.com">
+    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
+    <link href="https://fonts.googleapis.com/css2?family=Kanit:wght@500&display=swap" rel="stylesheet">
+    <meta name="viewport" content="width=device-width, initial-scale=1" />
+    <meta name="theme-color" content="#000000" />
+    <meta
+      name="description"
+      content="Web site created using create-react-app"
+    />
+    <link rel="apple-touch-icon" href="%PUBLIC_URL%/logo192.png" />
+    <!--
+      manifest.json provides metadata used when your web app is installed on a
+      user's mobile device or desktop. See https://developers.google.com/web/fundamentals/web-app-manifest/
+    -->
+    <link rel="manifest" href="%PUBLIC_URL%/manifest.json" />
+    <!--
+      Notice the use of %PUBLIC_URL% in the tags above.
+      It will be replaced with the URL of the `public` folder during the build.
+      Only files inside the `public` folder can be referenced from the HTML.
+
+      Unlike "/favicon.ico" or "favicon.ico", "%PUBLIC_URL%/favicon.ico" will
+      work correctly both with client-side routing and a non-root public URL.
+      Learn how to configure a non-root public URL by running `npm run build`.
+    -->
+    <title>React App</title>
+  </head>
+  <body>
+    <noscript>You need to enable JavaScript to run this app.</noscript>
+    <div id="root"></div>
+    <!--
+      This HTML file is a template.
+      If you open it directly in the browser, you will see an empty page.
+
+      You can add webfonts, meta tags, or analytics to this file.
+      The build step will place the bundled scripts into the <body> tag.
+
+      To begin the development, run `npm start` or `yarn start`.
+      To create a production bundle, use `npm run build` or `yarn build`.
+    -->
+  </body>
+</html>
view/public/manifest.json
ADDED
@@ -0,0 +1,25 @@
+{
+  "short_name": "React App",
+  "name": "Create React App Sample",
+  "icons": [
+    {
+      "src": "favicon.ico",
+      "sizes": "128x128 64x64 32x32 24x24 16x16",
+      "type": "image/x-icon"
+    },
+    {
+      "src": "logo192.png",
+      "type": "image/png",
+      "sizes": "192x192"
+    },
+    {
+      "src": "logo512.png",
+      "type": "image/png",
+      "sizes": "512x512"
+    }
+  ],
+  "start_url": ".",
+  "display": "standalone",
+  "theme_color": "#000000",
+  "background_color": "#ffffff"
+}
view/public/robots.txt
ADDED
@@ -0,0 +1,3 @@
+# https://www.robotstxt.org/robotstxt.html
+User-agent: *
+Disallow:
view/src/App.tsx
ADDED
@@ -0,0 +1,72 @@
+import React, { useCallback, useEffect, useState } from "react";
+import { TextField, Grid, Paper } from "@mui/material";
+
+function App() {
+  const [inputPrompt, setInputPrompt] = useState("");
+  const [images, setImages] = useState(Array(9).fill("images/white.jpg"));
+
+  const fetchImages = useCallback(async () => {
+    try {
+      const response = await fetch("/api/predict", {
+        method: "POST",
+        headers: { 'Content-Type': 'application/json' },
+        body: JSON.stringify({ prompt: inputPrompt })
+      });
+      const data = await response.json();
+      const imageUrls = data.base64_images.map((base64: string) => `data:image/jpeg;base64,${base64}`);
+      setImages(imageUrls);
+    } catch (error) {
+      console.error("Error fetching images:", error);
+    }
+  }, [inputPrompt]);
+
+  const handlePromptChange = (event: React.ChangeEvent<HTMLInputElement>) => {
+    setInputPrompt(event.target.value);
+    fetchImages();
+  };
+
+  return (
+    <div
+      className="App"
+      style={{
+        backgroundColor: "#282c34",
+        height: "100vh",
+        display: "flex",
+        alignItems: "center",
+        justifyContent: "center",
+        margin: "0",
+        color: "#ffffff",
+        padding: "20px",
+      }}
+    >
+      <div
+        style={{
+          backgroundColor: "#282c34",
+          alignItems: "center",
+          justifyContent: "center",
+          display: "flex",
+          flexDirection: "column",
+        }}
+      >
+        <Grid container spacing={2}>
+          {images.map((image, index) => (
+            <Grid item xs={4} key={index}>
+              <Paper style={{ padding: "10px", textAlign: "center" }}>
+                <img src={image} alt={`Generated ${index}`} style={{ maxWidth: "100%", maxHeight: "200px", borderRadius: "10px" }} />
+              </Paper>
+            </Grid>
+          ))}
+        </Grid>
+        <TextField
+          variant="outlined"
+          value={inputPrompt}
+          onChange={handlePromptChange}
+          style={{ marginBottom: "20px", marginTop: "20px", width: "640px", color: "#ffffff", borderColor: "#ffffff", borderRadius: "10px", backgroundColor: "#ffffff" }}
+          placeholder="Enter a prompt"
+        />
+      </div>
+    </div>
+  );
+}
+
+export default App;
view/src/index.css
ADDED
@@ -0,0 +1,8 @@
+@import url('https://fonts.googleapis.com/css?family=Kanit:400,700&display=swap');
+
+body {
+  font-family: 'Kanit', 'Ubuntu', sans-serif;
+  margin: 0;
+  -webkit-font-smoothing: antialiased;
+  -moz-osx-font-smoothing: grayscale;
+}
view/src/index.tsx
ADDED
@@ -0,0 +1,19 @@
+import React from "react";
+import ReactDOM from "react-dom/client";
+import "./index.css";
+import App from "./App";
+import reportWebVitals from "./reportWebVitals";
+
+const root = ReactDOM.createRoot(
+  document.getElementById("root") as HTMLElement
+);
+root.render(
+  <React.StrictMode>
+    <App />
+  </React.StrictMode>
+);
+
+// If you want to start measuring performance in your app, pass a function
+// to log results (for example: reportWebVitals(console.log))
+// or send to an analytics endpoint. Learn more: https://bit.ly/CRA-vitals
+reportWebVitals();
view/src/react-app-env.d.ts
ADDED
@@ -0,0 +1 @@
+/// <reference types="react-scripts" />
view/src/reportWebVitals.ts
ADDED
@@ -0,0 +1,15 @@
+import { ReportHandler } from 'web-vitals';
+
+const reportWebVitals = (onPerfEntry?: ReportHandler) => {
+  if (onPerfEntry && onPerfEntry instanceof Function) {
+    import('web-vitals').then(({ getCLS, getFID, getFCP, getLCP, getTTFB }) => {
+      getCLS(onPerfEntry);
+      getFID(onPerfEntry);
+      getFCP(onPerfEntry);
+      getLCP(onPerfEntry);
+      getTTFB(onPerfEntry);
+    });
+  }
+};
+
+export default reportWebVitals;
view/tsconfig.json
ADDED
@@ -0,0 +1,26 @@
+{
+  "compilerOptions": {
+    "target": "es5",
+    "lib": [
+      "dom",
+      "dom.iterable",
+      "esnext"
+    ],
+    "allowJs": true,
+    "skipLibCheck": true,
+    "esModuleInterop": true,
+    "allowSyntheticDefaultImports": true,
+    "strict": true,
+    "forceConsistentCasingInFileNames": true,
+    "noFallthroughCasesInSwitch": true,
+    "module": "esnext",
+    "moduleResolution": "node",
+    "resolveJsonModule": true,
+    "isolatedModules": true,
+    "noEmit": true,
+    "jsx": "react-jsx"
+  },
+  "include": [
+    "src"
+  ]
+}