Disty0 committed on
Commit
de04131
1 Parent(s): 44fb784

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +48 -20
README.md CHANGED
@@ -117,29 +117,54 @@ pipe = pipe.to(device, dtype=dtype)
117
  pipe.prior_pipe = pipe.prior_pipe.to(device, dtype=dtype)
118
 
119
 
120
- def encode_empty_prompt(
121
  prior_pipe,
122
  device,
123
- batch_size,
124
  num_images_per_prompt,
 
125
  ):
126
 
127
  text_inputs = prior_pipe.tokenizer(
128
- "",
129
- padding="max_length",
130
- max_length=prior_pipe.tokenizer.model_max_length,
131
- truncation=True,
132
  return_tensors="pt",
133
  )
134
-
135
- # Don't use attention mask for empty prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  text_encoder_output = prior_pipe.text_encoder(
137
- text_inputs.input_ids.to(device), attention_mask=None, output_hidden_states=True
138
  )
139
- prompt_embeds = text_encoder_output.hidden_states[-1]
 
 
 
 
140
  prompt_embeds = prompt_embeds.to(dtype=prior_pipe.text_encoder.dtype, device=device)
141
  prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
142
- return prompt_embeds
 
 
 
 
 
143
 
144
 
145
  prompt = "1girl, solo, looking at viewer, open mouth, blue eyes, medium breasts, blonde hair, gloves, dress, bow, hair between eyes, bare shoulders, upper body, hair bow, indoors, elbow gloves, hand on own chest, bridal gauntlets, candlestand, smile, rim lighting, from side, castle interior, looking side,"
@@ -147,18 +172,21 @@ quality_prompt = "extremely aesthetic, best quality, newest"
147
  negative_prompt = "very displeasing, displeasing, worst quality, bad quality, low quality, realistic, monochrome, comic, sketch, oldest, early, artist name, signature, blurry, simple background, upside down, interlocked fingers,"
148
  num_images_per_prompt=1
149
 
150
- # Encode prompts and quality prompts seperately:
151
- # device, batch_size, num_images_per_prompt, cfg, prompt
152
- prompt_embeds, prompt_embeds_pooled, _, _ = pipe.prior_pipe.encode_prompt(device, 1, num_images_per_prompt, False, prompt=prompt)
153
- quality_prompt_embeds, _, _, _ = pipe.prior_pipe.encode_prompt(device, 1, num_images_per_prompt, False, prompt=quality_prompt)
154
-
155
- negative_prompt_embeds, negative_prompt_embeds_pooled, _, _ = pipe.prior_pipe.encode_prompt(device, 1, num_images_per_prompt, False, prompt=negative_prompt)
156
- empty_prompt_embeds = encode_empty_prompt(pipe.prior_pipe, device, 1, num_images_per_prompt)
157
 
 
 
158
  prompt_embeds = torch.cat([prompt_embeds, quality_prompt_embeds], dim=1)
159
- negative_prompt_embeds = torch.cat([negative_prompt_embeds, empty_prompt_embeds], dim=1)
160
 
161
- pipe.prior_pipe.maybe_free_model_hooks()
 
 
 
 
 
 
162
 
