Spaces:
Runtime error
Runtime error
cocktailpeanut
commited on
Commit
•
d82dd30
1
Parent(s):
da5331f
update
Browse files- app.py +9 -5
- evosdxl_jp_v1.py +3 -1
- requirements.txt +2 -2
app.py
CHANGED
@@ -6,10 +6,11 @@ import uuid
|
|
6 |
|
7 |
import gradio as gr
|
8 |
import numpy as np
|
9 |
-
import spaces
|
10 |
import torch
|
11 |
from PIL import Image
|
12 |
from evosdxl_jp_v1 import load_evosdxl_jp
|
|
|
13 |
|
14 |
DESCRIPTION = """# 🐟 EvoSDXL-JP
|
15 |
🤗 [モデル一覧](https://huggingface.co/SakanaAI) | 📚 [技術レポート](https://arxiv.org/abs/2403.13187) | 📝 [ブログ](https://sakana.ai/evosdxl-jp/) | 🐦 [Twitter](https://twitter.com/SakanaAILabs)
|
@@ -23,12 +24,14 @@ if not torch.cuda.is_available():
|
|
23 |
MAX_SEED = np.iinfo(np.int32).max
|
24 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
25 |
|
26 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
27 |
|
28 |
NUM_IMAGES_PER_PROMPT = 1
|
29 |
ENABLE_CPU_OFFLOAD = False
|
30 |
USE_TORCH_COMPILE = False
|
31 |
-
SAFETY_CHECKER = True
|
|
|
32 |
DEVELOP_MODE = True
|
33 |
if SAFETY_CHECKER:
|
34 |
from safety_checker import StableDiffusionSafetyChecker
|
@@ -53,7 +56,8 @@ if SAFETY_CHECKER:
|
|
53 |
return images, has_nsfw_concepts
|
54 |
|
55 |
|
56 |
-
pipe = load_evosdxl_jp("cpu").to("cuda")
|
|
|
57 |
|
58 |
def show_warning(warning_text: str) -> gr.Blocks:
|
59 |
with gr.Blocks() as demo:
|
@@ -154,4 +158,4 @@ with gr.Blocks(css=css) as demo:
|
|
154 |
Sakana AIは、本モデルの使用によって生じた直接的または間接的な損失に対して、結果に関わらず、一切の責任を負いません。
|
155 |
利用者は、本モデルの使用に伴うリスクを十分に理解し、自身の判断で使用することが必要です。""")
|
156 |
|
157 |
-
demo.queue().launch()
|
|
|
6 |
|
7 |
import gradio as gr
|
8 |
import numpy as np
|
9 |
+
#import spaces
|
10 |
import torch
|
11 |
from PIL import Image
|
12 |
from evosdxl_jp_v1 import load_evosdxl_jp
|
13 |
+
import devicetorch
|
14 |
|
15 |
DESCRIPTION = """# 🐟 EvoSDXL-JP
|
16 |
🤗 [モデル一覧](https://huggingface.co/SakanaAI) | 📚 [技術レポート](https://arxiv.org/abs/2403.13187) | 📝 [ブログ](https://sakana.ai/evosdxl-jp/) | 🐦 [Twitter](https://twitter.com/SakanaAILabs)
|
|
|
24 |
MAX_SEED = np.iinfo(np.int32).max
|
25 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
26 |
|
27 |
+
#device = "cuda" if torch.cuda.is_available() else "cpu"
|
28 |
+
device = devicetorch.get(torch)
|
29 |
|
30 |
NUM_IMAGES_PER_PROMPT = 1
|
31 |
ENABLE_CPU_OFFLOAD = False
|
32 |
USE_TORCH_COMPILE = False
|
33 |
+
#SAFETY_CHECKER = True
|
34 |
+
SAFETY_CHECKER = False
|
35 |
DEVELOP_MODE = True
|
36 |
if SAFETY_CHECKER:
|
37 |
from safety_checker import StableDiffusionSafetyChecker
|
|
|
56 |
return images, has_nsfw_concepts
|
57 |
|
58 |
|
59 |
+
#pipe = load_evosdxl_jp("cpu").to("cuda")
|
60 |
+
pipe = load_evosdxl_jp("cpu").to(device)
|
61 |
|
62 |
def show_warning(warning_text: str) -> gr.Blocks:
|
63 |
with gr.Blocks() as demo:
|
|
|
158 |
Sakana AIは、本モデルの使用によって生じた直接的または間接的な損失に対して、結果に関わらず、一切の責任を負いません。
|
159 |
利用者は、本モデルの使用に伴うリスクを十分に理解し、自身の判断で使用することが必要です。""")
|
160 |
|
161 |
+
demo.queue().launch()
|
evosdxl_jp_v1.py
CHANGED
@@ -11,6 +11,7 @@ from diffusers import (
|
|
11 |
EulerDiscreteScheduler,
|
12 |
)
|
13 |
from diffusers.loaders import LoraLoaderMixin
|
|
|
14 |
|
15 |
SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
|
16 |
JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
|
@@ -137,7 +138,8 @@ def load_evosdxl_jp(device="cuda") -> StableDiffusionXLPipeline:
|
|
137 |
],
|
138 |
)
|
139 |
del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
|
140 |
-
|
|
|
141 |
unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
|
142 |
unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
|
143 |
unet.load_state_dict({**new_conv, **new_attn})
|
|
|
11 |
EulerDiscreteScheduler,
|
12 |
)
|
13 |
from diffusers.loaders import LoraLoaderMixin
|
14 |
+
import devicetorch
|
15 |
|
16 |
SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
|
17 |
JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
|
|
|
138 |
],
|
139 |
)
|
140 |
del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
|
141 |
+
devicetorch.empty_cache(torch)
|
142 |
+
#torch.cuda.empty_cache()
|
143 |
unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
|
144 |
unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
|
145 |
unet.load_state_dict({**new_conv, **new_attn})
|
requirements.txt
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
-
torch
|
2 |
diffusers==0.26.0
|
3 |
transformers
|
4 |
safetensors
|
5 |
accelerate
|
6 |
-
sentencepiece
|
|
|
1 |
+
#torch
|
2 |
diffusers==0.26.0
|
3 |
transformers
|
4 |
safetensors
|
5 |
accelerate
|
6 |
+
sentencepiece
|