tolgacangoz
commited on
Commit
•
a108e65
1
Parent(s):
c08a7a5
Upload matryoshka.py
Browse files- scheduler/matryoshka.py +2 -2
scheduler/matryoshka.py
CHANGED
@@ -4002,7 +4002,7 @@ class MatryoshkaPipeline(
|
|
4002 |
prompt_attention_mask = torch.cat(
|
4003 |
[
|
4004 |
prompt_attention_mask,
|
4005 |
-
torch.zeros(batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long),
|
4006 |
],
|
4007 |
dim=1,
|
4008 |
)
|
@@ -4014,7 +4014,7 @@ class MatryoshkaPipeline(
|
|
4014 |
negative_prompt_attention_mask = torch.cat(
|
4015 |
[
|
4016 |
negative_prompt_attention_mask,
|
4017 |
-
torch.zeros(batch_size, max_len - len(negative_prompt_attention_mask[0]), dtype=torch.long),
|
4018 |
],
|
4019 |
dim=1,
|
4020 |
)
|
|
|
4002 |
prompt_attention_mask = torch.cat(
|
4003 |
[
|
4004 |
prompt_attention_mask,
|
4005 |
+
torch.zeros(batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long, device=device),
|
4006 |
],
|
4007 |
dim=1,
|
4008 |
)
|
|
|
4014 |
negative_prompt_attention_mask = torch.cat(
|
4015 |
[
|
4016 |
negative_prompt_attention_mask,
|
4017 |
+
torch.zeros(batch_size, max_len - len(negative_prompt_attention_mask[0]), dtype=torch.long, device=device),
|
4018 |
],
|
4019 |
dim=1,
|
4020 |
)
|