Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import pytest | |
import torch | |
from mmocr.models.textrecog.encoders import (ABIVisionModel, BaseEncoder, | |
NRTREncoder, SAREncoder, | |
SatrnEncoder, TransformerEncoder) | |
def test_sar_encoder():
    """Exercise SAREncoder argument validation and its forward pass.

    Each invalid constructor keyword set below is expected to raise an
    ``AssertionError``. The default encoder must then reject an
    ``img_metas`` list longer than the batch and map a ``(1, 512, 4, 40)``
    feature map to a ``(1, 512)`` tensor.
    """
    invalid_kwargs = (
        {'enc_bi_rnn': 'bi'},
        {'enc_do_rnn': 2},
        {'enc_gru': 'gru'},
        {'d_model': 512.5},
        {'d_enc': 200.5},
        {'mask': 'mask'},
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(AssertionError):
            SAREncoder(**kwargs)

    sar = SAREncoder()
    sar.init_weights()
    sar.train()

    features = torch.randn(1, 512, 4, 40)
    metas = [{'valid_ratio': 1.0}]

    # Two metas for a batch of one: presumably rejected because the lengths
    # disagree -- TODO confirm against SAREncoder.forward.
    with pytest.raises(AssertionError):
        sar(features, metas * 2)

    encoded = sar(features, metas)
    assert encoded.shape == torch.Size([1, 512])
def test_nrtr_encoder():
    """Check that NRTREncoder maps a (1, 512, 1, 25) map to (1, 25, 512).

    The encoder is built with default arguments, initialised, and put in
    training mode before the forward pass.
    """
    tf_encoder = NRTREncoder()
    tf_encoder.init_weights()
    tf_encoder.train()
    feat = torch.randn(1, 512, 1, 25)
    out_enc = tf_encoder(feat)
    # No print here: on failure pytest's assertion rewriting already shows
    # the actual shape, so tests stay silent on success.
    assert out_enc.shape == torch.Size([1, 25, 512])
def test_satrn_encoder():
    """Check that SatrnEncoder maps a (1, 512, 8, 25) map to (1, 200, 512).

    The middle output dimension asserted here equals height * width of the
    input (8 * 25 = 200).
    """
    batch, height, width = 1, 8, 25

    encoder = SatrnEncoder()
    encoder.init_weights()
    encoder.train()

    output = encoder(torch.randn(batch, 512, height, width))
    assert output.shape == torch.Size([batch, height * width, 512])
def test_base_encoder():
    """Check that BaseEncoder returns a tensor shaped like its input.

    Uses a (1, 256, 4, 40) feature map after init_weights() and train().
    """
    base = BaseEncoder()
    base.init_weights()
    base.train()

    inputs = torch.randn(1, 256, 4, 40)
    assert base(inputs).shape == inputs.shape
def test_transformer_encoder():
    """Check that TransformerEncoder preserves a (10, 512, 8, 32) shape."""
    expected = torch.Size([10, 512, 8, 32])
    encoder = TransformerEncoder()
    output = encoder(torch.randn(*expected))
    assert output.shape == expected
def test_abi_vision_model():
    """Check the shapes of the tensors ABIVisionModel returns.

    The model is built with an ``ABIVisionDecoder`` (``max_seq_len=10``)
    and fed a (1, 512, 8, 32) map. The result dict must provide
    'feature', 'logits' and 'attn_scores' entries with the shapes below.
    """
    vision_model = ABIVisionModel(
        decoder=dict(
            type='ABIVisionDecoder', max_seq_len=10, use_result=None))
    outputs = vision_model(torch.randn(1, 512, 8, 32))

    expected_shapes = {
        'feature': (1, 10, 512),
        # 90 presumably matches the default charset size -- TODO confirm.
        'logits': (1, 10, 90),
        'attn_scores': (1, 10, 8, 32),
    }
    for key, shape in expected_shapes.items():
        assert outputs[key].shape == torch.Size(shape)