Time to start doing data science. Lately I have been thinking about combining AI with security and still feel fairly lost, but there should always be a way forward, so instead of worrying about what learning this is for, let's just dive into the technology.
PyTorch, like TensorFlow, is standard equipment for data science, so it is worth learning properly.
This blog mainly follows the PyTorch tutorial videos by the Bilibili up 我是土堆,
together with the official PyTorch documentation and source code.
Inspecting the members of the torch package:
dir(torch)
['AVG', 'AggregationType', 'AnyType', 'Argument', 'ArgumentSpec', 'BFloat16Storage', 'BFloat16Tensor', 'BenchmarkConfig', 'BenchmarkExecutionStats', 'Block', 'BoolStorage', 'BoolTensor', 'BoolType', 'BufferDict', 'ByteStorage', 'ByteTensor', 'CONV_BN_FUSION', 'CallStack', 'Capsule', 'CharStorage', 'CharTensor', 'ClassType', 'Code', 'CompilationUnit', 'CompleteArgumentSpec', 'ComplexDoubleStorage', 'ComplexFloatStorage', 'ComplexType', 'ConcreteModuleType', 'ConcreteModuleTypeBuilder', 'CudaBFloat16Storagebase', 'CudaBoolStoragebase', 'CudaByteStoragebase', 'CudaCharStoragebase', 'CudaComplexDoubleStoragebase', 'CudaComplexFloatStoragebase', 'CudaDoubleStoragebase', 'CudaFloatStoragebase', 'CudaHalfStoragebase', 'CudaIntStoragebase', 'CudaLongStoragebase', 'CudaShortStoragebase', ...]
Inspecting the members of torch.cuda:
dir(torch.cuda)
['Any', 'BFloat16Storage', 'BFloat16Tensor', 'BoolStorage', 'BoolTensor', 'ByteStorage', 'ByteTensor', 'CharStorage', 'CharTensor', 'ComplexDoubleStorage', 'ComplexFloatStorage', 'CudaError', 'DeferredCudaCallError', 'Device', 'Dict', 'DoubleStorage', 'DoubleTensor', 'Event', 'FloatStorage', 'FloatTensor', 'HalfStorage', 'HalfTensor', 'IntStorage', 'IntTensor', 'List', 'LongStorage', 'LongTensor', 'Optional', 'ShortStorage', 'ShortTensor', 'Stream', 'Tuple', 'Union', '_Cudabase', '_CudaDeviceProperties', '_Graph', '_Storagebase', '__annotations__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_check_capability', '_check_cubins', '_cudart', '_device', '_device_t', '_dummy_type', '_get_device_index', '_initialization_lock', '_initialized', '_is_in_bad_fork', '_lazy_call', '_lazy_init', '_lazy_new', '_queued_calls', '_sleep', '_tls', '_utils', 'amp', 'caching_allocator_alloc', 'caching_allocator_delete', 'can_device_access_peer', 'check_error', 'collections', 'contextlib', 'cudaStatus', 'cudart', 'current_blas_handle', 'current_device', 'current_stream', 'default_generators', 'default_stream', 'device', 'device_count', 'device_of', 'empty_cache', 'get_arch_list', 'get_device_capability', 'get_device_name', 'get_device_properties', 'get_gencode_flags', 'get_rng_state', 'get_rng_state_all', 'has_half', 'has_magma', 'init', 'initial_seed', 'ipc_collect', 'is_available', 'is_initialized', 'list_gpu_processes', 'manual_seed', 'manual_seed_all', 'max_memory_allocated', 'max_memory_cached', 'max_memory_reserved', 'memory', 'memory_allocated', 'memory_cached', 'memory_reserved', 'memory_snapshot', 'memory_stats', 'memory_stats_as_nested_dict', 'memory_summary', 'nccl', 'nvtx', 'os', 'profiler', 'random', 'reset_accumulated_memory_stats', 'reset_max_memory_allocated', 'reset_max_memory_cached', 'reset_peak_memory_stats', 'seed', 'seed_all', 'set_device', 'set_per_process_memory_fraction', 'set_rng_state', 'set_rng_state_all', 'sparse', 'stream', 'streams', 'synchronize', 'threading', 'torch', 'traceback', 'warnings']
Viewing the help text for a member:
help(torch.cuda.is_available)
Help on function is_available in module torch.cuda:

is_available() -> bool
    Returns a bool indicating if CUDA is currently available.
So torch.cuda.is_available() is the function that tells whether CUDA is currently usable.
In short, dir() and help() are quick, convenient ways to browse a package's members and read what each function does.
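dir() also lists all the dunder members; a tiny filter of my own (not from the tutorial) keeps the listing readable:

import torch

# keep only the public names, dropping the '_'-prefixed ones
public = [name for name in dir(torch.cuda) if not name.startswith("_")]
print(public)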
Dataset
The data for this exercise is the bees-and-ants image package (hymenoptera_data) provided by the up.
The custom dataset subclasses Dataset from torch.utils.data:
import os
from PIL import Image
from torch.utils.data import Dataset

class DatasetUsage(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        # e.g. dataset/hymenoptera_data/train + ants -> dataset/hymenoptera_data/train/ants
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)  # all image file names under that folder

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir  # the folder name doubles as the label
        return img, label

    def __len__(self):
        return len(self.img_path)
root_dir = "dataset/hymenoptera_data/train"
ant_label = "ants"
bee_label = "bees"
ant_dataset = DatasetUsage(root_dir, ant_label)
bee_dataset = DatasetUsage(root_dir, bee_label)
dataset_sum = ant_dataset + bee_dataset  # '+' is Dataset.__add__, explained below
print(dataset_sum)  # a ConcatDataset object
(1) Paths are relative; root_dir is the base, and os.path.join splices the pieces together.
(2) Reading a dataset means subclassing Dataset and implementing the three special methods __init__, __getitem__ and __len__.
(3) __len__ counts the data files under a given folder.
(4) __getitem__ must return the data together with its label; a quick sanity check is sketched below.
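A quick sanity check of the three methods, using the ant_dataset built above (Image.show() opens the picture in the system's default viewer):

# __len__: number of image files under dataset/hymenoptera_data/train/ants
print(len(ant_dataset))

# __getitem__: indexing the dataset returns the (PIL image, label) pair
img, label = ant_dataset[0]
print(label)  # 'ants', since the folder name is the label
img.show()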
As for the + overload used above, it is implemented on the parent Dataset class:
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
    return ConcatDataset([self, other])
ConcatDataset stores the incoming dataset objects in a list and stitches them together by translating a global index into the right (dataset, sample) pair; a small sketch after the excerpt shows that mapping in action.
class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        # Cannot verify that datasets is Sized
        assert len(datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore
        self.datasets = list(datasets)
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
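A small sketch of that index mapping, reusing the ant_dataset and bee_dataset from above: indices below len(ant_dataset) resolve into the first dataset, and everything past that falls through to the second via bisect_right.

dataset_sum = ant_dataset + bee_dataset
print(len(dataset_sum) == len(ant_dataset) + len(bee_dataset))  # True

_, label = dataset_sum[0]                 # still inside ant_dataset
print(label)                              # 'ants'

_, label = dataset_sum[len(ant_dataset)]  # maps to bee_dataset[0]
print(label)                              # 'bees'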
DataLoader
Start from the official documentation: https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)
The commonly used parameters:
dataset: the dataset to load from
batch_size: how many samples are drawn at a time; batch_size=4 packs 4 samples into one batch
shuffle: whether to draw samples in random order; shuffle=True reshuffles the data before each pass
num_workers: the number of subprocesses used for loading (0 means the data is loaded in the main process)
drop_last: whether to drop the final batch when fewer than batch_size samples remain (see the sketch below)
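A toy sketch makes batch_size and drop_last concrete (TensorDataset here is just a stand-in dataset, not part of the tutorial):

import torch
from torch.utils.data import DataLoader, TensorDataset

toy = TensorDataset(torch.arange(10))  # 10 samples
loader = DataLoader(toy, batch_size=4, drop_last=False)
print([batch[0].tolist() for batch in loader])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]: the smaller final batch is kept

loader = DataLoader(toy, batch_size=4, drop_last=True)
print([batch[0].tolist() for batch in loader])
# [[0, 1, 2, 3], [4, 5, 6, 7]]: the leftover 2 samples are dropped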
Here the CIFAR10 dataset serves as the DataLoader demo:
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# download=True fetches CIFAR10 on first run; train=False picks the test split to match the name test_data
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

img, target = test_data[0]
print(img.shape)  # torch.Size([3, 32, 32]), a single CHW image tensor
print(target)     # the class index of that image

writer = SummaryWriter("DataLoader_example")
step = 0
for data in test_loader:
    imgs, targets = data  # imgs: [64, 3, 32, 32], except possibly a smaller final batch
    writer.add_images("test data", imgs, step)
    step += 1
writer.close()
The effect: each step in TensorBoard's Images tab shows one batch of 64 images.
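To open it, point TensorBoard at the directory SummaryWriter wrote to:

tensorboard --logdir=DataLoader_example

then visit the printed localhost URL; dragging the step slider in the Images tab walks through the batches.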
Reference video: https://www.bilibili.com/video/BV1hE411t7RN?p=1