今天在加载数据的时候报错:
之前在使用Keras提供的to_categorical标签编码方法的时候一直没有仔细留意,因为直接传入标签数据即可得到one-hot编码的标签数据,今天在使用该方法的时候多加入了一个参数如下所示:
y_train = np_utils.to_categorical(y_train, nb_classes)
y_test = np_utils.to_categorical(y_test, nb_classes)
加入了nb_classes,也就是说人为指定了多少个类别,这时候就报错了,其实从报错意思上面来解读不难理解就是索引超标了,这个就要回到Keras提供的to_categorical方法上,这个默认是从0开始编码的,比如你的标签类别是1,2 ,3,那么直接去编码的话就会得到4个类别,one-hot向量也是四维的,结合这个实现原理就好解决这个报错了,解决方法如下:
y_train = np_utils.to_categorical(y_train - 1, nb_classes) y_test = np_utils.to_categorical(y_test - 1, nb_classes)
之后就可以了,记录一下!



