Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,590 Bytes
57876e1 d58b8ba 57876e1 d58b8ba 57876e1 1942098 57876e1 1942098 57876e1 3f4b0ee 57876e1 3f4b0ee 57876e1 3128011 57876e1 3222587 57876e1 490c4aa 57876e1 0127f03 57876e1 8cd5ef4 57876e1 8cd5ef4 9b1b6c3 8cd5ef4 57876e1 3222587 57876e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import os
import spaces
import gradio as gr
import random
import numpy as np
import torch
from torchvision.transforms.functional import center_crop
import subprocess
import sys

# Try to install detectron2 from source. Needed for semseg plotting functionality.
# os.system() reports failure only through its return code and never raises, so the
# fallback message below could never fire. subprocess.run(check=True) raises
# CalledProcessError on a non-zero exit. sys.executable pins pip to the interpreter
# that is actually running this app.
try:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "git+https://github.com/facebookresearch/detectron2.git"],
        check=True,
    )
except (OSError, subprocess.CalledProcessError) as e:
    print('detectron2 cannot be installed. Falling back to simple semseg visualization.')
    print(e)

# We recommend running this demo on an A100 GPU
# `device` is the torch device used for the model and inputs.
# `power_device` is the hardware name shown in the UI text.
if torch.cuda.is_available():
    device = "cuda"
    gpu_type = torch.cuda.get_device_name(torch.cuda.current_device())
    power_device = f"{gpu_type}"
    # NOTE(review): the return value is discarded. This presumably only touches the CUDA
    # allocator/context; confirm whether it is still needed.
    torch.cuda.max_memory_allocated(device=device)
else:
    device = "cpu"
    power_device = "CPU"
    # Only use xformers on GPU: remove it from this interpreter's environment on CPU hosts.
    # A failed uninstall (e.g. xformers not installed) is deliberately not fatal.
    subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "xformers"], check=False)
from fourm.demo_4M_sampler import Demo4MSampler
from fourm.data.modality_transforms import RGBTransform
# Enable TensorFloat-32 math for matmuls (PyTorch >= 1.12 defaults this to False)
# and for cuDNN (already True by default; set explicitly here).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Upper bound of the seed slider: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Hugging Face Hub id of the 4M model. The display name takes the repo part of the id
# and turns underscores into spaces ('4M-21_XL' -> '4M-21 XL').
FM_MODEL_ID = 'EPFL-VILAB/4M-21_XL'
MODEL_NAME = ' '.join(FM_MODEL_ID.split('/')[1].split('_'))

# Modalities handed to Demo4MSampler as `mods`. Human poses are left out because rendering
# them needs SMPL weights, which cannot be distributed. To enable pose prediction and rendering:
#   1) pip install timm yacs smplx pyrender pyopengl==3.1.4
#      (pyrender may need https://pyrender.readthedocs.io/en/latest/install/index.html)
#   2) Download SMPL data from https://smpl.is.tue.mpg.de/
#      (https://github.com/shubham-goel/4D-Humans/ shows an example setup)
#   3) Copy smpl_mean_params.npz, SMPL_to_J19.pkl and smpl/SMPL_NEUTRAL.pkl
#      into fourm/utils/hmr2_utils/data .
MANUAL_MODS_OVERRIDE = [
    'color_palette',
    'tok_depth@224',
    'tok_imagebind@224',
    'sam_instance',
    'tok_dinov2_global',
    'tok_normal@224',
    'tok_sam_edge@224',
    'det',
    'tok_canny_edge@224',
    'tok_semseg@224',
    'rgb@224',
    'caption',
    't5_caption',
    'tok_imagebind_global',
    'tok_rgb@224',
    'tok_clip@224',
    'metadata',
    'tok_dinov2@224',
]
# Construction arguments for the demo sampler. fm_sr and tok_human_poses are set to None,
# so those optional components are not supplied (presumably super-resolution and the
# human-pose tokenizer; confirm against Demo4MSampler). The text tokenizer path is
# relative to the working directory.
_sampler_kwargs = dict(
    fm=FM_MODEL_ID,
    fm_sr=None,
    tok_human_poses=None,
    tok_text='./text_tokenizer_4m_wordpiece_30k.json',
    mods=MANUAL_MODS_OVERRIDE,
)
# Build the sampler once at import time and move it to the device chosen above.
sampler = Demo4MSampler(**_sampler_kwargs).to(device)
def img_from_path(img_path: str):
    """Load an image file and turn it into the batched RGB input used by the sampler.

    Steps:
    1. Load and preprocess the file with RGBTransform (ImageNet default mean/std enabled).
    2. Center-crop to the largest square.
    3. Resize to 224x224.
    4. Post-process and add a leading batch dimension of size 1.

    NOTE(review): the result is presumably a torch tensor of shape (1, C, 224, 224);
    confirm against RGBTransform.postprocess.
    """
    transform = RGBTransform(imagenet_default_mean_and_std=True)
    image = transform.preprocess(transform.load(img_path))
    # Side of the largest centered square that fits in the (width, height) image.
    side = min(image.size)
    square = center_crop(image, (side, side))
    resized = square.resize((224, 224))
    batch = transform.postprocess(resized).unsqueeze(0)
    return batch
