Disty0 committed
Commit 99eb79b
1 Parent(s): 05566dd

Update README.md

Files changed (1)
  1. README.md +39 -27
README.md CHANGED
@@ -124,33 +124,45 @@ def encode_prompt(
     prompt=""
 ):

-    text_inputs = prior_pipe.tokenizer(
-        prompt,
-        padding="longest",
-        truncation=False,
-        return_tensors="pt",
-    )
-    chunk = []
-    padding = []
-    max_len = 75
-    start_token = text_inputs.input_ids[:,0].unsqueeze(0)
-    end_token = text_inputs.input_ids[:,-1].unsqueeze(0)
-    raw_input_ids = text_inputs.input_ids[:,1:-1]
-    prompt_len = len(raw_input_ids[0])
-    last_lenght = prompt_len % max_len
-
-    for i in range(int((prompt_len - last_lenght) / max_len)):
-        chunk.append(torch.cat([start_token, raw_input_ids[:,i*max_len:(i+1)*max_len], end_token], dim=1))
-    for i in range(max_len - last_lenght):
-        padding.append(text_inputs.input_ids[:,-1])
-
-    last_chunk = torch.cat([raw_input_ids[:,prompt_len-last_lenght:], torch.tensor([padding])], dim=1)
-    chunk.append(torch.cat([start_token, last_chunk, end_token], dim=1))
-    input_ids = torch.cat(chunk, dim=0).to(device)
-
-    # Don't use attention masks
+    if prompt == "":
+        text_inputs = prior_pipe.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=77,
+            truncation=False,
+            return_tensors="pt",
+        )
+        input_ids = text_inputs.input_ids
+        attention_mask=None
+    else:
+        text_inputs = prior_pipe.tokenizer(
+            prompt,
+            padding="longest",
+            truncation=False,
+            return_tensors="pt",
+        )
+        chunk = []
+        padding = []
+        max_len = 75
+        start_token = text_inputs.input_ids[:,0].unsqueeze(0)
+        end_token = text_inputs.input_ids[:,-1].unsqueeze(0)
+        raw_input_ids = text_inputs.input_ids[:,1:-1]
+        prompt_len = len(raw_input_ids[0])
+        last_lenght = prompt_len % max_len
+
+        for i in range(int((prompt_len - last_lenght) / max_len)):
+            chunk.append(torch.cat([start_token, raw_input_ids[:,i*max_len:(i+1)*max_len], end_token], dim=1))
+        for i in range(max_len - last_lenght):
+            padding.append(text_inputs.input_ids[:,-1])
+
+        last_chunk = torch.cat([raw_input_ids[:,prompt_len-last_lenght:], torch.tensor([padding])], dim=1)
+        chunk.append(torch.cat([start_token, last_chunk, end_token], dim=1))
+        input_ids = torch.cat(chunk, dim=0)
+        attention_mask = torch.ones(input_ids.shape, device=device, dtype=torch.int64)
+        attention_mask[-1,last_lenght+1:] = 0
+
     text_encoder_output = prior_pipe.text_encoder(
-        input_ids, attention_mask=None, output_hidden_states=True
+        input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
     )

     prompt_embeds = text_encoder_output.hidden_states[-1].reshape(1,-1,1280)
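The chunking logic above is easiest to check with concrete shapes. Here is a minimal, self-contained sketch of the new non-empty-prompt branch, using made-up token ids in place of real tokenizer output; the BOS/EOS ids and the 180-token prompt length are assumptions for illustration, not values taken from this repository.

```python
import torch

# Toy stand-ins for prior_pipe.tokenizer output (assumed values, for illustration only).
max_len = 75                                   # tokens per chunk, excluding BOS/EOS
bos_id, eos_id = 49406, 49407                  # usual CLIP special-token ids (assumed)
prompt_len = 180                               # pretend the prompt tokenizes to 180 tokens
raw_input_ids = torch.randint(0, 49000, (1, prompt_len))
start_token = torch.tensor([[bos_id]])
end_token = torch.tensor([[eos_id]])

chunks = []
last_len = prompt_len % max_len                # 180 % 75 = 30 tokens left for the final chunk

# Full 75-token chunks, each wrapped with BOS/EOS -> shape (1, 77).
for i in range((prompt_len - last_len) // max_len):
    chunks.append(torch.cat([start_token, raw_input_ids[:, i*max_len:(i+1)*max_len], end_token], dim=1))

# Final short chunk, padded up to 75 tokens with EOS before wrapping.
padding = end_token.repeat(1, max_len - last_len)
last_chunk = torch.cat([raw_input_ids[:, prompt_len-last_len:], padding], dim=1)
chunks.append(torch.cat([start_token, last_chunk, end_token], dim=1))

input_ids = torch.cat(chunks, dim=0)           # (3, 77): two full chunks + one padded chunk
attention_mask = torch.ones(input_ids.shape, dtype=torch.int64)
attention_mask[-1, last_len + 1:] = 0          # keep BOS + the 30 real tokens, mask the EOS padding

print(input_ids.shape, attention_mask.sum(dim=1))  # torch.Size([3, 77]) tensor([77, 77, 31])
```

Each 77-token row goes through the text encoder, and the hidden states are reshaped into a single (1, N, 1280) sequence, which is how prompts longer than the 77-token CLIP window reach the prior. The second hunk below only updates the usage comment to match.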
@@ -169,7 +181,7 @@ quality_prompt = "very aesthetic, best quality, newest"
 negative_prompt = "very displeasing, displeasing, worst quality, bad quality, low quality, realistic, monochrome, comic, sketch, oldest, early, artist name, signature, blurry, simple background, upside down, interlocked fingers,"
 num_images_per_prompt=1

-# Encode prompts and quality prompts eperately, don't use attention masks and long prompt support:
+# Encode prompts and quality prompts separately, with long prompt support; don't use attention masks for empty prompts:
 # pipe, device, num_images_per_prompt, prompt
 empty_prompt_embeds, _ = encode_prompt(pipe.prior_pipe, device, num_images_per_prompt, prompt="")
 
187