File size: 1,366 Bytes
8cd00a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from diffusers import DDPMScheduler
import torch

class HookedNoiseScheduler:
    """Proxy around a noise scheduler that runs user hooks around ``step``.

    Only ``scheduler``, ``pre_hooks`` and ``post_hooks`` are stored on the
    wrapper itself. Every other attribute read or write is forwarded to the
    wrapped scheduler, so the wrapper can be used in place of the scheduler
    it wraps.

    Hooks are plain callables appended to the two lists:

    * a pre-hook is called as ``hook(model_output, timestep, sample,
      generator)`` and may return a replacement 4-tuple, or ``None`` to
      leave the arguments unchanged;
    * a post-hook is called as ``hook(pred_prev_sample)`` and may return a
      replacement value, or ``None`` to leave it unchanged.
    """

    # Quoted so the annotation is not evaluated when the class body runs.
    scheduler: "DDPMScheduler"
    pre_hooks: list
    post_hooks: list

    # Names that live on the wrapper; everything else lives on the scheduler.
    _OWN_ATTRS = frozenset({'scheduler', 'pre_hooks', 'post_hooks'})

    def __init__(self, scheduler):
        """Wrap *scheduler* and start with empty hook lists."""
        object.__setattr__(self, 'scheduler', scheduler)
        object.__setattr__(self, 'pre_hooks', [])
        object.__setattr__(self, 'post_hooks', [])

    def step(
        self,
        model_output, timestep, sample, generator, return_dict
    ):
        """Run the pre-hooks, the wrapped ``step``, then the post-hooks.

        Args:
            model_output, timestep, sample, generator: Passed to each
                pre-hook in order and then, positionally, to the wrapped
                scheduler's ``step``.
            return_dict: Must be falsy; forwarded to the wrapped scheduler.

        Returns:
            A tuple whose first element is the (possibly hook-modified)
            first element returned by the wrapped ``step``. Any further
            elements the wrapped scheduler returned are passed through
            unchanged.

        Raises:
            NotImplementedError: If ``return_dict`` is truthy.
        """
        # An explicit raise (rather than ``assert``) keeps the check active
        # when Python runs with -O.
        if return_dict:
            raise NotImplementedError("return_dict == True is not implemented")

        for hook in self.pre_hooks:
            hook_output = hook(model_output, timestep, sample, generator)
            if hook_output is not None:
                model_output, timestep, sample, generator = hook_output

        # Star-unpack so a scheduler that returns more than one element does
        # not break the wrapper; only the first element is fed to the hooks.
        pred_prev_sample, *extras = self.scheduler.step(
            model_output, timestep, sample, generator, return_dict
        )

        for hook in self.post_hooks:
            hook_output = hook(pred_prev_sample)
            if hook_output is not None:
                pred_prev_sample = hook_output

        return (pred_prev_sample, *extras)

    def __getattr__(self, name):
        """Forward lookups that failed on the wrapper to the scheduler."""
        # __getattr__ only runs after normal lookup has failed. If one of the
        # wrapper's own names got here, the instance is not initialised yet
        # (e.g. while copy/pickle rebuild it via __new__), and reading
        # self.scheduler below would re-enter this method without end.
        if name in self._OWN_ATTRS:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.scheduler, name)

    def __setattr__(self, name, value):
        """Store own attributes locally; forward the rest to the scheduler."""
        if name in self._OWN_ATTRS:
            object.__setattr__(self, name, value)
        else:
            setattr(self.scheduler, name, value)