Spaces:
Runtime error
Runtime error
RohitGandikota
commited on
Commit
β’
68e2466
1
Parent(s):
4d3c7dc
fixing custom slider training
Browse files
app.py
CHANGED
@@ -245,7 +245,7 @@ class Demo:
|
|
245 |
save_name = f"{randn}_{target_concept.replace(',','').replace(' ','').replace('.','')[:10]}_{positive_prompt.replace(',','').replace(' ','').replace('.','')[:10]}"
|
246 |
save_name += f'_alpha-{1}'
|
247 |
save_name += f'_noxattn'
|
248 |
-
save_name += f'_rank_{rank}.pt'
|
249 |
|
250 |
# if torch.cuda.get_device_properties(0).total_memory * 1e-9 < 40:
|
251 |
# return [gr.update(interactive=True, value='Train'), gr.update(value='GPU Memory is not enough for training... Please upgrade to GPU atleast 40GB or clone the repo to your local machine.'), None, gr.update()]
|
@@ -257,7 +257,7 @@ class Demo:
|
|
257 |
attributes = 'white, black, asian, hispanic, indian, male, female'
|
258 |
|
259 |
self.training = True
|
260 |
-
train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=rank, device=self.device, attributes=attributes, save_name=save_name)
|
261 |
self.training = False
|
262 |
|
263 |
torch.cuda.empty_cache()
|
@@ -293,7 +293,7 @@ class Demo:
|
|
293 |
rank = 4
|
294 |
alpha = 1
|
295 |
if 'rank' in model_path:
|
296 |
-
rank = int(model_path.split('_')[-1].replace('.pt',''))
|
297 |
if 'alpha1' in model_path:
|
298 |
alpha = 1.0
|
299 |
network = LoRANetwork(
|
|
|
245 |
save_name = f"{randn}_{target_concept.replace(',','').replace(' ','').replace('.','')[:10]}_{positive_prompt.replace(',','').replace(' ','').replace('.','')[:10]}"
|
246 |
save_name += f'_alpha-{1}'
|
247 |
save_name += f'_noxattn'
|
248 |
+
save_name += f'_rank_{int(rank)}.pt'
|
249 |
|
250 |
# if torch.cuda.get_device_properties(0).total_memory * 1e-9 < 40:
|
251 |
# return [gr.update(interactive=True, value='Train'), gr.update(value='GPU Memory is not enough for training... Please upgrade to GPU atleast 40GB or clone the repo to your local machine.'), None, gr.update()]
|
|
|
257 |
attributes = 'white, black, asian, hispanic, indian, male, female'
|
258 |
|
259 |
self.training = True
|
260 |
+
train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=int(rank), device=self.device, attributes=attributes, save_name=save_name)
|
261 |
self.training = False
|
262 |
|
263 |
torch.cuda.empty_cache()
|
|
|
293 |
rank = 4
|
294 |
alpha = 1
|
295 |
if 'rank' in model_path:
|
296 |
+
rank = int(float(model_path.split('_')[-1].replace('.pt','')))
|
297 |
if 'alpha1' in model_path:
|
298 |
alpha = 1.0
|
299 |
network = LoRANetwork(
|