
Action Recognition (2): Creating a Dataset Class for Loading Video Data



Contents
  • Action Recognition (2): Creating a Dataset Class for Loading Video Data
    • 1 Design Approach
    • 2 Python Code

1 Design Approach

After the preprocessing in the previous post, the dataset is now organized as follows:

`-- dataset
    |-- test
    |   |-- class1
    |   |   |-- clip1
    |   |   |   |-- clip1_001.jpg
    |   |   |   `-- clip1_002.jpg
    |   |   `-- clip2
    |   |       |-- clip2_001.jpg
    |   |       `-- clip2_002.jpg
    |   `-- class2
    |       |-- clip1
    |       |   |-- clip1_001.jpg
    |       |   `-- clip1_002.jpg
    |       `-- clip2
    |           |-- clip2_001.jpg
    |           `-- clip2_002.jpg
    `-- train
        |-- class1
        |   |-- clip1
        |   |   |-- clip1_001.jpg
        |   |   `-- clip1_002.jpg
        |   `-- clip2
        |       |-- clip2_001.jpg
        |       `-- clip2_002.jpg
        `-- class2
            |-- clip1
            |   |-- clip1_001.jpg
            |   `-- clip1_002.jpg
            `-- clip2
                |-- clip2_001.jpg
                `-- clip2_002.jpg

This post subclasses PyTorch's built-in Dataset class to create a class for handling UCF101, and pairs it with the DataLoader class to load the data.

Overrides the `__init__()` method:

1. Collects the path of every clip into the list `clip_paths`.

2. Collects each clip's corresponding label into `labels`.

jpg_num = len(os.listdir(clip_path))
if jpg_num >= self.sample_fps:
    self.clip_paths.append(clip_path)
    self.labels.append(self.idx_to_class[action_class])

`sample_fps` is the number of frames to read from each clip. The snippet above checks whether a clip, once split into frames, contains at least that many images; if not, the clip is discarded and excluded from the training and test lists.
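The collection loop in `__init__` can be exercised on a throwaway directory tree. A minimal sketch, assuming made-up clip names and frame counts (the class names are real UCF101 classes), with `sample_fps` lowered from 16 to 3 so the tiny demo clips pass the filter:

```python
import os
import tempfile

# Build a fake dataset root: two classes, three clips, three frames each.
root = tempfile.mkdtemp()
for cls, clips in [("ApplyEyeMakeup", ["clip1", "clip2"]), ("Archery", ["clip1"])]:
    for c in clips:
        os.makedirs(os.path.join(root, cls, c))
        for j in range(3):
            open(os.path.join(root, cls, c, "{}_{:06d}.jpg".format(c, j + 1)), "w").close()

idx_to_class = {"ApplyEyeMakeup": 0, "Archery": 2}
sample_fps = 3  # lowered from 16 for the demo

# Same logic as the loop in __init__: keep a clip only if it has enough frames.
clip_paths, labels = [], []
for action_class in idx_to_class:
    for action_clip in os.listdir(os.path.join(root, action_class)):
        clip_path = os.path.join(root, action_class, action_clip)
        if len(os.listdir(clip_path)) >= sample_fps:
            clip_paths.append(clip_path)
            labels.append(idx_to_class[action_class])

print(len(clip_paths), sorted(set(labels)))  # 3 [0, 2]
```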

Overrides the `__getitem__()` method:

1. Samples `sample_fps` frames at equal intervals and concatenates them into the tensor `clip_fold` via `torch.cat()`. Because of restrictions on the objects a DataLoader can return, all frames are first stacked together: `clip_fold` has shape (n_channels, height, width), with the frames folded into the channel dimension. Later, the `unfold_batch()` and `unfold()` functions unpack them again.

2. The clips obtained from the DataLoader therefore have shape (n_batch, n_channels, height, width).
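The equal-interval sampling in `__getitem__` can be sketched with plain arithmetic. A hypothetical clip with 100 extracted frames (the clip name is an assumption for illustration):

```python
# Equal-interval frame selection, as in __getitem__.
total_jpgs, sample_fps = 100, 16
step = total_jpgs // sample_fps                      # int(total_jpgs / sample_fps) -> 6
indices = [i * step + 1 for i in range(sample_fps)]  # 1-based frame numbers
names = ["clip1_{:06d}.jpg".format(n) for n in indices]

print(indices[:4])  # [1, 7, 13, 19]
print(names[0])     # clip1_000001.jpg
```

Note that the last sampled index (91 here) never exceeds `total_jpgs`, which is why the `jpg_num >= self.sample_fps` filter in `__init__` is sufficient.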

Overrides the `__len__()` method:

1. Uses the length of `labels` as the number of clips in the dataset.

Defines a custom `onehot()` function:

1. Converts labels to one-hot form. The `labels` argument holds one batch of labels; dim=0 is the batch dimension, so the one-hot encoding is scattered along dim=1.
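The `scatter_`-based encoding can be illustrated without torch. A pure-Python sketch of the same row-wise assignment (the class count is shrunk from 101 to 5 purely for readability):

```python
# One batch of class indices (made-up values for the demo).
labels = [3, 0, 4]
n_classes = 5

# Equivalent of torch.zeros(...).scatter_(dim=1, index=labels, value=1):
# each row gets a single 1 at the column given by its label.
onehot = [[0.0] * n_classes for _ in labels]
for row, cls in enumerate(labels):
    onehot[row][cls] = 1.0

print(onehot[0])  # [0.0, 0.0, 0.0, 1.0, 0.0]
```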

Defines a custom `unfold_batch()` function:

1. Takes input of shape (n_batch, n_channels, height, width) and returns `clip_batch` with shape (n_batch, n_channel, n_frame, height, width), which is the input format required by Conv3d. Since RGB images have three channels, n_channels = 3 here; n_frame comes from stacking the `sample_fps` R-channel images together, the `sample_fps` G-channel images together, and the `sample_fps` B-channel images together.

Defines a custom `unfold()` function:

1. Reorganizes each clip's 48 stacked images (3 channels × 16 frames) of different channels and frames into the format needed for training.

2. After batching via `unfold_batch()`, that format is (n_batch, n_channel, n_frame, height, width).
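To see what `unfold()` undoes, here is a sketch of the interleaved layout produced by `torch.cat` and its regrouping into (n_channel, n_frame, h, w). NumPy stands in for torch purely to keep the example dependency-light, and the clip is shrunk to 2 frames of 2×2 pixels:

```python
import numpy as np

sample_fps = 2  # frames per clip (the class uses 16)
n_channel, h, w = 3, 2, 2

# Simulate torch.cat(clip, dim=0): each frame's value equals its frame index,
# so clip_fold's planes are interleaved as [R0, G0, B0, R1, G1, B1].
frames = [np.full((n_channel, h, w), j, dtype=np.int64) for j in range(sample_fps)]
clip_fold = np.concatenate(frames, axis=0)           # shape (6, 2, 2)

# NumPy equivalent of unfold(): channel i, frame j lives at index n_channel*j + i.
clip = np.stack(
    [np.stack([clip_fold[n_channel * j + i] for j in range(sample_fps)], axis=0)
     for i in range(n_channel)],
    axis=0)                                          # shape (3, 2, 2, 2)

print(clip.shape)        # (3, 2, 2, 2)
print(clip[0, 1, 0, 0])  # frame 1 recovered in channel 0 -> 1
```

The whole regrouping loop is equivalent to a single reshape-and-transpose, `clip_fold.reshape(sample_fps, n_channel, h, w).transpose(1, 0, 2, 3)`, which is worth knowing if the Python loop ever becomes a bottleneck.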

2 Python Code

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
from torchvision import transforms


class UCF101(Dataset):
    def __init__(self, dataset_root):
        self.dataset_root = dataset_root
        self.sample_fps = 16
        # self.class_to_idx = ['ApplyEyeMakeup']
        self.class_to_idx = ['ApplyEyeMakeup', 'ApplyLipstick', 'Archery', 'BabyCrawling', 'BalanceBeam',
                             'BandMarching', 'BaseballPitch', 'Basketball', 'BasketballDunk', 'BenchPress', 'Biking',
                             'Billiards', 'BlowDryHair', 'BlowingCandles', 'BodyWeightSquats', 'Bowling',
                             'BoxingPunchingBag', 'BoxingSpeedBag', 'BreastStroke', 'BrushingTeeth', 'CleanAndJerk',
                             'CliffDiving', 'CricketBowling', 'CricketShot', 'CuttingInKitchen', 'Diving', 'Drumming',
                             'Fencing', 'FieldHockeyPenalty', 'FloorGymnastics', 'FrisbeeCatch', 'FrontCrawl',
                             'GolfSwing', 'Haircut', 'Hammering', 'HammerThrow', 'HandstandPushups',
                             'HandstandWalking', 'HeadMassage', 'HighJump', 'HorseRace', 'HorseRiding', 'HulaHoop',
                             'IceDancing', 'JavelinThrow', 'JugglingBalls', 'JumpingJack', 'JumpRope', 'Kayaking',
                             'Knitting', 'LongJump', 'Lunges', 'MilitaryParade', 'Mixing', 'MoppingFloor', 'Nunchucks',
                             'ParallelBars', 'PizzaTossing', 'PlayingCello', 'PlayingDaf', 'PlayingDhol',
                             'PlayingFlute', 'PlayingGuitar', 'PlayingPiano', 'PlayingSitar', 'PlayingTabla',
                             'PlayingViolin', 'PoleVault', 'PommelHorse', 'PullUps', 'Punch', 'PushUps', 'Rafting',
                             'RockClimbingIndoor', 'RopeClimbing', 'Rowing', 'SalsaSpin', 'ShavingBeard', 'Shotput',
                             'SkateBoarding', 'Skiing', 'Skijet', 'SkyDiving', 'SoccerJuggling', 'SoccerPenalty',
                             'StillRings', 'SumoWrestling', 'Surfing', 'Swing', 'TableTennisShot', 'TaiChi',
                             'TennisSwing', 'ThrowDiscus', 'TrampolineJumping', 'Typing', 'UnevenBars',
                             'VolleyballSpiking', 'WalkingWithDog', 'WallPushups', 'WritingOnBoard', 'YoYo']

        # self.idx_to_class = {'ApplyEyeMakeup': 0}
        self.idx_to_class = {'ApplyEyeMakeup': 0, 'ApplyLipstick': 1, 'Archery': 2, 'BabyCrawling': 3, 'BalanceBeam': 4,
                             'BandMarching': 5, 'BaseballPitch': 6, 'Basketball': 7, 'BasketballDunk': 8,
                             'BenchPress': 9, 'Biking': 10, 'Billiards': 11, 'BlowDryHair': 12, 'BlowingCandles': 13,
                             'BodyWeightSquats': 14, 'Bowling': 15, 'BoxingPunchingBag': 16, 'BoxingSpeedBag': 17,
                             'BreastStroke': 18, 'BrushingTeeth': 19, 'CleanAndJerk': 20, 'CliffDiving': 21,
                             'CricketBowling': 22, 'CricketShot': 23, 'CuttingInKitchen': 24, 'Diving': 25,
                             'Drumming': 26, 'Fencing': 27, 'FieldHockeyPenalty': 28, 'FloorGymnastics': 29,
                             'FrisbeeCatch': 30, 'FrontCrawl': 31, 'GolfSwing': 32, 'Haircut': 33, 'Hammering': 34,
                             'HammerThrow': 35, 'HandstandPushups': 36, 'HandstandWalking': 37, 'HeadMassage': 38,
                             'HighJump': 39, 'HorseRace': 40, 'HorseRiding': 41, 'HulaHoop': 42, 'IceDancing': 43,
                             'JavelinThrow': 44, 'JugglingBalls': 45, 'JumpingJack': 46, 'JumpRope': 47, 'Kayaking': 48,
                             'Knitting': 49, 'LongJump': 50, 'Lunges': 51, 'MilitaryParade': 52, 'Mixing': 53,
                             'MoppingFloor': 54, 'Nunchucks': 55, 'ParallelBars': 56, 'PizzaTossing': 57,
                             'PlayingCello': 58, 'PlayingDaf': 59, 'PlayingDhol': 60, 'PlayingFlute': 61,
                             'PlayingGuitar': 62, 'PlayingPiano': 63, 'PlayingSitar': 64, 'PlayingTabla': 65,
                             'PlayingViolin': 66, 'PoleVault': 67, 'PommelHorse': 68, 'PullUps': 69, 'Punch': 70,
                             'PushUps': 71, 'Rafting': 72, 'RockClimbingIndoor': 73, 'RopeClimbing': 74, 'Rowing': 75,
                             'SalsaSpin': 76, 'ShavingBeard': 77, 'Shotput': 78, 'SkateBoarding': 79, 'Skiing': 80,
                             'Skijet': 81, 'SkyDiving': 82, 'SoccerJuggling': 83, 'SoccerPenalty': 84, 'StillRings': 85,
                             'SumoWrestling': 86, 'Surfing': 87, 'Swing': 88, 'TableTennisShot': 89, 'TaiChi': 90,
                             'TennisSwing': 91, 'ThrowDiscus': 92, 'TrampolineJumping': 93, 'Typing': 94,
                             'UnevenBars': 95, 'VolleyballSpiking': 96, 'WalkingWithDog': 97, 'WallPushups': 98,
                             'WritingOnBoard': 99, 'YoYo': 100}

        self.clip_paths = []
        self.labels = []

        for action_class in self.class_to_idx:
            for action_clip in os.listdir(os.path.join(self.dataset_root, action_class)):
                clip_path = os.path.join(self.dataset_root, action_class, action_clip)
                jpg_num = len(os.listdir(clip_path))
                if jpg_num >= self.sample_fps:
                    self.clip_paths.append(clip_path)
                    self.labels.append(self.idx_to_class[action_class])

    def __getitem__(self, item):
        clip_path = self.clip_paths[item]
        label = self.labels[item]
        load_transforms = transforms.Compose([
            # transforms.Grayscale()  # placeholder for optional preprocessing (currently unused)
        ])
        transform2 = transforms.Compose([
            transforms.ToTensor()
        ])
        clip = []
        total_jpgs = len(os.listdir(clip_path))
        index = int(total_jpgs / self.sample_fps)  # equal-interval sampling step
        for i in range(0, self.sample_fps):
            jpg_name = "{}_{:06d}.jpg".format(os.path.basename(clip_path), i * index + 1)
            img = Image.open(os.path.join(clip_path, jpg_name))
            clip.append(transform2(img))
        clip_fold = torch.cat(clip, dim=0)  # frames folded into the channel dimension
        return clip_fold, label

    def __len__(self):
        return len(self.labels)

    def unfold_batch(self, clip_batch_fold):
        # unfold to batch (n_batch, n_channel, n_frames, h, w)
        L_batch = []
        for i in range(0, clip_batch_fold.shape[0]):
            L_batch.append(self.unfold(clip_batch_fold[i]))
        clip_batch = torch.stack(L_batch, dim=0)
        return clip_batch

    def unfold(self, clip_fold):
        # a clip (n_channel, n_frames, h, w)
        L_clip = []
        n_channel = int(clip_fold.shape[0] / self.sample_fps)
        for i in range(0, n_channel):
            L_channel = []
            for j in range(0, self.sample_fps):
                L_channel.append(clip_fold[n_channel * j + i])  # frame j of channel i
            clip_channel = torch.stack(L_channel, dim=0)
            L_clip.append(clip_channel)
        clip = torch.stack(L_clip, dim=0)
        return clip

    def onehot(self,labels):
        labels = torch.LongTensor(labels)
        labels = torch.reshape(labels,[len(labels),1])
        onehot = torch.zeros([len(labels),101])
        onehot = onehot.scatter_(dim=1,index=labels,value=1)
        return onehot

Original article: https://www.mshxw.com/it/876772.html