awacke1 commited on
Commit
8d4588c
1 Parent(s): dd404a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -96
app.py CHANGED
@@ -9,7 +9,7 @@ import re
9
  import random
10
  import torch
11
  import time
12
- import shutil
13
  import zipfile
14
  from PIL import Image
15
  from io import BytesIO
@@ -23,6 +23,7 @@ except:
23
  SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
24
  TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
25
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
26
  mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
27
  xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
28
  device = torch.device(
@@ -31,73 +32,79 @@ device = torch.device(
31
  torch_device = device
32
  torch_dtype = torch.float16
33
 
34
- # CSS definition
35
- css = """
36
- #container{
37
- margin: 0 auto;
38
- max-width: 40rem;
39
- }
40
- #intro{
41
- max-width: 100%;
42
- text-align: center;
43
- margin: 0 auto;
44
- }
45
- """
46
-
47
  def encode_file_to_base64(file_path):
48
  with open(file_path, "rb") as file:
49
  encoded = base64.b64encode(file.read()).decode()
50
  return encoded
51
 
52
  def create_zip_of_files(files):
 
 
 
53
  zip_name = "all_files.zip"
54
  with zipfile.ZipFile(zip_name, 'w') as zipf:
55
  for file in files:
56
  zipf.write(file)
57
  return zip_name
58
 
 
59
  def get_zip_download_link(zip_file):
 
 
 
60
  with open(zip_file, 'rb') as f:
61
  data = f.read()
62
  b64 = base64.b64encode(data).decode()
63
  href = f'<a href="data:application/zip;base64,{b64}" download="{zip_file}">Download All</a>'
64
  return href
65
 
 
66
  def clear_all_images():
67
- base_dir = os.getcwd()
68
- img_files = [file for file in os.listdir(base_dir) if file.lower().endswith((".png", ".jpg", ".jpeg"))]
 
 
69
  for file in img_files:
70
  os.remove(file)
71
  print('removed:' + file)
72
-
73
- def save_all_images(images):
74
- if len(images) == 0:
75
- return None, None
76
-
77
  timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
78
- zip_filename = f"images_and_history_{timestamp}.zip"
79
-
80
  with zipfile.ZipFile(zip_filename, 'w') as zipf:
81
- # Add image files
82
- for file in images:
83
  zipf.write(file, os.path.basename(file))
84
-
85
- # Add prompt history file
86
- if os.path.exists("prompt_history.txt"):
87
- zipf.write("prompt_history.txt")
88
-
89
- # Generate download link
90
- zip_base64 = encode_file_to_base64(zip_filename)
91
- download_link = f'<a href="data:application/zip;base64,{zip_base64}" download="{zip_filename}">Download All (Images & History)</a>'
 
 
92
 
 
 
 
 
 
 
93
  return zip_filename, download_link
94
-
95
  def save_all_button_click():
96
  images = [file for file in os.listdir() if file.lower().endswith((".png", ".jpg", ".jpeg"))]
97
  zip_filename, download_link = save_all_images(images)
98
  if download_link:
99
- return gr.HTML(download_link)
100
 
 
 
101
  def clear_all_button_click():
102
  clear_all_images()
103
 
@@ -120,6 +127,7 @@ pipe.to(device=torch_device, dtype=torch_dtype).to(device)
120
  pipe.unet.to(memory_format=torch.channels_last)
121
  pipe.set_progress_bar_config(disable=True)
122
 
 
123
  if psutil.virtual_memory().total < 64 * 1024**3:
124
  pipe.enable_attention_slicing()
125
 
@@ -128,29 +136,28 @@ if TORCH_COMPILE:
128
  pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
129
  pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
130
 
 
131
  pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
132
  pipe.fuse_lora()
133
 
134
  def safe_filename(text):
 
135
  safe_text = re.sub(r'\W+', '_', text)
136
  timestamp = datetime.datetime.now().strftime("%Y%m%d")
137
  return f"{safe_text}_{timestamp}.png"
138
-
139
  def encode_image(image):
 
140
  buffered = BytesIO()
 
141
  return base64.b64encode(buffered.getvalue()).decode()
142
 
143
  def fake_gan():
144
- base_dir = os.getcwd()
145
- img_files = [file for file in os.listdir(base_dir) if file.lower().endswith((".png", ".jpg", ".jpeg"))]
146
  images = [(random.choice(img_files), os.path.splitext(file)[0]) for file in img_files]
147
  return images
148
-
149
- def save_prompt_to_history(prompt):
150
- with open("prompt_history.txt", "a") as f:
151
- timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
152
- f.write(f"{timestamp}: {prompt}\n")
153
-
154
  def predict(prompt, guidance, steps, seed=1231231):
155
  generator = torch.manual_seed(seed)
156
  last_time = time.time()
@@ -161,13 +168,10 @@ def predict(prompt, guidance, steps, seed=1231231):
161
  guidance_scale=guidance,
162
  width=512,
163
  height=512,
 
164
  output_type="pil",
165
  )
166
  print(f"Pipe took {time.time() - last_time} seconds")
167
-
168
- # Save prompt to history
169
- save_prompt_to_history(prompt)
170
-
171
  nsfw_content_detected = (
172
  results.nsfw_content_detected[0]
173
  if "nsfw_content_detected" in results
@@ -183,24 +187,35 @@ def predict(prompt, guidance, steps, seed=1231231):
183
  safe_prompt = "".join(x for x in replaced_prompt if x.isalnum() or x == "_")[:90]
184
  filename = f"{safe_date_time}_{safe_prompt}.png"
185
 
 
186
  if len(results.images) > 0:
187
- image_path = os.path.join("", filename)
188
  results.images[0].save(image_path)
189
  print(f"#Image saved as {image_path}")
190
  gr.File(image_path)
191
  gr.Button(link=image_path)
 
 
 
192
  except:
193
  return results.images[0]
194
 
195
  return results.images[0] if len(results.images) > 0 else None
196
 
197
- def read_prompt_history():
198
- if os.path.exists("prompt_history.txt"):
199
- with open("prompt_history.txt", "r") as f:
200
- return f.read()
201
- return "No prompts yet."
202
 
 
 
 
 
 
 
 
 
 
 
 
203
  with gr.Blocks(css=css) as demo:
 
204
  with gr.Column(elem_id="container"):
205
  gr.Markdown(
206
  """4📝RT🖼️Images - 🕹️ Real Time 🎨 Image Generator Gallery 🌐""",
@@ -215,8 +230,10 @@ with gr.Blocks(css=css) as demo:
215
 
216
  gr.Button("Download", link="/file=all_files.zip")
217
 
 
218
  image = gr.Image(type="filepath")
219
 
 
220
  with gr.Row(variant="compact"):
221
  text = gr.Textbox(
222
  label="Image Sets",
@@ -231,9 +248,12 @@ with gr.Blocks(css=css) as demo:
231
  )
232
 
233
  with gr.Row(variant="compact"):
 
234
  save_all_button = gr.Button("💾 Save All", scale=1)
 
235
  clear_all_button = gr.Button("🗑️ Clear All", scale=1)
236
 
 
237
  with gr.Accordion("Advanced options", open=False):
238
  guidance = gr.Slider(
239
  label="Guidance", minimum=0.0, maximum=5, value=0.3, step=0.001
@@ -243,54 +263,46 @@ with gr.Blocks(css=css) as demo:
243
  randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
244
  )
245
 
246
- with gr.Accordion("Prompt History", open=False):
247
- prompt_history = gr.Textbox(label="Prompt History", lines=10, max_lines=20, interactive=False)
248
-
249
  with gr.Accordion("Run with diffusers"):
250
  gr.Markdown(
251
  """## Running LCM-LoRAs it with `diffusers`
252
- ```bash
253
- pip install diffusers==0.23.0
254
- ```
255
-
256
- ```py
257
- from diffusers import DiffusionPipeline, LCMScheduler
258
- pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7").to("cuda")
259
- pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
260
- pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5") #yes, it's a normal LoRA
261
- results = pipe(
262
- prompt="ImageEditor",
263
- num_inference_steps=4,
264
- guidance_scale=0.0,
265
- )
266
- results.images[0]
267
- ```
268
- """
269
  )
270
 
271
- with gr.Column():
272
- file_obj = gr.File(label="Input File")
273
- input = file_obj
274
-
275
- inputs = [prompt, guidance, steps, seed]
276
- generate_bt.click(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
277
- btn.click(fake_gan, None, gallery)
278
- prompt.submit(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
279
- guidance.change(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
280
- steps.change(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
281
- seed.change(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
282
-
283
- def update_prompt_history():
284
- return read_prompt_history()
285
 
286
- generate_bt.click(fn=update_prompt_history, outputs=prompt_history)
287
- prompt.submit(fn=update_prompt_history, outputs=prompt_history)
288
 
289
- save_all_button.click(
290
- fn=lambda: save_all_images([f for f in os.listdir() if f.lower().endswith((".png", ".jpg", ".jpeg"))]),
291
- outputs=[gr.File(), gr.HTML()]
292
- )
293
- clear_all_button.click(clear_all_button_click)
294
 
295
  demo.queue()
296
- demo.launch(allowed_paths=["/"])
 
9
  import random
10
  import torch
11
  import time
12
+ import shutil # Added for zip functionality
13
  import zipfile
14
  from PIL import Image
15
  from io import BytesIO
 
23
  SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
24
  TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
25
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
26
+ # check if MPS is available OSX only M1/M2/M3 chips
27
  mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
28
  xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
29
  device = torch.device(
 
32
  torch_device = device
33
  torch_dtype = torch.float16
34
 
35
+ # Function to encode a file to base64
 
 
 
 
 
 
 
 
 
 
 
 
36
  def encode_file_to_base64(file_path):
37
  with open(file_path, "rb") as file:
38
  encoded = base64.b64encode(file.read()).decode()
39
  return encoded
40
 
41
  def create_zip_of_files(files):
42
+ """
43
+ Create a zip file from a list of files.
44
+ """
45
  zip_name = "all_files.zip"
46
  with zipfile.ZipFile(zip_name, 'w') as zipf:
47
  for file in files:
48
  zipf.write(file)
49
  return zip_name
50
 
51
+
52
  def get_zip_download_link(zip_file):
53
+ """
54
+ Generate a link to download the zip file.
55
+ """
56
  with open(zip_file, 'rb') as f:
57
  data = f.read()
58
  b64 = base64.b64encode(data).decode()
59
  href = f'<a href="data:application/zip;base64,{b64}" download="{zip_file}">Download All</a>'
60
  return href
61
 
62
+ # Function to clear all image files
63
  def clear_all_images():
64
+ base_dir = os.getcwd() # Get the current base directory
65
+ img_files = [file for file in os.listdir(base_dir) if file.lower().endswith((".png", ".jpg", ".jpeg"))] # List all files ending with ".jpg" or ".jpeg"
66
+
67
+ # Remove all image files
68
  for file in img_files:
69
  os.remove(file)
70
  print('removed:' + file)
71
+
72
+ # add file save and download and clear:
73
+ # Function to create a zip file from a list of files
74
+ def create_zip(files):
 
75
  timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
76
+ zip_filename = f"images_{timestamp}.zip"
77
+ print('Creating file ' + zip_filename)
78
  with zipfile.ZipFile(zip_filename, 'w') as zipf:
79
+ for file in files:
 
80
  zipf.write(file, os.path.basename(file))
81
+ print('added:' + file)
82
+ return zip_filename
83
+
84
+ def get_zip_download_link(zip_file):
85
+ """
86
+ Generate a link to download the zip file.
87
+ """
88
+ zip_base64 = encode_file_to_base64(zip_file) # Encode the zip file to base64
89
+ href = f'<a href="data:application/zip;base64,{zip_base64}" download="{zip_file}">Download All</a>'
90
+ return href
91
 
92
+ def save_all_images(images):
93
+ if len(images) == 0:
94
+ return None, None
95
+ zip_filename = create_zip_of_files(images) # Create a zip file from the list of image files
96
+ print(f"Zip file created: {zip_filename}")
97
+ download_link = get_zip_download_link(zip_filename)
98
  return zip_filename, download_link
99
+
100
  def save_all_button_click():
101
  images = [file for file in os.listdir() if file.lower().endswith((".png", ".jpg", ".jpeg"))]
102
  zip_filename, download_link = save_all_images(images)
103
  if download_link:
104
+ gr.HTML(download_link)
105
 
106
+
107
+ # Function to handle "Clear All" button click
108
  def clear_all_button_click():
109
  clear_all_images()
110
 
 
127
  pipe.unet.to(memory_format=torch.channels_last)
128
  pipe.set_progress_bar_config(disable=True)
129
 
130
+ # check if computer has less than 64GB of RAM using sys or os
131
  if psutil.virtual_memory().total < 64 * 1024**3:
132
  pipe.enable_attention_slicing()
133
 
 
136
  pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
137
  pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
138
 
139
+ # Load LCM LoRA
140
  pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
141
  pipe.fuse_lora()
142
 
143
  def safe_filename(text):
144
+ """Generate a safe filename from a string."""
145
  safe_text = re.sub(r'\W+', '_', text)
146
  timestamp = datetime.datetime.now().strftime("%Y%m%d")
147
  return f"{safe_text}_{timestamp}.png"
148
+
149
  def encode_image(image):
150
+ """Encode image to base64."""
151
  buffered = BytesIO()
152
+ #image.save(buffered, format="PNG")
153
  return base64.b64encode(buffered.getvalue()).decode()
154
 
155
  def fake_gan():
156
+ base_dir = os.getcwd() # Get the current base directory
157
+ img_files = [file for file in os.listdir(base_dir) if file.lower().endswith((".png", ".jpg", ".jpeg"))] # List all files ending with ".jpg" or ".jpeg"
158
  images = [(random.choice(img_files), os.path.splitext(file)[0]) for file in img_files]
159
  return images
160
+
 
 
 
 
 
161
  def predict(prompt, guidance, steps, seed=1231231):
162
  generator = torch.manual_seed(seed)
163
  last_time = time.time()
 
168
  guidance_scale=guidance,
169
  width=512,
170
  height=512,
171
+ # original_inference_steps=params.lcm_steps,
172
  output_type="pil",
173
  )
174
  print(f"Pipe took {time.time() - last_time} seconds")
 
 
 
 
175
  nsfw_content_detected = (
176
  results.nsfw_content_detected[0]
177
  if "nsfw_content_detected" in results
 
187
  safe_prompt = "".join(x for x in replaced_prompt if x.isalnum() or x == "_")[:90]
188
  filename = f"{safe_date_time}_{safe_prompt}.png"
189
 
190
+ # Save the image
191
  if len(results.images) > 0:
192
+ image_path = os.path.join("", filename) # Specify your directory
193
  results.images[0].save(image_path)
194
  print(f"#Image saved as {image_path}")
195
  gr.File(image_path)
196
  gr.Button(link=image_path)
197
+ # encoded_image = encode_image(image)
198
+ # html_link = f'<a href="data:image/png;base64,{encoded_image}" download="{filename}">Download Image</a>'
199
+ # gr.HTML(html_link)
200
  except:
201
  return results.images[0]
202
 
203
  return results.images[0] if len(results.images) > 0 else None
204
 
 
 
 
 
 
205
 
206
+ css = """
207
+ #container{
208
+ margin: 0 auto;
209
+ max-width: 40rem;
210
+ }
211
+ #intro{
212
+ max-width: 100%;
213
+ text-align: center;
214
+ margin: 0 auto;
215
+ }
216
+ """
217
  with gr.Blocks(css=css) as demo:
218
+
219
  with gr.Column(elem_id="container"):
220
  gr.Markdown(
221
  """4📝RT🖼️Images - 🕹️ Real Time 🎨 Image Generator Gallery 🌐""",
 
230
 
231
  gr.Button("Download", link="/file=all_files.zip")
232
 
233
+ # Image Result from last prompt
234
  image = gr.Image(type="filepath")
235
 
236
+ # Gallery of Generated Images with Image Names in Random Set to Download
237
  with gr.Row(variant="compact"):
238
  text = gr.Textbox(
239
  label="Image Sets",
 
248
  )
249
 
250
  with gr.Row(variant="compact"):
251
+ # Add "Save All" button with emoji
252
  save_all_button = gr.Button("💾 Save All", scale=1)
253
+ # Add "Clear All" button with emoji
254
  clear_all_button = gr.Button("🗑️ Clear All", scale=1)
255
 
256
+ # Advanced Generate Options
257
  with gr.Accordion("Advanced options", open=False):
258
  guidance = gr.Slider(
259
  label="Guidance", minimum=0.0, maximum=5, value=0.3, step=0.001
 
263
  randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
264
  )
265
 
266
+ # Diffusers
 
 
267
  with gr.Accordion("Run with diffusers"):
268
  gr.Markdown(
269
  """## Running LCM-LoRAs it with `diffusers`
270
+ ```bash
271
+ pip install diffusers==0.23.0
272
+ ```
273
+
274
+ ```py
275
+ from diffusers import DiffusionPipeline, LCMScheduler
276
+ pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7").to("cuda")
277
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
278
+ pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5") #yes, it's a normal LoRA
279
+ results = pipe(
280
+ prompt="ImageEditor",
281
+ num_inference_steps=4,
282
+ guidance_scale=0.0,
283
+ )
284
+ results.images[0]
285
+ ```
286
+ """
287
  )
288
 
289
+ # Function IO Eventing and Controls
290
+ inputs = [prompt, guidance, steps, seed]
291
+ generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
292
+ btn.click(fake_gan, None, gallery)
293
+ prompt.input(fn=predict, inputs=inputs, outputs=image, show_progress=False)
294
+ guidance.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
295
+ steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
296
+ seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
 
 
 
 
 
 
297
 
298
+ # Attach click event handlers to the buttons
299
+ save_all_button.click(save_all_button_click)
300
 
301
+ with gr.Column():
302
+ file_obj = gr.File(label="Input File")
303
+ input= file_obj
304
+
305
+ clear_all_button.click(clear_all_button_click)
306
 
307
  demo.queue()
308
+ demo.launch(allowed_paths=["/"])