| @@ -9,8 +9,25 @@ |
|
|
| 9 |
9 |
from doctr.models import ocr_predictor |
| 10 |
10 |
from doctr.models.predictor import OCRPredictor |
| 11 |
11 |
|
| 12 |
|
-DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large", "linknet_resnet50_rotation"] |
| 13 |
|
-RECO_ARCHS = ["crnn_vgg16_bn", "crnn_mobilenet_v3_small", "master", "sar_resnet31"] |
|
12 |
+DET_ARCHS = [ |
|
13 |
+"db_resnet50", |
|
14 |
+"db_resnet34", |
|
15 |
+"db_mobilenet_v3_large", |
|
16 |
+"db_resnet50_rotation", |
|
17 |
+"linknet_resnet18", |
|
18 |
+"linknet_resnet34", |
|
19 |
+"linknet_resnet50", |
|
20 |
+] |
|
21 |
+RECO_ARCHS = [ |
|
22 |
+"crnn_vgg16_bn", |
|
23 |
+"crnn_mobilenet_v3_small", |
|
24 |
+"crnn_mobilenet_v3_large", |
|
25 |
+"master", |
|
26 |
+"sar_resnet31", |
|
27 |
+"vitstr_small", |
|
28 |
+"vitstr_base", |
|
29 |
+"parseq", |
|
30 |
+] |
| 14 |
31 |
|
| 15 |
32 |
|
| 16 |
33 |
def load_predictor(det_arch: str, reco_arch: str, device) -> OCRPredictor: |