- 行为识别(二):创建Dataset类用于加载视频数据
- 1 设计思路
- 2 python代码
经过上一节的预处理,现在数据集的组织形式为
`-- dataset
|-- test
| |-- class1
| | |-- clip1
| | | |-- clip1_001.jpg
| | | `-- clip1_002.jpg
| | `-- clip2
| | |-- clip2_001.jpg
| | `-- clip2_002.jpg
| `-- class2
| |-- clip1
| | |-- clip1_001.jpg
| | `-- clip1_002.jpg
| `-- clip2
| |-- clip2_001.jpg
| `-- clip2_002.jpg
`-- train
|-- class1
| |-- clip1
| | |-- clip1_001.jpg
| | `-- clip1_002.jpg
| `-- clip2
| |-- clip2_001.jpg
| `-- clip2_002.jpg
`-- class2
|-- clip1
| |-- clip1_001.jpg
| `-- clip1_002.jpg
`-- clip2
|-- clip2_001.jpg
`-- clip2_002.jpg
本文将通过对Pytorch自带的Dataset类进行继承来创建处理UCF101的类,并配合DataLoader类来加载数据。
重载了self.__init__()函数:
1、用于获取clip的路径列表,并存放在列表clip_paths中
2、用于获取clip对应的label,并存放在labels中
jpg_num = len(os.listdir(clip_path))
if jpg_num >= self.sample_fps:
self.clip_paths.append(clip_path)
self.labels.append(self.idx_to_class[action_class])
sample_fps表示读取这个clip的多少张图片,上面一段代码用于判断这个clip分解成图片后,图片的数量是否满足要求,如果不满足则舍弃这个clip,不加入训练或测试列表
重载了self.__getitem__()函数:
1、等间隔抽取sample_fps张图片,通过torch.cat()函数存入张量clip_fold中,由于dataloader返回的对象限制的原因,这里先把所有图片堆叠在一起,clip_fold对应的格式为(n_channels, height, width),这里把frames堆叠到channel中去了,后面我们通过unfold_batch函数和unfold函数把它拆解开
2、dataloader获取到的clips的格式为(n_batch, n_channels, height, width)
重载了self.__len__()函数:
1、通过labels的长度来判断数据集中clip的数目
自定义了onehot()函数:
1、将label转换为onehot形式,传入的labels为一个batch的label,维度dim=0时为batch的维度,所以在dim=1的维度上进行onehot
自定义了unfold_batch()函数:
1、传入的参数的格式为(n_batch, n_channels, height, width),返回的clip_batch格式为(n_batch, n_channel, n_frame, height, width),这也是conv3d卷积输入所需要的格式。由于rgb图像有三个通道,所以这里n_channels=3,n_frame对应的就是将R通道的sample_fps张图像堆叠,G通道的sample_fps张图像堆叠,B通道的sample_fps张图像堆叠。
自定义了unfold()函数:
1、将每个clip堆叠在一起的48张不同通道和不同frame的图像重新组织为训练需要的格式
2、训练需要的格式为(n_batch, n_channel, n_frame, height, width)
2 python代码

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
from torchvision import transforms
class UCF101(Dataset):
    """Dataset of UCF101 video clips that were pre-extracted to JPEG frames.

    Expected directory layout (one split, e.g. ``train/`` or ``test/``):
        dataset_root/<ActionClass>/<clip_dir>/<clip_dir>_000001.jpg, ...

    ``__getitem__`` returns ``(clip_fold, label)`` where ``clip_fold`` has
    shape ``(sample_fps * n_channels, H, W)``: the sampled frames are
    concatenated along the channel axis so the default DataLoader collate
    can batch them.  Call :meth:`unfold_batch` on a batch afterwards to get
    the ``(n_batch, n_channel, n_frame, H, W)`` layout expected by Conv3d.
    """

    # Ordered list of the 101 UCF101 action classes; a name's position in
    # this list is its integer label.
    class_to_idx = ['ApplyEyeMakeup', 'ApplyLipstick', 'Archery', 'BabyCrawling', 'BalanceBeam',
                    'BandMarching', 'BaseballPitch', 'Basketball', 'BasketballDunk', 'BenchPress', 'Biking',
                    'Billiards', 'BlowDryHair', 'BlowingCandles', 'BodyWeightSquats', 'Bowling',
                    'BoxingPunchingBag', 'BoxingSpeedBag', 'BreastStroke', 'BrushingTeeth', 'CleanAndJerk',
                    'CliffDiving', 'CricketBowling', 'CricketShot', 'CuttingInKitchen', 'Diving', 'Drumming',
                    'Fencing', 'FieldHockeyPenalty', 'FloorGymnastics', 'FrisbeeCatch', 'FrontCrawl',
                    'GolfSwing', 'Haircut', 'Hammering', 'HammerThrow', 'HandstandPushups',
                    'HandstandWalking', 'HeadMassage', 'HighJump', 'HorseRace', 'HorseRiding', 'HulaHoop',
                    'IceDancing', 'JavelinThrow', 'JugglingBalls', 'JumpingJack', 'JumpRope', 'Kayaking',
                    'Knitting', 'LongJump', 'Lunges', 'MilitaryParade', 'Mixing', 'MoppingFloor', 'Nunchucks',
                    'ParallelBars', 'PizzaTossing', 'PlayingCello', 'PlayingDaf', 'PlayingDhol',
                    'PlayingFlute', 'PlayingGuitar', 'PlayingPiano', 'PlayingSitar', 'PlayingTabla',
                    'PlayingViolin', 'PoleVault', 'PommelHorse', 'PullUps', 'Punch', 'PushUps', 'Rafting',
                    'RockClimbingIndoor', 'RopeClimbing', 'Rowing', 'SalsaSpin', 'ShavingBeard', 'Shotput',
                    'SkateBoarding', 'Skiing', 'Skijet', 'SkyDiving', 'SoccerJuggling', 'SoccerPenalty',
                    'StillRings', 'SumoWrestling', 'Surfing', 'Swing', 'TableTennisShot', 'TaiChi',
                    'TennisSwing', 'ThrowDiscus', 'TrampolineJumping', 'Typing', 'UnevenBars',
                    'VolleyballSpiking', 'WalkingWithDog', 'WallPushups', 'WritingOnBoard', 'YoYo']

    # Maps class name -> integer label, derived from the list above so the
    # two can never drift apart (the original kept a hand-written 101-entry
    # dict).  NOTE(review): the attribute names class_to_idx / idx_to_class
    # read "swapped" relative to what they hold; they are kept as-is for
    # backward compatibility with existing callers.
    idx_to_class = {name: idx for idx, name in enumerate(class_to_idx)}

    def __init__(self, dataset_root, sample_fps=16):
        """Index every usable clip folder under ``dataset_root``.

        Args:
            dataset_root: path of one split directory (one folder per
                action class).
            sample_fps: number of frames sampled from each clip; clips with
                fewer extracted frames than this are skipped entirely.
        """
        self.dataset_root = dataset_root
        self.sample_fps = sample_fps
        self.clip_paths = []  # path of each usable clip folder
        self.labels = []      # integer label, parallel to clip_paths
        for action_class in self.class_to_idx:
            class_dir = os.path.join(self.dataset_root, action_class)
            for action_clip in os.listdir(class_dir):
                clip_path = os.path.join(class_dir, action_clip)
                # Discard clips that cannot supply sample_fps frames.
                if len(os.listdir(clip_path)) >= self.sample_fps:
                    self.clip_paths.append(clip_path)
                    self.labels.append(self.idx_to_class[action_class])

    def __getitem__(self, item):
        """Load one clip and return ``(clip_fold, label)``.

        Frames are sampled at a fixed stride of
        ``total_frames // sample_fps`` (file numbering is 1-based),
        converted with ``transforms.ToTensor`` and concatenated along dim 0
        into a ``(sample_fps * n_channels, H, W)`` tensor.
        """
        clip_path = self.clip_paths[item]
        label = self.labels[item]
        to_tensor = transforms.ToTensor()
        total_jpgs = len(os.listdir(clip_path))
        stride = total_jpgs // self.sample_fps  # >= 1, guaranteed by __init__ filter
        frames = []
        for i in range(self.sample_fps):
            # Frame files are named "<clip_dir>_000001.jpg" (1-based index).
            jpg_name = "{}_{:06d}.jpg".format(os.path.basename(clip_path), i * stride + 1)
            img = Image.open(os.path.join(clip_path, jpg_name))
            frames.append(to_tensor(img))
        clip_fold = torch.cat(frames, dim=0)
        return clip_fold, label

    def __len__(self):
        """Number of usable clips in this split."""
        return len(self.labels)

    def unfold_batch(self, clip_batch_flod):
        """Unfold a whole batch.

        Args:
            clip_batch_flod: tensor of shape
                ``(n_batch, sample_fps * n_channel, H, W)``.

        Returns:
            Tensor of shape ``(n_batch, n_channel, n_frame, H, W)`` — the
            layout Conv3d expects.
        """
        return torch.stack([self.unfold(fold) for fold in clip_batch_flod], dim=0)

    def unfold(self, clip_fold):
        """Unfold one clip from ``(sample_fps * n_channel, H, W)`` to
        ``(n_channel, n_frame, H, W)``.

        In the folded layout, channel-axis index ``n_channel * j + i``
        holds frame ``j`` of image channel ``i`` (whole RGB frames were
        concatenated one after another in :meth:`__getitem__`).
        """
        n_channel = clip_fold.shape[0] // self.sample_fps
        # Take only the slots the original indexing would touch, view them
        # as (frame, channel, H, W), then swap the first two axes.
        used = clip_fold[:n_channel * self.sample_fps]
        frames_first = used.reshape(self.sample_fps, n_channel, *clip_fold.shape[1:])
        return frames_first.permute(1, 0, 2, 3).contiguous()

    def onehot(self, labels):
        """One-hot encode a batch of integer labels.

        Args:
            labels: sequence of int class indices, one per batch element.

        Returns:
            Float tensor of shape ``(len(labels), n_classes)`` with a 1 at
            each label position (scatter on dim=1, the class dimension).
        """
        index = torch.LongTensor(labels).reshape(-1, 1)
        # Number of classes comes from the class list instead of a
        # hard-coded 101, so the two cannot disagree.
        onehot = torch.zeros(index.shape[0], len(self.class_to_idx))
        return onehot.scatter_(dim=1, index=index, value=1)



