ZennyKenny committed
Commit 8f76168
1 Parent(s): 12a7530

training-script

Files changed (1)
app.py +113 -17
app.py CHANGED
@@ -1,8 +1,11 @@
 import gradio as gr
 import numpy as np
 import random
-from diffusers import DiffusionPipeline
 import torch
+from diffusers import DiffusionPipeline, StableDiffusionXLBaseModel, StableDiffusionTrainer
+from transformers import CLIPTextModel, CLIPTokenizer, TrainingArguments
+from datasets import load_dataset
+from huggingface_hub import HfApi, HfFolder, Repository
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -11,7 +14,7 @@ if torch.cuda.is_available():
     pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
     pipe.enable_xformers_memory_efficient_attention()
     pipe = pipe.to(device)
-else:
+else:
     pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
     pipe = pipe.to(device)
 
@@ -19,24 +22,77 @@ MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
 def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
-
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
     generator = torch.Generator().manual_seed(seed)
 
     image = pipe(
-        prompt = prompt,
-        negative_prompt = negative_prompt,
-        guidance_scale = guidance_scale,
-        num_inference_steps = num_inference_steps,
-        width = width,
-        height = height,
-        generator = generator
-    ).images[0]
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
+        width=width,
+        height=height,
+        generator=generator
+    ).images[0]
 
     return image
 
+def get_latest_version(repo_id):
+    api = HfApi()
+    repo_info = api.repo_info(repo_id)
+    versions = [tag.name for tag in repo_info.tags]
+    if not versions:
+        return "v_0.0"
+    latest_version = sorted(versions)[-1]
+    return latest_version
+
+def increment_version(version):
+    major, minor = map(int, version.split('_')[1:])
+    minor += 1
+    return f"v_{major}.{minor}"
+
+def train_model(train_data_path, output_dir, num_train_epochs, per_device_train_batch_size, learning_rate):
+    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+    base_model = StableDiffusionXLBaseModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+
+    dataset = load_dataset('imagefolder', data_dir=train_data_path)
+
+    training_args = TrainingArguments(
+        output_dir=output_dir,
+        num_train_epochs=num_train_epochs,
+        per_device_train_batch_size=per_device_train_batch_size,
+        learning_rate=learning_rate,
+        logging_dir="./logs",
+        logging_steps=10,
+    )
+
+    trainer = StableDiffusionTrainer(
+        model=base_model,
+        args=training_args,
+        train_dataset=dataset['train'],
+        tokenizer=tokenizer,
+    )
+
+    trainer.train()
+    base_model.save_pretrained(output_dir)
+
+    # Publish the model
+    repo_id = "ZennyKenny/stable-diffusion-xl-base-1.0_NatalieDiffusion"
+    latest_version = get_latest_version(repo_id)
+    new_version = increment_version(latest_version)
+
+    api = HfApi()
+    token = HfFolder.get_token()
+    repo = Repository(output_dir, clone_from=repo_id, token=token)
+    repo.git_tag(new_version)
+    repo.push_tag(new_version)
+
+    return f"Training complete. Model saved to {output_dir} and published as version {new_version}."
+
 examples = [
     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
     "An astronaut riding a green horse",
@@ -133,14 +189,54 @@ with gr.Blocks(css=css) as demo:
                 )
 
         gr.Examples(
-            examples = examples,
-            inputs = [prompt]
+            examples=examples,
+            inputs=[prompt]
+        )
+
+        # Add new section for training the model
+        with gr.Accordion("Training Settings", open=False):
+            train_data_path = gr.Text(
+                label="Training Data Path",
+                placeholder="Enter the path to your training data",
+            )
+            output_dir = gr.Text(
+                label="Output Directory",
+                placeholder="Enter the output directory for the trained model",
         )
+            num_train_epochs = gr.Slider(
+                label="Number of Training Epochs",
+                minimum=1,
+                maximum=10,
+                step=1,
+                value=3,
+            )
+            per_device_train_batch_size = gr.Slider(
+                label="Batch Size per Device",
+                minimum=1,
+                maximum=16,
+                step=1,
+                value=4,
+            )
+            learning_rate = gr.Slider(
+                label="Learning Rate",
+                minimum=1e-5,
+                maximum=1e-3,
+                step=1e-5,
+                value=5e-5,
+            )
+            train_button = gr.Button("Train Model")
+            train_result = gr.Text(label="Training Result", show_label=False)
+
+        train_button.click(
+            fn=train_model,
+            inputs=[train_data_path, output_dir, num_train_epochs, per_device_train_batch_size, learning_rate],
+            outputs=[train_result],
+        )
 
     run_button.click(
-        fn = infer,
-        inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-        outputs = [result]
+        fn=infer,
+        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
+        outputs=[result]
    )
 
-demo.queue().launch()
+demo.queue().launch()
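
A note on the training path added in this commit: it assumes diffusers exposes StableDiffusionXLBaseModel and StableDiffusionTrainer classes. If those wrappers are not available in the installed diffusers build, the SDXL components can be loaded individually and handed to whatever training loop is used. The sketch below is illustrative only and not part of the commit; the subfolder names follow the layout of the stabilityai/stable-diffusion-xl-base-1.0 repo.

import torch
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

base = "stabilityai/stable-diffusion-xl-base-1.0"

# SDXL stores each component in a subfolder of the model repo.
tokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer_2")
text_encoder = CLIPTextModel.from_pretrained(base, subfolder="text_encoder")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(base, subfolder="text_encoder_2")
vae = AutoencoderKL.from_pretrained(base, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(base, subfolder="unet")
noise_scheduler = DDPMScheduler.from_pretrained(base, subfolder="scheduler")

# Typically only the UNet is fine-tuned; the encoders and VAE stay frozen.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)
optimizer = torch.optim.AdamW(unet.parameters(), lr=5e-5)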
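The tag helpers in the diff assume version strings of the form v_MAJOR.MINOR (e.g. v_0.0). Below is a small standalone sketch of that convention that compares versions numerically, so that v_0.10 sorts after v_0.9; the names parse_version and bump_minor are hypothetical and not part of the commit.

import re

# Hypothetical helpers: parse and bump tags of the form "v_MAJOR.MINOR".
_TAG_RE = re.compile(r"^v_(\d+)\.(\d+)$")

def parse_version(tag):
    """Return (major, minor) for a tag like 'v_1.2', or None if it does not match."""
    m = _TAG_RE.match(tag)
    return (int(m.group(1)), int(m.group(2))) if m else None

def bump_minor(tags):
    """Pick the numerically highest matching tag and bump its minor component."""
    versions = sorted(v for v in (parse_version(t) for t in tags) if v is not None)
    if not versions:
        return "v_0.1"
    major, minor = versions[-1]
    return f"v_{major}.{minor + 1}"

# Example: lexicographic sorting would rank "v_0.9" above "v_0.10"; numeric sorting does not.
print(bump_minor(["v_0.9", "v_0.10"]))  # -> v_0.11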
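The publish step in train_model goes through the git-based Repository helper from huggingface_hub. As an alternative sketch only, assuming a huggingface_hub version that provides HfApi.upload_folder, HfApi.create_tag and HfApi.list_repo_refs and a token with write access to the repo, the same upload-and-tag flow can be expressed through the HTTP API; publish_versioned is an illustrative name, and the repo id is the one used in the diff.

from huggingface_hub import HfApi

def publish_versioned(output_dir, repo_id, new_version, token=None):
    """Upload a trained model folder and mark the resulting commit with a git tag."""
    api = HfApi(token=token)
    # Push the folder contents to the model repo (creates a commit on main).
    api.upload_folder(
        folder_path=output_dir,
        repo_id=repo_id,
        repo_type="model",
        commit_message=f"Add model version {new_version}",
    )
    # Tag the new revision so it can be retrieved later by version name.
    api.create_tag(repo_id, tag=new_version, repo_type="model")
    return new_version

# Existing tags can be listed with:
# tags = [t.name for t in HfApi().list_repo_refs("ZennyKenny/stable-diffusion-xl-base-1.0_NatalieDiffusion").tags]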