File size: 6,927 Bytes
8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 a60b0e7 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 bf5ad47 8b7a3d1 c0a7c3c bf5ad47 8b7a3d1 a60b0e7 8b7a3d1 6558e17 5207f10 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 c0a7c3c 8b7a3d1 db0cd98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
#!/usr/bin/env python
from __future__ import annotations
import enum
import gradio as gr
from huggingface_hub import HfApi
from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
from inference import InferencePipeline
from utils import find_exp_dirs
class ModelSource(enum.Enum):
    """Origins a model can be listed from in the inference UI.

    The member values double as the labels of the "Model Source" radio
    button and are compared against the selected label when the model list
    is reloaded.
    """
    # Models listed from the Hub organization; the label is taken from
    # UploadTarget.MODEL_LIBRARY so both names stay in sync.
    HUB_LIB = UploadTarget.MODEL_LIBRARY.value
    # Models found in local experiment directories (see find_exp_dirs).
    LOCAL = 'Local'
class InferenceUtil:
def __init__(self, hf_token: str | None):
self.hf_token = hf_token
def load_hub_model_list(self) -> dict:
api = HfApi(token=self.hf_token)
choices = [
info.modelId
for info in api.list_models(author=MODEL_LIBRARY_ORG_NAME)
]
return gr.update(choices=choices,
value=choices[0] if choices else None)
@staticmethod
def load_local_model_list() -> dict:
choices = find_exp_dirs()
return gr.update(choices=choices,
value=choices[0] if choices else None)
def reload_model_list(self, model_source: str) -> dict:
if model_source == ModelSource.HUB_LIB.value:
return self.load_hub_model_list()
elif model_source == ModelSource.LOCAL.value:
return self.load_local_model_list()
else:
raise ValueError
def load_model_info(self, model_id: str) -> tuple[str, str]:
try:
card = InferencePipeline.get_model_card(model_id, self.hf_token)
except Exception:
return '', ''
base_model = getattr(card.data, 'base_model', '')
training_prompt = getattr(card.data, 'training_prompt', '')
return base_model, training_prompt
def reload_model_list_and_update_model_info(
self, model_source: str) -> tuple[dict, str, str]:
model_list_update = self.reload_model_list(model_source)
model_list = model_list_update['choices']
model_info = self.load_model_info(model_list[0] if model_list else '')
return model_list_update, *model_info
def create_inference_demo(pipe: InferencePipeline,
                          hf_token: str | None = None,
                          disable_run_button: bool = False) -> gr.Blocks:
    """Assemble the Gradio Blocks UI used to run inference.

    Args:
        pipe: Pipeline whose ``run`` method is bound to the prompt submit
            event and to the "Generate" button.
        hf_token: Token handed to ``InferenceUtil`` for model listing and
            model-card lookups.
        disable_run_button: When true, the "Generate" button is created
            non-interactive.

    Returns:
        The constructed ``gr.Blocks`` (not yet launched).
    """
    util = InferenceUtil(hf_token)
    source_labels = [source.value for source in ModelSource]

    # NOTE(review): the container nesting below was reconstructed from a
    # listing without indentation -- confirm it matches the intended layout.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                with gr.Box():
                    model_source = gr.Radio(label='Model Source',
                                            choices=source_labels,
                                            value=ModelSource.HUB_LIB.value)
                    reload_button = gr.Button('Reload Model List')
                    # Starts empty; filled by the reload callbacks below.
                    model_id = gr.Dropdown(label='Model ID',
                                           choices=None,
                                           value=None)
                    with gr.Accordion(
                            label='Model info (Base model and prompt used for training)',
                            open=False):
                        with gr.Row():
                            base_model_used_for_training = gr.Text(
                                label='Base model', interactive=False)
                            prompt_used_for_training = gr.Text(
                                label='Training prompt', interactive=False)
                prompt = gr.Textbox(label='Prompt',
                                    max_lines=1,
                                    placeholder='Example: "A panda is surfing"')
                video_length = gr.Slider(
                    label='Video length', minimum=4, maximum=12, step=1,
                    value=8)
                fps = gr.Slider(
                    label='FPS', minimum=1, maximum=12, step=1, value=1)
                seed = gr.Slider(
                    label='Seed', minimum=0, maximum=100000, step=1, value=0)
                with gr.Accordion('Advanced options', open=False):
                    # NOTE(review): minimum=0 lets the user request zero
                    # steps -- confirm the pipeline tolerates that.
                    num_steps = gr.Slider(
                        label='Number of Steps', minimum=0, maximum=100,
                        step=1, value=50)
                    guidance_scale = gr.Slider(
                        label='Guidance scale', minimum=0, maximum=50,
                        step=0.1, value=7.5)

                run_button = gr.Button('Generate',
                                       interactive=not disable_run_button)

                gr.Markdown('''
- After training, you can press "Reload Model List" button to load your trained model names.
- It takes a few minutes to download model first.
- Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G, (10 seconds with A100)
''')

            with gr.Column():
                result = gr.Video(label='Result')

        # Changing the source and pressing "Reload" do the same thing:
        # refresh the dropdown and the info of its first entry.
        model_info_fields = [
            base_model_used_for_training,
            prompt_used_for_training,
        ]
        for register in (model_source.change, reload_button.click):
            register(fn=util.reload_model_list_and_update_model_info,
                     inputs=model_source,
                     outputs=[model_id, *model_info_fields])

        # Picking another model only refreshes the info fields.
        model_id.change(fn=util.load_model_info,
                        inputs=model_id,
                        outputs=list(model_info_fields))

        # Pressing Enter in the prompt box and clicking "Generate" both run
        # the pipeline with the same argument list.
        generation_inputs = [
            model_id,
            prompt,
            video_length,
            fps,
            seed,
            num_steps,
            guidance_scale,
        ]
        for trigger in (prompt.submit, run_button.click):
            trigger(fn=pipe.run, inputs=generation_inputs, outputs=result)

    return demo
if __name__ == '__main__':
    import os

    # The token is optional: the lookup yields None when HF_TOKEN is unset,
    # and both the pipeline and the UI accept that.
    hf_token = os.environ.get('HF_TOKEN')
    pipe = InferencePipeline(hf_token)
    demo = create_inference_demo(pipe, hf_token)
    # Enable the request queue (with the given API and size settings)
    # before launching the app.
    demo.queue(
        api_open=False,
        max_size=10,
    ).launch()
|