栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch学习之Dataset类

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch学习之Dataset类

一、

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):
    def __init__(self, root_dir, label_dir):  # 提供一个全局变量
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir,self.label_dir)
        self.img_path = os.listdir(self.path)

    def __getitem__(self,idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img,label
    def __len__(self):
        return len(self.img_path)

root_dir = r"Dataset_data/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir,ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)

 

例子1:
from PIL import Image
path = "d:a"                   #获取图片的地址
img = Image.open(path)
img.size                         #尺寸
Img.show()                #显示数据,此处Image.show(Img)错误

例子2:import os
想获取图片的地址
1.获取所有图片的地址的列表list
2.通过相应的索引获取图片的地址
dir_path = "dataset/train/ants"
import os
img_path_list = os.listdir(dir_path)

想获取所有图片的地址
import os

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/738585.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号