# Andromeda/testing/model.py
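"""Unit tests for the Andromeda model wrapper.

The constants asserted throughout this suite (a 50432-token vocabulary and an
8192-token maximum sequence length) are assumed to match the defaults of
Andromeda.model.Andromeda.
"""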
import unittest

import torch

from Andromeda.model import Andromeda

class TestAndromeda(unittest.TestCase):
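    """Smoke tests for a default-configuration Andromeda instance."""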

    def setUp(self):
        self.model = Andromeda()

    def test_initialization(self):
        # model.Andromeda is the underlying Transformer; model.decoder wraps it
        # for autoregressive decoding.
        self.assertIsNotNone(self.model.Andromeda, "Transformer is not initialized.")
        self.assertIsNotNone(self.model.decoder, "AutoregressiveWrapper is not initialized.")

    def test_forward_pass(self):
        input_tokens = torch.randint(0, 50432, (1, 8192))
        output = self.model(input_tokens)
        self.assertIsInstance(output, torch.Tensor, "Output is not a PyTorch tensor.")
        self.assertEqual(output.shape[0], input_tokens.shape[0], "Output batch size does not match input.")

    def test_error_handling(self):
        with self.assertRaises(Exception):
            self.model.forward(None)

    def test_model_parameters(self):
        self.assertEqual(self.model.Andromeda.num_tokens, 50432, "Number of tokens is not correctly set.")
        self.assertEqual(self.model.Andromeda.max_seq_len, 8192, "Max sequence length is not correctly set.")

    def test_model_output(self):
        # Dropout makes repeated passes nondeterministic, so compare in eval mode.
        self.model.eval()
        input_tokens = torch.randint(0, 50432, (1, 8192))
        with torch.no_grad():
            output1 = self.model(input_tokens)
            output2 = self.model(input_tokens)
        self.assertTrue(torch.allclose(output1, output2), "Model does not produce consistent output.")

class TestAndromedaExtended(unittest.TestCase):
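    """Parameter sweeps over input shapes and Andromeda constructor options.

    These tests assume each constructor argument is exposed under the same name
    on the underlying Transformer (model.Andromeda / model.Andromeda.attn_layers),
    which is what the assertions below rely on.
    """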

    def setUp(self):
        self.model = Andromeda()

    def test_input_size(self):
        for seq_len in [512, 1024, 2048, 4096]:
            input_tokens = torch.randint(0, 50432, (1, seq_len))
            output = self.model(input_tokens)
            self.assertEqual(output.shape[1], seq_len, f"Output sequence length does not match input for seq_len={seq_len}.")

    def test_batch_size(self):
        for batch_size in [2, 4, 8, 16]:
            input_tokens = torch.randint(0, 50432, (batch_size, 8192))
            output = self.model(input_tokens)
            self.assertEqual(output.shape[0], batch_size, f"Output batch size does not match input for batch_size={batch_size}.")

    def test_token_range(self):
        for token in [0, 50431]:
            # Embedding lookups require integer token ids, so pin the dtype.
            input_tokens = torch.full((1, 8192), fill_value=token, dtype=torch.long)
            output = self.model(input_tokens)
            self.assertIsInstance(output, torch.Tensor, f"Output is not a PyTorch tensor for token={token}.")

    def test_model_depth(self):
        for depth in [16, 32, 64]:
            model = Andromeda(depth=depth)
            self.assertEqual(model.Andromeda.attn_layers.depth, depth, f"Model depth is not correctly set for depth={depth}.")

    def test_model_dim(self):
        for dim in [1280, 2560, 5120]:
            model = Andromeda(dim=dim)
            self.assertEqual(model.Andromeda.attn_layers.dim, dim, f"Model dimension is not correctly set for dim={dim}.")

    def test_model_heads(self):
        for heads in [12, 24, 48]:
            model = Andromeda(heads=heads)
            self.assertEqual(model.Andromeda.attn_layers.heads, heads, f"Number of heads is not correctly set for heads={heads}.")

    def test_model_dim_head(self):
        for dim_head in [64, 128, 256]:
            model = Andromeda(dim_head=dim_head)
            self.assertEqual(model.Andromeda.attn_layers.dim_head, dim_head, f"Head dimension is not correctly set for dim_head={dim_head}.")

    def test_model_alibi_num_heads(self):
        for alibi_num_heads in [6, 12, 24]:
            model = Andromeda(alibi_num_heads=alibi_num_heads)
            self.assertEqual(model.Andromeda.attn_layers.alibi_num_heads, alibi_num_heads, f"Number of alibi heads is not correctly set for alibi_num_heads={alibi_num_heads}.")

    def test_model_shift_tokens(self):
        for shift_tokens in [0, 1, 2]:
            model = Andromeda(shift_tokens=shift_tokens)
            self.assertEqual(model.Andromeda.attn_layers.shift_tokens, shift_tokens, f"Number of shift tokens is not correctly set for shift_tokens={shift_tokens}.")

    def test_model_use_abs_pos_emb(self):
        for use_abs_pos_emb in [True, False]:
            model = Andromeda(use_abs_pos_emb=use_abs_pos_emb)
            self.assertEqual(model.Andromeda.use_abs_pos_emb, use_abs_pos_emb, f"Use absolute position embedding flag is not correctly set for use_abs_pos_emb={use_abs_pos_emb}.")

    def test_model_alibi_pos_bias(self):
        for alibi_pos_bias in [True, False]:
            model = Andromeda(alibi_pos_bias=alibi_pos_bias)
            self.assertEqual(model.Andromeda.attn_layers.alibi_pos_bias, alibi_pos_bias, f"Alibi position bias flag is not correctly set for alibi_pos_bias={alibi_pos_bias}.")

    def test_model_rotary_xpos(self):
        for rotary_xpos in [True, False]:
            model = Andromeda(rotary_xpos=rotary_xpos)
            self.assertEqual(model.Andromeda.attn_layers.rotary_xpos, rotary_xpos, f"Rotary xpos flag is not correctly set for rotary_xpos={rotary_xpos}.")

    def test_model_attn_flash(self):
        for attn_flash in [True, False]:
            model = Andromeda(attn_flash=attn_flash)
            self.assertEqual(model.Andromeda.attn_layers.attn_flash, attn_flash, f"Attention flash flag is not correctly set for attn_flash={attn_flash}.")

if __name__ == '__main__':
    unittest.main()