Logistic Regression in Python

Logistic regression: the sigmoid function

$$g(z)=\frac{1}{1+e^{-z}}$$
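A minimal NumPy sketch of the sigmoid above. The function name `sigmoid` is our own choice; the original code inlines the same expression as `1 / (1 + np.exp(-g_x))`. It maps any real input into (0, 1), with g(0) = 0.5 and the symmetry g(−z) = 1 − g(z).

```python
import numpy as np

def sigmoid(z):
    """Sigmoid g(z) = 1 / (1 + e^{-z}), applied element-wise."""
    return 1.0 / (1.0 + np.exp(-np.asarray(z, dtype=float)))

z = np.array([-5.0, 0.0, 5.0])
print(sigmoid(z))                # every value lies strictly between 0 and 1
print(sigmoid(0.0))              # 0.5
print(sigmoid(-z) + sigmoid(z))  # each entry is (up to rounding) 1
```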

Hypothesis (prediction) function

$$h_\theta(x)=g(\theta^{T}x)$$
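In the code further down, every sample gets a leading 1.0 (the bias term), so θᵀx = θ0 + θ1·x1 + θ2·x2. The sketch below uses two rows of the dataset and made-up θ values (purely illustrative, not trained) to show that h_θ(x) ≥ 0.5 exactly when θᵀx ≥ 0, which is why the decision boundary is the line θᵀx = 0. The helper name `hypothesis` is our own.

```python
import numpy as np

def hypothesis(X, theta):
    """h_theta(x) = g(theta^T x) for every row of X (rows already include the bias 1.0)."""
    return 1.0 / (1.0 + np.exp(-(X @ theta)))

# Two rows from the dataset, with the bias column prepended
X = np.array([[1.0, -0.017612, 14.053064],
              [1.0, -1.395634, 4.662541]])
theta = np.array([[4.0], [0.5], [-0.6]])  # illustrative values, not learned

scores = X @ theta            # theta^T x for each row, shape (2, 1)
probs = hypothesis(X, theta)  # estimated P(y = 1 | x), shape (2, 1)
print(scores.ravel(), probs.ravel())
# Thresholding h at 0.5 is the same as thresholding theta^T x at 0
print((probs >= 0.5).ravel() == (scores >= 0).ravel())
```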

Cost function (for a single sample)

$$\mathrm{cost}(h_\theta(x),y)=\begin{cases}-\log(h_\theta(x)), & \text{if } y=1\\ -\log(1-h_\theta(x)), & \text{if } y=0\end{cases}$$

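If the label is 1, the larger the predicted value, the smaller the loss.

If the label is 0, the smaller the predicted value, the smaller the loss.

These correspond to the two cases above. Because y is always either 0 or 1, the two cases can be merged into a single expression: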

$$\mathrm{cost}(h_\theta(x^{(i)}),y^{(i)})=-y^{(i)}\log(h_\theta(x^{(i)}))-(1-y^{(i)})\log(1-h_\theta(x^{(i)}))$$
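With y restricted to 0 or 1, one of the two terms always vanishes, so the merged expression reproduces the piecewise definition. A quick numerical check; the function names `cost_cases` and `cost_combined` are our own:

```python
import numpy as np

def cost_cases(h, y):
    """Piecewise form: -log(h) if y == 1, -log(1 - h) if y == 0."""
    return -np.log(h) if y == 1 else -np.log(1.0 - h)

def cost_combined(h, y):
    """Single-expression form: -y*log(h) - (1 - y)*log(1 - h)."""
    return -y * np.log(h) - (1 - y) * np.log(1.0 - h)

for h in (0.1, 0.5, 0.9):
    for y in (0, 1):
        print(f"h={h}, y={y}: cost={cost_combined(h, y):.4f}")
        # Both forms agree for every (h, y) pair
        assert np.isclose(cost_cases(h, y), cost_combined(h, y))
```

The printed values also show the two rules stated above: with y = 1 the cost is small when h is near 1, and with y = 0 it is small when h is near 0.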

Loss function (average cost over all m training samples)

$$J(\theta)=\frac{1}{m}\sum_{i=1}^{m}\mathrm{cost}(h_\theta(x^{(i)}),y^{(i)})=-\frac{1}{m}\sum_{i=1}^{m}\left[y^{(i)}\log(h_\theta(x^{(i)}))+(1-y^{(i)})\log(1-h_\theta(x^{(i)}))\right]$$
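A vectorized sketch of J(θ) that mirrors the loss computation in the code below. The `eps` clipping is our own addition (not in the original); it keeps `np.log` finite if a prediction saturates at exactly 0 or 1. The name `loss_J` is hypothetical.

```python
import numpy as np

def loss_J(X, y, theta, eps=1e-12):
    """Mean cross-entropy J(theta) over m samples.

    X: (m, n) features including the bias column, y: (m, 1) labels in {0, 1},
    theta: (n, 1) weights. eps is an assumed safeguard against log(0).
    """
    m = X.shape[0]
    h = 1.0 / (1.0 + np.exp(-(X @ theta)))
    h = np.clip(h, eps, 1.0 - eps)
    return float(-np.sum(y * np.log(h) + (1 - y) * np.log(1 - h)) / m)

X = np.array([[1.0, -0.017612, 14.053064],
              [1.0, -1.395634, 4.662541]])
y = np.array([[0], [1]])
print(loss_J(X, y, np.zeros((3, 1))))  # ln 2 ≈ 0.6931, since every h = 0.5
```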

Gradient descent (steepest descent)

$$\theta_j := \theta_j-\alpha\,\frac{\partial J(\theta)}{\partial\theta_j},\quad\text{i.e.}\quad\theta_j := \theta_j-\frac{\alpha}{m}\sum_{i=1}^{m}\left(h_\theta(x^{(i)})-y^{(i)}\right)x_j^{(i)}$$
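Where the (h − y)·x_j factor comes from: a short derivation, using the sigmoid identity g′(z) = g(z)(1 − g(z)) and the chain rule.

```latex
\begin{aligned}
g'(z) &= \frac{e^{-z}}{(1+e^{-z})^{2}} = g(z)\,\bigl(1-g(z)\bigr) \\
\frac{\partial h_\theta(x)}{\partial \theta_j} &= h_\theta(x)\,\bigl(1-h_\theta(x)\bigr)\,x_j \\
\frac{\partial\,\mathrm{cost}}{\partial \theta_j}
  &= -\frac{y}{h_\theta(x)}\,\frac{\partial h_\theta(x)}{\partial \theta_j}
     + \frac{1-y}{1-h_\theta(x)}\,\frac{\partial h_\theta(x)}{\partial \theta_j} \\
  &= \bigl[-y\,(1-h_\theta(x)) + (1-y)\,h_\theta(x)\bigr]\,x_j
   = \bigl(h_\theta(x)-y\bigr)\,x_j \\
\frac{\partial J(\theta)}{\partial \theta_j}
  &= \frac{1}{m}\sum_{i=1}^{m}\bigl(h_\theta(x^{(i)})-y^{(i)}\bigr)\,x_j^{(i)}
\end{aligned}
```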

Stepping along −J′(θ), the negative gradient, keeps the loss decreasing, provided the learning rate α is small enough.
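The update rule can be checked numerically. The sketch below compares the vectorized gradient `X.T @ (h - y) / m` (the same expression as `dW` in the code below) with central finite differences, then shows that one step along −∇J with α = 0.03 lowers the loss on three rows of the dataset. The helper names, the starting θ and the finite-difference `delta` are assumptions for illustration.

```python
import numpy as np

def J(X, y, theta):
    h = 1.0 / (1.0 + np.exp(-(X @ theta)))
    return float(-np.mean(y * np.log(h) + (1 - y) * np.log(1 - h)))

def grad_J(X, y, theta):
    """Analytic gradient (1/m) * X^T (h - y), shape (n, 1)."""
    h = 1.0 / (1.0 + np.exp(-(X @ theta)))
    return X.T @ (h - y) / X.shape[0]

def numeric_grad(X, y, theta, delta=1e-6):
    """Central differences, perturbing one coordinate theta_j at a time."""
    g = np.zeros_like(theta)
    for j in range(theta.shape[0]):
        e = np.zeros_like(theta)
        e[j] = delta
        g[j] = (J(X, y, theta + e) - J(X, y, theta - e)) / (2 * delta)
    return g

# Three rows from the dataset, bias column prepended
X = np.array([[1.0, -0.017612, 14.053064],
              [1.0, -1.395634, 4.662541],
              [1.0, -0.752157, 6.538620]])
y = np.array([[0.0], [1.0], [0.0]])
theta = np.array([[0.1], [-0.2], [0.05]])  # arbitrary starting point

err = np.max(np.abs(grad_J(X, y, theta) - numeric_grad(X, y, theta)))
print("max gradient mismatch:", err)

alpha = 0.03  # same learning rate as the main script
theta_new = theta - alpha * grad_J(X, y, theta)
print(J(X, y, theta), "->", J(X, y, theta_new))
```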

Code implementation (Python)
import matplotlib.pyplot as plt
import numpy as np


def loadDataset():
    """Read logisticDataset.txt: each line holds x1, x2 and a 0/1 label."""
    data = []
    labels = []
    with open('logisticDataset.txt', 'r') as f:
        for line in f:
            row = line.split()  # split on tabs or any run of spaces
            if len(row) < 3:
                continue  # skip blank or malformed lines
            # Prepend 1.0 so that W[0] acts as the bias (intercept) term
            data.append([1.0, float(row[0]), float(row[1])])
            labels.append(int(row[2]))
    return data, labels


def plotBestFit(W):
    # Plot the training samples as points: red squares for label 1, green dots for label 0
    dataMat, labelMat = loadDataset()
    dataArr = np.array(dataMat)
    n = np.shape(dataArr)[0]
    xcord1 = []
    ycord1 = []
    xcord2 = []
    ycord2 = []
    for i in range(n):
        if int(labelMat[i]) == 1:
            xcord1.append(dataArr[i, 1])
            ycord1.append(dataArr[i, 2])
        else:
            xcord2.append(dataArr[i, 1])
            ycord2.append(dataArr[i, 2])
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(xcord1, ycord1, s=30, c='red', marker='s')
    ax.scatter(xcord2, ycord2, s=30, c='green')

    # Draw the decision boundary W0 + W1*x1 + W2*x2 = 0, i.e. x2 = (-W0 - W1*x1) / W2
    x = np.arange(-3.0, 3.0, 0.1)
    y = (-W[0] - W[1] * x) / W[2]
    ax.plot(x, y)
    plt.show()


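def plotloss(loss_list):
    # One point per gradient-descent iteration, so the x-axis always matches len(loss_list)
    x = np.arange(len(loss_list))
    plt.plot(x, np.array(loss_list), label='loss')

    plt.xlabel('iteration')  # number of gradient-descent steps
    plt.ylabel('loss')  # loss value J(W)
    plt.title('loss trend')  # how the loss changes as W is updated
    plt.legend()  # legend
    plt.show()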


def main():
    # Read the training set from the text file
    data, labels = loadDataset()
    # Convert to NumPy arrays for vectorized computation
    # Feature matrix X, shape (m, 3)
    X = np.array(data)
    # Label column vector y, shape (m, 1)
    y = np.array(labels).reshape(-1, 1)
    # Small random initial weights W, shape (3, 1)
    W = 0.001 * np.random.randn(3, 1)
    # m is the number of training samples
    m = len(X)
    # Learning rate for gradient descent
    learn_rate = 0.03

    loss_list = []
    # Gradient descent: repeatedly update W to minimize the loss
    for i in range(3000):
        # Hypothesis, loss and gradient are all computed with vectorized NumPy
        # Hypothesis h_W(x) = g(XW)
        g_x = np.dot(X, W)
        h_x = 1 / (1 + np.exp(-g_x))

        # Cross-entropy loss J(W)
        loss = np.log(h_x) * y + (1 - y) * np.log(1 - h_x)
        loss = -np.sum(loss) / m
        loss_list.append(loss)

        # Gradient dJ/dW = X^T (h - y) / m, then step against it
        dW = X.T.dot(h_x - y) / m
        W += -learn_rate * dW

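    # Visualize the learned W
    print('Learned W:')
    print(W)
    print('Decision boundary:')
    plotBestFit(W)
    print('Loss over the course of training:')
    plotloss(loss_list)

    # Classify a test point (this is the second row of the dataset, label 1);
    # the output is h_W(x), the estimated probability of class 1
    test_x = np.array([1, -1.395634, 4.662541])
    test_y = 1 / (1 + np.exp(-np.dot(test_x, W)))
    print(test_y)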


if __name__ == '__main__':
    main()
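The script prints only a probability for a single point. A hedged sketch of turning probabilities into class labels with a 0.5 threshold and measuring training accuracy: `train` and `predict` are hypothetical helpers that reuse the same update as `main()`, but with zero initial weights (instead of random) so runs are reproducible, and the first eight rows of the dataset stand in for the full file.

```python
import numpy as np

def train(X, y, learn_rate=0.03, iters=3000):
    """Batch gradient descent with the same update as main(): W -= lr * X^T (h - y) / m."""
    m, n = X.shape
    W = np.zeros((n, 1))  # zero init (assumption) instead of small random values
    losses = []
    for _ in range(iters):
        h = 1.0 / (1.0 + np.exp(-(X @ W)))
        losses.append(float(-np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))))
        W -= learn_rate * (X.T @ (h - y)) / m
    return W, losses

def predict(X, W, threshold=0.5):
    """Class 1 if h_W(x) >= threshold, else class 0."""
    return (1.0 / (1.0 + np.exp(-(X @ W))) >= threshold).astype(int)

# First eight rows of the dataset (x1, x2, label)
rows = np.array([[-0.017612, 14.053064, 0], [-1.395634, 4.662541, 1],
                 [-0.752157, 6.538620, 0], [-1.322371, 7.152853, 0],
                 [0.423363, 11.054677, 0], [0.406704, 7.067335, 1],
                 [0.667394, 12.741452, 0], [-2.460150, 6.866805, 1]])
X = np.hstack([np.ones((len(rows), 1)), rows[:, :2]])  # bias column first
y = rows[:, 2:3]

W, losses = train(X, y)
pred = predict(X, W)
print("loss:", losses[0], "->", losses[-1])
print("training accuracy:", np.mean(pred == y))
```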


Dataset (logisticDataset.txt; columns: x1, x2, label)
-0.017612   14.053064  0
-1.395634  4.662541   1
-0.752157  6.538620   0
-1.322371  7.152853   0
0.423363   11.054677  0
0.406704   7.067335   1
0.667394   12.741452  0
-2.460150  6.866805   1
0.569411   9.548755   0
-0.026632  10.427743  0
0.850433   6.920334   1
1.347183   13.175500  0
1.176813   3.167020   1
-1.781871  9.097953   0
-0.566606  5.749003   1
0.931635   1.589505   1
-0.024205  6.151823   1
-0.036453  2.690988   1
-0.196949  0.444165   1
1.014459   5.754399   1
1.985298   3.230619   1
-1.693453  -0.557540  1
-0.576525  11.778922  0
-0.346811  -1.678730  1
-2.124484  2.672471   1
1.217916   9.597015   0
-0.733928  9.098687   0
-3.642001  -1.618087  1
0.315985   3.523953   1
1.416614   9.619232   0
-0.386323  3.989286   1
0.556921   8.294984   1
1.224863   11.587360  0
-1.347803  -2.406051  1
1.196604   4.951851   1
0.275221   9.543647   0
0.470575   9.332488   0
-1.889567  9.542662   0
-1.527893  12.150579  0
-1.185247  11.309318  0
-0.445678  3.297303   1
1.042222   6.105155   1
-0.618787  10.320986  0
1.152083   0.548467   1
0.828534   2.676045   1
-1.237728  10.549033  0
-0.683565  -2.166125  1
0.229456   5.921938   1
-0.959885  11.555336  0
0.492911   10.993324  0
0.184992   8.721488   0
-0.355715  10.325976  0
-0.397822  8.058397   0
0.824839   13.730343  0
1.507278   5.027866   1
0.099671   6.835839   1
-0.344008  10.717485  0
1.785928   7.718645   1
-0.918801  11.560217  0
-0.364009  4.747300   1
-0.841722  4.119083   1
0.490426   1.960539   1
-0.007194  9.075792   0
0.356107   12.447863  0
0.342578   12.281162  0
-0.810823  -1.466018  1
2.530777   6.476801   1
1.296683   11.607559  0
0.475487   12.040035  0
-0.783277  11.009725  0
0.074798   11.023650  0
-1.337472  0.468339   1
-0.102781  13.763651  0
-0.147324  2.874846   1
0.518389   9.887035   0
1.015399   7.571882   0
-1.658086  -0.027255  1
1.319944   2.171228   1
2.056216   5.019981   1
-0.851633  4.375691   1
-1.510047  6.061992   0
-1.076637  -3.181888  1
1.821096   10.283990  0
3.010150   8.401766   1
-1.099458  1.688274   1
-0.834872  -1.733869  1
-0.846637  3.849075   1
1.400102   12.628781  0
1.752842   5.468166   1
0.078557   0.059736   1
0.089392   -0.715300  1
1.825662   12.693808  0
0.197445   9.744638   0
0.126117   0.922311   1
-0.679797  1.220530   1
0.677983   2.556666   1
0.761349   10.693862  0
-2.168791  0.143632   1
1.388610   9.341997   0
0.317029   14.739025  0
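`loadDataset` parses the file line by line; because the columns are whitespace-separated, NumPy's `np.loadtxt` can also read it in one call. A sketch that uses `io.StringIO` with the first three rows in place of the real `logisticDataset.txt` so it runs standalone (with the actual file you would pass the filename instead):

```python
import io
import numpy as np

text = """-0.017612   14.053064  0
-1.395634  4.662541   1
-0.752157  6.538620   0
"""
# np.loadtxt splits on any whitespace by default (tabs or runs of spaces)
raw = np.loadtxt(io.StringIO(text))                       # shape (m, 3)
X = np.hstack([np.ones((raw.shape[0], 1)), raw[:, :2]])   # prepend bias column
y = raw[:, 2].astype(int).reshape(-1, 1)                  # labels as an (m, 1) column
print(X.shape, y.ravel())
```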
Please credit the source when reprinting: article reprinted from www.mshxw.com
Article URL: https://www.mshxw.com/it/767792.html