jadechoghari committed on
Commit
1822fe2
1 Parent(s): f427d24

Rename modeling.py to pipeline_mar.py


The diffusers library is a better fit than transformers for this model.
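With the code living in pipeline_mar.py, the model is meant to be consumed through diffusers. A minimal loading sketch, assuming the repository is wired up to expose pipeline_mar.py as a custom pipeline (the repo id matches the default used in the new file; the exact from_pretrained arguments are an assumption, not part of this commit):

import torch
from diffusers import DiffusionPipeline

# Sketch only: pull the repo's custom pipeline code from the Hub.
# Assumes jadechoghari/mar is configured as a diffusers custom pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "jadechoghari/mar",
    custom_pipeline="jadechoghari/mar",
    trust_remote_code=True,
)

Generation options (num_ar_steps, cfg_scale, class_labels, ...) are then passed through to the pipeline as defined in pipeline_mar.py below.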

Files changed (2)
  1. modeling.py +0 -183
  2. pipeline_mar.py +83 -0
modeling.py DELETED
@@ -1,183 +0,0 @@
- from transformers import PretrainedConfig
- import torch.nn as nn
- from transformers import PreTrainedModel
- import torch
- from huggingface_hub import hf_hub_download
- from safetensors.torch import save_file, load_file
- import os
- from timm.models.vision_transformer import Block
- from . import mar
- from .vae import AutoencoderKL
- from .mar import MAR
- import numpy as np
-
- class MARConfig(PretrainedConfig):
-     model_type = "mar"
-
-     def __init__(self,
-                  img_size=256,
-                  vae_stride=16,
-                  patch_size=1,
-                  encoder_embed_dim=1024,
-                  encoder_depth=16,
-                  encoder_num_heads=16,
-                  decoder_embed_dim=1024,
-                  decoder_depth=16,
-                  decoder_num_heads=16,
-                  mlp_ratio=4.,
-                  norm_layer="LayerNorm",
-                  vae_embed_dim=16,
-                  mask_ratio_min=0.7,
-                  label_drop_prob=0.1,
-                  class_num=1000,
-                  attn_dropout=0.1,
-                  proj_dropout=0.1,
-                  buffer_size=64,
-                  diffloss_d=3,
-                  diffloss_w=1024,
-                  num_sampling_steps='100',
-                  diffusion_batch_mul=4,
-                  grad_checkpointing=False,
-                  **kwargs):
-         super().__init__(**kwargs)
-
-         # store parameters in the config
-         self.img_size = img_size
-         self.vae_stride = vae_stride
-         self.patch_size = patch_size
-         self.encoder_embed_dim = encoder_embed_dim
-         self.encoder_depth = encoder_depth
-         self.encoder_num_heads = encoder_num_heads
-         self.decoder_embed_dim = decoder_embed_dim
-         self.decoder_depth = decoder_depth
-         self.decoder_num_heads = decoder_num_heads
-         self.mlp_ratio = mlp_ratio
-         self.norm_layer = norm_layer
-         self.vae_embed_dim = vae_embed_dim
-         self.mask_ratio_min = mask_ratio_min
-         self.label_drop_prob = label_drop_prob
-         self.class_num = class_num
-         self.attn_dropout = attn_dropout
-         self.proj_dropout = proj_dropout
-         self.buffer_size = buffer_size
-         self.diffloss_d = diffloss_d
-         self.diffloss_w = diffloss_w
-         self.num_sampling_steps = num_sampling_steps
-         self.diffusion_batch_mul = diffusion_batch_mul
-         self.grad_checkpointing = grad_checkpointing
-
-
-
- class MARModel(PreTrainedModel):
-     # links to MARConfig class
-     config_class = MARConfig
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.config = config
-
-         # convert norm_layer from string to class
-         norm_layer = getattr(nn, config.norm_layer)
-
-         # init the mar model using the parameters from config
-         self.model = MAR(
-             img_size=config.img_size,
-             vae_stride=config.vae_stride,
-             patch_size=config.patch_size,
-             encoder_embed_dim=config.encoder_embed_dim,
-             encoder_depth=config.encoder_depth,
-             encoder_num_heads=config.encoder_num_heads,
-             decoder_embed_dim=config.decoder_embed_dim,
-             decoder_depth=config.decoder_depth,
-             decoder_num_heads=config.decoder_num_heads,
-             mlp_ratio=config.mlp_ratio,
-             norm_layer=norm_layer,  # use the actual class for the layer
-             vae_embed_dim=config.vae_embed_dim,
-             mask_ratio_min=config.mask_ratio_min,
-             label_drop_prob=config.label_drop_prob,
-             class_num=config.class_num,
-             attn_dropout=config.attn_dropout,
-             proj_dropout=config.proj_dropout,
-             buffer_size=config.buffer_size,
-             diffloss_d=config.diffloss_d,
-             diffloss_w=config.diffloss_w,
-             num_sampling_steps=config.num_sampling_steps,
-             diffusion_batch_mul=config.diffusion_batch_mul,
-             grad_checkpointing=config.grad_checkpointing,
-         )
-
-     def forward_train(self, imgs, labels):
-         # calls the forward method from the mar class - passing imgs & labels
-         return self.model(imgs, labels)
-
-     def forward(self, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
-         # call the sample_tokens method from the MAR class
-         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         checkpoint_path = hf_hub_download(
-             repo_id=pretrained_model_name_or_path,
-             filename=f"kl16.safetensors"
-         )
-         vae = AutoencoderKL(embed_dim=16, ch_mult=(1, 1, 2, 2, 4), ckpt_path=checkpoint_path)
-         vae = vae.to(device).eval()
-         # can customize more from the user
-         seed = 0
-         torch.manual_seed(seed)
-         np.random.seed(seed)
-         num_ar_steps = 64
-         cfg_scale = 4
-         cfg_schedule = "constant"
-         temperature = 1.0
-         # TODO: this should be defined by the user
-         class_labels = 207, 360, 388, 113, 355, 980, 323, 979  #@param {type:"raw"}
-         samples_per_row = 4
-
-         with torch.cuda.amp.autocast():
-             sampled_tokens = self.model.sample_tokens(
-                 bsz=len(class_labels), num_iter=num_ar_steps,
-                 cfg=cfg_scale, cfg_schedule=cfg_schedule,
-                 labels=torch.Tensor(class_labels).long().to(device),
-                 temperature=temperature, progress=True)
-             sampled_images = vae.decode(sampled_tokens / 0.2325)
-         return sampled_images
-
-     @classmethod
-     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-         # config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-         # model = cls(config)
-         buffer_size = kwargs.get('buffer_size', 64)
-         diffloss_d = kwargs.get('diffloss_d', 3)
-         diffloss_w = kwargs.get('diffloss_w', 1024)
-         num_sampling_steps_diffloss = kwargs.get('num_sampling_steps', 100)
-         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         model_type = "mar_base"
-         model_architecture = mar.__dict__[model_type](
-             buffer_size=buffer_size,
-             diffloss_d=diffloss_d,
-             diffloss_w=diffloss_w,
-             num_sampling_steps=str(num_sampling_steps_diffloss)
-         ).to(device)
-         checkpoint_path = hf_hub_download(
-             repo_id=pretrained_model_name_or_path,
-             filename=f"checkpoint-last.pth"
-         )
-
-         state_dict = torch.load(checkpoint_path, map_location=device)["model_ema"]
-
-         model_architecture.load_state_dict(state_dict, strict=False)
-
-         # update this so the model works on the forward call
-         model = model_architecture
-         model.eval()
-
-         return model
-
-
-     def save_pretrained(self, save_directory):
-         # we will save to safetensors
-         os.makedirs(save_directory, exist_ok=True)
-         state_dict = self.model.state_dict()
-         safetensors_path = os.path.join(save_directory, "pytorch_model.safetensors")
-         save_file(state_dict, safetensors_path)
-
-         # save the configuration as usual
-         self.config.save_pretrained(save_directory)
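For contrast, a hedged sketch of how the transformers-style wrapper deleted above was driven. The kwargs mirror the defaults read inside its from_pretrained; the import path is hypothetical, since the removed file used package-relative imports, and the repo id is the one the new pipeline defaults to:

# Sketch only: loading MAR weights through the removed transformers wrapper.
from modeling import MARModel  # hypothetical import path for the deleted file

mar_model = MARModel.from_pretrained(
    "jadechoghari/mar",        # Hub repo holding checkpoint-last.pth
    buffer_size=64,
    diffloss_d=3,
    diffloss_w=1024,
    num_sampling_steps=100,
)
# from_pretrained returns the bare MAR module in eval mode,
# ready for sample_tokens(); this commit moves that sampling
# path into the diffusers pipeline below.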
pipeline_mar.py ADDED
@@ -0,0 +1,83 @@
+ from diffusers import DiffusionPipeline
+ import torch
+ import numpy as np
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+ import os
+ from mar.vae import AutoencoderKL
+ from mar import mar
+
+ # inheriting from DiffusionPipeline for HF
+ class MARModel(DiffusionPipeline):
+
+     def __init__(self):
+         super().__init__()
+
+     @torch.no_grad()
+     def _call(self, *args, **kwargs):
+         """
+         This method downloads the model and VAE components,
+         then executes the forward pass based on the user's input.
+         """
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+
+         # init the mar model architecture
+         buffer_size = kwargs.get("buffer_size", 64)
+         diffloss_d = kwargs.get("diffloss_d", 3)
+         diffloss_w = kwargs.get("diffloss_w", 1024)
+         num_sampling_steps = kwargs.get("num_sampling_steps", 100)
+         model_type = kwargs.get("model_type", "mar_base")
+
+
+         self.model = mar.__dict__[model_type](
+             buffer_size=buffer_size,
+             diffloss_d=diffloss_d,
+             diffloss_w=diffloss_w,
+             num_sampling_steps=str(num_sampling_steps)
+         ).to(device)
+         # download and load the model weights (.safetensors or .pth)
+         model_checkpoint_path = hf_hub_download(
+             repo_id=kwargs.get("repo_id", "jadechoghari/mar"),
+             filename=kwargs.get("model_filename", "checkpoint-last.pth")
+         )
+
+         state_dict = torch.load(model_checkpoint_path, map_location=device)["model_ema"]
+
+         self.model.load_state_dict(state_dict, strict=False)
+         self.model.eval()
+
+         # download and load the vae
+         vae_checkpoint_path = hf_hub_download(
+             repo_id=kwargs.get("repo_id", "jadechoghari/mar"),
+             filename=kwargs.get("vae_filename", "kl16.ckpt")
+         )
+
+         vae = AutoencoderKL(embed_dim=16, ch_mult=(1, 1, 2, 2, 4), ckpt_path=vae_checkpoint_path)
+         vae = vae.to(device).eval()
+
+         # set up user-specified or default values for generation
+         seed = kwargs.get("seed", 0)
+         torch.manual_seed(seed)
+         np.random.seed(seed)
+
+         num_ar_steps = kwargs.get("num_ar_steps", 64)
+         cfg_scale = kwargs.get("cfg_scale", 4)
+         cfg_schedule = kwargs.get("cfg_schedule", "constant")
+         temperature = kwargs.get("temperature", 1.0)
+         class_labels = kwargs.get("class_labels", [207, 360, 388, 113, 355, 980, 323, 979])
+
+         # generate the tokens and images
+         with torch.cuda.amp.autocast():
+             sampled_tokens = self.model.sample_tokens(
+                 bsz=len(class_labels), num_iter=num_ar_steps,
+                 cfg=cfg_scale, cfg_schedule=cfg_schedule,
+                 labels=torch.Tensor(class_labels).long().to(device),
+                 temperature=temperature, progress=True
+             )
+
+             sampled_images = vae.decode(sampled_tokens / 0.2325)
+
+         return sampled_images
+
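As written, the pipeline exposes its generation logic through _call, so a rough usage sketch looks like the following. This is hypothetical: it instantiates the class directly, assumes pipeline_mar.py is importable locally, and only passes kwargs that the method actually reads (the repo id, checkpoint filenames, and class labels are the defaults hard-coded above):

from pipeline_mar import MARModel  # assumes pipeline_mar.py is on the import path

pipe = MARModel()
images = pipe._call(
    repo_id="jadechoghari/mar",            # Hub repo for weights and VAE
    model_filename="checkpoint-last.pth",
    vae_filename="kl16.ckpt",
    num_ar_steps=64,
    cfg_scale=4,
    cfg_schedule="constant",
    temperature=1.0,
    class_labels=[207, 360, 388, 113],     # one sample per ImageNet class id
    seed=0,
)
# `images` is the batch of decoded samples returned by vae.decode(...)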