RohitGandikota commited on
Commit
68e2466
β€’
1 Parent(s): 4d3c7dc

fixing custom slider training

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -245,7 +245,7 @@ class Demo:
245
  save_name = f"{randn}_{target_concept.replace(',','').replace(' ','').replace('.','')[:10]}_{positive_prompt.replace(',','').replace(' ','').replace('.','')[:10]}"
246
  save_name += f'_alpha-{1}'
247
  save_name += f'_noxattn'
248
- save_name += f'_rank_{rank}.pt'
249
 
250
  # if torch.cuda.get_device_properties(0).total_memory * 1e-9 < 40:
251
  # return [gr.update(interactive=True, value='Train'), gr.update(value='GPU Memory is not enough for training... Please upgrade to GPU atleast 40GB or clone the repo to your local machine.'), None, gr.update()]
@@ -257,7 +257,7 @@ class Demo:
257
  attributes = 'white, black, asian, hispanic, indian, male, female'
258
 
259
  self.training = True
260
- train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=rank, device=self.device, attributes=attributes, save_name=save_name)
261
  self.training = False
262
 
263
  torch.cuda.empty_cache()
@@ -293,7 +293,7 @@ class Demo:
293
  rank = 4
294
  alpha = 1
295
  if 'rank' in model_path:
296
- rank = int(model_path.split('_')[-1].replace('.pt',''))
297
  if 'alpha1' in model_path:
298
  alpha = 1.0
299
  network = LoRANetwork(
 
245
  save_name = f"{randn}_{target_concept.replace(',','').replace(' ','').replace('.','')[:10]}_{positive_prompt.replace(',','').replace(' ','').replace('.','')[:10]}"
246
  save_name += f'_alpha-{1}'
247
  save_name += f'_noxattn'
248
+ save_name += f'_rank_{int(rank)}.pt'
249
 
250
  # if torch.cuda.get_device_properties(0).total_memory * 1e-9 < 40:
251
  # return [gr.update(interactive=True, value='Train'), gr.update(value='GPU Memory is not enough for training... Please upgrade to GPU atleast 40GB or clone the repo to your local machine.'), None, gr.update()]
 
257
  attributes = 'white, black, asian, hispanic, indian, male, female'
258
 
259
  self.training = True
260
+ train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=int(rank), device=self.device, attributes=attributes, save_name=save_name)
261
  self.training = False
262
 
263
  torch.cuda.empty_cache()
 
293
  rank = 4
294
  alpha = 1
295
  if 'rank' in model_path:
296
+ rank = int(float(model_path.split('_')[-1].replace('.pt','')))
297
  if 'alpha1' in model_path:
298
  alpha = 1.0
299
  network = LoRANetwork(