Spaces:
Running
Running
File size: 8,396 Bytes
07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 a28e78a 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 ab7be96 07c6a04 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 |
from videosys.utils.logging import logger
# Module-wide PABManager instance. It stays None until set_pab_manager() is
# called; the module-level helper functions below read it.
PAB_MANAGER = None
class PABConfig:
    """Plain settings holder for Pyramid Attention Broadcast.

    Every constructor argument is stored unchanged on the instance under the
    same name. Two empty dicts, ``mlp_temporal_outputs`` and
    ``mlp_spatial_outputs``, are additionally created per instance; the
    manager uses them to keep saved MLP outputs keyed by
    ``(timestep, block_idx)``.

    Args:
        steps: Modulus applied to the per-block call counters.
        cross_broadcast / spatial_broadcast / temporal_broadcast: Switches for
            the corresponding broadcast kind.
        cross_threshold / spatial_threshold / temporal_threshold: Two-element
            sequences; the manager compares timesteps strictly between
            element 0 and element 1.
        cross_range / spatial_range / temporal_range: Divisor applied to the
            call counter by the manager.
        mlp_broadcast: Switch for MLP broadcast.
        mlp_spatial_broadcast_config / mlp_temporal_broadcast_config: Mapping
            from a timestep to a dict with ``"skip_count"`` and ``"block"``
            entries, as read by the manager.
    """

    def __init__(
        self,
        steps: int,
        cross_broadcast: bool = False,
        cross_threshold: list = None,
        cross_range: int = None,
        spatial_broadcast: bool = False,
        spatial_threshold: list = None,
        spatial_range: int = None,
        temporal_broadcast: bool = False,
        temporal_threshold: list = None,
        temporal_range: int = None,
        mlp_broadcast: bool = False,
        mlp_spatial_broadcast_config: dict = None,
        mlp_temporal_broadcast_config: dict = None,
    ):
        self.steps = steps

        # Attention broadcast settings, one (switch, threshold, range) triple per kind.
        self.cross_broadcast, self.cross_threshold, self.cross_range = (
            cross_broadcast,
            cross_threshold,
            cross_range,
        )
        self.spatial_broadcast, self.spatial_threshold, self.spatial_range = (
            spatial_broadcast,
            spatial_threshold,
            spatial_range,
        )
        self.temporal_broadcast, self.temporal_threshold, self.temporal_range = (
            temporal_broadcast,
            temporal_threshold,
            temporal_range,
        )

        # MLP broadcast settings.
        self.mlp_broadcast = mlp_broadcast
        self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
        self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config

        # Per-instance stores for saved MLP outputs; they start out empty.
        self.mlp_temporal_outputs = {}
        self.mlp_spatial_outputs = {}
