
tensorflow2.x - ValueError: Value tf.Tensor(***,) has insufficient rank for batching.


1. Problem Background

I wanted to turn all of the image files under the directory D:\BaiduNetdiskDownload\pycv-learning\data\spot_data_cls\splash into a single test dataset.

# Read an image file and preprocess it (decode, resize, scale to [0, 1])
def _decode_and_resize(filename, label=None):
    img_string = tf.io.read_file(filename)
    img_decoded = tf.image.decode_image(img_string, channels=3, expand_animations=False)
    img_resized = tf.image.resize(img_decoded, resized_img_shape) / 255.
    if label is None:
        return img_resized
    return img_resized, label

# Build a dataset usable for training or prediction from a collection of
# file paths (and, for training, the matching labels)
def processData(filenames, labels=None, trainmode=True):
    if trainmode:
        train_dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
        train_dataset = train_dataset.map(map_func=_decode_and_resize)
        # train_dataset = train_dataset.shuffle(buffer_size=25000)  # very memory-hungry, skipped
        train_dataset = train_dataset.batch(batch_size)
        train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        return train_dataset
    else:
        test_dataset = tf.data.Dataset.from_tensor_slices(filenames)
        test_dataset = test_dataset.map(map_func=_decode_and_resize)
        test_dataset = test_dataset.batch(batch_size)
        return test_dataset
    
    
test_dataset = processData(r'D:\BaiduNetdiskDownload\pycv-learning\data\spot_data_cls\splash', trainmode=False)

Below was a listing of the image files in that directory (not reproduced here). Running the call above is what produced the ValueError in the title.

2. Solution

I later realized that the filenames argument should be a collection of image file paths, not a single directory path. tf.data.Dataset.from_tensor_slices slices its input along the first axis, so it needs a tensor of rank at least 1; a single path string is a rank-0 scalar, which is exactly what the "insufficient rank for batching" message is about, as the short sketch below shows.
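A minimal sketch of the difference (the two .jpg names are hypothetical, just stand-ins for real files under the splash directory):

import tensorflow as tf

# A single directory path is a rank-0 (scalar) string tensor: there is no
# first axis for from_tensor_slices to slice along
single_path = tf.constant(r'D:\BaiduNetdiskDownload\pycv-learning\data\spot_data_cls\splash')
print(single_path.shape)   # ()

# A list of file paths becomes a rank-1 string tensor, which slices
# cleanly into one dataset element per file
path_list = tf.constant([r'D:\...\splash\a.jpg', r'D:\...\splash\b.jpg'])   # hypothetical names
print(path_list.shape)     # (2,)
dataset = tf.data.Dataset.from_tensor_slices(path_list)   # one element per path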

So I changed the code to the following, and the problem went away.

import tensorflow as tf

import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
%matplotlib inline
import os
from pathlib import Path
import shutil
import time



# Root directory of the dataset
DATA_ROOT = r'D:\BaiduNetdiskDownload\pycv-learning\data\spot_data_cls'

# Destination directories for the train / validation / test splits
TRAIN_DIR = r'D:\BaiduNetdiskDownload\pycv-learning\data\spot_data_cls\trainImageSet'
VAL_DIR = r'D:\BaiduNetdiskDownload\pycv-learning\data\spot_data_cls\valImageSet'
TEST_DIR = r'D:\BaiduNetdiskDownload\pycv-learning\data\spot_data_cls\testImageSet'


# Train:val:test split ratio of (8:2):1, i.e. 90% goes to train+val
# (of which 2/9 becomes the validation set) and 10% to test
Ratio = {'trainval': 0.9, 'val': 2/9, 'test': 0.1}
num_epochs = 10
learning_rate = 1e-3
resized_img_shape = (224, 224)
input_shape = (224, 224, 3)
batch_size = 8

 

from pathlib import Path

# Read an image file and preprocess it (decode, resize, scale to [0, 1])
def _decode_and_resize(filename, label=None):
    img_string = tf.io.read_file(filename)
    img_decoded = tf.image.decode_image(img_string, channels=3, expand_animations=False)
    img_resized = tf.image.resize(img_decoded, resized_img_shape) / 255.
    if label is None:
        return img_resized
    return img_resized, label

# Build a dataset usable for training or prediction from a collection of
# file paths (and, for training, the matching labels)
def processData(filenames, labels=None, trainmode=True):
    if trainmode:
        train_dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
        train_dataset = train_dataset.map(map_func=_decode_and_resize)
        # train_dataset = train_dataset.shuffle(buffer_size=25000)  # very memory-hungry, skipped
        train_dataset = train_dataset.batch(batch_size)
        train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        return train_dataset
    else:
        test_dataset = tf.data.Dataset.from_tensor_slices(filenames)
        test_dataset = test_dataset.map(map_func=_decode_and_resize)
        test_dataset = test_dataset.batch(batch_size)
        return test_dataset
    
    
    


# Collect the full path of every file under the splash directory
filepaths = [str(p) for p in Path(r'D:\BaiduNetdiskDownload\pycv-learning\data\spot_data_cls\splash').iterdir()]

test_dataset = processData(filepaths, trainmode=False)

# best_model is the trained model from earlier (its training is not shown in this post)
best_model.predict(test_dataset)
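As a quick sanity check, and to pair each prediction with its file, something like the following works (a sketch, assuming best_model is a trained classifier that outputs one probability vector per image):

# Each batch should come out as (None, 224, 224, 3) float images
print(test_dataset.element_spec)

# predict() walks the dataset in order, so the i-th output row
# corresponds to the i-th entry of filepaths
probs = best_model.predict(test_dataset)
pred_ids = np.argmax(probs, axis=-1)   # np is imported above
for path, pred in zip(filepaths, pred_ids):
    print(path, '->', pred)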

 