@spaces.GPU(duration=100)
def infer(img_path, seed=0, randomize_seed=False, target_modalities=None, top_p=0.8, top_k=0.0):
    """Predict target modalities from an RGB image and return them as PIL images.

    Args:
        img_path: Path of the uploaded image (the Image input uses type='filepath').
            None when the user has not uploaded anything.
        seed: Sampling seed passed to the sampler. Ignored when randomize_seed is set.
        randomize_seed: If True, the sampler receives seed=None (presumably it then
            picks its own seed; confirm in Demo4MSampler).
        target_modalities: Modalities to predict, passed through to the sampler.
        top_p: Nucleus-sampling threshold, passed through to the sampler.
        top_k: Top-k filtering value, passed through to the sampler.

    Returns:
        The output of sampler.modalities_to_pil for the predictions, which the UI shows
        in a Gallery.

    Raises:
        gr.Error: If no input image was provided.
    """
    if img_path is None:
        # Clicking "Predict" without an upload passes None. Show a readable message in the
        # UI instead of failing with an opaque error inside image loading.
        raise gr.Error('Please upload an RGB input image first.')
    if randomize_seed:
        seed = None
    elif seed is not None:
        # Gradio sliders may deliver floats such as 42.0, so give the sampler an integer seed.
        seed = int(seed)
    img = img_from_path(img_path).to(device)
    preds = sampler({'rgb@224': img}, seed=seed, target_modalities=target_modalities, top_p=top_p, top_k=top_k)
    return sampler.modalities_to_pil(preds, use_fixed_plotting_order=True, resize=512)
# Example inputs listed under the UI: examples/example_0.png ... examples/example_5.png.
examples = [f'examples/example_{idx}.png' for idx in range(6)]

