Update README.md
README.md (changed), hunks `@@ -117,29 +117,54 @@` and `@@ -147,18 +172,21 @@`

Removed: the previous `encode_empty_prompt(prior_pipe, device, batch_size, num_images_per_prompt)` helper, which tokenized an empty prompt with `truncation=True` and ran it through `prior_pipe.text_encoder`, together with the old `pipe.prior_pipe.encode_prompt(device, 1, num_images_per_prompt, False, prompt=...)` calls for the quality and negative prompts and the `negative_prompt_embeds = torch.cat([negative_prompt_embeds, empty_prompt_embeds], dim=1)` padding step.
Added: a new `encode_prompt` helper with long-prompt support (75-token chunking, no attention mask), replacing the old encoding path:

```py
pipe.prior_pipe = pipe.prior_pipe.to(device, dtype=dtype)


def encode_prompt(
    prior_pipe,
    device,
    num_images_per_prompt,
    prompt=""
):
    text_inputs = prior_pipe.tokenizer(
        prompt,
        padding="longest",
        truncation=False,
        return_tensors="pt",
    )
    chunk = []
    padding = []
    max_len = 75
    start_token = text_inputs.input_ids[:, 0].unsqueeze(0)
    end_token = text_inputs.input_ids[:, -1].unsqueeze(0)
    raw_input_ids = text_inputs.input_ids[:, 1:-1]
    prompt_len = len(raw_input_ids[0])
    last_length = prompt_len % max_len

    # Split the prompt into full 75-token windows, each wrapped in the start/end tokens.
    for i in range(int((prompt_len - last_length) / max_len)):
        chunk.append(torch.cat([start_token, raw_input_ids[:, i * max_len:(i + 1) * max_len], end_token], dim=1))
    # Right-pad the final, shorter window with the end token.
    for i in range(max_len - last_length):
        padding.append(text_inputs.input_ids[:, -1])

    last_chunk = torch.cat([raw_input_ids[:, prompt_len - last_length:], torch.tensor([padding])], dim=1)
    chunk.append(torch.cat([start_token, last_chunk, end_token], dim=1))
    input_ids = torch.cat(chunk, dim=0).to(device)

    # Don't use attention masks
    text_encoder_output = prior_pipe.text_encoder(
        input_ids, attention_mask=None, output_hidden_states=True
    )

    # Flatten the per-chunk hidden states back into one long sequence.
    start_embed = text_encoder_output.hidden_states[-1][:, 0].unsqueeze(0)
    end_embed = text_encoder_output.hidden_states[-1][:, -1].unsqueeze(0)
    prompt_embeds = text_encoder_output.hidden_states[-1][:, 1:-1].reshape(1, -1, 1280)
    prompt_embeds = torch.cat([start_embed, prompt_embeds, end_embed], dim=1)
    prompt_embeds = prompt_embeds.to(dtype=prior_pipe.text_encoder.dtype, device=device)
    prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)

    prompt_embeds_pooled = text_encoder_output.text_embeds[0].unsqueeze(0).unsqueeze(1)
    prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=prior_pipe.text_encoder.dtype, device=device)
    prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0)

    return prompt_embeds, prompt_embeds_pooled
```
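To make the windowing concrete, here is a small standalone sketch (an illustration added here, not part of the README) of how a 100-token prompt is split: each full window holds 75 prompt tokens, the last window is right-padded with the end-of-text id, and every window is re-wrapped in the begin/end ids, giving chunks of 77 tokens. The BOS/EOS ids below are the usual CLIP values and are assumptions for the example only.

```py
import torch

# Standalone illustration (not README code): windowing a 100-token prompt.
max_len = 75
bos, eos = 49406, 49407                      # assumed CLIP begin/end-of-text ids
raw_ids = torch.arange(100).unsqueeze(0)     # stand-in for a tokenized prompt without BOS/EOS

first = raw_ids[:, :max_len]                 # one full window of 75 tokens
tail = raw_ids[:, max_len:]                  # remaining 25 tokens
tail = torch.cat([tail, torch.full((1, max_len - tail.shape[1]), eos)], dim=1)  # pad to 75

wrap = lambda ids: torch.cat([torch.full((1, 1), bos), ids, torch.full((1, 1), eos)], dim=1)
input_ids = torch.cat([wrap(first), wrap(tail)], dim=0)
print(input_ids.shape)                       # torch.Size([2, 77])
```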
The prompt-encoding section of the README is updated to use the new helper:

```py
prompt = "1girl, solo, looking at viewer, open mouth, blue eyes, medium breasts, blonde hair, gloves, dress, bow, hair between eyes, bare shoulders, upper body, hair bow, indoors, elbow gloves, hand on own chest, bridal gauntlets, candlestand, smile, rim lighting, from side, castle interior, looking side,"
quality_prompt = "extremely aesthetic, best quality, newest"
negative_prompt = "very displeasing, displeasing, worst quality, bad quality, low quality, realistic, monochrome, comic, sketch, oldest, early, artist name, signature, blurry, simple background, upside down, interlocked fingers,"
num_images_per_prompt = 1

# Encode the prompts and the quality prompt separately, without attention masks
# and with long-prompt support:
# encode_prompt(prior_pipe, device, num_images_per_prompt, prompt)
empty_prompt_embeds, _ = encode_prompt(pipe.prior_pipe, device, num_images_per_prompt, prompt="")

prompt_embeds, prompt_embeds_pooled = encode_prompt(pipe.prior_pipe, device, num_images_per_prompt, prompt=prompt)
quality_prompt_embeds, _ = encode_prompt(pipe.prior_pipe, device, num_images_per_prompt, prompt=quality_prompt)
prompt_embeds = torch.cat([prompt_embeds, quality_prompt_embeds], dim=1)

negative_prompt_embeds, negative_prompt_embeds_pooled = encode_prompt(pipe.prior_pipe, device, num_images_per_prompt, prompt=negative_prompt)

# Pad the shorter sequence with empty-prompt chunks until both sides match in length.
while prompt_embeds.shape[1] < negative_prompt_embeds.shape[1]:
    prompt_embeds = torch.cat([prompt_embeds, empty_prompt_embeds], dim=1)

while negative_prompt_embeds.shape[1] < prompt_embeds.shape[1]:
    negative_prompt_embeds = torch.cat([negative_prompt_embeds, empty_prompt_embeds], dim=1)

output = pipe(
    width=1024,
```