Explicitly specifying the test/train sets in GridSearchCV

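As @MaxU said, it is better to let GridSearchCV handle the splitting. But if you want to enforce the split exactly as you set it up in the question, you can use PredefinedSplit, which does exactly that.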
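
To see how PredefinedSplit reads its fold labels, here is a minimal sketch with six made-up samples (the labels are purely illustrative). Every sample labelled -1 lands in the training set of every split, and each distinct non-negative label becomes one test fold:

```python
import numpy as np
from sklearn.model_selection import PredefinedSplit

# Hypothetical fold labels for 6 samples:
#   -1 -> never in a test fold (always used for training)
#    0 -> belongs to test fold 0
#    1 -> belongs to test fold 1
test_fold = np.array([-1, 0, 0, 1, 1, -1])
ps = PredefinedSplit(test_fold)

# One split per distinct non-negative label.
print(ps.get_n_splits())  # 2

for train_idx, test_idx in ps.split():
    print("TRAIN:", train_idx, "TEST:", test_idx)
# TRAIN: [0 3 4 5] TEST: [1 2]
# TRAIN: [0 1 2 5] TEST: [3 4]
```

With a single label 0 (as in the answer's code), you get exactly one train/validation split.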

So you need to make the following changes to your code:

import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, PredefinedSplit

# X, y and param_grid are assumed to be defined as in the question.

# Here X_test, y_test is the untouched data.
# The validation data (X_val, y_val) is currently inside X_train,
# which will be split using PredefinedSplit inside GridSearchCV.
X_train, X_test = np.array_split(X, [50])
y_train, y_test = np.array_split(y, [50])

# The indices which have the value -1 will be kept in train.
train_indices = np.full((35,), -1, dtype=int)

# The indices which have zero or positive values will be kept in test.
test_indices = np.full((15,), 0, dtype=int)

test_fold = np.append(train_indices, test_indices)
print(test_fold)
# OUTPUT: 35 values of -1 followed by 15 values of 0

ps = PredefinedSplit(test_fold)

# Check how many splits will be done, based on test_fold
print(ps.get_n_splits())
# OUTPUT: 1

for train_index, test_index in ps.split():
    print("TRAIN:", train_index, "TEST:", test_index)
# OUTPUT: TRAIN holds indices 0-34, TEST holds indices 35-49

# Now pass this `ps` to the cv parameter of GridSearchCV
grid_search = GridSearchCV(Ridge(random_state=444), param_grid, cv=ps)

# Here, pass X_train and y_train
grid_search.fit(X_train, y_train)
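
For a version that runs on its own, the sketch below replaces the question's data with a synthetic make_regression dataset and uses a hypothetical alpha grid. The 60-row shape, the noise level and the grid values are assumptions; only the 35/15 split of the first 50 rows mirrors the code above.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, PredefinedSplit

# Synthetic stand-in for the question's data (assumption): 60 rows in total.
X, y = make_regression(n_samples=60, n_features=5, noise=10.0, random_state=0)

# First 50 rows: train + validation; last 10 rows: untouched test set.
X_train, X_test = np.array_split(X, [50])
y_train, y_test = np.array_split(y, [50])

# Rows 0-34 are always used for training (-1); rows 35-49 form validation fold 0.
test_fold = np.concatenate([np.full(35, -1, dtype=int), np.full(15, 0, dtype=int)])
ps = PredefinedSplit(test_fold)

# Hypothetical grid of Ridge regularization strengths.
param_grid = {"alpha": [0.1, 1.0, 10.0]}

grid_search = GridSearchCV(Ridge(random_state=444), param_grid, cv=ps)
grid_search.fit(X_train, y_train)

print(grid_search.best_params_)
# Only split0_* columns exist because there is a single split;
# this array holds one validation score per alpha.
print(grid_search.cv_results_["split0_test_score"])
# The untouched test rows are used only once, after the search.
print(grid_search.score(X_test, y_test))
```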

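The X_train and y_train passed to fit() will be divided into train and test (validation, in your case) using the split we defined, so Ridge will be trained on the original data at indices [0:35] and tested on [35:50].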
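
To check which rows are actually used, the sketch below (again with synthetic data and a hypothetical alpha grid) compares the PredefinedSplit search with a search that receives the same split as an explicit list of (train, test) index arrays, which the cv parameter also accepts, and with a manual fit on rows 0-34 scored on rows 35-49. It also shows that, with the default refit=True, the best model is refit afterwards on all 50 rows of X_train, not only on rows 0-34.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, PredefinedSplit

# Synthetic stand-in for X_train / y_train (assumption).
X_train, y_train = make_regression(n_samples=50, n_features=5, noise=10.0, random_state=0)
param_grid = {"alpha": [0.1, 1.0, 10.0]}  # hypothetical grid

# The same split expressed two ways.
ps = PredefinedSplit(np.r_[np.full(35, -1), np.zeros(15, dtype=int)])
explicit = [(np.arange(35), np.arange(35, 50))]

gs_ps = GridSearchCV(Ridge(), param_grid, cv=ps).fit(X_train, y_train)
gs_list = GridSearchCV(Ridge(), param_grid, cv=explicit).fit(X_train, y_train)

# Both searches see identical train/validation rows, so their scores match.
print(np.allclose(gs_ps.cv_results_["mean_test_score"],
                  gs_list.cv_results_["mean_test_score"]))  # True

# Manual check for alpha=1.0 (index 1 in the grid): train on 0-34, score (R^2) on 35-49.
manual = Ridge(alpha=1.0).fit(X_train[:35], y_train[:35]).score(X_train[35:], y_train[35:])
print(np.isclose(manual, gs_ps.cv_results_["split0_test_score"][1]))  # True

# With refit=True (the default), the winning alpha is refit on all 50 rows.
best_alpha = gs_ps.best_params_["alpha"]
refit_coef = Ridge(alpha=best_alpha).fit(X_train, y_train).coef_
print(np.allclose(gs_ps.best_estimator_.coef_, refit_coef))  # True
```

If you want the final model to be trained only on rows 0-34, pass refit=False and fit that model yourself after reading best_params_.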

Hope this clears up how it works.



When reposting, please credit the source: www.mshxw.com
Article URL: https://www.mshxw.com/it/668413.html