Skip to content

Commit 06a7491

Browse files
committed
add option for text embedding late average upsampling
1 parent ac3c435 commit 06a7491

File tree

1 file changed

+67
-11
lines changed
  • src/f5_tts/model/backbones

1 file changed

+67
-11
lines changed

src/f5_tts/model/backbones/dit.py

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,16 @@
2929

3030

3131
class TextEmbedding(nn.Module):
32-
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
32+
def __init__(
33+
self, text_num_embeds, text_dim, mask_padding=True, average_upsampling=False, conv_layers=0, conv_mult=2
34+
):
3335
super().__init__()
3436
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
3537

3638
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
39+
self.average_upsampling = average_upsampling # zipvoice-style text late average upsampling (after text encoder)
40+
if average_upsampling:
41+
assert mask_padding, "text_embedding_average_upsampling requires text_mask_padding to be True"
3742

3843
if conv_layers > 0:
3944
self.extra_modeling = True
@@ -45,11 +50,47 @@ def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0,
4550
else:
4651
self.extra_modeling = False
4752

48-
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
53+
def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
54+
batch, text_len, text_dim = text.shape
55+
56+
if audio_mask is None:
57+
audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
58+
valid_mask = audio_mask & text_mask
59+
audio_lens = audio_mask.sum(dim=1) # [batch]
60+
valid_lens = valid_mask.sum(dim=1) # [batch]
61+
62+
upsampled_text = torch.zeros_like(text)
63+
64+
for i in range(batch):
65+
audio_len = audio_lens[i].item()
66+
valid_len = valid_lens[i].item()
67+
68+
if valid_len == 0:
69+
continue
70+
71+
valid_ind = torch.where(valid_mask[i])[0]
72+
valid_data = text[i, valid_ind, :] # [valid_len, text_dim]
73+
74+
base_repeat = audio_len // valid_len
75+
remainder = audio_len % valid_len
76+
77+
indices = []
78+
for j in range(valid_len):
79+
repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
80+
indices.extend([j] * repeat_count)
81+
82+
indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
83+
upsampled = valid_data[indices] # [audio_len, text_dim]
84+
85+
upsampled_text[i, :audio_len, :] = upsampled
86+
87+
return upsampled_text
88+
89+
def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None): # noqa: F722
4990
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
5091
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
5192
batch, text_len = text.shape[0], text.shape[1]
52-
text = F.pad(text, (0, seq_len - text_len), value=0)
93+
text = F.pad(text, (0, seq_len - text_len), value=0) # (opt.) if not self.average_upsampling:
5394
if self.mask_padding:
5495
text_mask = text == 0
5596

@@ -61,7 +102,7 @@ def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
61102
# possible extra modeling
62103
if self.extra_modeling:
63104
# sinus pos emb
64-
batch_start = torch.zeros((batch,), dtype=torch.long)
105+
batch_start = torch.zeros((batch,), device=text.device, dtype=torch.long)
65106
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
66107
text_pos_embed = self.freqs_cis[pos_idx]
67108
text = text + text_pos_embed
@@ -75,6 +116,9 @@ def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
75116
else:
76117
text = self.text_blocks(text)
77118

119+
if self.average_upsampling:
120+
text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)
121+
78122
return text
79123

80124

@@ -113,6 +157,7 @@ def __init__(
113157
text_num_embeds=256,
114158
text_dim=None,
115159
text_mask_padding=True,
160+
text_embedding_average_upsampling=False,
116161
qk_norm=None,
117162
conv_layers=0,
118163
pe_attn_head=None,
@@ -127,7 +172,11 @@ def __init__(
127172
if text_dim is None:
128173
text_dim = mel_dim
129174
self.text_embed = TextEmbedding(
130-
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
175+
text_num_embeds,
176+
text_dim,
177+
mask_padding=text_mask_padding,
178+
average_upsampling=text_embedding_average_upsampling,
179+
conv_layers=conv_layers,
131180
)
132181
self.text_cond, self.text_uncond = None, None # text cache
133182
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
@@ -190,19 +239,20 @@ def get_input_embed(
190239
drop_audio_cond: bool = False,
191240
drop_text: bool = False,
192241
cache: bool = True,
242+
audio_mask: bool["b n"] | None = None, # noqa: F722
193243
):
194244
seq_len = x.shape[1]
195245
if cache:
196246
if drop_text:
197247
if self.text_uncond is None:
198-
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
248+
self.text_uncond = self.text_embed(text, seq_len, drop_text=True, audio_mask=audio_mask)
199249
text_embed = self.text_uncond
200250
else:
201251
if self.text_cond is None:
202-
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
252+
self.text_cond = self.text_embed(text, seq_len, drop_text=False, audio_mask=audio_mask)
203253
text_embed = self.text_cond
204254
else:
205-
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
255+
text_embed = self.text_embed(text, seq_len, drop_text=drop_text, audio_mask=audio_mask)
206256

207257
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
208258

@@ -230,13 +280,19 @@ def forward(
230280
# t: conditioning time, text: text, x: noised audio + cond audio + text
231281
t = self.time_embed(time)
232282
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
233-
x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
234-
x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
283+
x_cond = self.get_input_embed(
284+
x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache, audio_mask=mask
285+
)
286+
x_uncond = self.get_input_embed(
287+
x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache, audio_mask=mask
288+
)
235289
x = torch.cat((x_cond, x_uncond), dim=0)
236290
t = torch.cat((t, t), dim=0)
237291
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
238292
else:
239-
x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
293+
x = self.get_input_embed(
294+
x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache, audio_mask=mask
295+
)
240296

241297
rope = self.rotary_embed.forward_from_seq_len(seq_len)
242298

0 commit comments

Comments
 (0)