# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmocr.models.textrecog.backbones import (ResNet, ResNet31OCR, ResNetABI,
                                              ShallowCNN, VeryDeepVgg)
def test_resnet31_ocr_backbone():
    """Test ResNet31OCR: bad constructor args raise, forward shape is right."""
    # Each (args, kwargs) pair is an invalid configuration that the
    # constructor must reject with an AssertionError.
    invalid_configs = [
        ((2.5, ), {}),
        ((3, ), dict(layers=5)),
        ((3, ), dict(channels=5)),
    ]
    for args, kwargs in invalid_configs:
        with pytest.raises(AssertionError):
            ResNet31OCR(*args, **kwargs)

    # Forward pass with the default configuration on a dummy batch.
    backbone = ResNet31OCR()
    backbone.init_weights()
    backbone.train()
    dummy_imgs = torch.randn(1, 3, 32, 160)
    out = backbone(dummy_imgs)
    assert out.shape == torch.Size([1, 512, 4, 40])
def test_vgg_deep_vgg_ocr_backbone():
    """VeryDeepVgg forward pass yields the expected feature-map shape."""
    backbone = VeryDeepVgg()
    backbone.init_weights()
    backbone.train()
    dummy_imgs = torch.randn(1, 3, 32, 160)
    out = backbone(dummy_imgs)
    assert out.shape == torch.Size([1, 512, 1, 41])
def test_shallow_cnn_ocr_backbone():
    """ShallowCNN forward pass yields the expected feature-map shape."""
    backbone = ShallowCNN()
    backbone.init_weights()
    backbone.train()
    # Single-channel input (e.g. grayscale), per the default config.
    dummy_imgs = torch.randn(1, 1, 32, 100)
    out = backbone(dummy_imgs)
    assert out.shape == torch.Size([1, 512, 8, 25])
def test_resnet_abi():
    """Test ResNetABI: bad constructor args raise, forward shape is right."""
    # Each (args, kwargs) pair is an invalid configuration that the
    # constructor must reject with an AssertionError.
    invalid_configs = [
        ((2.5, ), {}),
        ((3, ), dict(arch_settings=5)),
        ((3, ), dict(stem_channels=None)),
        # arch_settings/strides length mismatch.
        ((), dict(arch_settings=[3, 4, 6, 6], strides=[1, 2, 1, 2, 1])),
    ]
    for args, kwargs in invalid_configs:
        with pytest.raises(AssertionError):
            ResNetABI(*args, **kwargs)

    # Forward pass with the default configuration on a dummy batch.
    backbone = ResNetABI()
    backbone.train()
    dummy_imgs = torch.randn(1, 3, 32, 160)
    out = backbone(dummy_imgs)
    assert out.shape == torch.Size([1, 512, 8, 40])
def test_resnet():
    """Test the unified ResNet backbone in three configurations.

    Covers an ASTER-style ResNet45 (downsampling in both dimensions), an
    ABINet-style ResNet45 (scalar strides) and a ResNet31 variant built
    from maxpool/conv plugins, asserting the output feature-map shape for
    a 1x3x32x100 input in each case.
    """
    # FIX: `use_conv1x1` was previously passed the *string* 'True', which
    # only worked because any non-empty string is truthy; pass a real bool.
    resnet45_aster = ResNet(
        in_channels=3,
        stem_channels=[64, 128],
        block_cfgs=dict(type='BasicBlock', use_conv1x1=True),
        arch_layers=[3, 4, 6, 6, 3],
        arch_channels=[32, 64, 128, 256, 512],
        strides=[(2, 2), (2, 2), (2, 1), (2, 1), (2, 1)])
    resnet45_abi = ResNet(
        in_channels=3,
        stem_channels=32,
        block_cfgs=dict(type='BasicBlock', use_conv1x1=True),
        arch_layers=[3, 4, 6, 6, 3],
        arch_channels=[32, 64, 128, 256, 512],
        strides=[2, 1, 2, 1, 1])
    # ResNet31: identity strides; all downsampling is done by the
    # Maxpool2d plugins inserted before selected stages, with a 3x3
    # ConvModule appended after every stage.
    resnet_31 = ResNet(
        in_channels=3,
        stem_channels=[64, 128],
        block_cfgs=dict(type='BasicBlock'),
        arch_layers=[1, 2, 5, 3],
        arch_channels=[256, 256, 512, 512],
        strides=[1, 1, 1, 1],
        plugins=[
            dict(
                cfg=dict(type='Maxpool2d', kernel_size=2, stride=(2, 2)),
                stages=(True, True, False, False),
                position='before_stage'),
            dict(
                cfg=dict(type='Maxpool2d', kernel_size=(2, 1), stride=(2, 1)),
                stages=(False, False, True, False),
                position='before_stage'),
            dict(
                cfg=dict(
                    type='ConvModule',
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=dict(type='BN'),
                    act_cfg=dict(type='ReLU')),
                stages=(True, True, True, True),
                position='after_stage')
        ])
    # Forward a dummy batch through each variant and check shapes.
    img = torch.rand(1, 3, 32, 100)
    assert resnet45_aster(img).shape == torch.Size([1, 512, 1, 25])
    assert resnet45_abi(img).shape == torch.Size([1, 512, 8, 25])
    assert resnet_31(img).shape == torch.Size([1, 512, 4, 25])