jadechoghari commited on
Commit
de18685
1 Parent(s): eec2810

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +51 -28
pipeline.py CHANGED
@@ -10,7 +10,7 @@ from .vae import AutoencoderKL
10
  from .mar import mar_base, mar_large, mar_huge
11
 
12
  # inheriting from DiffusionPipeline for HF
13
- class MARModel(DiffusionPipeline):
14
 
15
  def __init__(self):
16
  super().__init__()
@@ -32,44 +32,52 @@ class MARModel(DiffusionPipeline):
32
  num_sampling_steps = kwargs.get("num_sampling_steps", 100)
33
  model_type = kwargs.get("model_type", "mar_base")
34
 
 
 
 
 
 
35
 
 
 
 
36
  if model_type == "mar_base":
37
- self.model = mar_base(
38
- buffer_size=buffer_size,
39
- diffloss_d=diffloss_d,
40
- diffloss_w=diffloss_w,
41
- num_sampling_steps=str(num_sampling_steps)
42
- ).to(device)
43
  elif model_type == "mar_large":
44
- self.model = mar_large(
45
- buffer_size=buffer_size,
46
- diffloss_d=diffloss_d,
47
- diffloss_w=diffloss_w,
48
- num_sampling_steps=str(num_sampling_steps)
49
- ).to(device)
50
  elif model_type == "mar_huge":
51
- self.model = mar_huge(
52
- buffer_size=buffer_size,
53
- diffloss_d=diffloss_d,
54
- diffloss_w=diffloss_w,
55
- num_sampling_steps=str(num_sampling_steps)
56
- ).to(device)
57
- # download and load the model weights (.safetensors or .pth)
58
  model_checkpoint_path = hf_hub_download(
59
  repo_id=kwargs.get("repo_id", "jadechoghari/mar"),
60
  filename=kwargs.get("model_filename", "checkpoint-last.pth")
61
  )
 
 
 
62
 
63
- state_dict = torch.load(model_checkpoint_path, map_location=device)["model_ema"]
 
 
 
 
 
64
 
65
- self.model.load_state_dict(state_dict, strict=False)
66
- self.model.eval()
 
67
 
68
  # download and load the vae
69
  vae_checkpoint_path = hf_hub_download(
70
  repo_id=kwargs.get("repo_id", "jadechoghari/mar"),
71
  filename=kwargs.get("vae_filename", "kl16.ckpt")
72
  )
 
73
 
74
  vae = AutoencoderKL(embed_dim=16, ch_mult=(1, 1, 2, 2, 4), ckpt_path=vae_checkpoint_path)
75
  vae = vae.to(device).eval()
@@ -83,19 +91,34 @@ class MARModel(DiffusionPipeline):
83
  cfg_scale = kwargs.get("cfg_scale", 4)
84
  cfg_schedule = kwargs.get("cfg_schedule", "constant")
85
  temperature = kwargs.get("temperature", 1.0)
86
- class_labels = kwargs.get("class_labels", [207, 360, 388, 113, 355, 980, 323, 979])
87
- class_labels = torch.Tensor(class_labels).long().to(device)
 
88
 
89
  # generate the tokens and images
90
  with torch.cuda.amp.autocast():
91
- sampled_tokens = self.model.sample_tokens(
92
  bsz=len(class_labels), num_iter=num_ar_steps,
93
  cfg=cfg_scale, cfg_schedule=cfg_schedule,
94
- labels=torch.Tensor(class_labels).long().to(device),
95
  temperature=temperature, progress=True
96
  )
97
 
98
  sampled_images = vae.decode(sampled_tokens / 0.2325)
99
 
100
- return sampled_images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
 
10
  from .mar import mar_base, mar_large, mar_huge
11
 
12
  # inheriting from DiffusionPipeline for HF
13
+ class MARModel(DiffusionPipeline):
14
 
15
  def __init__(self):
16
  super().__init__()
 
32
  num_sampling_steps = kwargs.get("num_sampling_steps", 100)
