栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

1.4tensorflow简单示例-线性模型

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

1.4tensorflow简单示例-线性模型

目录
  • tensorflow简单介绍
  • 代码拆分(每个代码块可以放到一个jupyter的cell里)
    • 导入tensorflow及其他包
    • 随机生成线性模型随机点
    • 线性模型的构建和训练
      • 构造线性模型
      • 二次代价函数
      • 定义一个梯度下降法来进行训练的优化器
      • 最小化代价函数
      • 定义变量初始化
      • 采用图和会话训练模型
  • 完整代码

tensorflow简单介绍

tensor采用图运算的方式搭建并训练深度学习网络,该部分使用的库包版本为tensorflow==1.14.0

代码拆分(每个代码块可以放到一个jupyter的cell里) 导入tensorflow及其他包
import tensorflow as tf
import numpy as np
随机生成线性模型随机点
# 使用numpy生成100个随机点
x_data = np.random.rand(100)
y_data = x_data*0.1 + 0.2
线性模型的构建和训练 构造线性模型
b = tf.Variable(0.)
k = tf.Variable(0.)
y = k*x_data + b
二次代价函数
loss = tf.reduce_mean(tf.square(y_data-y))
定义一个梯度下降法来进行训练的优化器
optimizer = tf.train.GradientDescentOptimizer(0.2)
最小化代价函数
train = optimizer.minimize(loss)
定义变量初始化
init = tf.global_variables_initializer()
采用图和会话训练模型
with tf.Session() as sess:
    sess.run(init)
    for step in range(201):
        sess.run(train)
        if step%20==0:
            print(step,sess.run([k,b,loss]))
完整代码
import tensorflow as tf
import numpy as np

# 使用numpy生成100个随机点
x_data = np.random.rand(100)
y_data = x_data*0.1 + 0.2

# 构造一个线性模型
b = tf.Variable(0.)
k = tf.Variable(0.)
y = k*x_data + b

# 二次代价函数
loss = tf.reduce_mean(tf.square(y_data-y))
# 定义一个梯度下降法来进行训练的优化器
optimizer = tf.train.GradientDescentOptimizer(0.2)
# 最小化代价函数
train = optimizer.minimize(loss)

# 初始化变量
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for step in range(201):
        sess.run(train)
        if step%20==0:
            print(step,sess.run([k,b,loss]))

输出:

output:
0 [0.0451881, 0.09705041, 0.016242022]
20 [0.09693174, 0.20142093, 9.116924e-07]
40 [0.098395586, 0.20074317, 2.493052e-07]
60 [0.099161, 0.20038863, 6.817332e-08]
80 [0.099561274, 0.20020323, 1.864167e-08]
100 [0.099770576, 0.20010626, 5.097479e-09]
120 [0.09988003, 0.20005557, 1.3938917e-09]
140 [0.099937275, 0.20002906, 3.8108994e-10]
160 [0.0999672, 0.20001519, 1.04200586e-10]
180 [0.09998284, 0.20000795, 2.8495e-11]
200 [0.09999103, 0.20000416, 7.796692e-12]
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/293690.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号