栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 面试经验 > 面试问答

增加类别

面试问答 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

增加类别

tf.estimator.DNNClassifier
构造函数具有
weight_column
参数:

weight_column

:字符串或
_NumericColumn
通过
tf.feature_column.numeric_column
定义表示权重的特征列而创建的
。在训练过程中,它可用于减轻体重或增强示例效果。它将乘以示例的损失。如果是字符串,则用作从中获取权重张量的键
features
。如果为
_NumericColumn
,则通过key获取原始张量
weight_column.key
,然后
weight_column.normalizer_fn
将其应用于权重张量。


因此,只需添加一个新列并为稀有类填充一些权重即可:

weight = tf.feature_column.numeric_column('weight')...tf.estimator.DNNClassifier(..., weight_column=weight)

[更新] 这是一个完整的工作示例:

import numpy as npimport tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets('mnist', one_hot=False)train_x, train_y = mnist.train.next_batch(1024)test_x, test_y = mnist.test.images, mnist.test.labelsx_column = tf.feature_column.numeric_column('x', shape=[784])weight_column = tf.feature_column.numeric_column('weight')classifier = tf.estimator.DNNClassifier(feature_columns=[x_column],       hidden_units=[100, 100],       weight_column=weight_column,       n_classes=10)# Trainingtrain_input_fn = tf.estimator.inputs.numpy_input_fn(x={'x': train_x, 'weight': np.ones(train_x.shape[0])},        y=train_y.astype(np.int32),        num_epochs=None, shuffle=True)classifier.train(input_fn=train_input_fn, steps=1000)# Testingtest_input_fn = tf.estimator.inputs.numpy_input_fn(x={'x': test_x, 'weight': np.ones(test_x.shape[0])},       y=test_y.astype(np.int32),       num_epochs=1, shuffle=False)acc = classifier.evaluate(input_fn=test_input_fn)print('Test Accuracy: %.3f' % acc['accuracy'])


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/636987.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号