Spaces:
Runtime error
Runtime error
import os | |
import subprocess | |
import tempfile | |
from pathlib import Path | |
from typing import Union | |
import shutil | |
import cv2 | |
import imageio | |
import numpy as np | |
import torch | |
import torchvision | |
from decord import VideoReader, cpu | |
from einops import rearrange, repeat | |
from t2v_enhanced.utils.iimage import IImage | |
from PIL import Image, ImageDraw, ImageFont | |
from torchvision.utils import save_image | |
channel_first = 0 | |
channel_last = -1 | |
def video_naming(prompt, extension, batch_idx, idx): | |
prompt_identifier = prompt.replace(" ", "_") | |
prompt_identifier = prompt_identifier.replace("/", "_") | |
if len(prompt_identifier) > 40: | |
prompt_identifier = prompt_identifier[:40] | |
filename = f"{batch_idx:04d}_{idx:04d}_{prompt_identifier}.{extension}" | |
return filename | |
def video_naming_chunk(prompt, extension, batch_idx, idx, chunk_idx): | |
prompt_identifier = prompt.replace(" ", "_") | |
prompt_identifier = prompt_identifier.replace("/", "_") | |
if len(prompt_identifier) > 40: | |
prompt_identifier = prompt_identifier[:40] | |
filename = f"{batch_idx}_{idx}_{chunk_idx}_{prompt_identifier}.{extension}" | |
return filename | |
class ResultProcessor(): | |
def __init__(self, fps: int, n_frames: int, logger=None) -> None: | |
self.fps = fps | |
self.logger = logger | |
self.n_frames = n_frames | |
def set_logger(self, logger): | |
self.logger = logger | |
def _create_video(self, video, prompt, filename: Union[str, Path], append_video: torch.FloatTensor = None, input_flow=None): | |
if video.ndim == 5: | |
# can be batches if we provide list of filenames | |
assert video.shape[0] == 1 | |
video = video[0] | |
if video.shape[0] == 3 and video.shape[1] == self.n_frames: | |
video = rearrange(video, "C F W H -> F C W H") | |
assert video.shape[1] == 3, f"Wrong video format. Got {video.shape}" | |
if isinstance(filename, Path): | |
filename = filename.as_posix() | |
# assert video.max() <= 1 and video.min() >= 0 | |
assert video.max() <=1.1 and video.min() >= -0.1, f"video has unexpected range: [{video.min()}, {video.max()}]" | |
vid_obj = IImage(video, vmin=0, vmax=1) | |
if prompt is not None: | |
vid_obj = vid_obj.append_text(prompt, padding=(0, 50, 0, 0)) | |
if append_video is not None: | |
if append_video.ndim == 5: | |
assert append_video.shape[0] == 1 | |
append_video = append_video[0] | |
if append_video.shape[0] < video.shape[0]: | |
append_video = torch.concat([append_video, | |
repeat(append_video[-1, None], "F C W H -> (rep F) C W H", rep=video.shape[0]-append_video.shape[0])], dim=0) | |
if append_video.ndim == 3 and video.ndim == 4: | |
append_video = repeat( | |
append_video, "C W H -> F C W H", F=video.shape[0]) | |
append_video = IImage(append_video, vmin=-1, vmax=1) | |
if prompt is not None: | |
append_video = append_video.append_text( | |
"input_frame", padding=(0, 50, 0, 0)) | |
vid_obj = vid_obj | append_video | |
vid_obj = vid_obj.setFps(self.fps) | |
vid_obj.save(filename) | |
def _create_prompt_file(self, prompt, filename, video_path: str = None): | |
filename = Path(filename) | |
filename = filename.parent / (filename.stem+".txt") | |
with open(filename.as_posix(), "w") as file_writer: | |
file_writer.write(prompt) | |
file_writer.write("\n") | |
if video_path is not None: | |
file_writer.write(video_path) | |
else: | |
file_writer.write(" no_source") | |
def log_video(self, video: torch.FloatTensor, prompt: str, video_id: str, log_folder: str, input_flow=None, video_path_input: str = None, extension: str = "gif", prompt_on_vid: bool = True, append_video: torch.FloatTensor = None): | |
with tempfile.TemporaryDirectory() as tmpdirname: | |
storage_fol = Path(tmpdirname) | |
filename = f"{video_id}.{extension}".replace("/", "_") | |
vid_filename = storage_fol / filename | |
self._create_video( | |
video, prompt if prompt_on_vid else None, vid_filename, append_video, input_flow=input_flow) | |
prompt_file = storage_fol / f"{video_id}.txt" | |
self._create_prompt_file(prompt, prompt_file, video_path_input) | |
if self.logger.experiment.__class__.__name__ == "_DummyExperiment": | |
run_fol = Path(self.logger.save_dir) / \ | |
self.logger.experiment_id / self.logger.run_id / "artifacts" / log_folder | |
if not run_fol.exists(): | |
run_fol.mkdir(parents=True, exist_ok=True) | |
shutil.copy(prompt_file.as_posix(), | |
(run_fol / f"{video_id}.txt").as_posix()) | |
shutil.copy(vid_filename, | |
(run_fol / filename).as_posix()) | |
else: | |
self.logger.experiment.log_artifact( | |
self.logger.run_id, prompt_file.as_posix(), log_folder) | |
self.logger.experiment.log_artifact( | |
self.logger.run_id, vid_filename, log_folder) | |
def save_to_file(self, video: torch.FloatTensor, prompt: str, video_filename: Union[str, Path], input_flow=None, conditional_video_path: str = None, prompt_on_vid: bool = True, conditional_video: torch.FloatTensor = None): | |
self._create_video( | |
video, prompt if prompt_on_vid else None, video_filename, conditional_video, input_flow=input_flow) | |
self._create_prompt_file( | |
prompt, video_filename, conditional_video_path) | |
def add_text_to_image(image_array, text, position, font_size, text_color, font_path=None): | |
# Convert the NumPy array to PIL Image | |
image_pil = Image.fromarray(image_array) | |
# Create a drawing object | |
draw = ImageDraw.Draw(image_pil) | |
if font_path is not None: | |
font = ImageFont.truetype(font_path, font_size) | |
else: | |
try: | |
# Load the font | |
font = ImageFont.truetype( | |
"/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf", font_size) | |
except: | |
font = ImageFont.load_default() | |
# Draw the text on the image | |
draw.text(position, text, font=font, fill=text_color) | |
# Convert the PIL Image back to NumPy array | |
modified_image_array = np.array(image_pil) | |
return modified_image_array | |
def add_text_to_video(video_path, prompt): | |
outputs_with_overlay = [] | |
with open(video_path, "rb") as f: | |
vr = VideoReader(f, ctx=cpu(0)) | |
for i in range(len(vr)): | |
frame = vr[i] | |
frame = add_text_to_image(frame, prompt, position=( | |
10, 10), font_size=15, text_color=(255, 0, 0),) | |
outputs_with_overlay.append(frame) | |
outputs = outputs_with_overlay | |
video_path = video_path.replace("mp4", "gif") | |
imageio.mimsave(video_path, outputs, duration=100, loop=0) | |
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=30, prompt=None): | |
videos = rearrange(videos, "b c t h w -> t b c h w") | |
outputs = [] | |
for x in videos: | |
x = torchvision.utils.make_grid(x, nrow=n_rows) | |
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) | |
if rescale: | |
x = (x + 1.0) / 2.0 # -1,1 -> 0,1 | |
x = (x * 255).numpy().astype(np.uint8) | |
outputs.append(x) | |
os.makedirs(os.path.dirname(path), exist_ok=True) | |
if prompt is not None: | |
outputs_with_overlay = [] | |
for frame in outputs: | |
frame_out = add_text_to_image( | |
frame, prompt, position=(10, 10), font_size=10, text_color=(255, 0, 0),) | |
outputs_with_overlay.append(frame_out) | |
outputs = outputs_with_overlay | |
imageio.mimsave(path, outputs, duration=round(1/fps*1000), loop=0) | |
# iio.imwrite(path, outputs) | |
# optimize(path) | |
def set_channel_pos(data, shape_dict, channel_pos): | |
assert data.ndim == 5 or data.ndim == 4 | |
batch_dim = data.shape[0] | |
frame_dim = shape_dict["frame_dim"] | |
channel_dim = shape_dict["channel_dim"] | |
width_dim = shape_dict["width_dim"] | |
height_dim = shape_dict["height_dim"] | |
assert batch_dim != frame_dim | |
assert channel_dim != frame_dim | |
assert channel_dim != batch_dim | |
video_shape = list(data.shape) | |
batch_pos = video_shape.index(batch_dim) | |
channel_pos = video_shape.index(channel_dim) | |
w_pos = video_shape.index(width_dim) | |
h_pos = video_shape.index(height_dim) | |
if w_pos == h_pos: | |
video_shape[w_pos] = -1 | |
h_pos = video_shape.index(height_dim) | |
pattern_order = {} | |
pattern_order[batch_pos] = "B" | |
pattern_order[channel_pos] = "C" | |
pattern_order[w_pos] = "W" | |
pattern_order[h_pos] = "H" | |
if data.ndim == 5: | |
frame_pos = video_shape.index(frame_dim) | |
pattern_order[frame_pos] = "F" | |
if channel_pos == channel_first: | |
pattern = " -> B F C W H" | |
else: | |
pattern = " -> B F W H C" | |
else: | |
if channel_pos == channel_first: | |
pattern = " -> B C W H" | |
else: | |
pattern = " -> B W H C" | |
pattern_input = [pattern_order[idx] for idx in range(data.ndim)] | |
pattern_input = " ".join(pattern_input) | |
pattern = pattern_input + pattern | |
data = rearrange(data, pattern) | |
def merge_first_two_dimensions(tensor): | |
dims = tensor.ndim | |
letters = [] | |
for letter_idx in range(dims-2): | |
letters.append(chr(letter_idx+67)) | |
latters_pattern = " ".join(letters) | |
tensor = rearrange(tensor, "A B "+latters_pattern + | |
" -> (A B) "+latters_pattern) | |
# TODO merging first two dimensions might be easier with reshape so no need to create letters | |
# should be 'tensor.view(*tensor.shape[:2], -1)' | |
return tensor | |
def apply_spatial_function_to_video_tensor(video, shape, func): | |
# TODO detect batch, frame, channel, width, and height | |
assert video.ndim == 5 | |
batch_dim = shape["batch_dim"] | |
frame_dim = shape["frame_dim"] | |
channel_dim = shape["channel_dim"] | |
width_dim = shape["width_dim"] | |
height_dim = shape["height_dim"] | |
assert batch_dim != frame_dim | |
assert channel_dim != frame_dim | |
assert channel_dim != batch_dim | |
video_shape = list(video.shape) | |
batch_pos = video_shape.index(batch_dim) | |
frame_pos = video_shape.index(frame_dim) | |
channel_pos = video_shape.index(channel_dim) | |
w_pos = video_shape.index(width_dim) | |
h_pos = video_shape.index(height_dim) | |
if w_pos == h_pos: | |
video_shape[w_pos] = -1 | |
h_pos = video_shape.index(height_dim) | |
pattern_order = {} | |
pattern_order[batch_pos] = "B" | |
pattern_order[channel_pos] = "C" | |
pattern_order[frame_pos] = "F" | |
pattern_order[w_pos] = "W" | |
pattern_order[h_pos] = "H" | |
pattern_order = sorted(pattern_order.items(), key=lambda x: x[1]) | |
pattern_order = [x[0] for x in pattern_order] | |
input_pattern = " ".join(pattern_order) | |
video = rearrange(video, input_pattern+" -> (B F) C W H") | |
video = func(video) | |
video = rearrange(video, "(B F) C W H -> "+input_pattern, F=frame_dim) | |
return video | |
def dump_frames(videos, as_mosaik, storage_fol, save_image_kwargs): | |
# assume videos is in format B F C H W, range [0,1] | |
num_frames = videos.shape[1] | |
num_videos = videos.shape[0] | |
if videos.shape[2] != 3 and videos.shape[-1] == 3: | |
videos = rearrange(videos, "B F W H C -> B F C W H") | |
frame_counter = 0 | |
if not isinstance(storage_fol, Path): | |
storage_fol = Path(storage_fol) | |
for frame_idx in range(num_frames): | |
print(f" Creating frame {frame_idx}") | |
batch_frame = videos[:, frame_idx, ...] | |
if as_mosaik: | |
filename = storage_fol / f"frame_{frame_counter:03d}.png" | |
save_image(batch_frame, fp=filename.as_posix(), | |
**save_image_kwargs) | |
frame_counter += 1 | |
else: | |
for video_idx in range(num_videos): | |
frame = batch_frame[video_idx] | |
filename = storage_fol / f"frame_{frame_counter:03d}.png" | |
save_image(frame, fp=filename.as_posix(), | |
**save_image_kwargs) | |
frame_counter += 1 | |
def gif_from_videos(videos): | |
assert videos.dim() == 5 | |
assert videos.min() >= 0 | |
assert videos.max() <= 1 | |
gif_file = Path("tmp.gif").absolute() | |
with tempfile.TemporaryDirectory() as tmpdirname: | |
storage_fol = Path(tmpdirname) | |
nrows = min(4, videos.shape[0]) | |
dump_frames( | |
videos=videos, storage_fol=storage_fol, as_mosaik=True, save_image_kwargs={"nrow": nrows}) | |
cmd = f"ffmpeg -y -f image2 -framerate 4 -i {storage_fol / 'frame_%03d.png'} {gif_file.as_posix()}" | |
subprocess.check_call( | |
cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT) | |
return gif_file | |
def add_margin(pil_img, top, right, bottom, left, color): | |
width, height = pil_img.size | |
new_width = width + right + left | |
new_height = height + top + bottom | |
result = Image.new(pil_img.mode, (new_width, new_height), color) | |
result.paste(pil_img, (left, top)) | |
return result | |
def resize_to_fit(image, size): | |
W, H = size | |
w, h = image.size | |
if H / h > W / w: | |
H_ = int(h * W / w) | |
W_ = W | |
else: | |
W_ = int(w * H / h) | |
H_ = H | |
return image.resize((W_, H_)) | |
def pad_to_fit(image, size): | |
W, H = size | |
w, h = image.size | |
pad_h = (H - h) // 2 | |
pad_w = (W - w) // 2 | |
return add_margin(image, pad_h, pad_w, pad_h, pad_w, (0, 0, 0)) |