Spaces:
Runtime error
Runtime error
RohitGandikota
commited on
Commit
β’
1e14cf1
1
Parent(s):
68e2466
adding train method dropdown
Browse files
app.py
CHANGED
@@ -230,12 +230,13 @@ class Demo:
|
|
230 |
self.iterations_input,
|
231 |
self.lr_input,
|
232 |
self.attributes_input,
|
233 |
-
self.is_person
|
|
|
234 |
],
|
235 |
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
236 |
)
|
237 |
|
238 |
-
def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, pbar = gr.Progress(track_tqdm=True)):
|
239 |
iterations_input = min(int(iterations_input),1000)
|
240 |
if attributes_input == '':
|
241 |
attributes_input = None
|
@@ -257,13 +258,13 @@ class Demo:
|
|
257 |
attributes = 'white, black, asian, hispanic, indian, male, female'
|
258 |
|
259 |
self.training = True
|
260 |
-
train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=int(rank), device=self.device, attributes=attributes, save_name=save_name)
|
261 |
self.training = False
|
262 |
|
263 |
torch.cuda.empty_cache()
|
264 |
-
model_map['
|
265 |
|
266 |
-
return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), f'models/{save_name}', gr.update(choices=list(model_map.keys()), value='
|
267 |
|
268 |
|
269 |
def inference(self, prompt, seed, start_noise, scale, model_name, pbar = gr.Progress(track_tqdm=True)):
|
|
|
230 |
self.iterations_input,
|
231 |
self.lr_input,
|
232 |
self.attributes_input,
|
233 |
+
self.is_person,
|
234 |
+
self.train_method_input
|
235 |
],
|
236 |
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
237 |
)
|
238 |
|
239 |
+
def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, train_method_input, pbar = gr.Progress(track_tqdm=True)):
|
240 |
iterations_input = min(int(iterations_input),1000)
|
241 |
if attributes_input == '':
|
242 |
attributes_input = None
|
|
|
258 |
attributes = 'white, black, asian, hispanic, indian, male, female'
|
259 |
|
260 |
self.training = True
|
261 |
+
train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=int(rank), train_method=train_method_input, device=self.device, attributes=attributes, save_name=save_name)
|
262 |
self.training = False
|
263 |
|
264 |
torch.cuda.empty_cache()
|
265 |
+
model_map[save_name.replace('.pt','')] = f'models/{save_name}'
|
266 |
|
267 |
+
return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), f'models/{save_name}', gr.update(choices=list(model_map.keys()), value=save_name.replace('.pt',''))]
|
268 |
|
269 |
|
270 |
def inference(self, prompt, seed, start_noise, scale, model_name, pbar = gr.Progress(track_tqdm=True)):
|
trainscripts/textsliders/data/config-xl.yaml
CHANGED
@@ -7,7 +7,7 @@ network:
|
|
7 |
type: "c3lier" # or "c3lier" or "lierla"
|
8 |
rank: 4
|
9 |
alpha: 1.0
|
10 |
-
training_method: "
|
11 |
train:
|
12 |
precision: "bfloat16"
|
13 |
noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
|
|
|
7 |
type: "c3lier" # or "c3lier" or "lierla"
|
8 |
rank: 4
|
9 |
alpha: 1.0
|
10 |
+
training_method: "noxattn"
|
11 |
train:
|
12 |
precision: "bfloat16"
|
13 |
noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
|
trainscripts/textsliders/demotrain.py
CHANGED
@@ -411,7 +411,7 @@ def train(
|
|
411 |
# train(config, prompts, device)
|
412 |
|
413 |
|
414 |
-
def train_xl(target, positive, negative, lr, iterations, config_file, rank, device, attributes,save_name):
|
415 |
|
416 |
config = config_util.load_config_from_yaml(config_file)
|
417 |
randn = torch.randint(1, 10000000, (1,)).item()
|
@@ -427,6 +427,7 @@ def train_xl(target, positive, negative, lr, iterations, config_file, rank, devi
|
|
427 |
attributes = []
|
428 |
config.network.alpha = 1.0
|
429 |
config.network.rank = int(rank)
|
|
|
430 |
|
431 |
# config.save.path += f'/{config.save.name}'
|
432 |
|
|
|
411 |
# train(config, prompts, device)
|
412 |
|
413 |
|
414 |
+
def train_xl(target, positive, negative, lr, iterations, config_file, rank, train_method, device, attributes,save_name):
|
415 |
|
416 |
config = config_util.load_config_from_yaml(config_file)
|
417 |
randn = torch.randint(1, 10000000, (1,)).item()
|
|
|
427 |
attributes = []
|
428 |
config.network.alpha = 1.0
|
429 |
config.network.rank = int(rank)
|
430 |
+
config.network.training_method = train_method
|
431 |
|
432 |
# config.save.path += f'/{config.save.name}'
|
433 |
|