Skip to content

Commit d26bdff

Browse files
Zafar Takhirov authored and facebook-github-bot committed
torchaudio: torch.quantization -> torch.ao.quantization (#1817)
Summary: Pull Request resolved: #1817 This changes the imports in the `torchaudio` to include the new import locations. ``` codemod -d pytorch/audio --extensions py 'torch.quantization' 'torch.ao.quantization' ``` Reviewed By: mthrok Differential Revision: D31302450 fbshipit-source-id: 4dd0087d867dc52e41023660c3056a66fc3e21a4
1 parent 1e7516f commit d26bdff

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

examples/libtorchaudio/speech_recognition/build_pipeline_from_fairseq.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import argparse
88
import logging
9+
from typing import Tuple
910

1011
import torch
1112
from torch.utils.mobile_optimizer import optimize_for_mobile
@@ -15,6 +16,12 @@
1516

1617
from greedy_decoder import Decoder
1718

19+
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
20+
if TORCH_VERSION >= (1, 10):
21+
import torch.ao.quantization as tq
22+
else:
23+
import torch.quantization as tq
24+
1825
_LG = logging.getLogger(__name__)
1926

2027

@@ -149,7 +156,7 @@ def _main():
149156
if args.quantize:
150157
_LG.info('Quantizing the model')
151158
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
152-
encoder = torch.quantization.quantize_dynamic(
159+
encoder = tq.quantize_dynamic(
153160
encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
154161
_LG.info(encoder)
155162

examples/libtorchaudio/speech_recognition/build_pipeline_from_huggingface_transformers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,19 @@
22
import argparse
33
import logging
44
import os
5+
from typing import Tuple
56

67
import torch
78
import torchaudio
89
from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
910
from greedy_decoder import Decoder
1011

12+
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
13+
if TORCH_VERSION >= (1, 10):
14+
import torch.ao.quantization as tq
15+
else:
16+
import torch.quantization as tq
17+
1118
_LG = logging.getLogger(__name__)
1219

1320

@@ -90,7 +97,7 @@ def _main():
9097
if args.quantize:
9198
_LG.info('Quantizing the model')
9299
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
93-
encoder = torch.quantization.quantize_dynamic(
100+
encoder = tq.quantize_dynamic(
94101
encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
95102
_LG.info(encoder)
96103

test/torchaudio_unittest/models/wav2vec2/model_test.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
import torch.nn.functional as F
5+
from typing import Tuple
56

67
from torchaudio.models.wav2vec2 import (
78
wav2vec2_asr_base,
@@ -24,6 +25,12 @@
2425
)
2526
from parameterized import parameterized
2627

28+
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
29+
if TORCH_VERSION >= (1, 10):
30+
import torch.ao.quantization as tq
31+
else:
32+
import torch.quantization as tq
33+
2734

2835
def _name_func(testcase_func, i, param):
2936
return f"{testcase_func.__name__}_{i}_{param[0][0].__name__}"
@@ -206,7 +213,7 @@ def _test_quantize_smoke_test(self, model):
206213

207214
# Remove the weight normalization forward hook
208215
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
209-
quantized = torch.quantization.quantize_dynamic(
216+
quantized = tq.quantize_dynamic(
210217
model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
211218

212219
# A lazy way to check that Modules are different
@@ -237,7 +244,7 @@ def _test_quantize_torchscript(self, model):
237244

238245
# Remove the weight normalization forward hook
239246
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
240-
quantized = torch.quantization.quantize_dynamic(
247+
quantized = tq.quantize_dynamic(
241248
model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
242249

243250
# A lazy way to check that Modules are different

0 commit comments

Comments (0)