栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 面试经验 > 面试问答

在TensorFlow中使用预训练的单词嵌入(word2vec或Glove)

面试问答 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

在TensorFlow中使用预训练的单词嵌入(word2vec或Glove)

您可以通过多种方式在TensorFlow中使用预训练的嵌入。假设您将NemPy数组嵌入到

embedding
具有
vocab_size
行和
embedding_dim
列的NumPy数组中,并且想要创建一个
W
可用于调用的张量
tf.nn.embedding_lookup()

  1. 只需创建
    W
    一个
    tf.constant()
    是需要
    embedding
    为它的价值:
    W = tf.constant(embedding, name="W")

这是最简单的方法,但是由于a的值

tf.constant()
多次存储在内存中,因此内存使用效率不高。由于
embedding
可能很大,因此只应将这种方法用于玩具示例。

  1. 创建

    W
    为a,
    tf.Variable
    并通过NumPy数组对其进行初始化
    tf.placeholder()

    W = tf.Variable(tf.constant(0.0, shape=[vocab_size, embedding_dim]), trainable=False, name="W")

    embedding_placeholder = tf.placeholder(tf.float32, [vocab_size, embedding_dim])
    embedding_init = W.assign(embedding_placeholder)

    sess = tf.Session()

    sess.run(embedding_init, feed_dict={embedding_placeholder: embedding})

这样可以避免

embedding
在图表中存储的副本,但确实需要足够的内存才能一次在内存中保留矩阵的两个副本(一个用于NumPy数组,一个用于
tf.Variable
)。请注意,我假设您想在训练期间保持嵌入矩阵不变,因此
W
是使用创建的
trainable=False

  1. 如果将嵌入训练为另一个TensorFlow模型的一部分,则可以使用

    tf.train.Saver
    从另一个模型的检查点文件中加载值。这意味着嵌入矩阵可以完全绕过Python。
    W
    按照选项2创建,然后执行以下操作:

    W = tf.Variable(...)

    embedding_saver = tf.train.Saver({“name_of_variable_in_other_model”: W})

    sess = tf.Session()
    embedding_saver.restore(sess, “checkpoint_filename.ckpt”)



转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/634052.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号