Skip to content

Commit c748fcd

Browse files
committed
integrate residual lookup free quantization
1 parent 4bc50b0 commit c748fcd

File tree

4 files changed

+49
-21
lines changed

4 files changed

+49
-21
lines changed

README.md

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,10 @@ from audiolm_pytorch import SoundStream, SoundStreamTrainer
6767
soundstream = SoundStream(
6868
codebook_size = 1024,
6969
rq_num_quantizers = 8,
70-
rq_groups = 2, # this paper proposes using multi-headed residual vector quantization - https://arxiv.org/abs/2305.02765
71-
attn_window_size = 128, # local attention receptive field at bottleneck
72-
attn_depth = 2 # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better
70+
rq_groups = 2, # this paper proposes using multi-headed residual vector quantization - https://arxiv.org/abs/2305.02765
71+
use_lookup_free_quantizer = True, # whether to use residual lookup free quantization
72+
attn_window_size = 128, # local attention receptive field at bottleneck
73+
attn_depth = 2 # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so I took the liberty to add some. Encodec went with LSTMs, but attention should be better
7374
)
7475

7576
trainer = SoundStreamTrainer(
@@ -509,3 +510,14 @@ $ accelerate launch train.py
509510
year = {2022}
510511
}
511512
```
513+
514+
```bibtex
515+
@misc{yu2023language,
516+
title = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
517+
author = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
518+
year = {2023},
519+
eprint = {2310.05737},
520+
archivePrefix = {arXiv},
521+
primaryClass = {cs.CV}
522+
}
523+
```

audiolm_pytorch/soundstream.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717

1818
from einops import rearrange, reduce, pack, unpack
1919

20-
from vector_quantize_pytorch import GroupedResidualVQ
20+
from vector_quantize_pytorch import (
21+
GroupedResidualVQ,
22+
ResidualLFQ
23+
)
2124

2225
from local_attention import LocalMHA
2326
from local_attention.transformer import FeedForward, DynamicPositionBias
@@ -433,6 +436,7 @@ def __init__(
433436
rq_groups = 1,
434437
rq_stochastic_sample_codes = False,
435438
rq_kwargs: dict = {},
439+
use_lookup_free_quantizer = True, # proposed in https://arxiv.org/abs/2310.05737, adapted in residual quantization fashion for audio
436440
input_channels = 1,
437441
discr_multi_scales = (1, 0.5, 0.25),
438442
stft_normalized = False,
@@ -513,21 +517,33 @@ def __init__(
513517

514518
self.rq_groups = rq_groups
515519

516-
self.rq = GroupedResidualVQ(
517-
dim = codebook_dim,
518-
num_quantizers = rq_num_quantizers,
519-
codebook_size = codebook_size,
520-
groups = rq_groups,
521-
decay = rq_ema_decay,
522-
commitment_weight = rq_commitment_weight,
523-
quantize_dropout_multiple_of = rq_quantize_dropout_multiple_of,
524-
kmeans_init = True,
525-
threshold_ema_dead_code = 2,
526-
quantize_dropout = True,
527-
quantize_dropout_cutoff_index = quantize_dropout_cutoff_index,
528-
stochastic_sample_codes = rq_stochastic_sample_codes,
529-
**rq_kwargs
530-
)
520+
if use_lookup_free_quantizer:
521+
assert rq_groups == 1, 'grouped residual LFQ not implemented yet'
522+
523+
self.rq = ResidualLFQ(
524+
dim = codebook_dim,
525+
num_quantizers = rq_num_quantizers,
526+
codebook_size = codebook_size,
527+
quantize_dropout = True,
528+
quantize_dropout_cutoff_index = quantize_dropout_cutoff_index,
529+
**rq_kwargs
530+
)
531+
else:
532+
self.rq = GroupedResidualVQ(
533+
dim = codebook_dim,
534+
num_quantizers = rq_num_quantizers,
535+
codebook_size = codebook_size,
536+
groups = rq_groups,
537+
decay = rq_ema_decay,
538+
commitment_weight = rq_commitment_weight,
539+
quantize_dropout_multiple_of = rq_quantize_dropout_multiple_of,
540+
kmeans_init = True,
541+
threshold_ema_dead_code = 2,
542+
quantize_dropout = True,
543+
quantize_dropout_cutoff_index = quantize_dropout_cutoff_index,
544+
stochastic_sample_codes = rq_stochastic_sample_codes,
545+
**rq_kwargs
546+
)
531547

532548
self.decoder_film = FiLM(codebook_dim, dim_cond = 2)
533549

audiolm_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.5.7'
1+
__version__ = '1.6.0'

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
'torchaudio',
3535
'transformers',
3636
'tqdm',
37-
'vector-quantize-pytorch>=1.7.0'
37+
'vector-quantize-pytorch>=1.10.2'
3838
],
3939
classifiers=[
4040
'Development Status :: 4 - Beta',

0 commit comments

Comments
 (0)