栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

yourtts代码解读

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

yourtts代码解读

输入及embedding处理

aux_input

{'d_vectors': None, 'speaker_ids': tensor([0], device='cuda:0'), 'language_ids': tensor([0], device='cuda:0')}

# speaker embedding
        if self.args.use_speaker_embedding and sid is not None:

            # sid tensor([0], device='cuda:0')
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1] [1,256,1] 1为batch 256为映射后的维度,1为第一个speaker
 # language embedding
        lang_emb = None
        if self.args.use_language_embedding and lid is not None:
            lang_emb = self.emb_l(lid).unsqueeze(-1) # [b,language维度,1] [1,4,1]

文本编码器

总体

生成batch的时候,已经对一个batch的文本统一了长度(padding)

x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
        # x torch.Size([1, 196, 105])
        # logs_p torch.Size([1, 192, 105])
        # x_mask torch.Size([1, 1, 105])

text_encoder前的x torch.Size([4, 117])
transformer后,但没线性投影x torch.Size([4, 196, 117]) #196是因为wordembedding
线性投影后stats torch.Size([4, 384, 117])
m torch.Size([4, 192, 117])
logs torch.Size([4, 192, 117])


TextEncoder内部的emb

x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]

x前 torch.Size([1, 105])  x后 torch.Size([1, 105, 192])  hidden_channels初始就为192


# concat the lang emb in embedding chars
if lang_emb is not None:
    x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)

x前 torch.Size([1, 105, 192])
x后 torch.Size([1, 105,196])

x变换为[b.h,t]形式

x和x_mask一起传入transformer的encoder中,此过程不细看

mask

NLP 中的Mask全解 - 知乎

x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 
sequence_mask用于在计算中屏蔽掉句子填充的部分
def sequence_mask(sequence_length, max_len=None):
    """Create a sequence mask for filtering padding in a sequence tensor.

    Args:
        sequence_length (torch.tensor): Sequence lengths. 传入batchsize大小
        max_len (int, Optional): Maximum sequence length. Defaults to None. 传入batch中最长文本的长度

    Shapes:
        - mask: :math:`[B, T_max]`
    """
    if max_len is None:  # 如果传入batch的最长文本的长度太长,会设置为none
        max_len = sequence_length.data.max()
    seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
    # B x T_max
    mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
    return mask

图中mask的维度为[4,583]即[B,t] batchsize和该batch中最长的文本长度t,这样在计算中就可以过滤掉填充的部分

x = self.encoder(x * x_mask, x_mask)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return x, m, logs, x_mask

x torch.Size([4, 196, 123])
transformer后x torch.Size([4, 196, 123])
线性投影后stats torch.Size([4, 384, 123]) 384 为transformer的 outchanels(192)*2
m torch.Size([4, 192, 123])
logs torch.Size([4, 192, 123])
 


# posterior encoder

引入了speaker_embedding

z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)

y torch.Size([1, 513, 551]) [B, C, T_spec]

z torch.Size([1, 192, 551])
y_mask torch.Size([1, 1, 551])
logs_q torch.Size([1, 192, 551])

z的第二维度也映射回192
 

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/769455.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号