import os

import torch
import torch.nn as nn
from safetensors.torch import save_file
from timm.models.vision_transformer import Block
from transformers import PretrainedConfig
from transformers import PreTrainedModel

from .mar import MAR


class MARConfig(PretrainedConfig):
    """Configuration holding the constructor arguments of the ``MAR`` model.

    Every argument is stored verbatim as an attribute of the same name so
    that ``MARModel`` can forward it to ``MAR``. ``norm_layer`` is kept as a
    string (the name of a class in ``torch.nn``) and is resolved to the
    actual class by ``MARModel``. Any additional keyword arguments are
    passed on to ``PretrainedConfig``.
    """

    model_type = "mar"

    def __init__(
        self,
        img_size=256,
        vae_stride=16,
        patch_size=1,
        encoder_embed_dim=1024,
        encoder_depth=16,
        encoder_num_heads=16,
        decoder_embed_dim=1024,
        decoder_depth=16,
        decoder_num_heads=16,
        mlp_ratio=4.,
        norm_layer="LayerNorm",
        vae_embed_dim=16,
        mask_ratio_min=0.7,
        label_drop_prob=0.1,
        class_num=1000,
        attn_dropout=0.1,
        proj_dropout=0.1,
        buffer_size=64,
        diffloss_d=3,
        diffloss_w=1024,
        num_sampling_steps='100',
        diffusion_batch_mul=4,
        grad_checkpointing=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # store parameters in the config
        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_depth = encoder_depth
        self.encoder_num_heads = encoder_num_heads
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.mlp_ratio = mlp_ratio
        self.norm_layer = norm_layer
        self.vae_embed_dim = vae_embed_dim
        self.mask_ratio_min = mask_ratio_min
        self.label_drop_prob = label_drop_prob
        self.class_num = class_num
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout
        self.buffer_size = buffer_size
        self.diffloss_d = diffloss_d
        self.diffloss_w = diffloss_w
        self.num_sampling_steps = num_sampling_steps
        self.diffusion_batch_mul = diffusion_batch_mul
        self.grad_checkpointing = grad_checkpointing


class MARModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that builds a ``MAR`` from a ``MARConfig``.

    The wrapped network is available as ``self.model``; ``forward`` and
    ``sample_tokens`` simply delegate to it.
    """

    # links to MARConfig class
    config_class = MARConfig

    def __init__(self, config):
        """Build the wrapped ``MAR`` model from ``config``.

        Args:
            config: A ``MARConfig`` (or compatible object) providing every
                attribute forwarded to ``MAR`` below.

        Raises:
            AttributeError: If ``config.norm_layer`` does not name an
                attribute of ``torch.nn``.
        """
        super().__init__(config)
        self.config = config

        # convert norm_layer from string to class
        norm_layer = getattr(nn, config.norm_layer)

        # init the mar model using the parameters from config
        self.model = MAR(
            img_size=config.img_size,
            vae_stride=config.vae_stride,
            patch_size=config.patch_size,
            encoder_embed_dim=config.encoder_embed_dim,
            encoder_depth=config.encoder_depth,
            encoder_num_heads=config.encoder_num_heads,
            decoder_embed_dim=config.decoder_embed_dim,
            decoder_depth=config.decoder_depth,
            decoder_num_heads=config.decoder_num_heads,
            mlp_ratio=config.mlp_ratio,
            norm_layer=norm_layer,  # use the actual class for the layer
            vae_embed_dim=config.vae_embed_dim,
            mask_ratio_min=config.mask_ratio_min,
            label_drop_prob=config.label_drop_prob,
            class_num=config.class_num,
            attn_dropout=config.attn_dropout,
            proj_dropout=config.proj_dropout,
            buffer_size=config.buffer_size,
            diffloss_d=config.diffloss_d,
            diffloss_w=config.diffloss_w,
            num_sampling_steps=config.num_sampling_steps,
            diffusion_batch_mul=config.diffusion_batch_mul,
            grad_checkpointing=config.grad_checkpointing,
        )

    def forward(self, imgs, labels):
        """Run the wrapped ``MAR`` model on ``imgs`` and ``labels``.

        Returns whatever ``MAR.__call__`` returns for these inputs.
        """
        # calls the forward method from the mar class - passing imgs & labels
        return self.model(imgs, labels)

    def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear",
                      labels=None, temperature=1.0, progress=False):
        """Delegate to ``MAR.sample_tokens`` with the same arguments.

        The arguments are forwarded positionally, in the order they are
        declared here, and the wrapped model's result is returned unchanged.
        """
        # call the sample_tokens method from the MAR class
        return self.model.sample_tokens(
            bsz, num_iter, cfg, cfg_schedule, labels, temperature, progress
        )

    # NOTE(review): disabled custom loader, kept for reference. It looks for
    # "checkpoint-last.safetensors", whereas save_pretrained() below writes
    # "pytorch_model.safetensors"; safetensors files are normally read with
    # safetensors.torch.load_file rather than torch.load - verify before
    # re-enabling.
    # @classmethod
    # def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    #     config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
    #     model = cls(config)
    #     safetensors_path = os.path.join(pretrained_model_name_or_path, "checkpoint-last.safetensors")
    #     if not os.path.exists(safetensors_path):
    #         raise FileNotFoundError(f"safetensors file not found at {safetensors_path}")
    #     state_dict = torch.load(safetensors_path, map_location='cpu')
    #     model.model.load_state_dict(state_dict)
    #     return model

    def save_pretrained(self, save_directory, **kwargs):
        """Save the wrapped model's weights and the config to a directory.

        Writes ``self.model.state_dict()`` (i.e. keys as seen by the inner
        ``MAR`` module) to ``<save_directory>/pytorch_model.safetensors`` and
        then saves the configuration with ``config.save_pretrained``.

        Args:
            save_directory: Target directory; created if it does not exist.
            **kwargs: Accepted for call-compatibility with
                ``PreTrainedModel.save_pretrained`` (library code may pass
                options such as ``state_dict`` or ``safe_serialization``).
                They are ignored: this override always writes the inner
                model's weights in safetensors format.
        """
        # we will save to safetensors
        os.makedirs(save_directory, exist_ok=True)

        # save_file refuses non-contiguous tensors, so normalise them first;
        # .contiguous() is a no-op for tensors that already are.
        state_dict = {
            name: tensor.contiguous()
            for name, tensor in self.model.state_dict().items()
        }
        safetensors_path = os.path.join(save_directory, "pytorch_model.safetensors")
        save_file(state_dict, safetensors_path)

        # save the configuration as usual
        self.config.save_pretrained(save_directory)