radames commited on
Commit
7dc6a72
1 Parent(s): 255bd44

Upload 23 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/example.gif filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # build frontend with node
2
+ FROM node:20-alpine AS frontend
3
+ RUN apk add --no-cache libc6-compat
4
+ WORKDIR /app
5
+
6
+ COPY view .
7
+ RUN npm ci
8
+ RUN npm run build
9
+
10
+ # build backend on CUDA
11
+ FROM nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04 AS backend
12
+ WORKDIR /app
13
+
14
+ ENV DEBIAN_FRONTEND=noninteractive
15
+ ENV NODE_MAJOR=20
16
+
17
+ RUN apt-get update && \
18
+ apt-get upgrade -y && \
19
+ apt-get install -y --no-install-recommends \
20
+ git \
21
+ git-lfs \
22
+ wget \
23
+ curl \
24
+ # python build dependencies \
25
+ build-essential \
26
+ libssl-dev \
27
+ zlib1g-dev \
28
+ libbz2-dev \
29
+ libreadline-dev \
30
+ libsqlite3-dev \
31
+ libncursesw5-dev \
32
+ xz-utils \
33
+ tk-dev \
34
+ libxml2-dev \
35
+ libxmlsec1-dev \
36
+ libffi-dev \
37
+ liblzma-dev && \
38
+ apt-get clean && \
39
+ rm -rf /var/lib/apt/lists/*
40
+
41
+ USER root
42
+
43
+ RUN useradd -m -u 1000 user
44
+ USER user
45
+ ENV HOME=/home/user \
46
+ PATH=/home/user/.local/bin:$PATH
47
+ WORKDIR $HOME/app
48
+
49
+ RUN curl https://pyenv.run | bash
50
+ ENV PATH=$HOME/.pyenv/shims:$HOME/.pyenv/bin:$PATH
51
+ ARG PYTHON_VERSION=3.10.12
52
+ RUN pyenv install $PYTHON_VERSION && \
53
+ pyenv global $PYTHON_VERSION && \
54
+ pyenv rehash && \
55
+ pip install --no-cache-dir -U pip setuptools wheel
56
+
57
+ COPY --chown=user:user . .
58
+ # change dir since pip needs to seed whl folder
59
+ RUN cd server && pip install --no-cache-dir --upgrade -r requirements.txt
60
+ RUN --mount=type=secret,id=GITHUB_TOKEN,mode=0444,required=true \
61
+ pip install git+https://$(cat /run/secrets/GITHUB_TOKEN)@github.com/cumulo-autumn/StreamDiffusion.git@main#egg=stream-diffusion[tensorrt]
62
+ RUN python -m streamdiffusion.tools.install-tensorrt
63
+
64
+ COPY --from=frontend /app/build ./view/build
65
+
66
+ WORKDIR $HOME/app/server
67
+
68
+ USER user
69
+ CMD ["python", "main.py"]
README.md CHANGED
@@ -1,11 +1,37 @@
1
- ---
2
- title: StreamDiffusion Realtime Txt2img
3
- emoji: 👁
4
- colorFrom: red
5
- colorTo: yellow
6
- sdk: docker
7
- app_port: 9090
8
- pinned: false
9
- ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Txt2Img Example
2
+
3
+ <p align="center">
4
+ <img src="./assets/example.gif" width=80%>
5
+ </p>
6
+
7
+ This example provides a simple implementation of the use of StreamDiffusion to generate images from text.
8
+
9
+ ## Usage
10
+
11
+ ```bash
12
+ chmod +x ./start.sh && ./start.sh
13
+ ```
14
+
15
+ or
16
+
17
+ ```bash
18
+ cd server
19
+ python3 main.py &
20
+ cd ../view
21
+ npm start
22
+ ```
23
+
24
+ ## Docker
25
+
26
+ Build
27
+ `GITHUB_TOKEN` is temp until project is public
28
+ ```bash
29
+ docker build --secret id=GITHUB_TOKEN,src=./github_token.txt -t realtime-txt2img .
30
+ ```
31
+
32
+ Run
33
+ ```bash
34
+ docker run -ti -p 9090:9090 -e HF_HOME=/data -v ~/.cache/huggingface:/data --gpus all realtime-txt2img
35
+ ```
36
+
37
+ `-e HF_HOME=/data -v ~/.cache/huggingface:/data` is used to mount your local huggingface cache to the container, so that you don't need to download the model again.
assets/example.gif ADDED

Git LFS Details

  • SHA256: 72bb052c634ed52b72a50753b247a2c36b526d93d1500c0cd419be8393bf2546
  • Pointer size: 132 Bytes
  • Size of remote file: 1.57 MB
server/config.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+ import torch
5
+
6
+
7
+ @dataclass
8
+ class Config:
9
+ """
10
+ The configuration for the API.
11
+ """
12
+
13
+ ####################################################################
14
+ # Server
15
+ ####################################################################
16
+ # In most cases, you should leave this as it is.
17
+ host: str = "0.0.0.0"
18
+ port: int = 9090
19
+ workers: int = 1
20
+
21
+ ####################################################################
22
+ # Generation configuration
23
+ ####################################################################
24
+ # The threshold for the Levenstein distance.
25
+ levenstein_distance_threshold: int = 3
26
+
27
+ ####################################################################
28
+ # Model configuration
29
+ ####################################################################
30
+ # SD1.x variant model
31
+ model_id: str = "SimianLuo/LCM_Dreamshaper_v7"
32
+
33
+ # LCM-LORA model
34
+ lcm_lora_id: str = "latent-consistency/lcm-lora-sdv1-5"
35
+ # TinyVAE model
36
+ vae_id: str = "madebyollin/taesd"
37
+ # Device to use
38
+ device: torch.device = torch.device("cuda")
39
+ # Data type
40
+ dtype: torch.dtype = torch.float16
41
+
42
+ ####################################################################
43
+ # Inference configuration
44
+ ####################################################################
45
+ # Number of inference steps
46
+ t_index_list: List[int] = field(default_factory=lambda: [0, 16, 32, 45])
47
+ # Number of warmup steps
48
+ warmup: int = 10
server/main.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import base64
3
+ import logging
4
+ from io import BytesIO
5
+ from pathlib import Path
6
+
7
+ import uvicorn
8
+ from config import Config
9
+ from fastapi import FastAPI
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from fastapi.staticfiles import StaticFiles
12
+
13
+ from PIL import Image
14
+ from pydantic import BaseModel
15
+ from wrapper import StreamDiffusionWrapper
16
+
17
+
18
+ logger = logging.getLogger("uvicorn")
19
+ PROJECT_DIR = Path(__file__).parent.parent
20
+
21
+
22
+ class PredictInputModel(BaseModel):
23
+ """
24
+ The input model for the /predict endpoint.
25
+ """
26
+
27
+ prompt: str
28
+
29
+
30
+ class PredictResponseModel(BaseModel):
31
+ """
32
+ The response model for the /predict endpoint.
33
+ """
34
+
35
+ base64_images: list[str]
36
+
37
+
38
+ class UpdatePromptResponseModel(BaseModel):
39
+ """
40
+ The response model for the /update_prompt endpoint.
41
+ """
42
+
43
+ prompt: str
44
+
45
+
46
+ class Api:
47
+ def __init__(self, config: Config) -> None:
48
+ """
49
+ Initialize the API.
50
+
51
+ Parameters
52
+ ----------
53
+ config : Config
54
+ The configuration.
55
+ """
56
+ self.config = config
57
+ self.stream_diffusion = StreamDiffusionWrapper(
58
+ model_id=config.model_id,
59
+ lcm_lora_id=config.lcm_lora_id,
60
+ vae_id=config.vae_id,
61
+ device=config.device,
62
+ dtype=config.dtype,
63
+ t_index_list=config.t_index_list,
64
+ warmup=config.warmup,
65
+ )
66
+ self.app = FastAPI()
67
+ self.app.add_api_route(
68
+ "/api/predict",
69
+ self._predict,
70
+ methods=["POST"],
71
+ response_model=PredictResponseModel,
72
+ )
73
+ self.app.add_middleware(
74
+ CORSMiddleware,
75
+ allow_origins=["*"],
76
+ allow_credentials=True,
77
+ allow_methods=["*"],
78
+ allow_headers=["*"],
79
+ )
80
+ self.app.mount(
81
+ "/", StaticFiles(directory="../view/build", html=True), name="public"
82
+ )
83
+
84
+ self._predict_lock = asyncio.Lock()
85
+ self._update_prompt_lock = asyncio.Lock()
86
+
87
+ self.last_prompt: str = ""
88
+ self.last_images: list[str] = [""]
89
+
90
+ async def _predict(self, inp: PredictInputModel) -> PredictResponseModel:
91
+ """
92
+ Predict an image and return.
93
+
94
+ Parameters
95
+ ----------
96
+ inp : PredictInputModel
97
+ The input.
98
+
99
+ Returns
100
+ -------
101
+ PredictResponseModel
102
+ The prediction result.
103
+ """
104
+ async with self._predict_lock:
105
+ if (
106
+ self._calc_levenstein_distance(inp.prompt, self.last_prompt)
107
+ < self.config.levenstein_distance_threshold
108
+ ):
109
+ logger.info("Using cached images")
110
+ return PredictResponseModel(base64_images=self.last_images)
111
+ self.last_prompt = inp.prompt
112
+ self.last_images = [self._pil_to_base64(image) for image in self.stream_diffusion(inp.prompt)]
113
+ return PredictResponseModel(base64_images=self.last_images)
114
+
115
+ def _pil_to_base64(self, image: Image.Image, format: str = "JPEG") -> bytes:
116
+ """
117
+ Convert a PIL image to base64.
118
+
119
+ Parameters
120
+ ----------
121
+ image : Image.Image
122
+ The PIL image.
123
+
124
+ format : str
125
+ The image format, by default "JPEG".
126
+
127
+ Returns
128
+ -------
129
+ bytes
130
+ The base64 image.
131
+ """
132
+ buffered = BytesIO()
133
+ image.convert("RGB").save(buffered, format=format)
134
+ return base64.b64encode(buffered.getvalue()).decode("ascii")
135
+
136
+ def _base64_to_pil(self, base64_image: str) -> Image.Image:
137
+ """
138
+ Convert a base64 image to PIL.
139
+
140
+ Parameters
141
+ ----------
142
+ base64_image : str
143
+ The base64 image.
144
+
145
+ Returns
146
+ -------
147
+ Image.Image
148
+ The PIL image.
149
+ """
150
+ if "base64," in base64_image:
151
+ base64_image = base64_image.split("base64,")[1]
152
+ return Image.open(BytesIO(base64.b64decode(base64_image))).convert("RGB")
153
+
154
+ def _calc_levenstein_distance(self, a: str, b: str) -> int:
155
+ """
156
+ Calculate the Levenstein distance between two strings.
157
+
158
+ Parameters
159
+ ----------
160
+ a : str
161
+ The first string.
162
+
163
+ b : str
164
+ The second string.
165
+
166
+ Returns
167
+ -------
168
+ int
169
+ The Levenstein distance.
170
+ """
171
+ if a == b:
172
+ return 0
173
+ a_k = len(a)
174
+ b_k = len(b)
175
+ if a == "":
176
+ return b_k
177
+ if b == "":
178
+ return a_k
179
+ matrix = [[] for i in range(a_k + 1)]
180
+ for i in range(a_k + 1):
181
+ matrix[i] = [0 for j in range(b_k + 1)]
182
+ for i in range(a_k + 1):
183
+ matrix[i][0] = i
184
+ for j in range(b_k + 1):
185
+ matrix[0][j] = j
186
+ for i in range(1, a_k + 1):
187
+ ac = a[i - 1]
188
+ for j in range(1, b_k + 1):
189
+ bc = b[j - 1]
190
+ cost = 0 if (ac == bc) else 1
191
+ matrix[i][j] = min(
192
+ [
193
+ matrix[i - 1][j] + 1,
194
+ matrix[i][j - 1] + 1,
195
+ matrix[i - 1][j - 1] + cost,
196
+ ]
197
+ )
198
+ return matrix[a_k][b_k]
199
+
200
+
201
+ if __name__ == "__main__":
202
+ from config import Config
203
+
204
+ config = Config()
205
+
206
+ uvicorn.run(
207
+ Api(config).app,
208
+ host=config.host,
209
+ port=config.port,
210
+ workers=config.workers,
211
+ )
server/requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ xformers
2
+ uvicorn[standard]==0.24.0.post1
3
+ fastapi==0.104
4
+ accelerate
5
+ # git+https://github.com/cumulo-autumn/StreamDiffusion.git@main#egg=stream-diffusion
6
+ --extra-index-url https://download.pytorch.org/whl/cu121
7
+ torch
8
+ torchvision
9
+ torchaudio
10
+ triton
11
+ # https://github.com/chengzeyi/stable-fast --index-url https://download.pytorch.org/whl/cu121
12
+ https://github.com/chengzeyi/stable-fast/releases/download/v0.0.14/stable_fast-0.0.14+torch210cu121-cp310-cp310-manylinux2014_x86_64.whl
server/wrapper.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ from typing import List
4
+
5
+ import PIL.Image
6
+ import requests
7
+ import torch
8
+ from diffusers import AutoencoderTiny, StableDiffusionPipeline
9
+
10
+ from streamdiffusion import StreamDiffusion
11
+ from streamdiffusion.acceleration.sfast import accelerate_with_stable_fast
12
+ from streamdiffusion.image_utils import postprocess_image
13
+
14
+
15
+ def download_image(url: str):
16
+ response = requests.get(url)
17
+ image = PIL.Image.open(io.BytesIO(response.content))
18
+ return image
19
+
20
+
21
+ class StreamDiffusionWrapper:
22
+ def __init__(
23
+ self,
24
+ model_id: str,
25
+ lcm_lora_id: str,
26
+ vae_id: str,
27
+ device: str,
28
+ dtype: str,
29
+ t_index_list: List[int],
30
+ warmup: int,
31
+ ):
32
+ self.device = device
33
+ self.dtype = dtype
34
+ self.prompt = ""
35
+
36
+ self.stream = self._load_model(
37
+ model_id=model_id,
38
+ lcm_lora_id=lcm_lora_id,
39
+ vae_id=vae_id,
40
+ t_index_list=t_index_list,
41
+ warmup=warmup,
42
+ )
43
+
44
+ def _load_model(
45
+ self,
46
+ model_id: str,
47
+ lcm_lora_id: str,
48
+ vae_id: str,
49
+ t_index_list: List[int],
50
+ warmup: int,
51
+ ):
52
+ if os.path.exists(model_id):
53
+ pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file(model_id).to(
54
+ device=self.device, dtype=self.dtype
55
+ )
56
+ else:
57
+ pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(model_id).to(
58
+ device=self.device, dtype=self.dtype
59
+ )
60
+
61
+ stream = StreamDiffusion(
62
+ pipe=pipe,
63
+ t_index_list=t_index_list,
64
+ torch_dtype=self.dtype,
65
+ is_drawing=True,
66
+ )
67
+ stream.load_lcm_lora(lcm_lora_id)
68
+ stream.fuse_lora()
69
+ stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(device=pipe.device, dtype=pipe.dtype)
70
+ stream = accelerate_with_stable_fast(stream)
71
+
72
+ stream.prepare(
73
+ "",
74
+ num_inference_steps=50,
75
+ generator=torch.manual_seed(2),
76
+ )
77
+
78
+ # warmup
79
+ for _ in range(warmup):
80
+ start = torch.cuda.Event(enable_timing=True)
81
+ end = torch.cuda.Event(enable_timing=True)
82
+
83
+ start.record()
84
+ stream.txt2img()
85
+ end.record()
86
+
87
+ torch.cuda.synchronize()
88
+
89
+ return stream
90
+
91
+ def __call__(self, prompt: str) -> List[PIL.Image.Image]:
92
+ self.stream.prepare("")
93
+
94
+ images = []
95
+ for i in range(9 + 3):
96
+ start = torch.cuda.Event(enable_timing=True)
97
+ end = torch.cuda.Event(enable_timing=True)
98
+
99
+ start.record()
100
+
101
+ if self.prompt != prompt:
102
+ self.stream.update_prompt(prompt)
103
+ self.prompt = prompt
104
+
105
+ x_output = self.stream.txt2img()
106
+ if i >= 3:
107
+ images.append(postprocess_image(x_output, output_type="pil")[0])
108
+ end.record()
109
+
110
+ torch.cuda.synchronize()
111
+
112
+ return images
113
+
114
+
115
+ if __name__ == "__main__":
116
+ wrapper = StreamDiffusionWrapper(10, 10)
117
+ wrapper()
start.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ cd view && npm run build && cd ..
2
+ cd server && python3 main.py
view/README.md ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Getting Started with Create React App
2
+
3
+ This project was bootstrapped with [Create React App](https://github.com/facebook/create-react-app).
4
+
5
+ ## Available Scripts
6
+
7
+ In the project directory, you can run:
8
+
9
+ ### `npm start`
10
+
11
+ Runs the app in the development mode.\
12
+ Open [http://localhost:3000](http://localhost:3000) to view it in the browser.
13
+
14
+ The page will reload if you make edits.\
15
+ You will also see any lint errors in the console.
16
+
17
+ ### `npm test`
18
+
19
+ Launches the test runner in the interactive watch mode.\
20
+ See the section about [running tests](https://facebook.github.io/create-react-app/docs/running-tests) for more information.
21
+
22
+ ### `npm run build`
23
+
24
+ Builds the app for production to the `build` folder.\
25
+ It correctly bundles React in production mode and optimizes the build for the best performance.
26
+
27
+ The build is minified and the filenames include the hashes.\
28
+ Your app is ready to be deployed!
29
+
30
+ See the section about [deployment](https://facebook.github.io/create-react-app/docs/deployment) for more information.
31
+
32
+ ### `npm run eject`
33
+
34
+ **Note: this is a one-way operation. Once you `eject`, you can’t go back!**
35
+
36
+ If you aren’t satisfied with the build tool and configuration choices, you can `eject` at any time. This command will remove the single build dependency from your project.
37
+
38
+ Instead, it will copy all the configuration files and the transitive dependencies (webpack, Babel, ESLint, etc) right into your project so you have full control over them. All of the commands except `eject` will still work, but they will point to the copied scripts so you can tweak them. At this point you’re on your own.
39
+
40
+ You don’t have to ever use `eject`. The curated feature set is suitable for small and middle deployments, and you shouldn’t feel obligated to use this feature. However we understand that this tool wouldn’t be useful if you couldn’t customize it when you are ready for it.
41
+
42
+ ## Learn More
43
+
44
+ You can learn more in the [Create React App documentation](https://facebook.github.io/create-react-app/docs/getting-started).
45
+
46
+ To learn React, check out the [React documentation](https://reactjs.org/).
view/package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
view/package.json ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "view",
3
+ "version": "0.1.0",
4
+ "private": true,
5
+ "dependencies": {
6
+ "@emotion/react": "^11.11.1",
7
+ "@emotion/styled": "^11.11.0",
8
+ "@mui/material": "^5.14.18",
9
+ "@testing-library/jest-dom": "^5.17.0",
10
+ "@testing-library/react": "^13.4.0",
11
+ "@testing-library/user-event": "^13.5.0",
12
+ "@types/jest": "^27.5.2",
13
+ "@types/node": "^16.18.64",
14
+ "@types/react": "^18.2.38",
15
+ "@types/react-dom": "^18.2.17",
16
+ "react": "^18.2.0",
17
+ "react-dom": "^18.2.0",
18
+ "react-scripts": "5.0.1",
19
+ "typescript": "^4.9.5",
20
+ "web-vitals": "^2.1.4"
21
+ },
22
+ "scripts": {
23
+ "start": "react-scripts start",
24
+ "build": "react-scripts build",
25
+ "test": "react-scripts test",
26
+ "eject": "react-scripts eject"
27
+ },
28
+ "eslintConfig": {
29
+ "extends": [
30
+ "react-app",
31
+ "react-app/jest"
32
+ ]
33
+ },
34
+ "browserslist": {
35
+ "production": [
36
+ ">0.2%",
37
+ "not dead",
38
+ "not op_mini all"
39
+ ],
40
+ "development": [
41
+ "last 1 chrome version",
42
+ "last 1 firefox version",
43
+ "last 1 safari version"
44
+ ]
45
+ },
46
+ "proxy": "http://localhost:9090"
47
+ }
view/pnpm-lock.yaml ADDED
The diff for this file is too large to render. See raw diff
 
view/public/favicon.ico ADDED
view/public/images/white.jpg ADDED
view/public/index.html ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="utf-8" />
5
+ <link rel="icon" href="%PUBLIC_URL%/favicon.ico" />
6
+ <link rel="preconnect" href="https://fonts.googleapis.com">
7
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
8
+ <link href="https://fonts.googleapis.com/css2?family=Kanit:wght@500&display=swap" rel="stylesheet">
9
+ <meta name="viewport" content="width=device-width, initial-scale=1" />
10
+ <meta name="theme-color" content="#000000" />
11
+ <meta
12
+ name="description"
13
+ content="Web site created using create-react-app"
14
+ />
15
+ <link rel="apple-touch-icon" href="%PUBLIC_URL%/logo192.png" />
16
+ <!--
17
+ manifest.json provides metadata used when your web app is installed on a
18
+ user's mobile device or desktop. See https://developers.google.com/web/fundamentals/web-app-manifest/
19
+ -->
20
+ <link rel="manifest" href="%PUBLIC_URL%/manifest.json" />
21
+ <!--
22
+ Notice the use of %PUBLIC_URL% in the tags above.
23
+ It will be replaced with the URL of the `public` folder during the build.
24
+ Only files inside the `public` folder can be referenced from the HTML.
25
+
26
+ Unlike "/favicon.ico" or "favicon.ico", "%PUBLIC_URL%/favicon.ico" will
27
+ work correctly both with client-side routing and a non-root public URL.
28
+ Learn how to configure a non-root public URL by running `npm run build`.
29
+ -->
30
+ <title>React App</title>
31
+ </head>
32
+ <body>
33
+ <noscript>You need to enable JavaScript to run this app.</noscript>
34
+ <div id="root"></div>
35
+ <!--
36
+ This HTML file is a template.
37
+ If you open it directly in the browser, you will see an empty page.
38
+
39
+ You can add webfonts, meta tags, or analytics to this file.
40
+ The build step will place the bundled scripts into the <body> tag.
41
+
42
+ To begin the development, run `npm start` or `yarn start`.
43
+ To create a production bundle, use `npm run build` or `yarn build`.
44
+ -->
45
+ </body>
46
+ </html>
view/public/manifest.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "short_name": "React App",
3
+ "name": "Create React App Sample",
4
+ "icons": [
5
+ {
6
+ "src": "favicon.ico",
7
+ "sizes": "128x128 64x64 32x32 24x24 16x16",
8
+ "type": "image/x-icon"
9
+ },
10
+ {
11
+ "src": "logo192.png",
12
+ "type": "image/png",
13
+ "sizes": "192x192"
14
+ },
15
+ {
16
+ "src": "logo512.png",
17
+ "type": "image/png",
18
+ "sizes": "512x512"
19
+ }
20
+ ],
21
+ "start_url": ".",
22
+ "display": "standalone",
23
+ "theme_color": "#000000",
24
+ "background_color": "#ffffff"
25
+ }
view/public/robots.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # https://www.robotstxt.org/robotstxt.html
2
+ User-agent: *
3
+ Disallow:
view/src/App.tsx ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useCallback, useEffect, useState } from "react";
2
+ import { TextField, Grid, Paper } from "@mui/material";
3
+
4
+ function App() {
5
+ const [inputPrompt, setInputPrompt] = useState("");
6
+ const [images, setImages] = useState(Array(9).fill("images/white.jpg"));
7
+
8
+ const fetchImages = useCallback(async () => {
9
+ try {
10
+ const response = await fetch("/api/predict", {
11
+ method: "POST",
12
+ headers: { 'Content-Type': 'application/json' },
13
+ body: JSON.stringify({ prompt: inputPrompt })
14
+ });
15
+ const data = await response.json();
16
+ const imageUrls = data.base64_images.map((base64: string) => `data:image/jpeg;base64,${base64}`);
17
+ setImages(imageUrls);
18
+ } catch (error) {
19
+ console.error("Error fetching images:", error);
20
+ }
21
+ }, [inputPrompt]);
22
+
23
+ const handlePromptChange = (event: React.ChangeEvent<HTMLInputElement>) => {
24
+ setInputPrompt(event.target.value);
25
+ fetchImages();
26
+ };
27
+
28
+ return (
29
+ <div
30
+ className="App"
31
+ style={{
32
+ backgroundColor: "#282c34",
33
+ height: "100vh",
34
+ display: "flex",
35
+ alignItems: "center",
36
+ justifyContent: "center",
37
+ margin: "0",
38
+ color: "#ffffff",
39
+ padding: "20px",
40
+ }}
41
+ >
42
+ <div
43
+ style={{
44
+ backgroundColor: "#282c34",
45
+ alignItems: "center",
46
+ justifyContent: "center",
47
+ display: "flex",
48
+ flexDirection: "column",
49
+ }}
50
+ >
51
+ <Grid container spacing={2}>
52
+ {images.map((image, index) => (
53
+ <Grid item xs={4} key={index}>
54
+ <Paper style={{ padding: "10px", textAlign: "center" }}>
55
+ <img src={image} alt={`Generated ${index}`} style={{ maxWidth: "100%", maxHeight: "200px", borderRadius: "10px" }} />
56
+ </Paper>
57
+ </Grid>
58
+ ))}
59
+ </Grid>
60
+ <TextField
61
+ variant="outlined"
62
+ value={inputPrompt}
63
+ onChange={handlePromptChange}
64
+ style={{ marginBottom: "20px", marginTop: "20px", width: "640px", color: "#ffffff", borderColor: "#ffffff", borderRadius: "10px", backgroundColor: "#ffffff" }}
65
+ placeholder="Enter a prompt"
66
+ />
67
+ </div>
68
+ </div>
69
+ );
70
+ }
71
+
72
+ export default App;
view/src/index.css ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ @import url('https://fonts.googleapis.com/css?family=Kanit:400,700&display=swap');
2
+
3
+ body {
4
+ font-family: 'Kanit', 'Ubuntu', sans-serif;
5
+ margin: 0;
6
+ -webkit-font-smoothing: antialiased;
7
+ -moz-osx-font-smoothing: grayscale;
8
+ }
view/src/index.tsx ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React from "react";
2
+ import ReactDOM from "react-dom/client";
3
+ import "./index.css";
4
+ import App from "./App";
5
+ import reportWebVitals from "./reportWebVitals";
6
+
7
+ const root = ReactDOM.createRoot(
8
+ document.getElementById("root") as HTMLElement
9
+ );
10
+ root.render(
11
+ <React.StrictMode>
12
+ <App />
13
+ </React.StrictMode>
14
+ );
15
+
16
+ // If you want to start measuring performance in your app, pass a function
17
+ // to log results (for example: reportWebVitals(console.log))
18
+ // or send to an analytics endpoint. Learn more: https://bit.ly/CRA-vitals
19
+ reportWebVitals();
view/src/react-app-env.d.ts ADDED
@@ -0,0 +1 @@
 
 
1
+ /// <reference types="react-scripts" />
view/src/reportWebVitals.ts ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { ReportHandler } from 'web-vitals';
2
+
3
+ const reportWebVitals = (onPerfEntry?: ReportHandler) => {
4
+ if (onPerfEntry && onPerfEntry instanceof Function) {
5
+ import('web-vitals').then(({ getCLS, getFID, getFCP, getLCP, getTTFB }) => {
6
+ getCLS(onPerfEntry);
7
+ getFID(onPerfEntry);
8
+ getFCP(onPerfEntry);
9
+ getLCP(onPerfEntry);
10
+ getTTFB(onPerfEntry);
11
+ });
12
+ }
13
+ };
14
+
15
+ export default reportWebVitals;
view/tsconfig.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "compilerOptions": {
3
+ "target": "es5",
4
+ "lib": [
5
+ "dom",
6
+ "dom.iterable",
7
+ "esnext"
8
+ ],
9
+ "allowJs": true,
10
+ "skipLibCheck": true,
11
+ "esModuleInterop": true,
12
+ "allowSyntheticDefaultImports": true,
13
+ "strict": true,
14
+ "forceConsistentCasingInFileNames": true,
15
+ "noFallthroughCasesInSwitch": true,
16
+ "module": "esnext",
17
+ "moduleResolution": "node",
18
+ "resolveJsonModule": true,
19
+ "isolatedModules": true,
20
+ "noEmit": true,
21
+ "jsx": "react-jsx"
22
+ },
23
+ "include": [
24
+ "src"
25
+ ]
26
+ }