saicharan1234 commited on
Commit
96b5a68
1 Parent(s): 84ec534

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +67 -92
main.py CHANGED
@@ -9,6 +9,7 @@ import numpy as np
9
  import random
10
  from PIL import Image
11
  import io
 
12
 
13
  app = FastAPI()
14
 
@@ -16,91 +17,71 @@ MAX_SEED = np.iinfo(np.int32).max
16
 
17
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
 
19
- # Set directory for model storage
20
- MODEL_DIR = "/data"
21
-
22
- # Ensure model directory exists
23
- os.makedirs(MODEL_DIR, exist_ok=True)
24
-
25
- # Download models to local directory
26
- HF_TOKEN = os.getenv("HF_TOKEN") # Replace with your actual token
27
- def download_model(repo_id, filename=None, model_dir=MODEL_DIR, token=HF_TOKEN):
28
- if filename:
29
- return hf_hub_download(repo_id=repo_id, filename=filename, local_dir=model_dir, token=token)
30
- return hf_hub_download(repo_id=repo_id, local_dir=model_dir, token=token)
31
-
32
- # Paths for models
33
- paths = {
34
- "Fluently XL Final": download_model("fluently/Fluently-XL-Final", "FluentlyXL-Final.safetensors"),
35
- "Fluently Anime": download_model("fluently/Fluently-anime"),
36
- "Fluently Epic": download_model("fluently/Fluently-epic"),
37
- "Fluently XL v4": download_model("fluently/Fluently-XL-v4"),
38
- "Fluently XL v3 Lightning": download_model("fluently/Fluently-XL-v3-lightning"),
39
- "Fluently v4 inpaint": download_model("fluently/Fluently-v4-inpainting"),
40
- "Fluently XL v3 inpaint": download_model("fluently/Fluently-XL-v3-inpainting", "FluentlyXL-v3-inpainting.safetensors"),
41
- }
42
-
43
- # Function to load model dynamically
44
- def load_model(model_name):
45
- if model_name == "Fluently XL Final":
46
- model = StableDiffusionXLPipeline.from_single_file(
47
- paths[model_name],
48
- torch_dtype=torch.float16,
49
- use_safetensors=True,
50
- )
51
- model.scheduler = EulerAncestralDiscreteScheduler.from_config(model.scheduler.config)
52
- elif model_name == "Fluently Anime":
53
- model = StableDiffusionPipeline.from_pretrained(
54
- paths[model_name],
55
- torch_dtype=torch.float16,
56
- use_safetensors=True,
57
- )
58
- model.scheduler = EulerAncestralDiscreteScheduler.from_config(model.scheduler.config)
59
- elif model_name == "Fluently Epic":
60
- model = StableDiffusionPipeline.from_pretrained(
61
- paths[model_name],
62
- torch_dtype=torch.float16,
63
- use_safetensors=True,
64
- )
65
- model.scheduler = EulerAncestralDiscreteScheduler.from_config(model.scheduler.config)
66
- elif model_name == "Fluently XL v4":
67
- model = StableDiffusionXLPipeline.from_pretrained(
68
- paths[model_name],
69
- torch_dtype=torch.float16,
70
- use_safetensors=True,
71
- )
72
- model.scheduler = EulerAncestralDiscreteScheduler.from_config(model.scheduler.config)
73
- elif model_name == "Fluently XL v3 Lightning":
74
- model = StableDiffusionXLPipeline.from_pretrained(
75
- paths[model_name],
76
- torch_dtype=torch.float16,
77
- use_safetensors=True,
78
- )
79
- model.scheduler = DPMSolverSinglestepScheduler.from_config(model.scheduler.config, use_karras_sigmas=False, timestep_spacing="trailing", lower_order_final=True)
80
- elif model_name in ["Fluently v4 inpaint", "Fluently XL v3 inpaint"]:
81
- if model_name == "Fluently v4 inpaint":
82
- model = StableDiffusionInpaintPipeline.from_pretrained(
83
- paths[model_name],
84
- torch_dtype=torch.float16,
85
- use_safetensors=True,
86
- )
87
- else:
88
- model = StableDiffusionXLInpaintPipeline.from_single_file(
89
- paths[model_name],
90
- torch_dtype=torch.float16,
91
- use_safetensors=True,
92
- )
93
- else:
94
- raise ValueError(f"Model {model_name} not found")
95
-
96
- model.to(device)
97
- return model
98
 
