210316 - PyTorch for class-imbalanced datasets: making every class appear in each batch in roughly equal numbers (to be tidied up)


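def prepare_dataloader(X, Y, P=None, dim='2D', batch_size=32, drop_last=True):
    """Wrap inputs X, integer labels Y and an optional third per-sample integer array P in a DataLoader."""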
    from torch.utils.data import Dataset, DataLoader
    import torch
    import numpy as np
    from collections import Counter
    from torch.utils.data.sampler import WeightedRandomSampler

    class Self_Def_Dataset(Dataset):
        def __init__(self, X, Y):
            if dim=='2D':
                X = np.reshape(X, (-1,1,32,32))
            self.X = torch.Tensor(X)
            self.Y = torch.LongTensor(Y)
            # `is None` avoids an element-wise comparison when P is a NumPy array
            if P is None:
                self.P = torch.LongTensor(Y)
                self.O = 'two_output'
            else:
                self.O = 'three_output'
                self.P = torch.LongTensor(P)

        def __getitem__(self, index):
            x = self.X[index]
            y = self.Y[index]
            p = self.P[index]
            if self.O=='two_output':
                return x, y
            if self.O=='three_output':
                return x, y, p

        def __len__(self):
            return len(self.X)

    # ------------------------- WeightedRandomSampler ------------------------ #
    # # ! [Batch for imbalanced dataset]
    # # ! https://stackoverflow.com/questions/60812032/using-weightedrandomsampler-in-pytorch
    # # ! https://towardsdatascience.com/pytorch-basics-sampling-samplers-2a0f29f0bf2a
    # # ! https://www.cnblogs.com/dxscode/p/14382444.html?ivk_sa=1024320u
    # Y_dict = dict(Counter(Y))
    # # Weight each class by the inverse of its frequency so rare classes are drawn more often
    # class_weights = {k: len(Y)/v for (k,v) in Y_dict.items()}
    # weights = [class_weights[Y[i]] for i in range(len(Y))]
    # sampler = WeightedRandomSampler(torch.DoubleTensor(weights), int(len(Y)))
    # dataset = Self_Def_Dataset(X, Y)
    # dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=drop_last, sampler=sampler)
    # ------------------------- WeightedRandomSampler ------------------------ #

    dataset = Self_Def_Dataset(X, Y)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
    return dataloader
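
The function above returns a plainly shuffled loader, so its batches follow the class proportions of the raw data. The sampler section it leaves commented out points toward `WeightedRandomSampler`. Below is a minimal standalone sketch of that idea, not the author's final version. It assumes inverse-class-frequency weights and sampling with replacement, and `make_balanced_loader` is a hypothetical helper name that uses `TensorDataset` in place of `Self_Def_Dataset` to stay self-contained.

```python
from collections import Counter

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler


def make_balanced_loader(X, Y, batch_size=32, drop_last=True):
    """Hypothetical helper: sample each item with probability inversely
    proportional to the size of its class (assumption: with replacement)."""
    Y = np.asarray(Y)
    class_counts = Counter(Y.tolist())
    # One weight per sample; samples from rare classes get larger weights.
    sample_weights = torch.tensor(
        [1.0 / class_counts[int(y)] for y in Y], dtype=torch.double
    )
    sampler = WeightedRandomSampler(
        sample_weights, num_samples=len(Y), replacement=True
    )
    dataset = TensorDataset(
        torch.as_tensor(np.asarray(X), dtype=torch.float32),
        torch.as_tensor(Y, dtype=torch.long),
    )
    # `shuffle` is left at its default (False) because a sampler is supplied.
    return DataLoader(
        dataset, batch_size=batch_size, sampler=sampler, drop_last=drop_last
    )


if __name__ == "__main__":
    # Toy imbalanced data: 900 samples of class 0 and 100 of class 1.
    X = np.random.rand(1000, 1, 32, 32).astype(np.float32)
    Y = np.array([0] * 900 + [1] * 100)
    loader = make_balanced_loader(X, Y, batch_size=100)
    xb, yb = next(iter(loader))
    # Per-class counts in the first batch.
    print(torch.bincount(yb, minlength=2))
```

This balances classes only in expectation: each batch is a random draw, so counts fluctuate and a very small batch can still miss a class. Minority samples are also repeated within an epoch, since sampling is with replacement. Guaranteeing that every class appears in every batch would need a custom batch sampler instead.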