# NOTE(review): removed web-scrape residue ("Spaces: Running on Zero") that
# preceded the module source and was not part of the original Python file.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from craftsman.utils.typing import *
from craftsman.utils.checkpoint import checkpoint

from .utils import init_linear
from .attention import ResidualAttentionBlock
class Perceiver(nn.Module):
    """A stack of residual self-attention blocks over a fixed-length context.

    The module simply chains ``layers`` copies of ``ResidualAttentionBlock``
    and applies them sequentially in ``forward``. All constructor arguments
    are keyword-only and are forwarded to each block.

    Args:
        n_ctx: Number of context tokens each attention block operates over.
        width: Embedding (channel) dimension of the token sequence.
        layers: Number of residual attention blocks to stack.
        heads: Number of attention heads per block.
        init_scale: Scale passed to each block's weight initialization.
        qkv_bias: Whether the blocks' QKV projections use a bias term.
        use_flash: Whether the blocks use flash attention.
        use_checkpoint: Whether the blocks use gradient checkpointing.
    """

    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
        qkv_bias: bool = True,
        use_flash: bool = False,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        # One independent attention block per layer; each receives the same
        # hyperparameters but has its own weights.
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    n_ctx=n_ctx,
                    width=width,
                    heads=heads,
                    init_scale=init_scale,
                    qkv_bias=qkv_bias,
                    use_flash=use_flash,
                    use_checkpoint=use_checkpoint,
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each residual attention block in sequence.

        Args:
            x: Token tensor; presumably shaped (batch, n_ctx, width) — TODO
                confirm against ``ResidualAttentionBlock``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        for block in self.resblocks:
            x = block(x)
        return x