File size: 1,217 Bytes
5de2f8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Only the factory function is exported by a star-import of this module.
__all__ = ["build_encoder"]

from importlib import import_module

# Registry: encoder class name -> module (relative to this package) that
# defines it. build_encoder() resolves the module with import_module() only
# when that encoder is requested. Insertion order is significant: the list of
# keys is shown in build_encoder()'s "Encoder only supports" error message.
name_to_module = dict(
    MobileNetV1Enhance='.rec_mv1_enhance',
    ResNet31='.rec_resnet_31',
    MobileNetV3='.rec_mobilenet_v3',
    PPLCNetV3='.rec_lcnetv3',
    PPHGNet_small='.rec_hgnet',
    ResNet='.rec_resnet_vd',
    MTB='.rec_nrtr_mtb',
    SVTRNet='.svtrnet',
    ResNet45='.rec_resnet_45',
    ViT='.vit',
    SVTRNet2DPos='.svtrnet2dpos',
    SVTRv2='.svtrv2',
    FocalSVTR='.focalsvtr',
    ResNet_FPN='.rec_resnet_fpn',
    ResNet_ASTER='.resnet31_rnn',
    SVTRv2LNConv='.svtrv2_lnconv',
    SVTRv2LNConvTwo33='.svtrv2_lnconv_two33',
    CAMEncoder='.cam_encoder',
    ConvNeXtV2='.convnextv2',
    AutoSTREncoder='.autostr_encoder',
    NRTREncoder='.nrtr_encoder',
    RepSVTREncoder='.repvit',
)


def build_encoder(config):
    """Instantiate the encoder selected by ``config['name']``.

    The class name is looked up in ``name_to_module``. The matching submodule
    of this package is imported on demand, and every remaining ``config``
    entry is forwarded to the class constructor as a keyword argument.

    Args:
        config: Mapping with a ``'name'`` entry naming the encoder class, plus
            any constructor keyword arguments. The mapping is not modified.

    Returns:
        The constructed encoder instance.

    Raises:
        KeyError: If ``config`` has no ``'name'`` entry.
        AssertionError: If ``config['name']`` is not a registered encoder.
    """
    # Work on a shallow copy so the caller's config keeps its 'name' entry
    # (the original popped it in place, breaking reuse of the same config).
    kwargs = dict(config)
    module_name = kwargs.pop('name')
    if module_name not in name_to_module:
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped under ``python -O``; AssertionError is kept so existing
        # callers see the same exception type.
        raise AssertionError(
            f'Encoder {module_name!r} is not supported. '
            f'Encoder only supports: {list(name_to_module.keys())}')

    module = import_module(name_to_module[module_name], package=__package__)
    encoder_cls = getattr(module, module_name)
    return encoder_cls(**kwargs)