ZZZXIANG commited on
Commit
904e2d4
1 Parent(s): 486f0e7

加註解1

Browse files

進度到注意力機制

Files changed (1) hide show
  1. zero123plus/pipeline.py +11 -6
zero123plus/pipeline.py CHANGED
@@ -25,20 +25,25 @@ from diffusers import (
25
  from diffusers.image_processor import VaeImageProcessor
26
  from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor, AttnProcessor2_0
27
  from diffusers.utils.import_utils import is_xformers_available
 
28
 
29
-
30
  def to_rgb_image(maybe_rgba: Image.Image):
31
  if maybe_rgba.mode == 'RGB':
32
  return maybe_rgba
33
- elif maybe_rgba.mode == 'RGBA':
34
  rgba = maybe_rgba
 
35
  img = numpy.random.randint(255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)
36
- img = Image.fromarray(img, 'RGB')
37
  img.paste(rgba, mask=rgba.getchannel('A'))
38
  return img
39
  else:
40
  raise ValueError("Unsupported image type.", maybe_rgba.mode)
41
-
 
 
 
42
 
43
  class ReferenceOnlyAttnProc(torch.nn.Module):
44
  def __init__(
@@ -75,8 +80,8 @@ class ReferenceOnlyAttnProc(torch.nn.Module):
75
  if self.enabled and is_cfg_guidance:
76
  res = torch.cat([res0, res])
77
  return res
78
-
79
-
80
  class RefOnlyNoisedUNet(torch.nn.Module):
81
  def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None:
82
  super().__init__()
 
25
  from diffusers.image_processor import VaeImageProcessor
26
  from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor, AttnProcessor2_0
27
  from diffusers.utils.import_utils import is_xformers_available
28
+ ### AutoencoderKL, UNet2DConditionModel ,DDPMScheduler 這幾個用於生成模型的核心模塊,負責編碼.擴散過程和調度
29
 
30
+ ### 接收一個 PIL 圖像物件,並將該圖像轉換為 RGB 格式
31
  def to_rgb_image(maybe_rgba: Image.Image):
32
  if maybe_rgba.mode == 'RGB':
33
  return maybe_rgba
34
+ elif maybe_rgba.mode == 'RGBA': # A為透明度
35
  rgba = maybe_rgba
36
+ ## 創建一個隨機的 RGB 圖像,尺寸與原始 RGBA 圖像相同
37
  img = numpy.random.randint(255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)
38
+ img = Image.fromarray(img, 'RGB') ##將這個 NumPy 陣列轉換為 PIL 的 RGB 圖像
39
  img.paste(rgba, mask=rgba.getchannel('A'))
40
  return img
41
  else:
42
  raise ValueError("Unsupported image type.", maybe_rgba.mode)
43
+
44
+ #### RGB相對RGBA來說與大多數顯示設備兼容,簡單且高效,更加簡單且資源友好
45
+ #### 並且是數位圖像處理和顯示的標準, 此專案中RGB已足夠
46
+ #### 無論是 JPEG、PNG 等圖片格式,還是 HTML 和 CSS 用於網頁設計的顏色表示,RGB 模式都是標準化的選擇。
47
 
48
  class ReferenceOnlyAttnProc(torch.nn.Module):
49
  def __init__(
 
80
  if self.enabled and is_cfg_guidance:
81
  res = torch.cat([res0, res])
82
  return res
83
+ #### 一種靈活的注意力機制,它可以在訓練或推理過程中根據不同的模式("w"、"r"、"m") 進行操作。
84
+ #### 目的是讓模型對不同的輸入數據賦予不同的「權重」,從而突出重要的信息,忽略次要的細節。
85
  class RefOnlyNoisedUNet(torch.nn.Module):
86
  def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None:
87
  super().__init__()