33
  model_type = kwargs.get("model_type", "mar_base")
34
 
35
+ model_mapping = {
36
+ "mar_base": mar_base,
37
+ "mar_large": mar_large,
38
+ "mar_huge": mar_huge
39
+ }
40
 
41
+ num_sampling_steps_diffloss = 100 # Example number of sampling steps
42
+
43
+ # download the pretrained model and set diffloss parameters
44
  if model_type == "mar_base":
45
+ diffloss_d = 6
46
+ diffloss_w = 1024
 
 
 
 
47
  elif model_type == "mar_large":
48
+ diffloss_d = 8
49
+ diffloss_w = 1280
 
 
 
 
50
  elif model_type == "mar_huge":
51
+ diffloss_d = 12
52
+ diffloss_w = 1536
53
+ else:
54
+ raise NotImplementedError
55
+ download and load the model weights (.safetensors or .pth)
 
 
56
  model_checkpoint_path = hf_hub_download(
57
  repo_id=kwargs.get("repo_id", "jadechoghari/mar"),
58
  filename=kwargs.get("model_filename", "checkpoint-last.pth")
59
  )
60
+ model_checkpoint_path = kwargs.get("model_checkpoint_path", "./mar/checkpoint-last.pth")
61
+
62
+ model_fn = model_mapping[model_type]
63
 
64
+ model = model_fn(
65
+ buffer_size=64,
66
+ diffloss_d=diffloss_d,
67
+ diffloss_w=diffloss_w,
68
+ num_sampling_steps=str(num_sampling_steps_diffloss)
69
+ ).cuda()
70
 
71
+ state_dict = torch.load(f"./mar/checkpoint-last.pth")["model_ema"]
72
+ model.load_state_dict(state_dict)
73
+ model.eval()
74
 
75
  # download and load the vae
76
  vae_checkpoint_path = hf_hub_download(
77
  repo_id=kwargs.get("repo_id", "jadechoghari/mar"),
78
  filename=kwargs.get("vae_filename", "kl16.ckpt")
79
  )
80
+ vae_checkpoint_path = kwargs.get("vae_checkpoint_path", vae_checkpoint_path)
81
 
82
  vae = AutoencoderKL(embed_dim=16, ch_mult=(1, 1, 2, 2, 4), ckpt_path=vae_checkpoint_path)
83
  vae = vae.to(device).eval()
 
91
  cfg_scale = kwargs.get("cfg_scale", 4)
92
  cfg_schedule = kwargs.get("cfg_schedule", "constant")
93
  temperature = kwargs.get("temperature", 1.0)
94
+ # class_labels = kwargs.get("class_labels", 207, 360, 388, 113, 355, 980, 323, 979)
95
+ class_labels = 207, 360, 388, 113, 355, 980, 323, 979
96
+ print("the labels", class_labels)
97
 
98
  # generate the tokens and images
99
  with torch.cuda.amp.autocast():
100
+ sampled_tokens = model.sample_tokens(
101
  bsz=len(class_labels), num_iter=num_ar_steps,
102
  cfg=cfg_scale, cfg_schedule=cfg_schedule,
103
+ labels=torch.Tensor(class_labels).long().cuda(),
104
  temperature=temperature, progress=True
105
  )
106
 
107
  sampled_images = vae.decode(sampled_tokens / 0.2325)
108
 
109
+ output_dir = kwargs.get("output_dir", "./")
110
+ os.makedirs(output_dir, exist_ok=True)
111
+
112
+ # save the images
113
+ image_path = os.path.join(output_dir, "sampled_image.png")
114
+ samples_per_row = kwargs.get("samples_per_row", 6)
115
+
116
+ save_image(
117
+ sampled_images, image_path, nrow=int(samples_per_row), normalize=True, value_range=(-1, 1)
118
+ )
119
+
120
+ # return as a pil image
121
+ image = Image.open(image_path)
122
+
123
+ return image
124