class PABManager:
    """Answer "broadcast or recompute?" queries based on a PABConfig.

    The manager holds a config object and, for MLP broadcast, keeps saved
    outputs in the config's ``mlp_spatial_outputs`` / ``mlp_temporal_outputs``
    dicts, keyed by ``(timestep, block_idx)``.
    """

    def __init__(self, config: "PABConfig"):
        """Store ``config`` and log a one-line summary of its settings."""
        self.config: "PABConfig" = config
        init_prompt = (
            f"Init Pyramid Attention Broadcast. steps: {config.steps}."
            f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
            f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
            f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
            f" mlp broadcast: {config.mlp_broadcast}."
        )
        logger.info(init_prompt)

    def _if_broadcast(self, enabled, broadcast_range, threshold, timestep, count):
        """Shared decision logic for the cross/temporal/spatial queries.

        The flag is True only when the kind is enabled, ``timestep`` is not
        None, ``count`` is not a multiple of ``broadcast_range`` and
        ``timestep`` lies strictly between ``threshold[0]`` and
        ``threshold[1]``. The checks short-circuit in that order, so the range
        and threshold may be None while the kind is disabled.

        Returns:
            ``(flag, next_count)`` where ``next_count`` is
            ``(count + 1) % config.steps``; the counter always advances.
        """
        flag = bool(
            enabled
            and (timestep is not None)
            and (count % broadcast_range != 0)
            and (threshold[0] < timestep < threshold[1])
        )
        count = (count + 1) % self.config.steps
        return flag, count

    def if_broadcast_cross(self, timestep: int, count: int):
        """Return ``(flag, next_count)`` using the cross-broadcast settings."""
        cfg = self.config
        return self._if_broadcast(cfg.cross_broadcast, cfg.cross_range, cfg.cross_threshold, timestep, count)

    def if_broadcast_temporal(self, timestep: int, count: int):
        """Return ``(flag, next_count)`` using the temporal-broadcast settings."""
        cfg = self.config
        return self._if_broadcast(
            cfg.temporal_broadcast, cfg.temporal_range, cfg.temporal_threshold, timestep, count
        )

    def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
        """Return ``(flag, next_count)`` using the spatial-broadcast settings.

        ``block_idx`` is accepted for signature compatibility and is not used.
        """
        cfg = self.config
        return self._if_broadcast(
            cfg.spatial_broadcast, cfg.spatial_range, cfg.spatial_threshold, timestep, count
        )

    @staticmethod
    def _is_t_in_skip_config(all_timesteps, timestep, config):
        """Find the configured skip window that contains ``timestep``.

        For each key of ``config`` that occurs in ``all_timesteps``, the window
        is the slice of ``all_timesteps`` starting at that key and extending
        ``skip_count`` further entries (clamped to the end of the list).

        Returns:
            ``(found, skip_range)``. When a window contains ``timestep``,
            ``found`` is True and ``skip_range`` is ``[first, last]`` of that
            window. Otherwise ``found`` is False and ``skip_range`` is the
            last window slice examined, or None if none was examined.
        """
        is_t_in_skip_config = False
        skip_range = None
        for key in config:
            if key not in all_timesteps:
                continue
            index = all_timesteps.index(key)
            skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
            if timestep in skip_range:
                is_t_in_skip_config = True
                # Take the end point from the already-clamped slice; indexing
                # all_timesteps[index + skip_count] directly raises IndexError
                # when the window extends past the end of the schedule.
                skip_range = [skip_range[0], skip_range[-1]]
                break
        return is_t_in_skip_config, skip_range

    def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
        """Decide whether the MLP of ``block_idx`` can reuse a saved output.

        Args:
            timestep: Current timestep (may be None).
            count: Caller-maintained counter.
            block_idx: Index of the querying block.
            all_timesteps: List of all timesteps; must support ``.index``.
            is_temporal: Select the temporal config instead of the spatial one.

        Returns:
            ``(flag, count, next_flag, skip_range)``:

            * ``timestep`` is itself a configured key listing ``block_idx``:
              ``flag`` False, ``next_flag`` True, ``count`` incremented.
            * ``timestep`` falls inside a window whose start key lists
              ``block_idx``: ``flag`` True, ``count`` reset to 0.
            * otherwise: ``flag`` False, ``count`` unchanged.
        """
        if not self.config.mlp_broadcast:
            # Hand the counter back untouched instead of replacing it with None.
            return False, count, False, None

        if is_temporal:
            cur_config = self.config.mlp_temporal_broadcast_config
        else:
            cur_config = self.config.mlp_spatial_broadcast_config
        if cur_config is None:
            # No windows configured for this kind: nothing can be skipped.
            return False, count, False, None

        is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)

        flag = False
        next_flag = False
        if (
            (timestep is not None)
            and (timestep in cur_config)
            and (block_idx in cur_config[timestep]["block"])
        ):
            next_flag = True
            count = count + 1
        elif (
            (timestep is not None)
            and is_t_in_skip_config
            and (block_idx in cur_config[skip_range[0]]["block"])
        ):
            flag = True
            count = 0
        return flag, count, next_flag, skip_range

    def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
        """Store ``ff_output`` under ``(timestep, block_idx)`` in the selected store."""
        if is_temporal:
            self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output
        else:
            self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output

    def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
        """Fetch the output saved at the start of ``skip_range`` for ``block_idx``.

        The entry is removed from the store once ``timestep`` equals the last
        element of ``skip_range``.

        Raises:
            ValueError: If no output is stored for ``(skip_range[0], block_idx)``.
        """
        skip_start_t = skip_range[0]
        store_key = (skip_start_t, block_idx)
        if is_temporal:
            store = self.config.mlp_temporal_outputs
        else:
            store = self.config.mlp_spatial_outputs
        skip_output = store.get(store_key, None) if store is not None else None

        if skip_output is None:
            raise ValueError(
                f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
            )
        if timestep == skip_range[-1]:
            # Last timestep of the window: drop the saved entry.
            del store[store_key]
        return skip_output

    def get_spatial_mlp_outputs(self):
        """Return the dict of saved spatial MLP outputs."""
        return self.config.mlp_spatial_outputs

    def get_temporal_mlp_outputs(self):
        """Return the dict of saved temporal MLP outputs."""
        return self.config.mlp_temporal_outputs
