Spaces:
Running
on
Zero
Running
on
Zero
Update model.py
Browse files
model.py
CHANGED
@@ -57,7 +57,7 @@ class Model:
|
|
57 |
beta_schedule="scaled_linear",
|
58 |
num_train_timesteps=1000,
|
59 |
steps_offset=1
|
60 |
-
)
|
61 |
# pipe.enable_freeu(b1=1.1, b2=1.1, s1=0.5, s2=0.7)
|
62 |
pipe.enable_xformers_memory_efficient_attention()
|
63 |
pipe.force_zeros_for_empty_prompt = False
|
@@ -70,34 +70,34 @@ class Model:
|
|
70 |
print(f'Loaded {model_id}...')
|
71 |
return pipe
|
72 |
|
73 |
-
def set_base_model(self, base_model_id: str) -> str:
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
|
85 |
def load_controlnet_weight(self, task_name: str) -> None:
|
86 |
print('Entered load_controlnet_weight....')
|
87 |
-
if task_name == self.task_name:
|
88 |
-
|
89 |
-
if self.pipe is not None and hasattr(self.pipe, "controlnet"):
|
90 |
-
|
91 |
-
torch.cuda.empty_cache()
|
92 |
-
gc.collect()
|
93 |
-
model_id = CONTROLNET_MODEL_IDS[task_name]
|
94 |
-
controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
|
95 |
-
print(f'Loaded {model_id}...')
|
96 |
-
controlnet.to(self.device)
|
97 |
-
torch.cuda.empty_cache()
|
98 |
-
gc.collect()
|
99 |
-
self.pipe.controlnet = controlnet
|
100 |
-
self.task_name = task_name
|
101 |
|
102 |
def get_prompt(self, prompt: str, additional_prompt: str) -> str:
|
103 |
if not prompt:
|
|
|
57 |
beta_schedule="scaled_linear",
|
58 |
num_train_timesteps=1000,
|
59 |
steps_offset=1
|
60 |
+
).to('cuda')
|
61 |
# pipe.enable_freeu(b1=1.1, b2=1.1, s1=0.5, s2=0.7)
|
62 |
pipe.enable_xformers_memory_efficient_attention()
|
63 |
pipe.force_zeros_for_empty_prompt = False
|
|
|
70 |
print(f'Loaded {model_id}...')
|
71 |
return pipe
|
72 |
|
73 |
+
# def set_base_model(self, base_model_id: str) -> str:
|
74 |
+
# if not base_model_id or base_model_id == self.base_model_id:
|
75 |
+
# return self.base_model_id
|
76 |
+
# del self.pipe
|
77 |
+
# torch.cuda.empty_cache()
|
78 |
+
# gc.collect()
|
79 |
+
# try:
|
80 |
+
# self.pipe = self.load_pipe(base_model_id, self.task_name)
|
81 |
+
# except Exception:
|
82 |
+
# self.pipe = self.load_pipe(self.base_model_id, self.task_name)
|
83 |
+
# return self.base_model_id
|
84 |
|
85 |
def load_controlnet_weight(self, task_name: str) -> None:
|
86 |
print('Entered load_controlnet_weight....')
|
87 |
+
# if task_name == self.task_name:
|
88 |
+
# return
|
89 |
+
# if self.pipe is not None and hasattr(self.pipe, "controlnet"):
|
90 |
+
# del self.pipe.controlnet
|
91 |
+
# torch.cuda.empty_cache()
|
92 |
+
# gc.collect()
|
93 |
+
# model_id = CONTROLNET_MODEL_IDS[task_name]
|
94 |
+
# controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
|
95 |
+
# print(f'Loaded {model_id}...')
|
96 |
+
# controlnet.to(self.device)
|
97 |
+
# torch.cuda.empty_cache()
|
98 |
+
# gc.collect()
|
99 |
+
# self.pipe.controlnet = controlnet
|
100 |
+
# self.task_name = task_name
|
101 |
|
102 |
def get_prompt(self, prompt: str, additional_prompt: str) -> str:
|
103 |
if not prompt:
|