RohitGandikota commited on
Commit
1e14cf1
β€’
1 Parent(s): 68e2466

adding train method dropdown

Browse files
app.py CHANGED
@@ -230,12 +230,13 @@ class Demo:
230
  self.iterations_input,
231
  self.lr_input,
232
  self.attributes_input,
233
- self.is_person
 
234
  ],
235
  outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
236
  )
237
 
238
- def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, pbar = gr.Progress(track_tqdm=True)):
239
  iterations_input = min(int(iterations_input),1000)
240
  if attributes_input == '':
241
  attributes_input = None
@@ -257,13 +258,13 @@ class Demo:
257
  attributes = 'white, black, asian, hispanic, indian, male, female'
258
 
259
  self.training = True
260
- train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=int(rank), device=self.device, attributes=attributes, save_name=save_name)
261
  self.training = False
262
 
263
  torch.cuda.empty_cache()
264
- model_map['Custom Slider'] = f'models/{save_name}'
265
 
266
- return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), f'models/{save_name}', gr.update(choices=list(model_map.keys()), value='Custom Slider')]
267
 
268
 
269
  def inference(self, prompt, seed, start_noise, scale, model_name, pbar = gr.Progress(track_tqdm=True)):
 
230
  self.iterations_input,
231
  self.lr_input,
232
  self.attributes_input,
233
+ self.is_person,
234
+ self.train_method_input
235
  ],
236
  outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
237
  )
238
 
239
+ def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, train_method_input, pbar = gr.Progress(track_tqdm=True)):
240
  iterations_input = min(int(iterations_input),1000)
241
  if attributes_input == '':
242
  attributes_input = None
 
258
  attributes = 'white, black, asian, hispanic, indian, male, female'
259
 
260
  self.training = True
261
+ train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=int(rank), train_method=train_method_input, device=self.device, attributes=attributes, save_name=save_name)
262
  self.training = False
263
 
264
  torch.cuda.empty_cache()
265
+ model_map[save_name.replace('.pt','')] = f'models/{save_name}'
266
 
267
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), f'models/{save_name}', gr.update(choices=list(model_map.keys()), value=save_name.replace('.pt',''))]
268
 
269
 
270
  def inference(self, prompt, seed, start_noise, scale, model_name, pbar = gr.Progress(track_tqdm=True)):
trainscripts/textsliders/data/config-xl.yaml CHANGED
@@ -7,7 +7,7 @@ network:
7
  type: "c3lier" # or "c3lier" or "lierla"
8
  rank: 4
9
  alpha: 1.0
10
- training_method: "xattn"
11
  train:
12
  precision: "bfloat16"
13
  noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
 
7
  type: "c3lier" # or "c3lier" or "lierla"
8
  rank: 4
9
  alpha: 1.0
10
+ training_method: "noxattn"
11
  train:
12
  precision: "bfloat16"
13
  noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
trainscripts/textsliders/demotrain.py CHANGED
@@ -411,7 +411,7 @@ def train(
411
  # train(config, prompts, device)
412
 
413
 
414
- def train_xl(target, positive, negative, lr, iterations, config_file, rank, device, attributes,save_name):
415
 
416
  config = config_util.load_config_from_yaml(config_file)
417
  randn = torch.randint(1, 10000000, (1,)).item()
@@ -427,6 +427,7 @@ def train_xl(target, positive, negative, lr, iterations, config_file, rank, devi
427
  attributes = []
428
  config.network.alpha = 1.0
429
  config.network.rank = int(rank)
 
430
 
431
  # config.save.path += f'/{config.save.name}'
432
 
 
411
  # train(config, prompts, device)
412
 
413
 
414
+ def train_xl(target, positive, negative, lr, iterations, config_file, rank, train_method, device, attributes,save_name):
415
 
416
  config = config_util.load_config_from_yaml(config_file)
417
  randn = torch.randint(1, 10000000, (1,)).item()
 
427
  attributes = []
428
  config.network.alpha = 1.0
429
  config.network.rank = int(rank)
430
+ config.network.training_method = train_method
431
 
432
  # config.save.path += f'/{config.save.name}'
433