栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

AlphaFold2代码阅读(五)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

AlphaFold2代码阅读(五)

2021SC@SDUSC


class TemplatePairStack
class TemplatePairStack(hk.Module)

  def __init__(self, config, global_config, name='template_pair_stack'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, pair_act, pair_mask, is_training, safe_key=None):

class TemplatePairStack负责的是模板的配对堆栈
参数:
pair_act:单个模板的配对激活,形状为 [N_res, N_res, c_t]。
pair_mask:对掩码,形状为 [N_res, N_res]。
is_training:模块是否处于训练模式。
safe_key:封装随机数生成密钥的安全密钥对象。
返回:
更新的pair_act,形状[N_res, N_res, c_t]。

成对的模板特征被线性投影,以形成初始模板表示
每个模板表示都使用模板对堆栈进行独立处理,所有可训练的参数都在模板之间共享。

if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

这里和之前的class EvoformerIteration类似
在这里hk.next_rng_key()返回一个唯一的rng键

    def block(x):
  
      pair_act, safe_key = x

      dropout_wrapper_fn = functools.partial(
          dropout_wrapper, is_training=is_training, global_config=gc)

      safe_key, *sub_keys = safe_key.split(6)
      sub_keys = iter(sub_keys)

      pair_act = dropout_wrapper_fn(
          TriangleAttention(c.triangle_attention_starting_node, gc,
                            name='triangle_attention_starting_node'),
          pair_act,
          pair_mask,
          next(sub_keys))
      pair_act = dropout_wrapper_fn(
          TriangleAttention(c.triangle_attention_ending_node, gc,
                            name='triangle_attention_ending_node'),
          pair_act,
          pair_mask,
          next(sub_keys))
      pair_act = dropout_wrapper_fn(
          TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
                                 name='triangle_multiplication_outgoing'),
          pair_act,
          pair_mask,
          next(sub_keys))
      pair_act = dropout_wrapper_fn(
          TriangleMultiplication(c.triangle_multiplication_incoming, gc,
                                 name='triangle_multiplication_incoming'),
          pair_act,
          pair_mask,
          next(sub_keys))
      pair_act = dropout_wrapper_fn(
          Transition(c.pair_transition, gc, name='pair_transition'),
          pair_act,
          pair_mask,
          next(sub_keys))

      return pair_act, safe_key

这里是负责生成模板对栈的一个块。最后返回了pair_act,和safe_key
其中上述代码中出现的5个pair_act = dropout_wrapper_fn对应于伪代码

这里的DropoutRowwise用来表示具有形状掩码 [1, Nres, Nchannel]的dropout操作,是跨行共享的
使用DropoutColumnwise来表示形状掩码[Nres,1,Nclanel]的版本,该版本跨列共享。
关于dropout操作:
dropout是指在训练一个大的神经网络的时候,随机“关闭”一些神经元,即把这些神经元从网络中“抹去”,这相当于在本次训练中,这些被“抹去”的神经元不参与本次训练,英文即是“dropout”的意思。如下图所示:

我们看到图中打叉的神经元就是被“dropout”掉的神经元,和这些个神经元相连接的权重值也一并被“抹去”,不参与本次训练。不参与本次训练是说在当前的batch中,不参与训练,每个batch都会随机挑选神经元做dropout。

    if gc.use_remat:
      block = hk.remat(block)

    res_stack = layer_stack.layer_stack(c.num_block)(block)
    pair_act, safe_key = res_stack((pair_act, safe_key))
    return pair_act

上述代码中的layer_stack是来自layer_stack.py中的一个函数
这个函数返回的是使用有效函数调用时将生成层堆栈的可调用对象

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/498956.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号