class TextEmbedding(nn.Module):
-    def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
+    def __init__(
+        self, text_num_embeds, text_dim, mask_padding=True, average_upsampling=False, conv_layers=0, conv_mult=2
+    ):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        self.mask_padding = mask_padding  # mask filler and batch padding tokens or not
+        self.average_upsampling = average_upsampling  # ZipVoice-style late average upsampling of text (after text encoder)
+        if average_upsampling:
+            assert mask_padding, "text_embedding_average_upsampling requires text_mask_padding to be True"

        if conv_layers > 0:
            self.extra_modeling = True
@@ -45,11 +50,47 @@ def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0,
        else:
            self.extra_modeling = False

-    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
+    def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
+        batch, text_len, text_dim = text.shape
+
+        if audio_mask is None:
+            audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
+        valid_mask = audio_mask & text_mask
+        audio_lens = audio_mask.sum(dim=1)  # [batch]
+        valid_lens = valid_mask.sum(dim=1)  # [batch]
+
+        upsampled_text = torch.zeros_like(text)
+
+        for i in range(batch):
+            audio_len = audio_lens[i].item()
+            valid_len = valid_lens[i].item()
+
+            if valid_len == 0:
+                continue
+
+            valid_ind = torch.where(valid_mask[i])[0]
+            valid_data = text[i, valid_ind, :]  # [valid_len, text_dim]
+
+            base_repeat = audio_len // valid_len
+            remainder = audio_len % valid_len
+
+            indices = []
+            for j in range(valid_len):
+                repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
+                indices.extend([j] * repeat_count)
+
+            indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
+            upsampled = valid_data[indices]  # [audio_len, text_dim]
+
+            upsampled_text[i, :audio_len, :] = upsampled
+
+        return upsampled_text
+
+    def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None):  # noqa: F722
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        batch, text_len = text.shape[0], text.shape[1]
-        text = F.pad(text, (0, seq_len - text_len), value=0)
+        text = F.pad(text, (0, seq_len - text_len), value=0)  # (opt.) if not self.average_upsampling:
        if self.mask_padding:
            text_mask = text == 0

@@ -61,7 +102,7 @@ def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
-            batch_start = torch.zeros((batch,), dtype=torch.long)
+            batch_start = torch.zeros((batch,), device=text.device, dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed
@@ -75,6 +116,9 @@ def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        else:
            text = self.text_blocks(text)

+        if self.average_upsampling:
+            text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)
+
        return text

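For intuition, here is a minimal standalone sketch (not part of the diff) of the repeat-count scheme in `average_upsample_text_by_mask`: every valid text position is repeated `base_repeat` times, and the `remainder` leftover audio frames are absorbed by the last tokens. The concrete lengths below are made-up values.

```python
# Standalone illustration of the index-repetition scheme; lengths are invented.
audio_len, valid_len = 8, 3  # e.g. 8 mel frames to cover with 3 valid text tokens

base_repeat = audio_len // valid_len  # 2: every token repeats at least twice
remainder = audio_len % valid_len     # 2: two frames left over to distribute

indices = []
for j in range(valid_len):
    # tokens at the tail (j >= valid_len - remainder) absorb one extra frame each
    repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
    indices.extend([j] * repeat_count)

assert indices == [0, 0, 1, 1, 1, 2, 2, 2] and len(indices) == audio_len
```

So token 0 covers two frames while tokens 1 and 2 cover three each; gathering `valid_data[indices]` then lays these embeddings out across the audio timeline.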
@@ -113,6 +157,7 @@ def __init__(
        text_num_embeds=256,
        text_dim=None,
        text_mask_padding=True,
+        text_embedding_average_upsampling=False,
        qk_norm=None,
        conv_layers=0,
        pe_attn_head=None,
@@ -127,7 +172,11 @@ def __init__(
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed = TextEmbedding(
-            text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
+            text_num_embeds,
+            text_dim,
+            mask_padding=text_mask_padding,
+            average_upsampling=text_embedding_average_upsampling,
+            conv_layers=conv_layers,
        )
        self.text_cond, self.text_uncond = None, None  # text cache
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
@@ -190,19 +239,20 @@ def get_input_embed(
        drop_audio_cond: bool = False,
        drop_text: bool = False,
        cache: bool = True,
+        audio_mask: bool["b n"] | None = None,  # noqa: F722
    ):
        seq_len = x.shape[1]
        if cache:
            if drop_text:
                if self.text_uncond is None:
-                    self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
+                    self.text_uncond = self.text_embed(text, seq_len, drop_text=True, audio_mask=audio_mask)
                text_embed = self.text_uncond
            else:
                if self.text_cond is None:
-                    self.text_cond = self.text_embed(text, seq_len, drop_text=False)
+                    self.text_cond = self.text_embed(text, seq_len, drop_text=False, audio_mask=audio_mask)
                text_embed = self.text_cond
        else:
-            text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
+            text_embed = self.text_embed(text, seq_len, drop_text=drop_text, audio_mask=audio_mask)

        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
@@ -230,13 +280,19 @@ def forward(
        # t: conditioning time, text: text, x: noised audio + cond audio + text
        t = self.time_embed(time)
        if cfg_infer:  # pack cond & uncond forward: b n d -> 2b n d
-            x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
-            x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
+            x_cond = self.get_input_embed(
+                x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache, audio_mask=mask
+            )
+            x_uncond = self.get_input_embed(
+                x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache, audio_mask=mask
+            )
            x = torch.cat((x_cond, x_uncond), dim=0)
            t = torch.cat((t, t), dim=0)
            mask = torch.cat((mask, mask), dim=0) if mask is not None else None
        else:
-            x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
+            x = self.get_input_embed(
+                x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache, audio_mask=mask
+            )

        rope = self.rotary_embed.forward_from_seq_len(seq_len)
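To enable the feature end to end, only the new `text_embedding_average_upsampling` flag matters; below is a hedged usage sketch in which the import path and all other constructor values are illustrative assumptions, not verified defaults.

```python
# Sketch only: constructor values besides the new flag are assumptions.
from f5_tts.model import DiT  # adjust to the actual package layout

model = DiT(
    dim=1024,
    mel_dim=100,
    text_num_embeds=256,
    text_mask_padding=True,  # required: the assert in TextEmbedding enforces this
    text_embedding_average_upsampling=True,  # new flag introduced in this diff
    conv_layers=4,
)
# model.forward(..., mask=...) then threads the padding mask through
# get_input_embed as audio_mask, so valid text embeddings are spread
# evenly over each sample's audio length.
```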
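One wiring note: because `get_input_embed` caches `self.text_cond` and `self.text_uncond`, the `audio_mask` passed on the first call is baked into the cached text embedding. With `average_upsampling` enabled, the cache therefore stays valid only while the mask is unchanged, which holds within a single sampling run but is worth keeping in mind if the cache outlives one.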