Spaces:
Runtime error
Runtime error
加註解1
Browse files進度到注意力機制
- zero123plus/pipeline.py +11 -6
zero123plus/pipeline.py
CHANGED
@@ -25,20 +25,25 @@ from diffusers import (
|
|
25 |
from diffusers.image_processor import VaeImageProcessor
|
26 |
from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor, AttnProcessor2_0
|
27 |
from diffusers.utils.import_utils import is_xformers_available
|
|
|
28 |
|
29 |
-
|
30 |
def to_rgb_image(maybe_rgba: Image.Image):
    """Return an RGB version of a PIL image, flattening alpha onto white.

    Args:
        maybe_rgba: Image whose ``mode`` is ``'RGB'`` or ``'RGBA'``.

    Returns:
        The input object itself when it is already ``'RGB'``; otherwise a new
        ``'RGB'`` image of the same size, with the RGBA pixels pasted over a
        solid white background using the alpha channel as the paste mask.

    Raises:
        ValueError: If the image mode is neither ``'RGB'`` nor ``'RGBA'``.
    """
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    if maybe_rgba.mode == 'RGBA':
        # Solid white canvas: wherever the alpha mask is transparent, the
        # white background is what remains after the paste.
        background = Image.new('RGB', maybe_rgba.size, (255, 255, 255))
        background.paste(maybe_rgba, mask=maybe_rgba.getchannel('A'))
        return background
    raise ValueError("Unsupported image type.", maybe_rgba.mode)
|
41 |
-
|
|
|
|
|
|
|
42 |
|
43 |
class ReferenceOnlyAttnProc(torch.nn.Module):
|
44 |
def __init__(
|
@@ -75,8 +80,8 @@ class ReferenceOnlyAttnProc(torch.nn.Module):
|
|
75 |
if self.enabled and is_cfg_guidance:
|
76 |
res = torch.cat([res0, res])
|
77 |
return res
|
78 |
-
|
79 |
-
|
80 |
class RefOnlyNoisedUNet(torch.nn.Module):
|
81 |
def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None:
|
82 |
super().__init__()
|
|
|
25 |
from diffusers.image_processor import VaeImageProcessor
|
26 |
from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor, AttnProcessor2_0
|
27 |
from diffusers.utils.import_utils import is_xformers_available
|
28 |
+
### AutoencoderKL、UNet2DConditionModel、DDPMScheduler 這幾個是生成模型的核心模塊,負責編碼、擴散過程和調度
|
29 |
|
30 |
+
### 接收一個 PIL 圖像物件,並將該圖像轉換為 RGB 格式
|
31 |
def to_rgb_image(maybe_rgba: Image.Image):
    """Return an RGB version of a PIL image, flattening alpha onto white.

    Args:
        maybe_rgba: Image whose ``mode`` is ``'RGB'`` or ``'RGBA'``.

    Returns:
        The input object itself when it is already ``'RGB'``; otherwise a new
        ``'RGB'`` image of the same size, with the RGBA pixels pasted over a
        solid white background using the alpha channel as the paste mask.

    Raises:
        ValueError: If the image mode is neither ``'RGB'`` nor ``'RGBA'``.
    """
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    if maybe_rgba.mode == 'RGBA':  # 'A' is the alpha (transparency) channel
        # Solid white canvas (not random): wherever the alpha mask is
        # transparent, the white background is what remains after the paste.
        background = Image.new('RGB', maybe_rgba.size, (255, 255, 255))
        background.paste(maybe_rgba, mask=maybe_rgba.getchannel('A'))
        return background
    raise ValueError("Unsupported image type.", maybe_rgba.mode)
|
43 |
+
|
44 |
+
#### 相較於 RGBA,RGB 與大多數顯示設備兼容,而且更簡單、更高效,也更節省資源
|
45 |
+
#### 並且是數位圖像處理和顯示的標準, 此專案中RGB已足夠
|
46 |
+
#### 無論是 JPEG、PNG 等圖片格式,還是 HTML 和 CSS 用於網頁設計的顏色表示,RGB 模式都是標準化的選擇。
|
47 |
|
48 |
class ReferenceOnlyAttnProc(torch.nn.Module):
|
49 |
def __init__(
|
|
|
80 |
if self.enabled and is_cfg_guidance:
|
81 |
res = torch.cat([res0, res])
|
82 |
return res
|
83 |
+
#### 一種靈活的注意力機制,它可以在訓練或推理過程中根據不同的模式("w"、"r"、"m") 進行操作。
|
84 |
+
#### 目的是讓模型對不同的輸入數據賦予不同的「權重」,從而突出重要的信息,忽略次要的細節。
|
85 |
class RefOnlyNoisedUNet(torch.nn.Module):
|
86 |
def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None:
|
87 |
super().__init__()
|