# matmul_qk(bs, 8, seq_len_q, seq_len_k)
matmul_qk tf.matmul(q, k, transpose_b True) # (..., seq_len_q, seq_len_k)
# 缩放 matmul_qk
dk tf.cast(tf.shape(k)[-1], tf.float32)
scaled_attention_logits matmul_qk / tf.math.sqrt(dk)
# 将 mask 加入到缩放的张量上。
if mask is not None:
# mask为1的位置变成非常小的数
scaled_attention_logits (mask * -1e9)
# softmax 在最后一个轴 seq_len_k 上归一化 因此分数
# 相加等于1。
attention_weights tf.nn.softmax(scaled_attention_logits, axis -1) # (..., seq_len_q, seq_len_k)
output tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v)
return output, attention_weights
def get_angles(pos, i, d_model):
    """Compute the positional-encoding angles.

    Evaluates ``pos / 10000 ** (2 * (i // 2) / d_model)``. The integer
    division ``i // 2`` makes each even index and the odd index after it
    share the same rate.

    Args:
        pos: Position index or array of position indices.
        i: Embedding-dimension index or array of such indices.
        d_model: Model (embedding) dimensionality used to scale the exponent.

    Returns:
        ``pos * angle_rates``, broadcast according to the shapes of
        ``pos`` and ``i``.
    """
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates
def positional_encoding(position, d_model):
    """Build the sinusoidal positional-encoding table.

    Args:
        position: Number of positions to encode (rows ``0 .. position-1``).
        d_model: Embedding dimensionality (number of columns).

    Returns:
        A ``tf.float32`` tensor with a leading axis of size 1 prepended to
        the ``(position, d_model)`` table.
    """
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(d_model)[np.newaxis, :],
                            d_model)

    # Apply sin to the even indices of the array (2i).
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

    # Apply cos to the odd indices of the array (2i+1).
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

    # Prepend a size-1 axis so the table broadcasts over a leading dimension.
    pos_encoding = angle_rads[np.newaxis, ...]

    return tf.cast(pos_encoding, dtype=tf.float32)
def create_padding_mask(seq):
    """Create a padding mask in which 1 marks a masked (padding) position.

    Positions where ``seq`` equals 0 become 1.0; all others become 0.0.

    Args:
        seq: Tensor of token ids; the value 0 is treated as padding.
            Assumed to be 2-D ``(batch_size, seq_len)`` given the indexing
            below (confirm against callers).

    Returns:
        A ``tf.float32`` tensor with two size-1 axes inserted after the
        first axis, i.e. ``(batch_size, 1, 1, seq_len)``.
    """
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)

    # Add extra dimensions so the mask can be added to the
    # attention logits.
    return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)
def create_look_ahead_mask(size):
    """Create a look-ahead mask in which 1 marks a masked position.

    ``tf.linalg.band_part(..., -1, 0)`` keeps the lower triangle (including
    the diagonal) of an all-ones matrix; subtracting it from 1 leaves ones
    strictly above the diagonal.

    Args:
        size: Side length of the square mask.

    Returns:
        A ``(size, size)`` tensor with 1 above the diagonal and 0 elsewhere.
    """
    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return mask  # (seq_len, seq_len)
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention layer."""

    def __init__(self, d_model, num_heads):
        """Create the projection layers.

        Args:
            d_model: Output size of every dense projection; must be
                divisible by ``num_heads``.
            num_heads: Number of attention heads.
        """
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0, (
            "d_model must be divisible by num_heads")

        # Per-head feature size.
        self.depth = d_model // self.num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).

        Transposes the result so that its shape is
        (batch_size, num_heads, seq_len, depth).
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        """Project the inputs, attend per head and merge the heads.

        Args:
            v: Value tensor.
            k: Key tensor.
            q: Query tensor; its first dimension is used as the batch size.
            mask: Mask forwarded unchanged to
                ``scaled_dot_product_attention``.

        Returns:
            A tuple ``(output, attention_weights)``.
        """
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask)

        # Move the head axis next to depth so the two can be merged.
        scaled_attention = tf.transpose(
            scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

        concat_attention = tf.reshape(
            scaled_attention,
            (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output, attention_weights
# Bahdanau attention
class BahdanauAttention(tf.keras.layers.Layer):
    """Additive (Bahdanau-style) attention built on Keras ``AdditiveAttention``."""

    def __init__(self, units):
        """Create the query/key projections and the attention layer.

        Args:
            units: Output size of the two bias-free dense projections.
        """
        super().__init__()
        # For Eqn. (4), the Bahdanau attention
        self.W1 = tf.keras.layers.Dense(units, use_bias=False)
        self.W2 = tf.keras.layers.Dense(units, use_bias=False)

        self.attention = tf.keras.layers.AdditiveAttention()

    def call(self, query, value, mask):
        """Attend from ``query`` over ``value``.

        Args:
            query: Tensor checked as ('batch', 't', 'query_units').
            value: Tensor checked as ('batch', 's', 'value_units').
            mask: Tensor checked as ('batch', 's'); used as the value mask.

        Returns:
            A tuple ``(context_vector, attention_weights)`` checked as
            ('batch', 't', 'value_units') and ('batch', 't', 's').
        """
        shape_checker = ShapeChecker()
        shape_checker(query, ('batch', 't', 'query_units'))
        shape_checker(value, ('batch', 's', 'value_units'))
        shape_checker(mask, ('batch', 's'))

        # From Eqn. (4), `W1@ht`.
        w1_query = self.W1(query)
        shape_checker(w1_query, ('batch', 't', 'attn_units'))

        # From Eqn. (4), `W2@hs`.
        w2_key = self.W2(value)
        shape_checker(w2_key, ('batch', 's', 'attn_units'))

        # Every query position is kept; only value positions are masked.
        query_mask = tf.ones(tf.shape(query)[:-1], dtype=bool)
        value_mask = mask

        # Inputs are ordered [query, value, key]: the projected value acts
        # as the key while the unprojected value is what gets averaged.
        context_vector, attention_weights = self.attention(
            inputs=[w1_query, value, w2_key],
            mask=[query_mask, value_mask],
            return_attention_scores=True,
        )
        shape_checker(context_vector, ('batch', 't', 'value_units'))
        shape_checker(attention_weights, ('batch', 't', 's'))

        return context_vector, attention_weights
# Reference: https://www.tensorflow.org/text/tutorials/transformer#scaled_dot_product_attention