prithivMLmods committed on
Commit 9a0d412
1 Parent(s): 29ba56c

Update app.py

Files changed (1)
  1. app.py +27 -23
app.py CHANGED
@@ -29,7 +29,7 @@ MODEL_ID = os.getenv("MODEL_VAL_PATH", "SG161222/RealVisXL_V4.0_Lightning")
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
-BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # Allow generating multiple images at once
+BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -59,24 +59,37 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
         seed = random.randint(0, MAX_SEED)
     return seed
 
+def set_wallpaper_size(size):
+    if size == "Mobile (1080x1920)":
+        return 1080, 1920
+    elif size == "Desktop (1920x1080)":
+        return 1920, 1080
+    elif size == "Extended (1920x512)":
+        return 1920, 512
+    elif size == "Headers (1080x512)":
+        return 1080, 512
+    else:
+        return 1024, 1024  # Default return if none of the conditions are met
+
 @spaces.GPU(duration=60, enable_queue=True)
 def generate(
     prompt: str,
     negative_prompt: str = "",
     use_negative_prompt: bool = False,
     seed: int = 1,
-    width: int = 1024,
-    height: int = 1024,
     guidance_scale: float = 3,
     num_inference_steps: int = 25,
     randomize_seed: bool = False,
     use_resolution_binning: bool = True,
     num_images: int = 1,  # Number of images to generate
+    wallpaper_size: str = "Default (1024x1024)",
     progress=gr.Progress(track_tqdm=True),
 ):
     seed = int(randomize_seed_fn(seed, randomize_seed))
     generator = torch.Generator(device=device).manual_seed(seed)
 
+    width, height = set_wallpaper_size(wallpaper_size)
+
     options = {
         "prompt": [prompt] * num_images,
         "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
@@ -92,7 +105,7 @@ def generate(
         options["use_resolution_binning"] = True
 
     images = []
-    for i in range(0, num_images, BATCH_SIZE):
+    for i in range(0, num_images, BATCH_SIZE):
         batch_options = options.copy()
         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
         if "negative_prompt" in batch_options:
@@ -103,7 +116,6 @@ def generate(
     return image_paths, seed
 
 with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
-
     gr.Markdown(DESCRIPTIONx)
     with gr.Row():
         prompt = gr.Text(
@@ -115,6 +127,13 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
         )
         run_button = gr.Button("Run ⚡", scale=0)
     result = gr.Gallery(label="Result", columns=1, show_label=False)
+
+    with gr.Row(visible=True):
+        wallpaper_size = gr.Radio(
+            choices=["Mobile (1080x1920)", "Desktop (1920x1080)", "Extended (1920x512)", "Headers (1080x512)", "Default (1024x1024)"],
+            label="Pixel Size(x*y)",
+            value="Default (1024x1024)"
+        )
 
     with gr.Accordion("Advanced options", open=False, visible=True):
         num_images = gr.Slider(
@@ -142,21 +161,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
            value=0,
         )
         randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-        with gr.Row(visible=True):
-            width = gr.Slider(
-                label="Width",
-                minimum=512,
-                maximum=MAX_IMAGE_SIZE,
-                step=64,
-                value=1024,
-            )
-            height = gr.Slider(
-                label="Height",
-                minimum=512,
-                maximum=MAX_IMAGE_SIZE,
-                step=64,
-                value=1024,
-            )
+
         with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance Scale",
@@ -197,12 +202,11 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
            negative_prompt,
            use_negative_prompt,
            seed,
-            width,
-            height,
            guidance_scale,
            num_inference_steps,
            randomize_seed,
-            num_images
+            num_images,
+            wallpaper_size,
        ],
        outputs=[result, seed],
        api_name="run",
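
The net effect of this commit is that the Width/Height sliders are removed and replaced by a single wallpaper_size radio, which generate() resolves to pixel dimensions through the new set_wallpaper_size helper before building the pipeline options. The sketch below is a standalone, illustrative check of that mapping only; the preset labels and return values come from the diff above, while the type hints and the demo loop at the bottom are added here for clarity and are not part of the commit.

# Standalone sketch of the wallpaper-size mapping introduced in this commit.
# No GPU, Gradio, or diffusers pipeline is needed to exercise the logic.

def set_wallpaper_size(size: str) -> tuple[int, int]:
    # Same presets as the Radio choices added to the UI.
    if size == "Mobile (1080x1920)":
        return 1080, 1920
    elif size == "Desktop (1920x1080)":
        return 1920, 1080
    elif size == "Extended (1920x512)":
        return 1920, 512
    elif size == "Headers (1080x512)":
        return 1080, 512
    return 1024, 1024  # fallback, matches "Default (1024x1024)"


if __name__ == "__main__":
    # Illustrative only: print the width/height that generate() would receive
    # for each radio choice.
    for choice in ["Mobile (1080x1920)", "Desktop (1920x1080)",
                   "Extended (1920x512)", "Headers (1080x512)",
                   "Default (1024x1024)"]:
        width, height = set_wallpaper_size(choice)
        print(f"{choice:>22} -> width={width}, height={height}")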