Spaces:
Running
on
A10G
Running
on
A10G
File size: 294 Bytes
bfd34e9 |
1 2 3 4 5 6 7 8 9 |
import torch
from ... import share
def forward(self, x, context=None):
x = x + self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) # Self Attn.
x = x + self.attn2(self.norm2(x), context=context) # Cross Attn.
x = x + self.ff(self.norm3(x))
return x
|