File size: 3,328 Bytes
59b2cda 85f2b73 59b2cda 85f2b73 59b2cda 85f2b73 59b2cda 85f2b73 59b2cda 85f2b73 59b2cda 85f2b73 59b2cda 85f2b73 59b2cda 85f2b73 59b2cda 85f2b73 59b2cda 85f2b73 783e710 85f2b73 783e710 85f2b73 59b2cda 85f2b73 59b2cda 85f2b73 59b2cda 85f2b73 59b2cda 85f2b73 59b2cda 85f2b73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
from __future__ import annotations
import gc
import pathlib
import sys
import tempfile
import gradio as gr
import imageio
import PIL.Image
import torch
from einops import rearrange
from huggingface_hub import ModelCard
sys.path.append('Tune-A-Video')
from tuneavideo.models.unet import UNet3DConditionModel
from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
class InferencePipeline:
    """Load a fine-tuned Tune-A-Video model and render prompts to MP4 files.

    The most recently loaded pipeline is kept on the instance together with
    its ``model_id`` so that repeated requests for the same model skip the
    reload.
    """

    def __init__(self, hf_token: str | None = None):
        """Create a holder with no pipeline loaded yet.

        Args:
            hf_token: Optional Hugging Face token forwarded to model-card
                and weight loading calls.
        """
        self.hf_token = hf_token
        self.pipe = None
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model_id = None

    def clear(self) -> None:
        """Forget the cached pipeline and hand cached CUDA memory back."""
        self.model_id = None
        # Plain assignment drops the reference; no separate ``del`` needed.
        self.pipe = None
        # Collect first so tensors kept alive only by reference cycles are
        # freed before the caching allocator is asked to release its blocks.
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def check_if_model_is_local(model_id: str) -> bool:
        """Return True when ``model_id`` names an existing filesystem path."""
        return pathlib.Path(model_id).exists()

    @staticmethod
    def get_model_card(model_id: str,
                       hf_token: str | None = None) -> ModelCard:
        """Load the model card for a local directory or a Hub repository.

        Args:
            model_id: Local model directory or Hub repository id.
            hf_token: Optional Hugging Face token passed to the loader.

        Returns:
            The ``ModelCard`` read from ``<model_id>/README.md`` when the
            path exists locally, otherwise loaded using ``model_id`` itself.
        """
        if InferencePipeline.check_if_model_is_local(model_id):
            card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
        else:
            card_path = model_id
        return ModelCard.load(card_path, token=hf_token)

    @staticmethod
    def get_base_model_info(model_id: str, hf_token: str | None = None) -> str:
        """Return the ``base_model`` field recorded in the model card."""
        card = InferencePipeline.get_model_card(model_id, hf_token)
        return card.data.base_model

    def load_pipe(self, model_id: str) -> None:
        """Make ``model_id`` the active pipeline, reusing it if already loaded.

        Args:
            model_id: Local directory or Hub repository id of the fine-tuned
                model whose ``unet`` subfolder is loaded on top of the base
                model named in its model card.
        """
        if model_id == self.model_id:
            return
        base_model_id = self.get_base_model_info(model_id, self.hf_token)
        unet = UNet3DConditionModel.from_pretrained(
            model_id,
            subfolder='unet',
            torch_dtype=torch.float16,
            use_auth_token=self.hf_token)
        pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
                                                  unet=unet,
                                                  torch_dtype=torch.float16,
                                                  use_auth_token=self.hf_token)
        # Release the previous pipeline only after the new weights loaded
        # successfully (a failed load keeps the old model usable) but before
        # the new one is moved to the device, so the two are not both held
        # there at the same time.
        self.clear()
        self.pipe = pipe.to(self.device)
        self.model_id = model_id  # type: ignore

    def run(
        self,
        model_id: str,
        prompt: str,
        video_length: int,
        fps: int,
        seed: int,
        n_steps: int,
        guidance_scale: float,
    ) -> str:
        """Generate a 512x512 video for ``prompt`` and save it as an MP4.

        Args:
            model_id: Model to load (or reuse) via ``load_pipe``.
            prompt: Text prompt passed to the pipeline.
            video_length: Value forwarded as the pipeline's ``video_length``.
            fps: Frame rate used when encoding the output file.
            seed: Seed for the ``torch.Generator`` driving sampling.
            n_steps: Value forwarded as ``num_inference_steps``.
            guidance_scale: Value forwarded as ``guidance_scale``.

        Returns:
            Path of the temporary ``.mp4`` file. The file is created with
            ``delete=False``, so removing it is left to the caller.

        Raises:
            gr.Error: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')

        self.load_pipe(model_id)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        out = self.pipe(
            prompt,
            video_length=video_length,
            width=512,
            height=512,
            num_inference_steps=n_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )  # type: ignore

        frames = rearrange(out.videos[0], 'c t h w -> t h w c')
        # ``.cpu()`` is a no-op for CPU tensors and keeps ``.numpy()`` valid
        # should the pipeline ever return device tensors.
        frames = (frames * 255).to(torch.uint8).cpu().numpy()

        # Only the reserved path is needed; close the handle right away so
        # no descriptor leaks and the encoder is the sole writer of the file.
        with tempfile.NamedTemporaryFile(suffix='.mp4',
                                         delete=False) as out_file:
            out_path = out_file.name

        # The context manager closes the writer even if a frame fails.
        with imageio.get_writer(out_path, fps=fps) as writer:
            for frame in frames:
                writer.append_data(frame)
        return out_path
|