1. sklearn的train_test_split2. flatten()函数3. np.argwhere()函数
1. sklearn的train_test_splitX_train,X_test, y_train, y_test =sklearn.model_selection.train_test_split(train_data,train_target,test_size=0.4,
因此,之后的fit要用第一和三的调用结果,即,fit(X_train,y_train)
2. flatten()函数flatten是numpy.ndarray.flatten的一个函数,即返回一个一维数组。
flatten只能适用于numpy对象,即array或者mat,不适用于普通的list列表。
from numpy import *
a=array([[1,2],[3,4],[5,6]])
print(a)
# 输出
array([[1, 2],
[3, 4],
[5, 6]])
# flatten()
# 默认按行的方向降维
a.flatten()
# 结果
array([1, 2, 3, 4, 5, 6])
# 按列降维
a.flatten('F')
# 结果
array([1, 3, 5, 2, 4, 6])
#按行降维
a.flatten('A')
# 结果
array([1, 2, 3, 4, 5, 6])
3. np.argwhere()函数
np.argwhere( np.array > num )
返回大于num的数组元素的索引。
x = np.arange(6).reshape(2,3) print(x) # 结果 [[0 1 2] [3 4 5]] a = np.argwhere(x>1) print(a) # 结果 [[0 2] [1 0] [1 1] [1 2]]



