
Hung-yi Lee (2020) Homework 2 (hw2)



Dataset: https://wwr.lanzoui.com/ibnAxud394j
Password: bvf8

Contents
  • Logistic Regression
      • Preparing Data
      • Some Useful Functions
      • Functions about gradient and loss
      • Training
      • Plotting Loss and accuracy curve
      • Predicting testing labels
  • Probabilistic generative model
      • Preparing Data
      • Mean and Covariance
      • Computing weights and bias
      • Predicting testing labels




!tar -zxvf data.tar.gz
!ls
data/
data/sample_submission.csv
data/test_no_label.csv
data/train.csv
data/X_test
data/X_train
data/Y_train
acc.png  data  data.tar.gz  loss.png  output_logistic.csv  hw2.ipynb
Logistic Regression

Preparing Data

Load the data, normalize each attribute, and then split the processed data into a training set and a development set.

import numpy as np

np.random.seed(0)
X_train_fpath = './data/X_train'
Y_train_fpath = './data/Y_train'
X_test_fpath = './data/X_test'
output_fpath = './output_{}.csv'

# Parse csv files to numpy array
with open(X_train_fpath) as f:
    next(f)  # skip the header row
    X_train = np.array([line.strip('\n').split(',')[1:] for line in f], dtype = float)
with open(Y_train_fpath) as f:
    next(f)
    Y_train = np.array([line.strip('\n').split(',')[1] for line in f], dtype = float)
with open(X_test_fpath) as f:
    next(f)
    X_test = np.array([line.strip('\n').split(',')[1:] for line in f], dtype = float)

def _normalize(X, train = True, specified_column = None, X_mean = None, X_std = None):
    # This function normalizes specific columns of X.
    # The mean and standard deviation of the training data are reused when processing the testing data.
    #
    # Arguments:
    #     X: data to be processed
    #     train: 'True' when processing training data, 'False' for testing data
    #     specified_column: indexes of the columns to normalize. If 'None', all columns are normalized.
    #     X_mean: mean value of training data, used when train = 'False'
    #     X_std: standard deviation of training data, used when train = 'False'
    # Outputs:
    #     X: normalized data
    #     X_mean: computed mean value of training data
    #     X_std: computed standard deviation of training data

    if specified_column is None:
        specified_column = np.arange(X.shape[1])
    if train:
        X_mean = np.mean(X[:, specified_column] ,0).reshape(1, -1)
        X_std  = np.std(X[:, specified_column], 0).reshape(1, -1)

    X[:,specified_column] = (X[:, specified_column] - X_mean) / (X_std + 1e-8)
     
    return X, X_mean, X_std

def _train_dev_split(X, Y, dev_ratio = 0.25):
    # This function splits data into training set and development set.
    train_size = int(len(X) * (1 - dev_ratio))
    return X[:train_size], Y[:train_size], X[train_size:], Y[train_size:]

# Normalize training and testing data
X_train, X_mean, X_std = _normalize(X_train, train = True)
X_test, _, _= _normalize(X_test, train = False, specified_column = None, X_mean = X_mean, X_std = X_std)
    
# Split data into training set and development set
dev_ratio = 0.1
X_train, Y_train, X_dev, Y_dev = _train_dev_split(X_train, Y_train, dev_ratio = dev_ratio)

train_size = X_train.shape[0]
dev_size = X_dev.shape[0]
test_size = X_test.shape[0]
data_dim = X_train.shape[1]
print('Size of training set: {}'.format(train_size))
print('Size of development set: {}'.format(dev_size))
print('Size of testing set: {}'.format(test_size))
print('Dimension of data: {}'.format(data_dim))
Size of training set: 48830
Size of development set: 5426
Size of testing set: 27622
Dimension of data: 510
X_train.shape
(48830, 510)
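
The key point of the normalization step above is that the test data is scaled with the mean and standard deviation of the training data, not with its own statistics. The following is a minimal sketch of that idea on toy arrays; the helper name `standardize` and the toy values are illustrative assumptions, not part of the homework code.

```python
import numpy as np

def standardize(X, mean=None, std=None):
    # Hypothetical helper: z-score every column.
    # Statistics are computed only when none are passed in (the "training" case).
    if mean is None:
        mean = X.mean(axis=0, keepdims=True)
        std = X.std(axis=0, keepdims=True)
    # The small constant guards against division by zero for constant columns.
    return (X - mean) / (std + 1e-8), mean, std

# Toy data (assumed values, for illustration only)
train = np.array([[1.0, 10.0], [3.0, 20.0], [5.0, 30.0]])
test = np.array([[3.0, 20.0], [7.0, 40.0]])

train_z, mu, sigma = standardize(train)      # compute statistics on training data
test_z, _, _ = standardize(test, mu, sigma)  # reuse them for the test data

print(train_z.mean(axis=0), train_z.std(axis=0))
print(test_z)
```

A test row equal to the training mean maps to zero, which would not generally hold if the test set were normalized with its own statistics.
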
Some Useful Functions

These functions may be used repeatedly in the training loop.

def _shuffle(X, Y):
    # This function shuffles two equal-length lists/arrays, X and Y, together.
    randomize = np.arange(len(X))
    np.random.shuffle(randomize)
    return (X[randomize], Y[randomize])

def _sigmoid(z):
    # Sigmoid function can be used to calculate probability.
    return np.clip(1 / (1.0 + np.exp(-z)), 1e-8, 1 - (1e-8))  # Output is clipped to [1e-8, 1 - 1e-8] to avoid overflow in later log computations.

def _f(X, w, b):
    # Logistic regression function, parameterized by w and b
    #
    # Arguments:
    #     X: input data, shape = [batch_size, data_dimension]
    #     w: weight vector, shape = [data_dimension, ]
    #     b: bias, scalar
    # Output:
    #     predicted probability of each row of X being positively labeled, shape = [batch_size, ]
    return _sigmoid(np.matmul(X, w) + b)

def _predict(X, w, b):
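    # This function returns a 0/1 prediction for each row of X by rounding the output of the logistic regression function.
    # np.int was removed in NumPy 1.24; the builtin int is the replacement.
    return np.round(_f(X, w, b)).astype(int)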
    
def _accuracy(Y_pred, Y_label):
    # This function calculates prediction accuracy
    acc = 1 - np.mean(np.abs(Y_pred - Y_label))
    return acc
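
To see why the sigmoid output is clipped, the sketch below evaluates a standalone copy of the function at extreme inputs and checks that both logarithms needed by the cross entropy stay finite. The function names and toy labels are illustrative assumptions; the logic mirrors `_sigmoid` and `_accuracy` above.

```python
import numpy as np

def sigmoid(z):
    # Same clipping as _sigmoid above: keeps outputs strictly inside (0, 1).
    return np.clip(1 / (1.0 + np.exp(-z)), 1e-8, 1 - 1e-8)

def accuracy(y_pred, y_label):
    # Fraction of matching 0/1 labels.
    return 1 - np.mean(np.abs(y_pred - y_label))

z = np.array([-50.0, 0.0, 50.0])
p = sigmoid(z)

# Without clipping, p[2] would be exactly 1.0 in floating point and log(1 - p) would be -inf.
print(np.log(p), np.log(1 - p))

# np.round rounds halves to the nearest even value, so a probability of exactly 0.5 becomes 0.
labels = np.round(p).astype(int)
print(labels, accuracy(labels, np.array([0, 1, 1])))
```
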
Functions about gradient and loss

Cross entropy:

$$ C\left(f\left(x^{n}\right), \hat{y}^{n}\right) = -\left[\hat{y}^{n} \ln f\left(x^{n}\right) + \left(1-\hat{y}^{n}\right) \ln \left(1-f\left(x^{n}\right)\right)\right] $$

def _cross_entropy_loss(y_pred, Y_label):
    # This function computes the cross entropy.
    #
    # Arguments:
    #     y_pred: probabilistic predictions, float vector
    #     Y_label: ground truth labels, bool vector
    # Output:
    #     cross entropy, scalar
    cross_entropy = -np.dot(Y_label, np.log(y_pred)) - np.dot((1 - Y_label), np.log(1 - y_pred))
    return cross_entropy

def _gradient(X, Y_label, w, b):
    # This function computes the gradient of cross entropy loss with respect to weight w and bias b.
    y_pred = _f(X, w, b)
    pred_error = Y_label - y_pred
    w_grad = -np.sum(pred_error * X.T, 1)
    b_grad = -np.sum(pred_error)
    return w_grad, b_grad
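
For the summed cross entropy, the gradient with respect to w is the negative sum over samples of (label - prediction) times the input, and the gradient with respect to b is the negative sum of (label - prediction), which is what `_gradient` computes. One way to gain confidence in such a formula is to compare it with central finite differences. The sketch below does this on random toy data; the data, the seed, and the function names are assumptions made for illustration.

```python
import numpy as np

def sigmoid(z):
    return np.clip(1 / (1.0 + np.exp(-z)), 1e-8, 1 - 1e-8)

def loss(X, y, w, b):
    # Summed cross entropy, as in _cross_entropy_loss above.
    p = sigmoid(X @ w + b)
    return -np.dot(y, np.log(p)) - np.dot(1 - y, np.log(1 - p))

def gradient(X, y, w, b):
    # Analytic gradient, as in _gradient above (written with a matrix product).
    err = y - sigmoid(X @ w + b)
    return -X.T @ err, -np.sum(err)

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
y = (rng.random(20) > 0.5).astype(float)
w = rng.normal(size=3) * 0.1
b = 0.05

w_grad, b_grad = gradient(X, y, w, b)

# Central finite differences, one coordinate at a time.
eps = 1e-6
num_w = np.zeros_like(w)
for j in range(len(w)):
    step = np.zeros_like(w)
    step[j] = eps
    num_w[j] = (loss(X, y, w + step, b) - loss(X, y, w - step, b)) / (2 * eps)
num_b = (loss(X, y, w, b + eps) - loss(X, y, w, b - eps)) / (2 * eps)

print(np.max(np.abs(w_grad - num_w)), abs(b_grad - num_b))
```
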
Training

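We train with mini-batch gradient descent. The training data is divided into many mini-batches; for each mini-batch we compute the gradient and loss and update the model parameters from that batch. Once one pass is complete, that is, once every mini-batch of the training set has been used once, we shuffle all the training data, split it into new mini-batches, and run the next pass, until the preset number of passes is reached.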

# Initialize weights and bias to zero
w = np.zeros((data_dim,)) 
b = np.zeros((1,))

# Some parameters for training    
max_iter = 10
batch_size = 8
learning_rate = 0.2

# Keep the loss and accuracy at every iteration for plotting
train_loss = []
dev_loss = []
train_acc = []
dev_acc = []

# Count the number of parameter updates
step = 1

# Iterative training
for epoch in range(max_iter):
    # Random shuffle at the beginning of each epoch
    X_train, Y_train = _shuffle(X_train, Y_train)
        
    # Mini-batch training
    for idx in range(int(np.floor(train_size / batch_size))):
        X = X_train[idx*batch_size:(idx+1)*batch_size]
        Y = Y_train[idx*batch_size:(idx+1)*batch_size]

        # Compute the gradient
        w_grad, b_grad = _gradient(X, Y, w, b)
            
        # gradient descent update
        # learning rate decay with time
        w = w - learning_rate/np.sqrt(step) * w_grad
        b = b - learning_rate/np.sqrt(step) * b_grad

        step = step + 1
            
    # Compute loss and accuracy of training set and development set
    y_train_pred = _f(X_train, w, b)
    Y_train_pred = np.round(y_train_pred)  # round probabilities to 0/1 labels
    train_acc.append(_accuracy(Y_train_pred, Y_train))
    train_loss.append(_cross_entropy_loss(y_train_pred, Y_train) / train_size)

    y_dev_pred = _f(X_dev, w, b)
    Y_dev_pred = np.round(y_dev_pred)
    dev_acc.append(_accuracy(Y_dev_pred, Y_dev))
    dev_loss.append(_cross_entropy_loss(y_dev_pred, Y_dev) / dev_size)

print('Training loss: {}'.format(train_loss[-1]))
print('Development loss: {}'.format(dev_loss[-1]))
print('Training accuracy: {}'.format(train_acc[-1]))
print('Development accuracy: {}'.format(dev_acc[-1]))
Training loss: 0.271355435246406
Development loss: 0.2896359675026286
Training accuracy: 0.8836166291214418
Development accuracy: 0.8733873940287504
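
The training loop above divides the learning rate by the square root of the update counter, so the step size shrinks as training proceeds. The sketch below only reproduces that bookkeeping, using the settings and training-set size reported above; the variable names are illustrative. It also shows that flooring the batch count leaves a few samples unused in each epoch (a different few each time, because of the shuffle).

```python
import numpy as np

# Values taken from the training settings and the printed training-set size above
train_size, batch_size, max_iter, learning_rate = 48830, 8, 10, 0.2

updates_per_epoch = train_size // batch_size                     # full mini-batches per epoch
skipped_per_epoch = train_size - updates_per_epoch * batch_size  # samples left over by the floor
total_updates = updates_per_epoch * max_iter

# Effective step size at update t is learning_rate / sqrt(t), with t starting at 1.
steps = np.arange(1, total_updates + 1)
effective_lr = learning_rate / np.sqrt(steps)

print(updates_per_epoch, skipped_per_epoch, total_updates)
print(effective_lr[0], effective_lr[-1])
```
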
print(w)
[ 5.92915192e-01  1.24074857e-01  1.83630863e-01 -7.70052933e-02
  1.01394159e-01  3.20603386e-02 -4.03196028e+00  3.65554831e-02
  8.69717534e-02 -3.20047799e-02 -3.03096235e-01  1.94064135e-02
 -3.85584310e-02  7.00523961e-02  2.73524545e-02  2.04627379e-02
  1.09724785e-02 -9.36514029e-03 -2.00279580e-03 -6.09131082e-02
  2.43449989e-02  3.14315744e-02 -3.11362963e-02  2.62313688e-02
  6.41679870e-02 -2.59866508e-02 -4.58250689e-02  1.53066552e-02
  5.03107399e-03  3.15076844e-02 -3.99684936e-02 -3.84093140e-02
 -3.88026206e-02  5.78755912e-02 -2.99634188e-02  1.43790145e-02
  7.34378526e-02  1.90405432e-03  5.09819666e-03 -1.75756429e-03
 -1.93598360e-02  5.19897751e-02  2.66014738e-02  2.21600062e-02
 -7.12755653e-02 -4.38964656e-02 -1.00193950e-02 -5.54453739e-03
  3.73557951e-02 -1.35260617e-02  3.79730351e-02  4.06755762e-02
 -3.52166616e-02 -3.73820684e-02 -3.08406773e-02 -9.06102046e-03
 -4.22869883e-02  5.45105197e-02 -1.84837645e-02 -9.73816178e-04
  2.35096290e-02  4.36160353e-03 -5.22678316e-03  6.56595626e-02
 -3.96729095e-02  2.34134376e-02 -2.33975169e-02  5.87461431e-02
  7.36327176e-02  8.38750350e-02 -1.30694593e-02 -2.91644238e-02
  1.42799260e-02 -1.40928226e-01 -7.22504366e-02 -3.07968494e-02
  6.18218957e-03 -1.93561165e-02 -1.77971010e-02 -1.53469573e-02
  8.26096958e-02 -1.04483810e-02 -7.93838732e-02 -2.90085757e-02
 -3.62440449e-02 -3.47261797e-02 -1.25102172e-01  2.46400252e-02
 -3.49167611e-03 -8.78659553e-02  3.86936104e-02  1.49035447e-03
  2.32374232e-02 -5.74477020e-02 -1.36891742e-01  1.36123979e-02
 -4.88133865e-02 -5.74810409e-02  1.77515908e-01 -5.23769256e-03
  4.06755762e-02 -9.04404090e-03 -6.34098078e-03  1.54080695e-01
  8.85118474e-02  3.36738814e-03  8.23197488e-02 -4.97083063e-03
  2.35096290e-02 -4.80790235e-02 -1.16362870e-02 -1.74476376e-01
  8.02275522e-02  3.67902632e-01 -1.70801496e-01 -1.06224464e-01
  2.83182527e-01  2.57067335e-01  2.95422914e-02 -1.55890413e-01
  2.86382860e-01  3.78936970e-02 -1.01663872e-01 -4.88765495e-01
 -1.53292684e-01 -1.45762205e-01 -3.74187655e-02  7.27300522e-03
 -7.57582232e-02  6.80128902e-02 -9.93414482e-03 -5.80535623e-02
 -1.23798193e-02 -3.13483910e-03  2.27974344e-01 -2.69272900e-02
 -2.03951860e-01  3.79730351e-02  2.66014738e-02 -3.88026206e-02
 -6.09131082e-02  2.44176324e-02 -9.06102046e-03 -4.58250689e-02
  5.21858719e-03 -3.11362963e-02  2.43449989e-02  3.14315744e-02
  4.06755762e-02 -9.73816178e-04  2.21600062e-02 -4.38964656e-02
 -2.00279580e-03 -3.05408583e-02  3.73557951e-02  5.22790699e-02
 -5.54453739e-03  1.78015408e-02 -7.12755653e-02  2.70737039e-03
  2.35096290e-02 -1.59047546e-01  9.15192919e-03 -9.31798486e-02
 -3.47261797e-02  1.05507312e-01  1.36123979e-02 -3.24444325e-02
  8.94897505e-02  4.06755762e-02 -1.37699304e-02 -8.69028761e-02
 -1.31838624e-02 -1.44602158e-01  2.35096290e-02  1.49838911e-01
  2.42577307e-02  3.43031737e-02  2.50531873e-02 -2.78316196e-02
 -6.15983620e-02  3.30344036e-02  2.31681547e-03 -2.51053434e-02
 -4.40556137e-02  5.57076208e-02 -3.63960679e-02 -8.69624995e-02
  2.00569516e-02 -8.57990684e-04  5.35089606e-02  2.87088021e-01
 -2.87088021e-01  8.95978968e-03  3.45749147e-02 -2.56327505e-02
 -2.59273404e-02 -3.16117225e-02 -2.03838303e-02  2.57507332e-02
  8.69717534e-02 -1.74139681e-02 -2.49845413e-02 -2.34973847e-01
  1.18355195e-01  1.49699481e-01  1.72561102e-02  3.95507939e-02
 -1.45467260e-02  4.65258485e-02  1.17125583e+00  2.66533224e-01
  6.63525864e-01 -7.27827630e-01 -5.07667170e-02 -4.98233485e-02
  3.39861659e-01  1.97761640e-01  2.48696844e-01 -5.82757947e-03
  6.48555493e-02  8.79573597e-03 -1.30352193e-01 -5.50644196e-03
  4.56574410e-02  3.52357383e-02  1.26392843e-02 -1.45606838e-02
  2.26005503e-02 -3.70668567e-02  6.17429857e-03  6.23984329e-03
 -6.37739438e-03 -1.26999656e-03  2.70265911e-02  4.07870314e-02
  1.75550803e-02  4.88231677e-02  2.47230063e-02  3.41322946e-02
 -8.00533358e-03 -2.33559829e-02  4.48247183e-02 -1.59850017e-02
  1.04536247e-03  1.47047168e-02 -4.07252943e-02  6.14464901e-02
 -1.56271527e-02 -5.20450618e-02  1.72157416e-02 -2.34973759e-01
  5.45197343e-04 -4.45609679e-03  4.06560403e-02  8.79573597e-03
 -4.57926861e-03  2.48132054e-02  3.39024037e-02  6.35446089e-04
  7.53157026e-04  2.90646964e-02  1.85728920e-02 -2.99027909e-02
  8.22948962e-04 -6.51020747e-01 -1.99112025e-02  6.21244989e-02
  1.65282739e-02  5.70440732e-02 -1.34075287e-02  2.84760273e-02
  4.99995700e-02 -1.43149285e-02 -1.55062812e-02  5.77329914e-02
 -3.69292506e-01  1.85185705e-01 -1.62540396e+00 -6.10308780e-02
  1.90180669e-03 -4.04372675e-02  5.61635933e-02  1.09464402e-01
  1.95775552e-02 -1.41957598e+00  0.00000000e+00  1.52638296e-01
  1.51874586e-01 -3.35707151e-01  2.10206942e-02 -8.88917872e-03
  1.49249708e-01  2.08298466e-02  1.81110355e-01  6.04814282e-02
  3.95429998e-02 -2.45003137e-02  1.37817961e-01 -4.82220155e-02
 -5.67961117e-01 -1.83616676e-02  0.00000000e+00  8.16848461e-02
 -1.63960438e-01  2.52193554e-02 -1.29585721e+00  1.20071731e-02
 -2.34232528e-02 -1.92393043e-02  2.97280736e-02  1.32905137e-01
  1.50310263e-01  1.69551084e-02 -1.16779181e+00 -5.78850347e-01
  1.82971161e-01 -1.89815872e-02  2.57564229e-02  4.14126041e-02
  2.77022159e-01 -2.39123563e-01 -1.10876152e-01 -7.35060996e-01
  1.38831593e-02  4.28625899e-02 -1.19294741e-02  6.30333266e-02
  1.03028112e-03 -3.50181242e-02 -7.23435726e-03 -2.48613475e-03
  1.05123678e-02 -1.72779241e-02 -7.35060996e-01  3.85574950e-02
 -6.53372808e-03  6.30333266e-02 -1.19474000e-02 -5.50644196e-03
 -2.48613475e-03  4.19392113e-02 -1.72779241e-02 -7.35060996e-01
 -1.65351730e-02  7.62270772e-03  6.30333266e-02 -1.19474000e-02
  8.68768714e-03 -5.50644196e-03 -2.48613475e-03 -5.67461262e-02
  6.30333266e-02 -8.79573597e-03  7.83747806e-03 -2.48613475e-03
  2.01469636e-02 -4.07393974e-02  3.66672040e-01  9.38992277e-01
 -5.11948249e-01 -1.17099133e-01 -7.98523467e-01  3.11488704e-02
  3.18128110e-02 -2.66613114e-02  1.11889429e-01 -1.79139718e-01
 -8.02663810e-02 -5.60761108e-03 -1.40482455e-02  3.03537826e-02
 -1.19358442e-02  8.01655241e-02 -1.07307290e-01 -9.81331256e-03
  2.74842494e-04  7.57205724e-02  1.57883850e-02  2.44127053e-02
  1.60225360e-02  1.84400447e-02  2.20798953e-02  2.07603753e-02
 -1.06303656e+00 -2.68607668e-02  9.87097940e-02  3.61727162e-03
  9.17383247e-02 -9.48509683e-02 -5.39205127e-03  7.72450470e-02
  1.05271766e-02 -1.13628793e-02  6.68173269e-02 -1.85428803e-02
  3.04033958e-02  2.83028233e-02 -5.33805724e-01 -2.60507241e-02
 -1.71841724e-02 -1.70909119e-02 -3.64454160e-02  5.74614270e-02
  1.61836314e-01 -1.17644747e-02  4.01201859e-02  4.25922863e-02
  2.24045782e-02 -9.01743750e-02  6.99669105e-04 -2.71684652e-03
 -3.86347376e-02  4.01499594e-03  1.86945941e-02  6.05694714e-02
 -3.15348183e-02  3.99132207e-02  2.91947040e-02  2.49551591e-02
 -5.86622605e-02  3.60658476e-02  7.84888374e-04  5.02475053e-02
 -9.46789071e-02  3.33914878e-02  2.62680076e-02  3.20750525e-02
 -5.94337201e-01  1.08208228e-01 -1.53157252e-02  2.14145748e-02
 -6.50727785e-02  8.32744213e-02 -6.04723054e-03 -3.97474836e-02
  5.62623026e-02 -2.11583861e-03  5.94678753e-03  3.60987927e-02
 -3.67392667e-02 -1.53115476e-02 -1.09345814e+00  5.78387804e-03
 -3.49596972e-04  8.32295952e-02 -3.78016197e-02  1.66456045e-01
 -2.75717539e-02 -1.22802640e-02 -1.86061772e-01 -2.41814902e-02
 -9.50519553e-03  7.43952479e-02  6.96787657e-02 -2.90360728e-02
  2.80817919e-02  6.00474195e-02 -6.38743060e-02 -7.49832513e-02
  3.36176183e-02  9.16755527e-03 -4.55444360e-02  8.08578673e-02
  2.25488387e-02  2.17729754e-02  2.75659487e-02  7.55106467e-02
 -1.23013676e-01 -1.13212007e-01 -9.00521884e-02 -7.22166756e-01
  2.76286610e-03 -1.46307816e-02  7.45215701e-02  1.65937294e-02
 -1.64732342e-02 -7.48207181e-03 -1.57432618e-02  7.50349338e-03
 -3.54804751e-02  6.72622789e-02 -1.58077957e-03 -2.00818661e-02
  7.05419732e-02  5.56183347e-02  4.35683709e-02 -1.51234264e-02
 -5.16263909e-02  4.48685948e-02 -3.25668703e-01  7.63067239e-03
  1.02488270e-02 -1.67080754e-02 -1.15820085e-02 -2.86029040e-02
  5.34047963e-02 -3.72715378e-02 -3.42018332e-02 -2.63880779e-02
  1.32741848e-01 -7.14347815e-02  3.02172395e-02  6.45711463e-02
  4.12247244e-01 -4.90592028e-01  7.14347815e-02  8.22661492e-01
  2.48613475e-03 -2.48613475e-03]
Plotting Loss and accuracy curve
import matplotlib.pyplot as plt

# Loss curve
plt.plot(train_loss)
plt.plot(dev_loss)
plt.title('Loss')
plt.legend(['train', 'dev'])
plt.savefig('loss.png')
plt.show()

# Accuracy curve
plt.plot(train_acc)
plt.plot(dev_acc)
plt.title('Accuracy')
plt.legend(['train', 'dev'])
plt.savefig('acc.png')
plt.show()


Predicting testing labels

Predict the labels of the test set and save them to output_logistic.csv.

# Predict testing labels
predictions = _predict(X_test, w, b)
with open(output_fpath.format('logistic'), 'w') as f:
    f.write('id,label\n')
    for i, label in  enumerate(predictions):
        f.write('{},{}\n'.format(i, label))

# Print out the most significant weights
ind = np.argsort(np.abs(w))[::-1]
with open(X_test_fpath) as f:
    content = f.readline().strip('\n').split(',')
features = np.array(content)
for i in ind[0:10]:
    print(features[i], w[i])
 Not in universe -4.031960278019252
 Spouse of householder -1.6254039587051405
 Other Rel <18 never married RP of subfamily -1.4195759775765422
 Child 18+ ever marr Not in a subfamily -1.2958572076664725
 Unemployed full-time 1.1712558285885906
 Other Rel <18 ever marr RP of subfamily -1.1677918072962385
 Italy -1.0934581438006181
 Vietnam -1.0630365633146415
num persons worked for employer 0.93899227735665
 1 0.8226614922117186
Probabilistic generative model

A binary classifier based on a generative model.

Preparing Data

The training and test sets are processed exactly as in logistic regression. However, because the generative model has a closed-form optimal solution, no development set is needed.

# Parse csv files to numpy array
with open(X_train_fpath) as f:
    next(f)  # skip the header row
    X_train = np.array([line.strip('\n').split(',')[1:] for line in f], dtype = float)
with open(Y_train_fpath) as f:
    next(f)
    Y_train = np.array([line.strip('\n').split(',')[1] for line in f], dtype = float)
with open(X_test_fpath) as f:
    next(f)
    X_test = np.array([line.strip('\n').split(',')[1:] for line in f], dtype = float)

# Normalize training and testing data
X_train, X_mean, X_std = _normalize(X_train, train = True)
X_test, _, _= _normalize(X_test, train = False, specified_column = None, X_mean = X_mean, X_std = X_std)
Mean and Covariance

In the generative model, we need to compute the mean and covariance of the data within each of the two classes.

# Compute in-class mean
X_train_0 = np.array([x for x, y in zip(X_train, Y_train) if y == 0])
X_train_1 = np.array([x for x, y in zip(X_train, Y_train) if y == 1])

mean_0 = np.mean(X_train_0, axis = 0)
mean_1 = np.mean(X_train_1, axis = 0)  

# Compute in-class covariance
cov_0 = np.zeros((data_dim, data_dim))
cov_1 = np.zeros((data_dim, data_dim))

for x in X_train_0:
    cov_0 += np.dot(np.transpose([x - mean_0]), [x - mean_0]) / X_train_0.shape[0]
for x in X_train_1:
    cov_1 += np.dot(np.transpose([x - mean_1]), [x - mean_1]) / X_train_1.shape[0]

# Shared covariance is taken as a weighted average of individual in-class covariance.
cov = (cov_0 * X_train_0.shape[0] + cov_1 * X_train_1.shape[0]) / (X_train_0.shape[0] + X_train_1.shape[0])
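
The per-sample loops above accumulate one outer product per row. As a sketch under the assumption that a fully vectorized form is acceptable, the same in-class covariance can be written as a single matrix product of the centered data. The toy data and function names below are illustrative; the loop version is a copy of the pattern used above so the two can be compared.

```python
import numpy as np

rng = np.random.default_rng(0)
X0 = rng.normal(size=(30, 4))           # toy samples for class 0 (assumed data)
X1 = rng.normal(loc=1.0, size=(20, 4))  # toy samples for class 1 (assumed data)

def class_cov_loop(X):
    # Same accumulation pattern as the loops above: one outer product per sample.
    mean = X.mean(axis=0)
    cov = np.zeros((X.shape[1], X.shape[1]))
    for x in X:
        cov += np.dot(np.transpose([x - mean]), [x - mean]) / X.shape[0]
    return cov

def class_cov_vectorized(X):
    # Center once, then a single matrix product gives the sum of outer products.
    centered = X - X.mean(axis=0)
    return centered.T @ centered / X.shape[0]

cov_0 = class_cov_vectorized(X0)
cov_1 = class_cov_vectorized(X1)

# Shared covariance: class covariances weighted by class size, as above.
shared = (cov_0 * len(X0) + cov_1 * len(X1)) / (len(X0) + len(X1))

print(np.allclose(class_cov_loop(X0), cov_0))
```

Both versions divide by the number of samples (the biased estimate), which matches `np.cov(X, rowvar=False, bias=True)`.
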
Computing weights and bias
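
For reference, the closed-form expressions that the code in this section evaluates can be written as follows. The notation is an assumption introduced here: \mu_0 and \mu_1 are the class means, \Sigma is the shared covariance, and N_0 and N_1 are the class sizes.

```latex
P(C_0 \mid x) = \sigma\left(w^{\top} x + b\right), \qquad
w = \Sigma^{-1}\left(\mu_0 - \mu_1\right), \qquad
b = -\tfrac{1}{2}\,\mu_0^{\top} \Sigma^{-1} \mu_0
    + \tfrac{1}{2}\,\mu_1^{\top} \Sigma^{-1} \mu_1
    + \ln\frac{N_0}{N_1}
```

Because the sigmoid here gives the posterior probability of class 0, the predicted label is one minus the rounded output, which is why the prediction code uses `1 - _predict(...)`.
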

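# Compute the inverse of the covariance matrix.
# Since the covariance matrix may be nearly singular, np.linalg.inv() may give a large numerical error.
# Via singular value decomposition, the matrix inverse can be obtained efficiently and accurately.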
u, s, v = np.linalg.svd(cov, full_matrices=False)
inv = np.matmul(v.T * 1 / s, u.T)

# Directly compute weights and bias
w = np.dot(inv, mean_0 - mean_1)
# Parentheses keep the multi-line expression a single statement.
b = (
    (-0.5) * np.dot(mean_0, np.dot(inv, mean_0))
    + 0.5 * np.dot(mean_1, np.dot(inv, mean_1))
    + np.log(float(X_train_0.shape[0]) / X_train_1.shape[0])
)

# Compute accuracy on training set
Y_train_pred = 1 - _predict(X_train, w, b)
print('Training accuracy: {}'.format(_accuracy(Y_train_pred, Y_train)))
Training accuracy: 0.871885137127691
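
The SVD-based inverse used above can be checked on a small, well-conditioned matrix. The sketch below builds a toy covariance (assumed data), inverts it with the same expression, and compares the result with `np.linalg.pinv`. Note that `np.linalg.svd` returns the third factor already transposed, so it is named `vt` here.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(50, 5))
cov = A.T @ A / A.shape[0]  # toy symmetric positive definite matrix (assumed data)

# cov = u @ diag(s) @ vt, so its inverse is vt.T @ diag(1 / s) @ u.T.
u, s, vt = np.linalg.svd(cov, full_matrices=False)
inv = np.matmul(vt.T * (1 / s), u.T)  # broadcasting scales the columns of vt.T by 1 / s

print(np.allclose(inv @ cov, np.eye(5)))
print(np.allclose(inv, np.linalg.pinv(cov)))
```

When some singular values are close to zero, dividing by them amplifies numerical noise; `np.linalg.pinv` discards singular values below a relative cutoff, which may be a safer choice in that situation.
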
Predicting testing labels

Predict the labels of the test set and save them to output_generative.csv.

# Predict testing labels
predictions = 1 - _predict(X_test, w, b)
with open(output_fpath.format('generative'), 'w') as f:
    f.write('id,label\n')
    for i, label in  enumerate(predictions):
        f.write('{},{}\n'.format(i, label))

# Print out the most significant weights
ind = np.argsort(np.abs(w))[::-1]
with open(X_test_fpath) as f:
    content = f.readline().strip('\n').split(',')
features = np.array(content)
for i in ind[0:10]:
    print(features[i], w[i])
 Retail trade 8.1796875
 34 -6.21875
 37 -5.7890625
 Other service -5.328125
 Other professional services -4.515625
 44 4.234375
 29 -3.125
 Same county 3.078125
 Different state in West -3.0546875
 Forestry and fisheries 2.90625