import torch
from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, DDIMScheduler
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, BatchFeature
import json
from dataclasses import dataclass
from typing import List, Optional
from custum_3d_diffusion.modules import register
from custum_3d_diffusion.trainings.base import BasicTrainer
from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2mvimg import StableDiffusionImage2MVCustomPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput


def get_HW(resolution):
    """Parse a resolution spec (int, [H, W] list, or JSON string) into (H, W)."""
    if isinstance(resolution, str):
        resolution = json.loads(resolution)
    if isinstance(resolution, int):
        H = W = resolution
    elif isinstance(resolution, list):
        H, W = resolution
    else:
        raise ValueError(f"Unsupported resolution spec: {resolution!r}")
    return H, W
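
# Examples of the parsing above:
#   get_HW(512)          -> (512, 512)
#   get_HW("512")        -> (512, 512)
#   get_HW("[512, 320]") -> (512, 320)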


@register("image2mvimage_trainer")
class Image2MVImageTrainer(BasicTrainer):
    """
    Trainer for simple image to multiview images.
    """

    @dataclass
    class TrainerConfig(BasicTrainer.TrainerConfig):
        trainer_name: str = "image2mvimage"
        condition_image_column_name: str = "conditioning_image"
        image_column_name: str = "image"
        condition_dropout: float = 0.
        condition_image_resolution: str = "512"  # int or [H, W], parsed by get_HW
        validation_images: Optional[List[str]] = None
        noise_offset: float = 0.1   # noise-offset trick for better dark/bright samples
        max_loss_drop: float = 0.
        snr_gamma: float = 5.0      # gamma for Min-SNR loss weighting
        log_distribution: bool = False
        latents_offset: Optional[List[float]] = None  # passed through to the pipeline
        input_perturbation: float = 0.
        noisy_condition_input: bool = False  # whether to add noise to the ref unet input
        normal_cls_offset: int = 0
        condition_offset: bool = True
        zero_snr: bool = False               # rescale betas for zero terminal SNR
        linear_beta_schedule: bool = False

    cfg: TrainerConfig
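
    # Example config override (hypothetical values; any image-variations
    # checkpoint with `image_encoder`/`feature_extractor` subfolders fits):
    #   cfg = Image2MVImageTrainer.TrainerConfig(
    #       pretrained_model_name_or_path="lambdalabs/sd-image-variations-diffusers",
    #       condition_image_resolution="[256, 256]",
    #       zero_snr=True,
    #   )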

    def configure(self) -> None:
        return super().configure()

    def init_shared_modules(self, shared_modules: dict) -> dict:
        # Load the frozen VAE / CLIP image encoder once and cache them in
        # shared_modules so other trainers can reuse them.
        if 'vae' not in shared_modules:
            vae = AutoencoderKL.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.weight_dtype
            )
            vae.requires_grad_(False)
            vae.to(self.accelerator.device, dtype=self.weight_dtype)
            shared_modules['vae'] = vae
        if 'image_encoder' not in shared_modules:
            image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="image_encoder"
            )
            image_encoder.requires_grad_(False)
            image_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
            shared_modules['image_encoder'] = image_encoder
        if 'feature_extractor' not in shared_modules:
            feature_extractor = CLIPImageProcessor.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="feature_extractor"
            )
            shared_modules['feature_extractor'] = feature_extractor
        return shared_modules
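
    # Sketch of the intended call pattern (hypothetical driver code; the real
    # orchestration lives in the training entry point):
    #   shared = {}
    #   shared = trainer_a.init_shared_modules(shared)  # loads vae/encoder
    #   shared = trainer_b.init_shared_modules(shared)  # reuses cached modules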

    def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
        raise NotImplementedError()

    def loss_rescale(self, loss, timesteps=None):
        raise NotImplementedError()

    def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
        raise NotImplementedError()

    def construct_pipeline(self, shared_modules, unet, old_version=False):
        MyPipeline = StableDiffusionImage2MVCustomPipeline
        pipeline = MyPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            vae=shared_modules['vae'],
            image_encoder=shared_modules['image_encoder'],
            feature_extractor=shared_modules['feature_extractor'],
            unet=unet,
            safety_checker=None,
            torch_dtype=self.weight_dtype,
            latents_offset=self.cfg.latents_offset,
            noisy_cond_latents=self.cfg.noisy_condition_input,
            condition_offset=self.cfg.condition_offset,
        )
        pipeline.set_progress_bar_config(disable=True)
        # Swap in an Euler-ancestral sampler, optionally rescaled for zero
        # terminal SNR and/or a linear beta schedule to match training.
        scheduler_dict = {}
        if self.cfg.zero_snr:
            scheduler_dict.update(rescale_betas_zero_snr=True)
        if self.cfg.linear_beta_schedule:
            scheduler_dict.update(beta_schedule='linear')
        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
        return pipeline

    def get_forward_args(self):
        if self.cfg.seed is None:
            generator = None
        else:
            generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)

        H, W = get_HW(self.cfg.resolution)
        H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)

        sub_img_H = H // 2
        num_imgs = H // sub_img_H * W // sub_img_H
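        # Views are generated at half the target height: e.g. with
        # H = W = 512, sub_img_H = 256 and
        # num_imgs = (512 // 256) * (512 // 256) = 4 square views per prompt.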
        forward_args = dict(
            num_images_per_prompt=num_imgs,
            num_inference_steps=50,
            height=sub_img_H,
            width=sub_img_H,
            height_cond=H_cond,
            width_cond=W_cond,
            generator=generator,
        )
        if self.cfg.zero_snr:
            # Zero-terminal-SNR sampling pairs with classifier-free guidance
            # rescaling (Lin et al., "Common Diffusion Noise Schedules and
            # Sample Steps Are Flawed").
            forward_args.update(guidance_rescale=0.7)
        return forward_args

    def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
        forward_args = self.get_forward_args()
        forward_args.update(pipeline_call_kwargs)
        return pipeline(**forward_args)

    def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        raise NotImplementedError()
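
# A minimal inference sketch (hypothetical driver code; assumes the UNet has
# been trained or loaded elsewhere, and that the custom pipeline accepts a
# conditioning image via an `image` argument):
#   shared = trainer.init_shared_modules({})
#   pipeline = trainer.construct_pipeline(shared, unet).to("cuda")
#   out = trainer.pipeline_forward(pipeline, image=cond_image)
#   out.images  # num_imgs generated views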