# Page CSS:
# - centers the outer container (max 1500px wide) and the input column (max 400px wide);
# - centers the run button.
css = """
#col-container {
margin: 0 auto;
max-width: 1500px;
}
#col-input-container {
margin: 0 auto;
max-width: 400px;
}
#run-button {
margin: 0 auto;
}
"""
# Gradio UI layout:
# - a title above a row holding the input column (description, image upload,
#   run button, advanced settings) and the prediction gallery;
# - clickable examples below the row.
with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    with gr.Column(elem_id="col-container"):
        # Plain string: this text has no placeholders, so it needs no f-prefix.
        gr.Markdown("""
# 4M: Massively Multimodal Masked Modeling
""")
        with gr.Row():
            with gr.Column(elem_id="col-input-container"):
                gr.Markdown(f"""
*A framework for training any-to-any multimodal foundation models. Scalable. Open-sourced. Across tens of modalities and tasks.*
[`Website`](https://4m.epfl.ch) | [`GitHub`](https://github.com/apple/ml-4m) <br>[`4M Paper (NeurIPS'23)`](https://arxiv.org/abs/2312.06647) | [`4M-21 Paper (NeurIPS'24)`](https://arxiv.org/abs/2406.09406)
This demo predicts all modalities from a given RGB input, using [{FM_MODEL_ID}](https://huggingface.co/{FM_MODEL_ID}), running on *{power_device}*.
For more generative any-to-any examples, please see our [GitHub repo](https://github.com/apple/ml-4m#generation).
""")
                img_path = gr.Image(label='RGB input image', type='filepath')
                run_button = gr.Button(f"Predict with {MODEL_NAME}", scale=0, elem_id="run-button")
                with gr.Accordion("Advanced Settings", open=False):
                    # (label, modality key) pairs. Values are the keys passed to infer().
                    target_modalities = gr.CheckboxGroup(
                        choices=[
                            ('CLIP-B/16', 'tok_clip@224'), ('DINOv2-B/14', 'tok_dinov2@224'), ('ImageBind-H/14', 'tok_imagebind@224'),
                            ('Depth', 'tok_depth@224'), ('Surface normals', 'tok_normal@224'), ('Semantic segmentation', 'tok_semseg@224'),
                            ('Canny edges', 'tok_canny_edge@224'), ('SAM edges', 'tok_sam_edge@224'), ('Caption', 'caption'),
                            ('Bounding boxes', 'det'), ('SAM instances (single pass*)', 'sam_instance'), ('Color palette', 'color_palette'),
                            ('Metadata', 'metadata'),
                        ],
                        value=[
                            'tok_clip@224', 'tok_dinov2@224', 'tok_imagebind@224',
                            'tok_depth@224', 'tok_normal@224', 'tok_semseg@224',
                            'tok_canny_edge@224', 'tok_sam_edge@224', 'caption',
                            'det', 'sam_instance', 'color_palette', 'metadata'
                        ],
                        label="Target modalities",
                        info='Choose which modalities are predicted (in this order).'
                    )
                    # Raw string: the Markdown escapes "\*" are invalid Python escape sequences in a
                    # normal (f-)string (SyntaxWarning on Python 3.12+). The raw string keeps the
                    # identical text without the warning.
                    gr.Markdown(r"""
**Information on modalities**:
\* *SAM instances* in this demo are generated in a single pass and may look sparse. For sampling dense SAM instances, please see the convenience function
[`generate_sam_dense`](https://github.com/apple/ml-4m/blob/e11539965e45aa6731143d742c4493c46b4ef620/fourm/models/generate.py#L1230-L1273)
in `fourm.models.generate.GenerationSampler`, and our [4M-21 interactive notebook](https://github.com/apple/ml-4m/blob/main/notebooks/generation_4M-21.ipynb) for usage examples.
\*\* While 4M-21 models are capable of predicting *4D human poses*, visualizing them requires the SMPL model which cannot be distributed.
To visualize poses, please follow these steps:
1) Install via `pip install timm yacs smplx pyrender pyopengl==3.1.4`.
You may need to follow the [pyrender install instructions](https://pyrender.readthedocs.io/en/latest/install/index.html).
2) Download SMPL data from https://smpl.is.tue.mpg.de/. See https://github.com/shubham-goel/4D-Humans/ for an example.
3) Copy the required SMPL files (`smpl_mean_params.npz`, `SMPL_to_J19.pkl`, `smpl/SMPL_NEUTRAL.pkl`) to `fourm/utils/hmr2_utils/data` .
""")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                    top_p = gr.Slider(label="Top-p", minimum=0.0, maximum=1.0, step=0.01, value=0.8)
                    top_k = gr.Slider(label="Top-k", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
            result = gr.Gallery(
                label="Predictions", show_label=True, elem_id="gallery", type='pil',
                columns=[4], rows=None, object_fit="contain", height="auto"
            )
        # Examples only feed the image. infer() falls back to its default settings,
        # and results are cached lazily on first use.
        gr.Examples(
            examples = examples,
            fn = infer,
            inputs = [img_path],
            outputs = [result],
            cache_examples='lazy',
        )
    run_button.click(
        fn = infer,
        inputs = [img_path, seed, randomize_seed, target_modalities, top_p, top_k],
        outputs = [result]
    )
# At most 10 requests wait in the queue at once.
demo.queue(max_size=10).launch()