Spaces:
Paused
Paused
File size: 7,027 Bytes
2ba49a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
from PIL import Image
import cupy as cp
import numpy as np
from tqdm import tqdm
from ..extensions.FastBlend.patch_match import PyramidPatchMatcher
from ..extensions.FastBlend.runners.fast import TableManager
from .base import VideoProcessor
class FastBlendSmoother(VideoProcessor):
    """Temporally smooth stylized video frames with FastBlend patch matching.

    Each stylized ("rendered") frame is blended with neighbouring stylized
    frames that are remapped onto it by patch matching the corresponding
    guide ("original") frames. Three strategies are available:

    * ``"fast"``     -- blending tables built by ``TableManager`` over the
                        sequence in forward and in reversed order.
    * ``"balanced"`` -- every (source, target) pair inside the window is
                        matched individually and folded into a running mean.
    * ``"accurate"`` -- for each target, every frame in its window is matched
                        with ``use_mean_target_style=True`` and averaged.
    """

    def __init__(
        self,
        inference_mode="fast",
        batch_size=8,
        window_size=60,
        minimum_patch_size=5,
        threads_per_block=8,
        num_iter=5,
        gpu_id=0,
        guide_weight=10.0,
        initialize="identity",
        tracking_window_size=0,
    ):
        """Store the blending strategy and the patch-matcher configuration.

        Args:
            inference_mode: ``"fast"``, ``"balanced"`` or ``"accurate"``;
                validated when the smoother is called.
            batch_size: Number of frame pairs sent to the patch matcher per call.
            window_size: Number of neighbouring frames on each side of a target
                (balanced/accurate); forwarded to
                ``TableManager.process_window_sum`` in fast mode.
            minimum_patch_size, threads_per_block, num_iter, gpu_id,
            guide_weight, initialize, tracking_window_size: Forwarded
                unchanged to every ``PyramidPatchMatcher`` this object creates.
        """
        self.inference_mode = inference_mode
        self.batch_size = batch_size
        self.window_size = window_size
        # Keyword arguments passed verbatim to PyramidPatchMatcher.
        self.ebsynth_config = {
            "minimum_patch_size": minimum_patch_size,
            "threads_per_block": threads_per_block,
            "num_iter": num_iter,
            "gpu_id": gpu_id,
            "guide_weight": guide_weight,
            "initialize": initialize,
            "tracking_window_size": tracking_window_size,
        }

    @staticmethod
    def from_model_manager(model_manager, **kwargs):
        """Build a smoother from ``kwargs``; ``model_manager`` is currently unused."""
        # TODO: fetch GPU ID from model_manager
        return FastBlendSmoother(**kwargs)

    def _create_patch_match_engine(self, frames_style, **extra_config):
        """Create a PyramidPatchMatcher sized to the first style frame.

        ``extra_config`` carries mode-specific options (e.g.
        ``use_mean_target_style``) in addition to ``self.ebsynth_config``.
        Frames are assumed to have 3 channels (hard-coded ``channel=3``).
        """
        return PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **extra_config,
            **self.ebsynth_config,
        )

    def inference_fast(self, frames_guide, frames_style):
        """Smooth frames with FastBlend's table-based "fast" mode.

        Args:
            frames_guide: Guide frames as numpy arrays, one per frame.
            frames_style: Stylized frames as numpy arrays, same order.

        Returns:
            list of ``PIL.Image.Image``, one per frame, in input order.
        """
        table_manager = TableManager()
        patch_match_engine = self._create_patch_match_engine(frames_style)
        # left part: tables built over the sequence in forward order
        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, self.batch_size, desc="Fast Mode Step 1/4")
        table_l = table_manager.remapping_table_to_blending_table(table_l)
        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 2/4")
        # right part: same passes over the reversed sequence, flipped back afterwards
        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, self.batch_size, desc="Fast Mode Step 3/4")
        table_r = table_manager.remapping_table_to_blending_table(table_r)
        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 4/4")[::-1]
        # merge: weighted average of left sum, right sum and the style frame
        frames = []
        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
            # NOTE(review): presumably both tables already include the target
            # frame itself, so it is subtracted once (weight -1) to avoid
            # counting it twice -- confirm against TableManager.
            weight_m = -1
            weight = weight_l + weight_m + weight_r
            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
            frames.append(Image.fromarray(frame.clip(0, 255).astype("uint8")))
        return frames

    def inference_balanced(self, frames_guide, frames_style):
        """Smooth frames by averaging each target with all remapped neighbours.

        For every target, each other frame within ``window_size`` on either
        side is remapped onto the target's guide. The results are folded into
        a running mean that starts from the target's own style frame.

        Args:
            frames_guide: Guide frames as numpy arrays, one per frame.
            frames_style: Stylized frames as numpy arrays, same order.

        Returns:
            list of ``PIL.Image.Image``, one per frame, in input order. A
            target with no neighbours in its window (e.g. a single frame, or
            ``window_size <= 0``) is returned as its own style frame.
        """
        patch_match_engine = self._create_patch_match_engine(frames_style)
        n = len(frames_style)

        def window_length(target):
            # Frames in the clamped window around ``target``, target included.
            return min(n, target + self.window_size + 1) - max(0, target - self.window_size)

        # One task per (source, target) pair inside the window, grouped by target.
        tasks = [
            (source, target)
            for target in range(n)
            for source in range(max(0, target - self.window_size), min(n, target + self.window_size + 1))
            if source != target
        ]
        output_frames = [None] * n
        # Per target: (running mean or None, number of frames folded in).
        # The count starts at 1 for the target's own style frame.
        running = [(None, 1) for _ in range(n)]
        for batch_id in tqdm(range(0, len(tasks), self.batch_size), desc="Balanced Mode"):
            tasks_batch = tasks[batch_id: batch_id + self.batch_size]
            source_guide = np.stack([frames_guide[source] for source, _ in tasks_batch])
            target_guide = np.stack([frames_guide[target] for _, target in tasks_batch])
            source_style = np.stack([frames_style[source] for source, _ in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for (_, target), result in zip(tasks_batch, target_style):
                mean, count = running[target]
                if mean is None:
                    mean = frames_style[target]
                # Incremental mean: fold one more remapped frame into the average.
                mean = mean * (count / (count + 1)) + result / (count + 1)
                count += 1
                if count == window_length(target):
                    # Emit the *updated* mean so the last neighbour is included.
                    output_frames[target] = Image.fromarray(mean.clip(0, 255).astype("uint8"))
                    running[target] = (None, 1)  # drop the float buffer early
                else:
                    running[target] = (mean, count)
        # Targets that had no neighbours never appear in ``tasks``.
        for target, image in enumerate(output_frames):
            if image is None:
                output_frames[target] = Image.fromarray(frames_style[target].clip(0, 255).astype("uint8"))
        return output_frames

    def inference_accurate(self, frames_guide, frames_style):
        """Smooth frames by averaging every window frame remapped onto the target.

        For each target, all frames in ``[target - window_size,
        target + window_size]`` (clamped, target included) are matched against
        the target's guide in batches of ``batch_size``. The remapped results
        are then averaged.

        Args:
            frames_guide: Guide frames as numpy arrays, one per frame.
            frames_style: Stylized frames as numpy arrays, same order.

        Returns:
            list of ``PIL.Image.Image``, one per frame, in input order.
        """
        patch_match_engine = self._create_patch_match_engine(frames_style, use_mean_target_style=True)
        output_frames = []
        n = len(frames_style)
        for target in tqdm(range(n), desc="Accurate Mode"):
            l, r = max(target - self.window_size, 0), min(target + self.window_size + 1, n)
            remapped_frames = []
            for i in range(l, r, self.batch_size):
                j = min(i + self.batch_size, r)
                source_guide = np.stack([frames_guide[source] for source in range(i, j)])
                target_guide = np.stack([frames_guide[target]] * (j - i))
                source_style = np.stack([frames_style[source] for source in range(i, j)])
                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
                remapped_frames.append(target_style)
            # NOTE(review): with use_mean_target_style=True the matcher may
            # already return a per-batch mean; if so, a short final batch gets
            # the same weight as full ones here -- confirm.
            frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
            output_frames.append(Image.fromarray(frame.clip(0, 255).astype("uint8")))
        return output_frames

    def release_vram(self):
        """Free all cached blocks in CuPy's default device and pinned-host memory pools."""
        mempool = cp.get_default_memory_pool()
        pinned_mempool = cp.get_default_pinned_memory_pool()
        mempool.free_all_blocks()
        pinned_mempool.free_all_blocks()

    def __call__(self, rendered_frames, original_frames=None, **kwargs):
        """Smooth ``rendered_frames`` using ``original_frames`` as guides.

        Args:
            rendered_frames: Stylized frames (anything ``np.array`` accepts,
                e.g. PIL images).
            original_frames: Guide frames, one per rendered frame. Required
                despite the ``None`` default.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            list of ``PIL.Image.Image`` (empty if ``rendered_frames`` is empty).

        Raises:
            TypeError: if ``original_frames`` is None.
            ValueError: if ``inference_mode`` is not a supported mode.
        """
        if original_frames is None:
            raise TypeError("original_frames is required by FastBlendSmoother")
        runners = {
            "fast": self.inference_fast,
            "balanced": self.inference_balanced,
            "accurate": self.inference_accurate,
        }
        if self.inference_mode not in runners:
            raise ValueError("inference_mode must be fast, balanced or accurate")
        rendered_frames = [np.array(frame) for frame in rendered_frames]
        original_frames = [np.array(frame) for frame in original_frames]
        if not rendered_frames:
            return []
        try:
            output_frames = runners[self.inference_mode](original_frames, rendered_frames)
        finally:
            # Release CuPy's cached memory even if inference raised.
            self.release_vram()
        return output_frames
|