Skip to content

Commit 2fbcce1

Browse files
authored
Merge pull request #211 from Visual-Behavior/trt
add create_calibrator function
2 parents 05124ed + 6b12f47 commit 2fbcce1

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

alonet/torch2trt/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from .TRTEngineBuilder import TRTEngineBuilder
22
from .TRTExecutor import TRTExecutor
33
from .base_exporter import BaseTRTExporter
4-
from .utils import load_trt_custom_plugins
4+
from .utils import load_trt_custom_plugins, create_calibrator
5+
from .calibrator import DataBatchStreamer
56

67
from alonet import ALONET_ROOT
78
import os
89

10+
911
MS_DEFORM_IM2COL_PLUGIN_LIB = os.path.join(
1012
ALONET_ROOT, "torch2trt/plugins/ms_deform_im2col/build/libms_deform_im2col_trt.so"
1113
)

alonet/torch2trt/calibrator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1+
from aloscene.frame import Frame
2+
13
import os
24
import torch
35
import numpy as np
46
import tensorrt as trt
5-
67
import pycuda.driver as cuda
7-
import pycuda.autoinit
8-
9-
from aloscene.frame import Frame
108

119

1210
class DataBatchStreamer:

alonet/torch2trt/utils.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1+
from alonet import ALONET_ROOT
2+
from alonet.torch2trt.calibrator import (
3+
LegacyCalibrator,
4+
MinMaxCalibrator,
5+
EntropyCalibrator,
6+
EntropyCalibrator2,
7+
)
8+
19
import os
10+
import ctypes
211

312
try:
413
import pycuda.driver as cuda
@@ -10,8 +19,25 @@
1019
pass
1120

1221

13-
import ctypes
14-
from alonet import ALONET_ROOT
22+
def create_calibrator(name: str, *args, **kwargs):
23+
"""Creates calibrator from name
24+
25+
Parameters
26+
----------
27+
name : str
28+
Calibrator name
29+
"""
30+
CALIBS = ["minmax", "entropy", "entropy2", "legacy"]
31+
if name == "entropy2":
32+
return EntropyCalibrator2(*args, **kwargs)
33+
elif name == "entropy":
34+
return EntropyCalibrator(*args, **kwargs)
35+
elif name == "minmax":
36+
return MinMaxCalibrator(*args, **kwargs)
37+
elif name == "legacy":
38+
return LegacyCalibrator(*args, **kwargs)
39+
else:
40+
raise AttributeError(f"Unknown calibrator name, should be one of {' '.join(CALIBS)}")
1541

1642

1743
def load_trt_custom_plugins(lib_path: str):

0 commit comments

Comments
 (0)