Speed increased by loading the pipelines once at startup

#1
Files changed (1)
  1. app.py +6 -11
app.py CHANGED
@@ -33,11 +33,10 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-def load_pipeline(pipeline_type):
-    if pipeline_type == "text2img":
-        return StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=torch.float16)
-    elif pipeline_type == "img2img":
-        return StableDiffusion3Img2ImgPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
+pipe_t2i = StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=torch.float16)
+pipe_t2i.to(device)
+pipe_i2i = StableDiffusion3Img2ImgPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
+pipe_i2i.to(device)
 
 def save_image(img):
     unique_name = str(uuid.uuid4()) + ".png"
@@ -66,15 +65,13 @@ def generate(
     use_resolution_binning: bool = True,
     progress=gr.Progress(track_tqdm=True),
 ):
-    pipe = load_pipeline("text2img")
-    pipe.to(device)
     seed = int(randomize_seed_fn(seed, randomize_seed))
     generator = torch.Generator().manual_seed(seed)
 
     if not use_negative_prompt:
         negative_prompt = None  # type: ignore
 
-    output = pipe(
+    output = pipe_t2i(
         prompt=prompt,
         negative_prompt=negative_prompt,
         width=width,
@@ -104,8 +101,6 @@ def img2img_generate(
     use_resolution_binning: bool = True,
     progress=gr.Progress(track_tqdm=True),
 ):
-    pipe = load_pipeline("img2img")
-    pipe.to(device)
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator().manual_seed(seed)
 
@@ -114,7 +109,7 @@ def img2img_generate(
 
     init_image = init_image.resize((768, 768))
 
-    output = pipe(
+    output = pipe_i2i(
         prompt=prompt,
         image=init_image,
         negative_prompt=negative_prompt,
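
For reference, a possible follow-up (not part of this PR): a minimal sketch of the same startup-time loading that shares the already-loaded weights between the two pipelines via from_pipe, assuming a diffusers release that provides DiffusionPipeline.from_pipe (SD3 pipelines already require a version that does). The model_path value below is a placeholder; app.py defines its own.

import torch
from diffusers import StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline

# Placeholder repo id for illustration only; app.py defines its own model_path.
model_path = "stabilityai/stable-diffusion-3-medium-diffusers"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the text-to-image pipeline once at startup, as this PR does.
pipe_t2i = StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe_t2i.to(device)

# Reuse the already-loaded components for img2img instead of instantiating a
# second full copy of the weights; from_pipe shares the underlying modules.
pipe_i2i = StableDiffusion3Img2ImgPipeline.from_pipe(pipe_t2i)

Calling from_pretrained twice, as in the diff above, keeps two separate copies of the weights in memory; sharing components preserves the speed win of loading once while using roughly the memory of a single pipeline. Note also that when ENABLE_CPU_OFFLOAD is set, pipe.enable_model_cpu_offload() would likely take the place of the .to(device) calls, since offloading manages device placement itself.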