99
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
100
  if randomize_seed:
101
  seed = random.randint(0, MAX_SEED)
102
  return seed
103
 
 
104
  @app.post("/generate")
105
  async def generate(
106
  model: str = Form(...),
@@ -125,10 +106,8 @@ async def generate(
125
  inpaint_image_pil = Image.open(io.BytesIO(await inpaint_image.read())) if inpaint_image else None
126
  mask_image_pil = Image.open(io.BytesIO(await mask_image.read())) if mask_image else None
127
 
128
- model_pipeline = load_model(model)
129
-
130
  if model == "Fluently XL Final":
131
- images = model_pipeline(
132
  prompt=prompt,
133
  negative_prompt=negative_prompt,
134
  width=width,
@@ -139,7 +118,7 @@ async def generate(
139
  output_type="pil",
140
  ).images
141
  elif model == "Fluently Anime":
142
- images = model_pipeline(
143
  prompt=prompt,
144
  negative_prompt=negative_prompt,
145
  width=width,
@@ -150,7 +129,7 @@ async def generate(
150
  output_type="pil",
151
  ).images
152
  elif model == "Fluently Epic":
153
- images = model_pipeline(
154
  prompt=prompt,
155
  negative_prompt=negative_prompt,
156
  width=width,
@@ -161,7 +140,7 @@ async def generate(
161
  output_type="pil",
162
  ).images
163
  elif model == "Fluently XL v4":
164
- images = model_pipeline(
165
  prompt=prompt,
166
  negative_prompt=negative_prompt,
167
  width=width,
@@ -172,7 +151,7 @@ async def generate(
172
  output_type="pil",
173
  ).images
174
  elif model == "Fluently XL v3 Lightning":
175
- images = model_pipeline(
176
  prompt=prompt,
177
  negative_prompt=negative_prompt,
178
  width=width,
@@ -183,8 +162,8 @@ async def generate(
183
  output_type="pil",
184
  ).images
185
  elif model == "Fluently v4 inpaint" or model == "Fluently XL v3 inpaint":
186
- blurred_mask = model_pipeline.mask_processor.blur(mask_image_pil, blur_factor=blur_factor)
187
- images = model_pipeline(
188
  prompt=prompt,
189
  image=inpaint_image_pil,
190
  mask_image=blurred_mask,
@@ -198,10 +177,6 @@ async def generate(
198
  output_type="pil",
199
  ).images
200
 
201
- # Unload the model from the device
202
- model_pipeline.to("cpu")
203
- torch.cuda.empty_cache()
204
-
205
  img = images[0]
206
  img_byte_arr = io.BytesIO()
207
  img.save(img_byte_arr, format='PNG')
 
9
  import random
10
  from PIL import Image
11
  import io
12
+ import os
13
 
14
  app = FastAPI()
15
 
 
17
 
18
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19
 
20
+ # Load HF token from environment variable
21
+ HF_TOKEN = os.getenv("HF_TOKEN")
22
+
23
+ # Load pipelines
24
+ pipe_xl_final = StableDiffusionXLPipeline.from_single_file(
25
+ hf_hub_download(repo_id="fluently/Fluently-XL-Final", filename="FluentlyXL-Final.safetensors", token=HF_TOKEN),
26
+ torch_dtype=torch.float16,
27
+ use_safetensors=True,
28
+ )
29
+ pipe_xl_final.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_xl_final.scheduler.config)
30
+ pipe_xl_final.to(device)
31
+
32
+ pipe_anime = StableDiffusionPipeline.from_pretrained(
33
+ "fluently/Fluently-anime",
34
+ torch_dtype=torch.float16,
35
+ use_safetensors=True,
36
+ )
37
+ pipe_anime.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_anime.scheduler.config)
38
+ pipe_anime.to(device)
39
+
40
+ pipe_epic = StableDiffusionPipeline.from_pretrained(
41
+ "fluently/Fluently-epic",
42
+ torch_dtype=torch.float16,
43
+ use_safetensors=True,
44
+ )
45
+ pipe_epic.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_epic.scheduler.config)
46
+ pipe_epic.to(device)
47
+
48
+ pipe_xl_inpaint = StableDiffusionXLInpaintPipeline.from_single_file(
49
+ "https://huggingface.co/fluently/Fluently-XL-v3-inpainting/blob/main/FluentlyXL-v3-inpainting.safetensors",
50
+ torch_dtype=torch.float16,
51
+ use_safetensors=True,
52
+ )
53
+ pipe_xl_inpaint.to(device)
54
+
55
+ pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
56
+ "fluently/Fluently-v4-inpainting",
57
+ torch_dtype=torch.float16,
58
+ use_safetensors=True,
59
+ )
60
+ pipe_inpaint.to(device)
61
+
62
+ pipe_xl = StableDiffusionXLPipeline.from_pretrained(
63
+ "fluently/Fluently-XL-v4",
64
+ torch_dtype=torch.float16,
65
+ use_safetensors=True,
66
+ )
67
+ pipe_xl.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_xl.scheduler.config)
68
+ pipe_xl.to(device)
69
+
70
+ pipe_xl_lightning = StableDiffusionXLPipeline.from_pretrained(
71
+ "fluently/Fluently-XL-v3-lightning",
72
+ torch_dtype=torch.float16,
73
+ use_safetensors=True,
74
+ )
75
+ pipe_xl_lightning.scheduler = DPMSolverSinglestepScheduler.from_config(pipe_xl_lightning.scheduler.config, use_karras_sigmas=False, timestep_spacing="trailing", lower_order_final=True)
76
+ pipe_xl_lightning.to(device)
77
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
80
  if randomize_seed:
81
  seed = random.randint(0, MAX_SEED)
82
  return seed
83
 
84
+
85
  @app.post("/generate")
86
  async def generate(
87
  model: str = Form(...),
 
106
  inpaint_image_pil = Image.open(io.BytesIO(await inpaint_image.read())) if inpaint_image else None
107
  mask_image_pil = Image.open(io.BytesIO(await mask_image.read())) if mask_image else None
108
 
 
 
109
  if model == "Fluently XL Final":
110
+ images = pipe_xl_final(
111
  prompt=prompt,
112
  negative_prompt=negative_prompt,
113
  width=width,
 
118
  output_type="pil",
119
  ).images
120
  elif model == "Fluently Anime":
121
+ images = pipe_anime(
122
  prompt=prompt,
123
  negative_prompt=negative_prompt,
124
  width=width,
 
129
  output_type="pil",
130
  ).images
131
  elif model == "Fluently Epic":
132
+ images = pipe_epic(
133
  prompt=prompt,
134
  negative_prompt=negative_prompt,
135
  width=width,
 
140
  output_type="pil",
141
  ).images
142
  elif model == "Fluently XL v4":
143
+ images = pipe_xl(
144
  prompt=prompt,
145
  negative_prompt=negative_prompt,
146
  width=width,
 
151
  output_type="pil",
152
  ).images
153
  elif model == "Fluently XL v3 Lightning":
154
+ images = pipe_xl_lightning(
155
  prompt=prompt,
156
  negative_prompt=negative_prompt,
157
  width=width,
 
162
  output_type="pil",
163
  ).images
164
  elif model == "Fluently v4 inpaint" or model == "Fluently XL v3 inpaint":
165
+ blurred_mask = pipe_inpaint.mask_processor.blur(mask_image_pil, blur_factor=blur_factor)
166
+ images = pipe_inpaint(
167
  prompt=prompt,
168
  image=inpaint_image_pil,
169
  mask_image=blurred_mask,
 
177
  output_type="pil",
178
  ).images
179
 
 
 
 
 
180
  img = images[0]
181
  img_byte_arr = io.BytesIO()
182
  img.save(img_byte_arr, format='PNG')