import torchvision

# 1. CIFAR-10 without transforms: each sample is returned as an (image, class index) pair.
# root: directory where the dataset is stored; train: True for the training split,
# False for the test split; download: fetch the archive into root if it is not there yet.
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

print(test_set[0])  # e.g. (<PIL.Image.Image image mode=RGB size=32x32 at 0x...>, 3)
print(test_set.classes)
# ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

img, target = test_set[0]
print(test_set.classes[target])  # cat  (class index 3)
img.show()  # display img
# 2. dataset + transform 运用：用 transform 把图片转为 tensor，并用 TensorBoard 显示
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
# Transform pipeline that converts each dataset image into a tensor.
tensor_compose = transforms.Compose([transforms.ToTensor()])
writer = SummaryWriter("logs")
# root: dataset directory; train: training vs. test split; download: fetch if missing;
# transform: applied to every sample, so img comes back as a tensor instead of an image.
train_set = torchvision.datasets.CIFAR10(
    root="./dataset", train=True, transform=tensor_compose, download=True
)
test_set = torchvision.datasets.CIFAR10(
    root="./dataset", train=False, transform=tensor_compose, download=True
)

# Log the first 10 training images under the "dataset" tag, using the sample
# index as the global step so they can be browsed with TensorBoard's step slider.
try:
    for i in range(10):
        img, _ = train_set[i]  # the class index is not needed for logging
        writer.add_image("dataset", img, i)
finally:
    # Close (and flush) the writer even if loading or logging fails midway.
    writer.close()
# 3. COCO 数据集
import os

# COCO 2017 root directory. The default is machine-specific; set the COCO_ROOT
# environment variable to point elsewhere instead of editing the paths below.
COCO_ROOT = os.environ.get("COCO_ROOT", "/home/wtj/Data/Coco2017")

# CocoDetection: root is the image folder of a split, annFile the matching
# instances_*.json annotation file. val2017 is used as the held-out ("test") set.
# NOTE(review): CocoDetection relies on pycocotools being installed — confirm.
coco_train_set = torchvision.datasets.CocoDetection(
    root=os.path.join(COCO_ROOT, "train2017"),
    annFile=os.path.join(COCO_ROOT, "annotations", "instances_train2017.json"),
)
coco_test_set = torchvision.datasets.CocoDetection(
    root=os.path.join(COCO_ROOT, "val2017"),
    annFile=os.path.join(COCO_ROOT, "annotations", "instances_val2017.json"),
)
# Print the first validation sample (an image together with its annotations).
print(coco_test_set[0])



