栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

tensorflow从入门到精通——线性回归实现

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

tensorflow从入门到精通——线性回归实现

import tensorflow as tf
import os
os.environ['TF_LOG_MIN_LEVEL'] = '2'

class LinearRegression():

    def __init__(self,data=None):
        if data is None:
            self.X,self.Y = self.gener_data()
        else:
            self.X,self.Y = data

        self.weights = tf.Variable(initial_value=tf.random_normal(shape=[1,1],mean=0,stddev=0.1),dtype=tf.float32)
        self.bias = tf.Variable(initial_value=tf.random_normal(shape=[1,1],mean=0,stddev=0.1),dtype=tf.float32)

        with tf.get_default_graph().device("/gpu:0"):
            self.y_ped = tf.matmul(self.X,self.weights)+self.bias

        #     损失函数
        self.loss = tf.reduce_mean(tf.square(self.Y-self.y_ped))
        # 定义损失函数
        self.optim = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(self.loss)
        self.init_data = tf.global_variables_initializer()

    def fit(self,epochs=1000):

        print("model training....")
        with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=True
        )) as sess:
            sess.run(self.init_data)
            print("初始化变量:Weithts = %f, bias = %f, loss = %f" %
                  (sess.run(self.weights), sess.run(self.bias), sess.run(self.loss)))

            for i in range(epochs):
                #优化
                sess.run(self.optim)
                if (i + 1) % 1 == 0:
                    print("训练后第%d次后:Weithts = %f, bias = %f, loss = %f" %
                          (i + 1, sess.run(self.weights), sess.run(self.bias), sess.run(self.loss)))


    def predict(self,x):

        return tf.matmul(x,self.weights)+self.bias

    def gener_data(self):

        X =  tf.random_normal(shape=[100,1],dtype=tf.float32)
        noise = tf.random_normal(shape=[100,1],dtype=tf.float32)/1000.
        # tf.case
        Y = tf.matmul(X,[[2.0]])+0.5+noise

        return X,Y

if __name__ == '__main__':
    model = LinearRegression()
    model.fit(epochs=10000)



```![在这里插入图片描述](https://img-blog.csdnimg.cn/bd5efcf9b3514979aba18999f18d9b19.png?x-oss-process=image/watermark,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBA5bCP6ZmIcGhk,size_20,color_FFFFFF,t_70,g_se,x_16)
![在这里插入图片描述](https://img-blog.csdnimg.cn/56a1000dc9ad4c1a8e7643e6f066448a.png?x-oss-process=image/watermark,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBA5bCP6ZmIcGhk,size_20,color_FFFFFF,t_70,g_se,x_16)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/294574.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号