def set_pab_manager(config: PABConfig):
    """Create a PABManager for ``config`` and install it as the module-wide manager."""
    global PAB_MANAGER
    manager = PABManager(config)
    PAB_MANAGER = manager
def enable_pab():
    """Report whether any attention broadcast kind is switched on.

    Returns False when no manager has been installed. Otherwise returns the
    ``or`` of the config's cross, spatial and temporal broadcast switches
    (the MLP switch is not consulted here).
    """
    manager = PAB_MANAGER
    if manager is None:
        return False
    cfg = manager.config
    return cfg.cross_broadcast or cfg.spatial_broadcast or cfg.temporal_broadcast
def update_steps(steps: int):
    """Overwrite ``steps`` on the installed manager's config; no-op without a manager."""
    if PAB_MANAGER is None:
        return
    PAB_MANAGER.config.steps = steps
def if_broadcast_cross(timestep: int, count: int):
    """Module-level entry point for the cross-broadcast query.

    Delegates to the installed manager when ``enable_pab()`` is truthy;
    otherwise returns ``(False, count)`` with the counter untouched.
    """
    if enable_pab():
        return PAB_MANAGER.if_broadcast_cross(timestep, count)
    return False, count
def if_broadcast_temporal(timestep: int, count: int):
    """Module-level entry point for the temporal-broadcast query.

    Delegates to the installed manager when ``enable_pab()`` is truthy;
    otherwise returns ``(False, count)`` with the counter untouched.
    """
    if enable_pab():
        return PAB_MANAGER.if_broadcast_temporal(timestep, count)
    return False, count
def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
    """Module-level entry point for the spatial-broadcast query.

    Delegates to the installed manager when ``enable_pab()`` is truthy;
    otherwise returns ``(False, count)`` with the counter untouched.
    """
    if enable_pab():
        return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
    return False, count
def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
    """Module-level entry point for the MLP-broadcast query.

    Returns:
        A 4-tuple ``(flag, count, next_flag, skip_range)``. When
        ``enable_pab()`` is falsy the result is ``(False, count, False, None)``;
        otherwise it is whatever the installed manager's ``if_skip_mlp``
        returns.
    """
    if not enable_pab():
        # Same arity as the manager's result so callers can always unpack
        # four values; the first two elements are unchanged from before.
        return False, count, False, None
    return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)
def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
    """Forward to the installed manager's ``save_skip_output`` and return its result.

    A manager must already be installed; there is no None check here.
    """
    manager = PAB_MANAGER
    return manager.save_skip_output(timestep, block_idx, ff_output, is_temporal)
def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
    """Forward to the installed manager's ``get_mlp_output`` and return its result.

    A manager must already be installed; there is no None check here.
    """
    manager = PAB_MANAGER
    return manager.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
|