import torch
from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

import json
from dataclasses import dataclass
from typing import List, Optional

from custum_3d_diffusion.modules import register
from custum_3d_diffusion.trainings.base import BasicTrainer
from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2mvimg import StableDiffusionImage2MVCustomPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

def get_HW(resolution):
    """Parse a resolution spec into an (H, W) pair.

    Accepts an int, an [H, W] list, or a JSON string encoding either.
    """
    if isinstance(resolution, str):
        resolution = json.loads(resolution)
    if isinstance(resolution, int):
        H = W = resolution
    elif isinstance(resolution, list):
        H, W = resolution
    else:
        raise ValueError(f"Unsupported resolution spec: {resolution!r}")
    return H, W
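
# For example, get_HW(512) and get_HW("512") both yield (512, 512), while
# get_HW("[640, 960]") yields (640, 960).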

@register("image2mvimage_trainer")
class Image2MVImageTrainer(BasicTrainer):
    """
    Trainer for simple image to multiview images.
    """
    @dataclass
    class TrainerConfig(BasicTrainer.TrainerConfig):
        trainer_name: str = "image2mvimage"
        condition_image_column_name: str = "conditioning_image"
        image_column_name: str = "image"
        condition_dropout: float = 0.                       # probability of dropping the condition (for classifier-free guidance)
        condition_image_resolution: str = "512"
        validation_images: Optional[List[str]] = None
        noise_offset: float = 0.1                           # strength of offset noise added to the training noise
        max_loss_drop: float = 0.                           # maximum loss drop for loss rescaling (0 disables)
        snr_gamma: float = 5.0                              # gamma for min-SNR loss weighting
        log_distribution: bool = False
        latents_offset: Optional[List[float]] = None        # per-channel offset applied to VAE latents
        input_perturbation: float = 0.                      # extra noise perturbation on the noisy latents
        noisy_condition_input: bool = False                 # whether to add noise to the reference UNet input
        normal_cls_offset: int = 0
        condition_offset: bool = True
        zero_snr: bool = False                              # rescale betas for zero terminal SNR
        linear_beta_schedule: bool = False                  # use a linear (rather than scaled-linear) beta schedule

    cfg: TrainerConfig

    def configure(self) -> None:
        return super().configure()

    def init_shared_modules(self, shared_modules: dict) -> dict:
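        # Load the frozen (no-grad) VAE, CLIP image encoder, and CLIP feature
        # extractor once, caching them in `shared_modules` so they can be
        # reused without reloading.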
        if 'vae' not in shared_modules:
            vae = AutoencoderKL.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.weight_dtype
            )
            vae.requires_grad_(False)
            vae.to(self.accelerator.device, dtype=self.weight_dtype)
            shared_modules['vae'] = vae
        if 'image_encoder' not in shared_modules:
            image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="image_encoder"
            )
            image_encoder.requires_grad_(False)
            image_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
            shared_modules['image_encoder'] = image_encoder
        if 'feature_extractor' not in shared_modules:
            feature_extractor = CLIPImageProcessor.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="feature_extractor"
            )
            shared_modules['feature_extractor'] = feature_extractor
        return shared_modules

    def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
        raise NotImplementedError()

    def loss_rescale(self, loss, timesteps=None):
        raise NotImplementedError()

    def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
        raise NotImplementedError()
    
    def construct_pipeline(self, shared_modules, unet, old_version=False):
        MyPipeline = StableDiffusionImage2MVCustomPipeline
        pipeline = MyPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            vae=shared_modules['vae'],
            image_encoder=shared_modules['image_encoder'],
            feature_extractor=shared_modules['feature_extractor'],
            unet=unet,
            safety_checker=None,
            torch_dtype=self.weight_dtype,
            latents_offset=self.cfg.latents_offset,
            noisy_cond_latents=self.cfg.noisy_condition_input,
            condition_offset=self.cfg.condition_offset,
        )
        pipeline.set_progress_bar_config(disable=True)
        scheduler_dict = {}
        if self.cfg.zero_snr:
            scheduler_dict.update(rescale_betas_zero_snr=True)
        if self.cfg.linear_beta_schedule:
            scheduler_dict.update(beta_schedule='linear')
        
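        # Swap in an ancestral Euler sampler for inference. Zero-terminal-SNR
        # beta rescaling follows "Common Diffusion Noise Schedules and Sample
        # Steps are Flawed" (Lin et al., 2023).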
        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
        return pipeline

    def get_forward_args(self):
        if self.cfg.seed is None:
            generator = None
        else:
            generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)
        
        H, W = get_HW(self.cfg.resolution)
        H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)

        # Each generated image is a grid of square sub-views, each half the
        # target height (e.g. a 2x2 grid of 4 views for a square target).
        sub_img_H = H // 2
        num_imgs = (H // sub_img_H) * (W // sub_img_H)

        forward_args = dict(
            num_images_per_prompt=num_imgs,
            num_inference_steps=50,
            height=sub_img_H,
            width=sub_img_H,
            height_cond=H_cond,
            width_cond=W_cond,
            generator=generator,
        )
        if self.cfg.zero_snr:
            # Guidance rescaling counteracts overexposure under zero-SNR
            # schedules; 0.7 is the value recommended by Lin et al. (2023).
            forward_args.update(guidance_rescale=0.7)
        return forward_args

    def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
        forward_args = self.get_forward_args()
        forward_args.update(pipeline_call_kwargs)
        return pipeline(**forward_args)

    def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        raise NotImplementedError()
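
# Illustrative usage sketch (assumptions: `trainer` is a fully configured
# Image2MVImageTrainer and `unet` is a trained UNet; the conditioning keyword
# of the custom pipeline is assumed to be `image`):
#
#   shared = trainer.init_shared_modules({})
#   pipeline = trainer.construct_pipeline(shared, unet)
#   output = trainer.pipeline_forward(pipeline, image=condition_image)
#   multiview_images = output.images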