dvir-bria commited on
Commit
c5ce443
1 Parent(s): d4b0ed5

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +26 -26
model.py CHANGED
@@ -57,7 +57,7 @@ class Model:
57
  beta_schedule="scaled_linear",
58
  num_train_timesteps=1000,
59
  steps_offset=1
60
- )
61
  # pipe.enable_freeu(b1=1.1, b2=1.1, s1=0.5, s2=0.7)
62
  pipe.enable_xformers_memory_efficient_attention()
63
  pipe.force_zeros_for_empty_prompt = False
@@ -70,34 +70,34 @@ class Model:
70
  print(f'Loaded {model_id}...')
71
  return pipe
72
 
73
- def set_base_model(self, base_model_id: str) -> str:
74
- if not base_model_id or base_model_id == self.base_model_id:
75
- return self.base_model_id
76
- del self.pipe
77
- torch.cuda.empty_cache()
78
- gc.collect()
79
- try:
80
- self.pipe = self.load_pipe(base_model_id, self.task_name)
81
- except Exception:
82
- self.pipe = self.load_pipe(self.base_model_id, self.task_name)
83
- return self.base_model_id
84
 
85
  def load_controlnet_weight(self, task_name: str) -> None:
86
  print('Entered load_controlnet_weight....')
87
- if task_name == self.task_name:
88
- return
89
- if self.pipe is not None and hasattr(self.pipe, "controlnet"):
90
- del self.pipe.controlnet
91
- torch.cuda.empty_cache()
92
- gc.collect()
93
- model_id = CONTROLNET_MODEL_IDS[task_name]
94
- controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
95
- print(f'Loaded {model_id}...')
96
- controlnet.to(self.device)
97
- torch.cuda.empty_cache()
98
- gc.collect()
99
- self.pipe.controlnet = controlnet
100
- self.task_name = task_name
101
 
102
  def get_prompt(self, prompt: str, additional_prompt: str) -> str:
103
  if not prompt:
 
57
  beta_schedule="scaled_linear",
58
  num_train_timesteps=1000,
59
  steps_offset=1
60
+ ).to('cuda')
61
  # pipe.enable_freeu(b1=1.1, b2=1.1, s1=0.5, s2=0.7)
62
  pipe.enable_xformers_memory_efficient_attention()
63
  pipe.force_zeros_for_empty_prompt = False
 
70
  print(f'Loaded {model_id}...')
71
  return pipe
72
 
73
+ # def set_base_model(self, base_model_id: str) -> str:
74
+ # if not base_model_id or base_model_id == self.base_model_id:
75
+ # return self.base_model_id
76
+ # del self.pipe
77
+ # torch.cuda.empty_cache()
78
+ # gc.collect()
79
+ # try:
80
+ # self.pipe = self.load_pipe(base_model_id, self.task_name)
81
+ # except Exception:
82
+ # self.pipe = self.load_pipe(self.base_model_id, self.task_name)
83
+ # return self.base_model_id
84
 
85
  def load_controlnet_weight(self, task_name: str) -> None:
86
  print('Entered load_controlnet_weight....')
87
+ # if task_name == self.task_name:
88
+ # return
89
+ # if self.pipe is not None and hasattr(self.pipe, "controlnet"):
90
+ # del self.pipe.controlnet
91
+ # torch.cuda.empty_cache()
92
+ # gc.collect()
93
+ # model_id = CONTROLNET_MODEL_IDS[task_name]
94
+ # controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
95
+ # print(f'Loaded {model_id}...')
96
+ # controlnet.to(self.device)
97
+ # torch.cuda.empty_cache()
98
+ # gc.collect()
99
+ # self.pipe.controlnet = controlnet
100
+ # self.task_name = task_name
101
 
102
  def get_prompt(self, prompt: str, additional_prompt: str) -> str:
103
  if not prompt: