File size: 171 Bytes
ca4fc4d
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
import torch
from Andromeda.configs import Andromeda1Billion

# Smoke test: build the Andromeda1Billion model and run a single forward pass
# on a batch of random token ids.

# Fall back to CPU when no CUDA device is available instead of failing on a
# hard-coded `.cuda()` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Andromeda1Billion().to(device)
# Inference only: put modules such as dropout into evaluation behaviour.
model.eval()

# One sequence (batch size 1) of 1024 random token ids drawn from [0, 256),
# allocated directly on the target device to avoid a host->device copy.
x = torch.randint(0, 256, (1, 1024), device=device)

# Forward-only check: disabling autograd avoids storing activations for a
# backward pass that never happens.
with torch.no_grad():
    # NOTE(review): original note says the output is (1, 1024, 20000);
    # confirm against the config's vocabulary/output size.
    out = model(x)