163
  output = pipe(
164
  width=1024,
 
117
  pipe.prior_pipe = pipe.prior_pipe.to(device, dtype=dtype)
118
 
119
 
120
def encode_prompt(
    prior_pipe,
    device,
    num_images_per_prompt,
    prompt=""
):
    """Encode `prompt` with the prior pipeline's CLIP text encoder, with
    long-prompt support: the token stream is split into 75-token chunks,
    each wrapped in BOS/EOS, encoded in one batch, and the per-chunk hidden
    states are concatenated back into a single sequence.

    Args:
        prior_pipe: pipeline exposing `.tokenizer` and `.text_encoder`
            (assumed CLIP-style: BOS first, EOS/pad last — TODO confirm).
        device: target device for the returned embeddings.
        num_images_per_prompt: how many times to repeat the embeddings
            along the batch dimension.
        prompt: text to encode; defaults to the empty prompt.

    Returns:
        (prompt_embeds, prompt_embeds_pooled) — last-hidden-state
        embeddings and the pooled text embedding of the first chunk,
        both cast to the text encoder's dtype and repeated per image.
    """
    # Tokenize without truncation so prompts longer than the model's
    # context window keep all of their tokens.
    text_inputs = prior_pipe.tokenizer(
        prompt,
        padding="longest",
        truncation=False,
        return_tensors="pt",
    )

    max_len = 75  # usable tokens per CLIP chunk (77 minus BOS and EOS)
    input_ids_full = text_inputs.input_ids
    start_token = input_ids_full[:, 0].unsqueeze(0)   # BOS id, shape (1, 1)
    end_token = input_ids_full[:, -1].unsqueeze(0)    # EOS/pad id, shape (1, 1)
    raw_input_ids = input_ids_full[:, 1:-1]           # prompt tokens without BOS/EOS
    prompt_len = raw_input_ids.shape[1]
    last_length = prompt_len % max_len

    # Full 75-token chunks, each re-wrapped in BOS/EOS.
    chunks = []
    for i in range(prompt_len // max_len):
        chunks.append(
            torch.cat([start_token, raw_input_ids[:, i * max_len:(i + 1) * max_len], end_token], dim=1)
        )

    # Pad the trailing partial chunk with the EOS/pad id. Emit it only when
    # there are leftover tokens, or when the prompt produced no full chunk
    # at all (e.g. the empty prompt) — a prompt whose length is an exact
    # multiple of 75 must not gain a redundant all-padding chunk.
    if last_length > 0 or not chunks:
        pad = end_token.repeat(1, max_len - last_length)
        last_chunk = torch.cat([raw_input_ids[:, prompt_len - last_length:], pad], dim=1)
        chunks.append(torch.cat([start_token, last_chunk, end_token], dim=1))

    # Batch all chunks along dim 0: shape (num_chunks, 77).
    input_ids = torch.cat(chunks, dim=0).to(device)

    # Don't use attention masks.
    text_encoder_output = prior_pipe.text_encoder(
        input_ids, attention_mask=None, output_hidden_states=True
    )

    hidden = text_encoder_output.hidden_states[-1]  # (num_chunks, 77, dim)
    start_embed = hidden[:, 0].unsqueeze(0)
    end_embed = hidden[:, -1].unsqueeze(0)
    # Flatten the inner 75 tokens of every chunk into one sequence; infer the
    # embedding dim from the tensor instead of hard-coding 1280.
    prompt_embeds = hidden[:, 1:-1].reshape(1, -1, hidden.shape[-1])
    prompt_embeds = torch.cat([start_embed, prompt_embeds, end_embed], dim=1)
    prompt_embeds = prompt_embeds.to(dtype=prior_pipe.text_encoder.dtype, device=device)
    prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)

    # Pooled embedding: first chunk only, shaped (1, 1, dim) before repeat.
    prompt_embeds_pooled = text_encoder_output.text_embeds[0].unsqueeze(0).unsqueeze(1)
    prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=prior_pipe.text_encoder.dtype, device=device)
    prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0)

    return prompt_embeds, prompt_embeds_pooled
168
 
169
 
170
prompt = "1girl, solo, looking at viewer, open mouth, blue eyes, medium breasts, blonde hair, gloves, dress, bow, hair between eyes, bare shoulders, upper body, hair bow, indoors, elbow gloves, hand on own chest, bridal gauntlets, candlestand, smile, rim lighting, from side, castle interior, looking side,"
negative_prompt = "very displeasing, displeasing, worst quality, bad quality, low quality, realistic, monochrome, comic, sketch, oldest, early, artist name, signature, blurry, simple background, upside down, interlocked fingers,"
num_images_per_prompt = 1

# Encode the prompt and the quality prompt separately (no attention masks,
# long-prompt chunking). The empty-prompt embedding is used below as filler.
empty_prompt_embeds, _ = encode_prompt(pipe.prior_pipe, device, num_images_per_prompt, prompt="")

prompt_embeds, prompt_embeds_pooled = encode_prompt(pipe.prior_pipe, device, num_images_per_prompt, prompt=prompt)
quality_prompt_embeds, _ = encode_prompt(pipe.prior_pipe, device, num_images_per_prompt, prompt=quality_prompt)
prompt_embeds = torch.cat([prompt_embeds, quality_prompt_embeds], dim=1)

negative_prompt_embeds, negative_prompt_embeds_pooled = encode_prompt(pipe.prior_pipe, device, num_images_per_prompt, prompt=negative_prompt)

# Append empty-prompt chunks to whichever side is shorter so the positive
# and negative embeddings end up with matching sequence lengths for CFG.
while prompt_embeds.shape[1] < negative_prompt_embeds.shape[1]:
    prompt_embeds = torch.cat([prompt_embeds, empty_prompt_embeds], dim=1)

while negative_prompt_embeds.shape[1] < prompt_embeds.shape[1]:
    negative_prompt_embeds = torch.cat([negative_prompt_embeds, empty_prompt_embeds], dim=1)
190
 
191
  output = pipe(
192
  width=1024,