栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

TensorData和Dataloader

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

TensorData和Dataloader

'''
Description: torch--TensorData
Autor: 365JHWZGo
Date: 2021-11-15 21:42:12
LastEditors: 365JHWZGo
LastEditTime: 2021-11-15 21:53:45
'''

import torch
import torch.utils.data as Data
import numpy as np

BATCH_SIZE = 2

#numpy
#方法一
x = np.linspace(1,10,10)[:,np.newaxis]
#方法二
x = np.expand_dims(np.linspace(1,10,10),1)
y = np.square(x)

#tensor
x = torch.unsqueeze(torch.linspace(1,10,10),dim=1)
y = torch.square(x)

dataset = Data.TensorDataset(torch.from_numpy(x),torch.from_numpy(y))
train_loader = torch.utils.data.DataLoader(
    batch_size=BATCH_SIZE,
    dataset=dataset,
    shuffle=True,
    num_workers=2
)
if __name__ == '__main__':
    for step,(batch_x,batch_y) in enumerate(train_loader):
        print('step:{0}nbatch_x:{1}nbatch_y:{2}n'.format(step,batch_x,batch_y))


运行结果:

step:0
batch_x:tensor([[5.],[3.]], dtype=torch.float64)
batch_y:tensor([[25.],[ 9.]], dtype=torch.float64)

step:1
batch_x:tensor([[9.],[1.]], dtype=torch.float64)
batch_y:tensor([[81.],[ 1.]], dtype=torch.float64)

step:2
batch_x:tensor([[6.],[4.]], dtype=torch.float64)
batch_y:tensor([[36.],[16.]], dtype=torch.float64)

step:3
batch_x:tensor([[2.],[7.]], dtype=torch.float64)
batch_y:tensor([[ 4.],[49.]], dtype=torch.float64)

step:4
batch_x:tensor([[10.],[ 8.]], dtype=torch.float64)
batch_y:tensor([[100.],[ 64.]], dtype=torch.float64)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/503788.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号