Click-Through Rate (CTR) Prediction

Runtime Environment

Equipment environment:
system: Win10 64-bit
python version: 3.7.10
matplotlib version: 3.4.2
numpy version: 1.20.3
pandas version: 1.2.4
seaborn version: 0.11.1
sklearn version: 0.24.2
imblearn version: 0.8.0


Data Exploration and Analysis: Data Overview

From df.info(), no feature has missing values. There are 9 object-typed columns, which will likely need one-hot encoding and deserve a closer look (the dataset is fairly large, so rather than printing it all, use df.head() to inspect the first few rows).
From the output of df.describe().T, only the non-object columns are numeric. id is spread roughly uniformly over its range; the 'click' column is highly imbalanced (visible from its mean and std); C1, C14 and C17 are widely dispersed.
From df['click'].value_counts() (0: 6948, 1: 1502), clicked samples amount to only about 21% of the non-clicked samples.
The click/non-click counts within each category of a feature can be inspected with groupby(<feature>)['click'].value_counts(). Some features have thousands of categories, so listing or plotting everything is impractical; instead, visualize the low-cardinality features, and for the high-cardinality ones inspect the data and replace categories that are rare or have a low click rate with a single new class (see the one-hot section for details).
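As a concrete illustration of the per-category counts and click rates described above, here is a minimal sketch on a tiny hypothetical frame (the real CSV is not reproduced here; the column values are invented):

```python
import pandas as pd

# Hypothetical miniature stand-in for the CTR sample data
df = pd.DataFrame({
    'banner_pos': [0, 0, 0, 1, 1, 7],
    'click':      [0, 1, 0, 0, 1, 1],
})

# Click / non-click counts within each category of a feature
counts = df.groupby('banner_pos')['click'].value_counts()
print(counts)

# Click rate per category: the mean of the 0/1 click column
rates = df.groupby('banner_pos')['click'].mean()
print(rates)
```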

df = pd.read_csv('7_Regreaaion/task/data/train_sample_ctr.csv')
print(df.info())
# RangeIndex: 8450 entries, 0 to 8449
# Data columns (total 24 columns):
#  #   Column            Non-Null Count  Dtype  
# ---  ------            --------------  -----  
#  0   id                8450 non-null   float64
#  1   click             8450 non-null   int64  
#  2   hour              8450 non-null   int64
#  3   C1                8450 non-null   int64
#  4   banner_pos        8450 non-null   int64
#  5   site_id           8450 non-null   object
#  6   site_domain       8450 non-null   object
#  7   site_category     8450 non-null   object
#  8   app_id            8450 non-null   object
#  9   app_domain        8450 non-null   object
#  10  app_category      8450 non-null   object
#  11  device_id         8450 non-null   object
#  12  device_ip         8450 non-null   object
#  13  device_model      8450 non-null   object
#  14  device_type       8450 non-null   int64
#  15  device_conn_type  8450 non-null   int64
#  16  C14               8450 non-null   int64
#  17  C15               8450 non-null   int64
#  18  C16               8450 non-null   int64
#  19  C17               8450 non-null   int64
#  20  C18               8450 non-null   int64
#  21  C19               8450 non-null   int64
#  22  C20               8450 non-null   int64
#  23  C21               8450 non-null   int64
# dtypes: float64(1), int64(14), object(9)

print(df.describe().T)
#                    count          mean           std           min           25%           50%           75%           max
# id                8450.0  9.179895e+18  5.358954e+18  1.275880e+15  4.536770e+18  9.077210e+18  1.377865e+19  1.844660e+19
# click             8450.0  1.777515e-01  3.823260e-01  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00  1.000000e+00
# hour              8450.0  1.410255e+07  2.981426e+02  1.410210e+07  1.410230e+07  1.410252e+07  1.410281e+07  1.410302e+07
# C1                8450.0  1.004960e+03  1.085978e+00  1.001000e+03  1.005000e+03  1.005000e+03  1.005000e+03  1.012000e+03
# banner_pos        8450.0  2.872189e-01  5.109662e-01  0.000000e+00  0.000000e+00  0.000000e+00  1.000000e+00  7.000000e+00
# device_type       8450.0  1.010178e+00  5.145676e-01  0.000000e+00  1.000000e+00  1.000000e+00  1.000000e+00  5.000000e+00
# device_conn_type  8450.0  3.469822e-01  8.695646e-01  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00  5.000000e+00
# C14               8450.0  1.883368e+04  5.012526e+03  3.750000e+02  1.692000e+04  2.035200e+04  2.189400e+04  2.404100e+04
# C15               8450.0  3.190272e+02  2.367281e+01  2.160000e+02  3.200000e+02  3.200000e+02  3.200000e+02  7.680000e+02
# C16               8450.0  6.137112e+01  5.269093e+01  3.600000e+01  5.000000e+01  5.000000e+01  5.000000e+01  1.024000e+03
# C17               8450.0  2.110165e+03  6.155403e+02  1.120000e+02  1.863000e+03  2.325000e+03  2.526000e+03  2.756000e+03
# C18               8450.0  1.446154e+00  1.327442e+00  0.000000e+00  0.000000e+00  2.000000e+00  3.000000e+00  3.000000e+00
# C19               8450.0  2.225257e+02  3.430777e+02  3.300000e+01  3.500000e+01  3.900000e+01  1.710000e+02  1.839000e+03
# C20               8450.0  5.289988e+04  4.997953e+04 -1.000000e+00 -1.000000e+00  1.000355e+05  1.001030e+05  1.002480e+05
# C21               8450.0  8.392923e+01  7.076504e+01  1.300000e+01  2.300000e+01  6.100000e+01  1.100000e+02  2.530000e+02

print(df['click'].value_counts())
# 0    6948
# 1    1502
Feature Correlation Analysis

The heatmap shows that C1 and device_type, and C17 and C14, are strongly correlated, and C18 and C21 show some correlation. (Note: the plot can be produced with drawingHeatMap(df, click_data, nonclick_data).)

Processing the Time Field

The 'hour' field encodes year, month, day and hour. Extract each component for analysis and drop the components with only a single distinct value, e.g. the year and month are identical across all rows. (Note: see draw_time_click().)
After extraction, visualize the time components that have multiple categories. The plots show a clear difference between weekends and weekdays, and the click rate is higher in the early-morning hours. 'day' makes little difference and can be considered for removal. 'is_weekday' needs to be binarized or one-hot encoded.
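The extraction described above can be sketched as follows (a self-contained example on a few hypothetical 'hour' values; the appendix's time_hander() does the same via pd.to_datetime with format '%y%m%d%H'). Note that, despite its name, the 'is_weekday' column actually flags weekends:

```python
import pandas as pd

# The raw 'hour' field packs YYMMDDHH into a single integer, e.g. 14102100
hours = pd.Series([14102100, 14102523, 14103012])

t = pd.to_datetime(hours, format='%y%m%d%H')
day      = t.dt.day                     # day of month
int_hour = t.dt.hour                    # hour of day, 0-23
weekend  = t.dt.dayofweek.isin([5, 6])  # True for Saturday/Sunday

print(day.tolist(), int_hour.tolist(), weekend.tolist())
```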

Choosing the features to one-hot encode: splitting low-cardinality from high-cardinality features

Given the information above, every categorical feature below except 'id' and 'hour' needs one-hot encoding; the 'is_weekday' column extracted from 'hour' also needs it.

# The number of distinct ids is close to the sample count, so the column can simply be dropped
feature_less_kinds, feature_many_kinds = split_feature(df)  # moved into main
print(feature_less_kinds, feature_many_kinds)
# Low-cardinality features:
# {'click': 2, 'C1': 7, 'banner_pos': 7, 'site_category': 16, 'app_category': 18, 'device_type': 4, 'device_conn_type': 4, 'C15': 6, 'C16': 7, 'C18': 4}
# High-cardinality features:
# {'id': 8358, 'hour': 240, 'site_id': 541, 'site_domain': 442, 'app_id': 415, 'app_domain': 39, 'device_id': 1491, 'device_ip': 7713,
#  'device_model': 1244, 'C14': 891, 'C17': 322, 'C19': 61, 'C20': 126, 'C21': 56}
Handling features with many categories

The rare categories could simply all be merged into one class, but that risks losing rare categories that are strongly discriminative. Instead, this project inspects the data and, within a given category-count bracket, replaces every category whose click rate is below a threshold with a single new class; categories outside the bracket are left alone. (A cardinality of 20 is used as the boundary between "few" and "many" categories in a feature.)
(Note: this corresponds to the function feature_class_process().)
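A minimal sketch of the merging idea on invented data: categories that are both rare and low-CTR collapse into one placeholder class. The thresholds and the 'place_1' label are illustrative; the project's feature_class_process() applies several count brackets and rate thresholds in sequence:

```python
import pandas as pd

# Hypothetical feature with a few sparse categories, plus its click labels
feature = pd.Series(['a', 'a', 'a', 'b', 'c', 'd', 'd'])
clicks  = pd.Series([ 1,   0,   1,   0,   0,   0,   1 ])

counts = feature.value_counts()
rates  = clicks.groupby(feature).mean()

# Merge every category that is rare (fewer than 2 rows) AND has a low
# click rate (<= 0.5) into a single placeholder class
to_merge = counts.index[(counts < 2) & (rates[counts.index] <= 0.5)]
merged = feature.where(~feature.isin(to_merge), 'place_1')
print(merged.tolist())  # ['a', 'a', 'a', 'place_1', 'place_1', 'd', 'd']
```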

Visualizing the low-cardinality features

Data Cleaning

Merging categories within the high-cardinality features

df, use_cols_filter = reorganize_data(df, feature_many_kinds, remove_cols=['id' , 'hour'], group_setting=group_setting_dict)

One-hot encoding the features that need it

one_heat_cols = data.columns.values
not_include_cols = ['C14', 'C21', 'device_type', 'click', 'day', 'int_hour']  # drop the highly correlated features
# not_include_cols = ['click', 'day', 'int_hour']  # keep the highly correlated features
one_heat_cols = [col for col in one_heat_cols if col not in not_include_cols]
use_filter_one_heat_cols = ''.join([col+'_.*|' for col in one_heat_cols])
print(use_filter_one_heat_cols, len(use_filter_one_heat_cols))
data = one_hot_hander(one_heat_cols, data)
data = data.filter(regex=use_filter_one_heat_cols+'day|int_hour|click')
Model Building and Training

Several models are compared; the models used are listed below.

clfs = []
lr = LogisticRegression(solver ='saga', penalty='l1', C=1.0, n_jobs=-1)
clfs.append(lr)
rf = RandomForestClassifier(n_estimators=61, max_depth=50, n_jobs=-1)
clfs.append(rf)
dt = DecisionTreeClassifier(max_depth=50, max_features=100, random_state=11)
clfs.append(dt)
gdbt = GradientBoostingClassifier(n_estimators=50, learning_rate=0.1, max_depth=50, max_features=50)
clfs.append(gdbt)
Class Balancing

Comparing the two balancing strategies below, the models trained on simply duplicated data achieve higher test and recall scores than those trained with SMOTE, so simple duplication is used to balance the classes.
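The duplication strategy can be sketched on a toy frame (column values invented); the SMOTE path would instead call imblearn's SMOTE().fit_resample on the training split to synthesize new minority samples rather than copying existing ones:

```python
import pandas as pd

# Toy imbalanced frame: four non-clicks, one click
data = pd.DataFrame({'x': [1, 2, 3, 4, 5],
                     'click': [0, 0, 0, 0, 1]})

# Simple duplication, as in train_test_split_and_sample_balanced(method='copy'):
# append extra copies of the minority (clicked) rows until the classes match
pos = data.loc[data['click'] == 1]
balanced = pd.concat([data, pos, pos, pos], axis=0)
print(balanced['click'].value_counts())  # 0: 4, 1: 4
```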

Balancing by simple duplication

Results with this method:

        train_score  test_score  recall_score  precision_score  predict_time
Logist     0.655750    0.650888      0.650888         0.633333      0.048962
Random     0.995148    0.898379      0.898379         0.835407      0.203685
Decisi     0.995148    0.879856      0.879856         0.804487      0.038521
Gradie     0.995148    0.911500      0.911500         0.859476      0.089314
Balancing with SMOTE

Results with this method:

        train_score  test_score  recall_score  precision_score  predict_time
Logist     0.882245    0.542553      0.542553         0.812500      0.012487
Random     0.996007    0.562766      0.562766         0.712230      0.155159
Decisi     0.996007    0.554255      0.554255         0.623188      0.009837
Gradie     0.996007    0.560638      0.560638         0.702128      0.018247
Model Evaluation

From the results below, GBDT achieves the highest precision while also predicting quickly.

Without dropping the highly correlated features

        train_score  test_score  recall_score  precision_score  predict_time
Logist     0.655750    0.650888      0.650888         0.633333      0.048962
Random     0.995148    0.898379      0.898379         0.835407      0.203685
Decisi     0.995148    0.879856      0.879856         0.804487      0.038521
Gradie     0.995148    0.911500      0.911500         0.859476      0.089314
After dropping the highly correlated features

        train_score  test_score  recall_score  precision_score  predict_time
Logist     0.650568    0.650373      0.650373         0.636025      0.044393
Random     0.992281    0.897093      0.897093         0.833095      0.188017
Decisi     0.992502    0.870852      0.870852         0.794184      0.031313
Gradie     0.992502    0.915102      0.915102         0.860442      0.067582
Appendix: Code
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.utils import shuffle
import xgboost as xgb
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.metrics import precision_score, recall_score
from time import time, perf_counter
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn import ensemble
from matplotlib import pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']  # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False    # display the minus sign correctly

df = pd.read_csv('7_Regreaaion/task/data/train_sample_ctr.csv')

# print(df.info())
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
# print(df.describe().T)
# print(df['click'].value_counts())

click_data, nonclick_data = df.loc[df['click']==1], df.loc[df['click']==0]
# print(len(click_data), len(nonclick_data))

def drawingHeatMap(heatmapData, pos_data, neg_data):
    fig, axes = plt.subplots(figsize=(12,8), nrows=1, ncols=3)
    fig.subplots_adjust(left=0.1, right=0.95, bottom=0.15, top=0.9, wspace=0.1, hspace=0.15)
    axes = axes.flatten()

    corrHeatmapData, corr_pos_data, corr_neg_data = heatmapData.corr(), pos_data.corr(), neg_data.corr()

    # To show only the lower triangle:
    # mask = np.zeros_like(corrHeatmapData, dtype=bool)
    # mask[np.triu_indices_from(mask)] = True

    # Hide the upper triangle and any cell with |corr| <= 0.5
    mask1 = np.triu(np.ones_like(corrHeatmapData, dtype=bool))  # np.bool is deprecated; use the builtin bool
    mask2 = np.abs(corrHeatmapData) <= 0.5
    mask = mask1 | mask2

    # cmap = sns.diverging_palette(150, 275, s=80, l=40, n=9, center="light", as_cmap=True)
    ax1 = sns.heatmap(corrHeatmapData, ax=axes[0], linewidths=0.1, cbar=False, vmin=-1, vmax=1, annot=True, fmt='.1f', annot_kws={'size':8}, mask=mask, cmap='YlGnBu')
    ax1.set_title('all data corr', color='b')
    ax2 = sns.heatmap(corr_pos_data, ax=axes[1], yticklabels=False, cbar=False, linewidths=0.1, vmin=-1, vmax=1, annot=True, fmt='.1f', annot_kws={'size':8}, mask=mask, cmap='YlGnBu')
    ax2.set_title('click data corr', color='b')
    ax3 = sns.heatmap(corr_neg_data, ax=axes[2], yticklabels=False, linewidths=0.1, vmin=-1, vmax=1, annot=True, fmt='.1f', annot_kws={'size':8}, mask=mask, cmap='YlGnBu')
    ax3.set_title('nonclick data corr', color='b')

    plt.suptitle('Correlation between variables')
    # plt.savefig('7_Regreaaion/task/picture/heatmap.jpg', bbox_inches='tight')
    plt.show()

# From the heatmap: C1/device_type and C17/C14 are strongly correlated; C18/C21 show some correlation
# drawingHeatMap(df, click_data, nonclick_data)
# print(df.head())

def split_feature(df):
    cols = df.columns
    # print(cols)
    feature_many_kinds = {}
    feature_less_kinds = {}
    for col in cols:
        if col:
            col_counts = df[col].value_counts()
            if any(col_counts.values > 1000) and len(col_counts) < 20:
                feature_less_kinds[col] = len(col_counts)
            else:
                feature_many_kinds[col] = len(col_counts)
    return feature_less_kinds, feature_many_kinds

# The number of distinct ids is close to the sample count, so the column can simply be dropped
# feature_less_kinds, feature_many_kinds = split_feature(df)  # moved into main
# print(feature_less_kinds, feature_many_kinds)
# Low-cardinality features:
# {'click': 2, 'C1': 7, 'banner_pos': 7, 'site_category': 16, 'app_category': 18, 'device_type': 4, 'device_conn_type': 4, 'C15': 6, 'C16': 7, 'C18': 4}
# High-cardinality features:
# {'id': 8358, 'hour': 240, 'site_id': 541, 'site_domain': 442, 'app_id': 415, 'app_domain': 39, 'device_id': 1491, 'device_ip': 7713,
#  'device_model': 1244, 'C14': 891, 'C17': 322, 'C19': 61, 'C20': 126, 'C21': 56}

# Inspect per-category click rates for the low-cardinality features
def draw_lesscategory_clickrate(feature_less_kinds):
    index_category_click = list(feature_less_kinds.keys())
    index_category_click.remove('click')

    fig = plt.figure(figsize=(9, 9))
    fig_pos = 331
    for col in index_category_click:
        plt.subplot(fig_pos)
        # Why can't groupby plot directly here? To investigate later
        # ax = df.groupby(df[col]).plot(kind='bar', label=u'各类的数量', legend=True)
        ax = df[col].value_counts().plot(kind='bar')
        ax2 = ax.twinx()
        ax2.tick_params(axis='y', colors='r')
        click_rates = df.groupby(df[col])['click'].sum().sort_index() / df[col].value_counts().sort_index()
        ax2.plot(click_rates.values, color='r')
        plt.title(col)
        fig_pos += 1
    plt.figlegend([u'各类的数量', u'各类下点击的概率'])
    fig.supxlabel('category name in each feature')
    fig.supylabel('number of category')
    plt.tight_layout()
    # plt.savefig('7_Regreaaion/task/picture/feature_less_kinds_click_rate.jpg')
    plt.show()
# The plots show that within some features a particular category has a markedly higher click rate, so these features cannot be dropped
# draw_lesscategory_clickrate(feature_less_kinds)

########################## Processing the 'hour' field #################################
# Split the packed date into year, month, day, hour and weekend-flag columns
def time_hander(df):
    time = pd.DataFrame()
    time['date_time'] = pd.to_datetime(df['hour'], format='%y%m%d%H')
    time['year'] = time['date_time'].dt.year
    time['month'] = time['date_time'].dt.month
    time['day'] = time['date_time'].dt.day
    time['int_hour'] = time['date_time'].dt.hour
    time['is_weekday'] = time['date_time'].dt.dayofweek.isin([5, 6])  # despite the name, this flags weekends (Sat/Sun)
    time = time.drop('date_time', axis=1)

    return time

# Inspect click rates for the useful time components (components with only one
# distinct value are removed inside the function, since they carry no signal)
def draw_time_click(df):
    # Keep only the extracted time components with more than one distinct value
    # (at least in this small sample; the full dataset may need a fresh look)
    time_df = time_hander(df)
    time_dict_set = {}
    for col in time_df.columns:
        time_kind_set = set(time_df[col])
        if len(time_kind_set) > 1:
            time_dict_set[col] = time_kind_set
    # The output shows that only 'day', 'int_hour' and 'is_weekday' have multiple values
    # print(time_dict_set)

    df = pd.concat([df, time_df[list(time_dict_set.keys())]], axis=1)
    fig, axes = plt.subplots(figsize=(11, 4), nrows=1, ncols=3)
    axes = axes.flatten()
    fig_pos = 0
    for col in time_dict_set.keys():
        ax = df[col].value_counts().sort_index().plot(kind='bar', alpha=0.5, ax=axes[fig_pos], legend=True, sharex=True)
        ax.set_title(col)
        ax2 = ax.twinx()
        ax2.tick_params(axis='y', colors='r')
        click_rates = df.groupby(df[col])['click'].sum().sort_index()/df[col].value_counts().sort_index().values
        ax2.plot(click_rates.values, color='r', label='click_rate')
        # The line below could replace the one above, but then the 'day' panel
        # fails to show the click rate while the other two panels are fine; cause unknown
        # ax2.plot(click_rates, color='r', label='click_rate')
        ax2.legend(click_rates.index.name)
        fig_pos += 1
    plt.legend(ncol=2, fontsize=8, framealpha=0.8)
    # plt.figlegend(['number', 'click_rate'])
    plt.ylabel(u'特征下各类别数量')
    ax2.set_ylabel(u'特征下各类中的点击率')
    plt.tight_layout()
    plt.savefig('7_Regreaaion/task/picture/time_msg_click_rate.jpg')
    plt.show()

    return time_df
# Keep only the components extracted from 'hour' that vary: 'day', 'int_hour', 'is_weekday'. The plots show the click rate is nearly uniform across 'day', so 'day' can be dropped
# draw_time_click(df)

########################## Merging categories within high-cardinality features ##################
def feature_class_process(name_feature, n_class_boundaries, rates, fill_class, name_pos='click', df=df):
    flag_fill_new = True

    feature_value_cnt = df[name_feature].value_counts()
    feature_pos_rate = df.groupby(df[name_feature])[name_pos].sum().sort_index()/ df[name_feature].value_counts().sort_index()
    feature_pos_rate.name = 'rate'
    # n_feature_class_pos = df.groupby(df[name_feature])[name_pos].sum()
    df_feature = pd.concat([feature_value_cnt, feature_pos_rate], axis=1)

    def class_polymerization(index_condition, fill_str):
        indexs = df_feature[name_feature][index_condition].index.values
        for index_str in indexs:
            df[name_feature].replace(index_str, fill_str, inplace=True)

    for i in range(len(n_class_boundaries)):
        lower_limit = 1 if i==0 else n_class_boundaries[i-1]
        rate = 1 if i==0 else rates[i-1]

        condition_1 = np.array(df_feature[name_feature] >= lower_limit)
        condition_2 = np.array(df_feature[name_feature]  <  n_class_boundaries[i])
        condition_3 = np.array(df_feature['rate'] <= rate)
        index_condition = list(condition_1 * condition_2 * condition_3)
        if any(index_condition):
            fill_str = fill_class[i] if flag_fill_new==True else fill_str
            flag_fill_new = False if sum(index_condition) == len(index_condition) else True
            class_polymerization(index_condition, fill_str)
    
    return df[name_feature]

def reorganize_data(df, feature_many_kinds, remove_cols, group_setting):
    feature_many_kinds_cols = list(feature_many_kinds.keys())
    for col in remove_cols:
        feature_many_kinds_cols.remove(col)
    # feature_many_kinds_cols.remove('hour')
    use_cols_filter = ''
    for col in feature_many_kinds_cols:
        # fill_strs = [col+'_'+s for s in fill_class]
        use_cols_filter += '|'+col
        df[col] = feature_class_process(name_feature=col, n_class_boundaries=group_setting['n_boundaries'], rates=group_setting['rates'], fill_class=group_setting['fill_class'])
    # print(use_cols_filter)
    return df, use_cols_filter

def one_hot_hander(feature_list, data):
    data_one_heat = []
    for feature_name in feature_list:
        data_one_heat.append(pd.get_dummies(data[feature_name], prefix= feature_name))
    data_hoted = pd.concat([data]+data_one_heat, axis=1)
    return data_hoted

def train_test_split_and_sample_balanced(data, method=None, test_size=0.3):
    if method == 'copy':
        click_data = data.loc[data['click']==1]
        data = pd.concat([click_data, data, click_data, click_data], axis=0)
        data = shuffle(data, random_state=333)
    
    data_in, data_target = data.drop('click', axis=1), data['click']
    X_train, X_test, y_train, y_test = train_test_split(data_in, data_target, test_size=test_size, random_state=111)

    if method == 'smo':
        from imblearn.over_sampling import SMOTE

        test = pd.concat([X_test, y_test], axis=1)
        test = pd.concat([test[y_test==0].sample(n=np.sum(y_test==1), axis=0), test[y_test==1]], axis=0)
        X_test, y_test = test.drop('click', axis=1), test['click']
        # print(test['click'].value_counts())
 
        sm = SMOTE(random_state=333)
        X_train, y_train = sm.fit_resample(X_train, y_train)
        # print(y_train.value_counts())      

    return X_train, X_test, y_train, y_test

if __name__ == '__main__':
    # The number of distinct ids is close to the sample count, so the column can simply be dropped
    feature_less_kinds, feature_many_kinds = split_feature(df)
    # grouping config for the high-cardinality features
    group_setting_dict = {'n_boundaries':[5, 10, 50, 100],'rates':[0.8, 0.7, 0.5],
        'fill_class':['place_1', 'place_2', 'place_3', 'place_4']}
    # Regroup the high-cardinality features
    df, use_cols_filter = reorganize_data(df, feature_many_kinds, remove_cols=['id' , 'hour'], group_setting=group_setting_dict)

    df = pd.concat([df, time_hander(df)[['day', 'int_hour', 'is_weekday']]], axis=1)
    data = df.filter(regex='click|day|int_hour|is_weekday|C1|banner_pos|site_category|app_category|device_type'
                           '|device_conn_type|C15|C16|C18' + use_cols_filter)

    one_heat_cols = data.columns.values
    not_include_cols = ['C14', 'C21', 'device_type', 'click', 'day', 'int_hour']  # drop the highly correlated features
    # not_include_cols = ['click', 'day', 'int_hour']  # keep the highly correlated features
    one_heat_cols = [col for col in one_heat_cols if col not in not_include_cols]
    use_filter_one_heat_cols = ''.join([col+'_.*|' for col in one_heat_cols])
    print(use_filter_one_heat_cols, len(use_filter_one_heat_cols))
    data = one_hot_hander(one_heat_cols, data)
    data = data.filter(regex=use_filter_one_heat_cols+'day|int_hour|click')

    # train/test split
    X_train, X_test, y_train, y_test = train_test_split_and_sample_balanced(data, method='smo', test_size=0.3)

    clfs = []
    lr = LogisticRegression(solver ='saga', penalty='l1', C=1.0, n_jobs=-1)
    clfs.append(lr)
    rf = RandomForestClassifier(n_estimators=61, max_depth=50, n_jobs=-1)
    clfs.append(rf)
    dt = DecisionTreeClassifier(max_depth=50, max_features=100, random_state=11)
    clfs.append(dt)
    gdbt = GradientBoostingClassifier(n_estimators=50, learning_rate=0.1, max_depth=50, max_features=50)
    clfs.append(gdbt)
    # voting = ensemble.VotingClassifier(clfs, voting='hard')
    # clfs.append(voting)
    # xgboost = xgb.XGBClassifier()
    # clfs.append(xgboost)

    mode_metrics = pd.DataFrame()

    for clf in clfs:
        clf.fit(X_train, y_train)
        start = perf_counter()
        y_predict = clf.predict(X_test)
        end = perf_counter()

        mode_name = clf.__class__.__name__.replace('Classifier', '')[:6]
        mode_metrics.loc[mode_name, 'train_score'] = np.mean(clf.predict(X_train)==y_train)
        mode_metrics.loc[mode_name, 'test_score'] = np.mean(y_predict==y_test)
        mode_metrics.loc[mode_name, 'recall_score'] = recall_score(y_test, y_predict, average='micro')
        mode_metrics.loc[mode_name, 'precision_score'] = precision_score(y_test, y_predict)
        mode_metrics.loc[mode_name, 'predict_time'] = end - start
    print(mode_metrics)
    ax = mode_metrics.plot(kind='bar', secondary_y=['predict_time'], figsize=(8, 6))
    plt.xticks(rotation=90)
    ax.set_xlabel('mode name', color='r')
    ax.set_ylabel('accuracy', color='r')
    ax.right_ax.set_ylabel('predict time', color='r')
    # plt.savefig('7_Regreaaion/task/picture/result_variants_mode_reduce.jpg')
    # plt.savefig('7_Regreaaion/task/picture/result_variants_mode.jpg')
    plt.show()
