栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

transform mask_mask transformer?

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

transform mask_mask transformer?

import torch
import numpy as np
import matplotlib.pyplot as plt
def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    print(attn_shape)
    print(np.ones(attn_shape))
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    print(subsequent_mask)
    return torch.from_numpy(subsequent_mask) == 0
print(subsequent_mask(5))
plt.figure(figsize=(5,5))
plt.imshow(subsequent_mask(20)[0])
print(subsequent_mask(5))
plt.figure(figsize=(5,5))
plt.imshow(subsequent_mask(20)[0])
涉及的知识点: 1 np.triu or numpy.triu

对于m*n mimport numpy as np print('数组的上三角部分:n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[10,11,12],[10,11,12],[10,11,12]], k=-1))) print('数组的上三角部分:n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[10,11,12],[10,11,12],[10,11,12]], k=0))) print('数组的上三角部分:n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[10,11,12],[10,11,12],[10,11,12]], k=1))) #输出 数组的上三角部分: [[ 1 2 3] [ 4 5 6] [ 0 8 9] [ 0 0 12] [ 0 0 0] [ 0 0 0] [ 0 0 0]] 数组的上三角部分: [[1 2 3] [0 5 6] [0 0 9] [0 0 0] [0 0 0] [0 0 0] [0 0 0]] 数组的上三角部分: [[0 2 3] [0 0 6] [0 0 0] [0 0 0] [0 0 0] [0 0 0] [0 0 0]]

矩阵的shape是(7,3),可见k=-1是从第三行(index=2)为下标开始的,依次类推k=0是从第二行(index=1)为下标开始的,k=1是从第一行(index=0)为下标开始的

对于n*n的矩阵

import numpy as np
print('数组的上三角部分:n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9]], k=-1)))
print('数组的上三角部分:n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9]], k=0)))
print('数组的上三角部分:n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9]], k=1)))
#输出
数组的上三角部分:
[[1 2 3]
 [4 5 6]
 [0 8 9]]
数组的上三角部分:
[[1 2 3]
 [0 5 6]
 [0 0 9]]
数组的上三角部分:
[[0 2 3]
 [0 0 6]
 [0 0 0]]

对于m*n mimport numpy as np print('数组的上三角部分:n{}'.format(np.triu([[1,2,3],[4,5,6]], k=-1))) print('数组的上三角部分:n{}'.format(np.triu([[1,2,3],[4,5,6]], k=0))) print('数组的上三角部分:n{}'.format(np.triu([[1,2,3],[4,5,6]], k=1))) #输出 数组的上三角部分: [[1 2 3] [4 5 6]] 数组的上三角部分: [[1 2 3] [0 5 6]] 数组的上三角部分: [[0 2 3] [0 0 6]]

从第一行可以看到对于这个2*3的矩阵,k=-1表示从第三行开始,但是矩阵没有第三行,所以原样输出
其他k的取值还是按照之前陈述的规律输出

2 astype
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')

作用就是numpy.ndarray类型中数字转换成uint8类型的数据
uint8表示:uint8是8位无符号整型

3 torch.from_numpy(subsequent_mask)

是将ndarray类型的数据转换成tensor类型的数据

4 torch.from_numpy(subsequent_mask) == 0

将每个位置的数==0和零判断是否相等,如果=0,此位置为True,否为False
目的:是将下三角为0,上三角为1的矩阵进行翻转得到,下三角为True,上三角为False

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/783197.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号