Kunpeng Song committed
Commit 338f71e
1 Parent(s): 8a4a948
Files changed (3)
  1. app.py +1 -1
  2. model_lib/moMA_generator.py +2 -2
  3. model_lib/modules.py +2 -2
app.py CHANGED
@@ -28,7 +28,7 @@ def inference(rgb, mask, subject, prompt):
     seed_everything(0)
     args = parse_args()
     #load MoMA from HuggingFace. Auto download
-    model = MoMA_main_modal(args).to(args.device, dtype=torch.bfloat16)
+    model = MoMA_main_modal(args).to(args.device, dtype=torch.float16)
 
 
     ################ change texture ##################
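
The only change in app.py is the dtype the model is loaded in, from bfloat16 to float16. The commit gives no rationale, but a plausible one is hardware support: bfloat16 kernels require Ampere-or-newer GPUs, while float16 also runs on older cards such as the T4 common on Hugging Face Spaces. A minimal sketch of choosing the dtype at runtime instead of hard-coding it; pick_half_dtype is a hypothetical helper, not part of this repository:

import torch

def pick_half_dtype() -> torch.dtype:
    # Hypothetical helper: prefer bfloat16 where the GPU supports it,
    # fall back to float16 elsewhere (e.g. pre-Ampere cards like the T4).
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16

# Usage, mirroring the changed line above:
# model = MoMA_main_modal(args).to(args.device, dtype=pick_half_dtype())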
model_lib/moMA_generator.py CHANGED
@@ -93,7 +93,7 @@ class MoMA_generator:
         print('Loading StableDiffusion: Realistic_Vision...')
         self.pipe = StableDiffusionPipeline.from_pretrained(
             "SG161222/Realistic_Vision_V4.0_noVAE",
-            torch_dtype=torch.bfloat16,
+            torch_dtype=torch.float16,
             scheduler=noise_scheduler,
             vae=vae,
             feature_extractor=None,
@@ -112,7 +112,7 @@
             cross_attention_dim=768,
             clip_embeddings_dim=1024,
             clip_extra_context_tokens=4,
-        ).to(self.device, dtype=torch.bfloat16)
+        ).to(self.device, dtype=torch.float16)
         return image_proj_model
 
     def set_ip_adapter(self):
model_lib/modules.py CHANGED
@@ -87,7 +87,7 @@ class MoMA_main_modal(nn.Module):
 
         add_function(self.model_llava)
 
-        self.mapping = LlamaMLP_mapping(4096,1024).to(self.device, dtype=torch.bfloat16)
+        self.mapping = LlamaMLP_mapping(4096,1024).to(self.device, dtype=torch.float16)
         self.load_saved_components()
         self.freeze_modules()
 
@@ -140,7 +140,7 @@
         batch = Dataset_evaluate_MoMA(rgb_path, prompt, subject, mask_path,self)
         self.moMA_generator.set_selfAttn_strength(strength)
 
-        with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
+        with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
             with torch.no_grad():
                 ### key steps
                 llava_emb = self.forward_MLLM(batch).clone().detach()
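
The autocast context in model_lib/modules.py has to change in step with the weights: under autocast, eligible ops run in the requested half-precision dtype while numerically sensitive ops stay in float32, and pairing it with torch.no_grad() is the usual inference setup. A minimal, self-contained sketch of the equivalent context on a CUDA device (torch.amp.autocast is the current spelling of the deprecated torch.cuda.amp.autocast used in the diff):

import torch

x = torch.randn(4, 4, device="cuda")
with torch.amp.autocast("cuda", dtype=torch.float16), torch.no_grad():
    y = x @ x  # matmul is autocast-eligible, so it runs in float16
print(y.dtype)  # torch.float16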