|
17 | 17 |
|
18 | 18 | from einops import rearrange, reduce, pack, unpack
|
19 | 19 |
|
20 |
| -from vector_quantize_pytorch import GroupedResidualVQ |
| 20 | +from vector_quantize_pytorch import ( |
| 21 | + GroupedResidualVQ, |
| 22 | + ResidualLFQ |
| 23 | +) |
21 | 24 |
|
22 | 25 | from local_attention import LocalMHA
|
23 | 26 | from local_attention.transformer import FeedForward, DynamicPositionBias
|
@@ -433,6 +436,7 @@ def __init__(
|
433 | 436 | rq_groups = 1,
|
434 | 437 | rq_stochastic_sample_codes = False,
|
435 | 438 | rq_kwargs: dict = {},
|
| 439 | +        use_lookup_free_quantizer = True,   # proposed in https://arxiv.org/abs/2310.05737, adapted in a residual quantization fashion for audio |
436 | 440 | input_channels = 1,
|
437 | 441 | discr_multi_scales = (1, 0.5, 0.25),
|
438 | 442 | stft_normalized = False,
|
@@ -513,21 +517,33 @@ def __init__(
|
513 | 517 |
|
514 | 518 | self.rq_groups = rq_groups
|
515 | 519 |
|
516 |
| - self.rq = GroupedResidualVQ( |
517 |
| - dim = codebook_dim, |
518 |
| - num_quantizers = rq_num_quantizers, |
519 |
| - codebook_size = codebook_size, |
520 |
| - groups = rq_groups, |
521 |
| - decay = rq_ema_decay, |
522 |
| - commitment_weight = rq_commitment_weight, |
523 |
| - quantize_dropout_multiple_of = rq_quantize_dropout_multiple_of, |
524 |
| - kmeans_init = True, |
525 |
| - threshold_ema_dead_code = 2, |
526 |
| - quantize_dropout = True, |
527 |
| - quantize_dropout_cutoff_index = quantize_dropout_cutoff_index, |
528 |
| - stochastic_sample_codes = rq_stochastic_sample_codes, |
529 |
| - **rq_kwargs |
530 |
| - ) |
| 520 | + if use_lookup_free_quantizer: |
| 521 | + assert rq_groups == 1, 'grouped residual LFQ not implemented yet' |
| 522 | + |
| 523 | + self.rq = ResidualLFQ( |
| 524 | + dim = codebook_dim, |
| 525 | + num_quantizers = rq_num_quantizers, |
| 526 | + codebook_size = codebook_size, |
| 527 | + quantize_dropout = True, |
| 528 | + quantize_dropout_cutoff_index = quantize_dropout_cutoff_index, |
| 529 | + **rq_kwargs |
| 530 | + ) |
| 531 | + else: |
| 532 | + self.rq = GroupedResidualVQ( |
| 533 | + dim = codebook_dim, |
| 534 | + num_quantizers = rq_num_quantizers, |
| 535 | + codebook_size = codebook_size, |
| 536 | + groups = rq_groups, |
| 537 | + decay = rq_ema_decay, |
| 538 | + commitment_weight = rq_commitment_weight, |
| 539 | + quantize_dropout_multiple_of = rq_quantize_dropout_multiple_of, |
| 540 | + kmeans_init = True, |
| 541 | + threshold_ema_dead_code = 2, |
| 542 | + quantize_dropout = True, |
| 543 | + quantize_dropout_cutoff_index = quantize_dropout_cutoff_index, |
| 544 | + stochastic_sample_codes = rq_stochastic_sample_codes, |
| 545 | + **rq_kwargs |
| 546 | + ) |
531 | 547 |
|
532 | 548 | self.decoder_film = FiLM(codebook_dim, dim_cond = 2)
|
533 | 549 |
|
|
0 commit comments