# 数据增强 (data augmentation)
def augment(l, hflip=True, rot=True):
    """Randomly flip/transpose a list of images, applying one shared transform.

    The random decisions are drawn once per call, so every image in ``l``
    receives exactly the same transform (this keeps paired images, e.g. an
    input and its target, spatially aligned).

    Args:
        l: Iterable of NumPy arrays whose first two axes are spatial, i.e.
            ``(H, W)`` or ``(H, W, C)``; any axes after the first two are left
            untouched.
        hflip: If true, reverse axis 1 (left-right flip) with probability 0.5.
        rot: If true, independently reverse axis 0 (top-bottom flip) and swap
            axes 0 and 1, each with probability 0.5.  Together with ``hflip``
            this yields 8 equally likely flip/right-angle-rotation combinations.

    Returns:
        A list with one transformed array per input image.  The results are
        NumPy views that may have negative strides; some consumers (e.g.
        ``torch.from_numpy``) reject such arrays, so copy them with
        ``np.ascontiguousarray`` first if needed.
    """
    # Draw each decision once, before touching any image, so all images share
    # the same transform.  `and` short-circuits, so a disabled option consumes
    # no random number.
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        # Index only the two spatial axes so 2-D (grayscale) arrays work as
        # well as (H, W, C) ones; for 3-D input this is identical to the
        # explicit img[:, ::-1, :] / img[::-1, :, :] forms.
        if hflip:
            img = img[:, ::-1]
        if vflip:
            img = img[::-1]
        if rot90:
            # Swap H and W; equals transpose(1, 0, 2) for 3-D input.
            img = img.swapaxes(0, 1)
        return img

    return [_augment(_l) for _l in l]
# 通道变换 (channel conversion)
def set_channel(l, n_channel):
    """Return the images in ``l`` converted toward ``n_channel`` channels.

    Every image is first given an explicit trailing channel axis if it is
    2-D.  Then:

    * 3 channels with ``n_channel == 1``: keep only channel 0 of
      ``sc.rgb2ycbcr(image)`` (presumably the Y/luma plane, assuming ``sc``
      is ``skimage.color`` -- confirm against the file's imports), with a
      trailing axis of size 1.
    * 1 channel with ``n_channel == 3``: repeat the single channel along
      axis 2.
    * Anything else: the (possibly axis-expanded) image is returned as is.

    Args:
        l: Iterable of NumPy image arrays, 2-D or 3-D with channels last.
        n_channel: Desired channel count; only 1 and 3 trigger a conversion.

    Returns:
        A list with one array per input image, in the same order.
    """
    def _match_channels(image):
        # A 2-D array is treated as single-channel: append a channel axis.
        hwc = image if image.ndim != 2 else np.expand_dims(image, axis=2)
        channels = hwc.shape[2]

        if (n_channel, channels) == (1, 3):
            # Colour -> single channel: keep the first converted plane.
            luma = sc.rgb2ycbcr(hwc)[:, :, 0]
            return np.expand_dims(luma, 2)

        if (n_channel, channels) == (3, 1):
            # Single channel -> colour: replicate along the channel axis.
            return np.concatenate([hwc] * n_channel, 2)

        return hwc

    return list(map(_match_channels, l))