import torch
import numpy as np
from torchvision.models import AlexNet
from torch.optim.lr_scheduler import CosineAnnealingLR,OneCycleLR, CosineAnnealingWarmRestarts
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import LambdaLR
def circle_warmup_cosinedecay(warm_lr, max_lr, warm_step=500, epoch_step=5000,
                              decay=0.01, total_step=None):
    """
    Circle Warmup Cosine Decay Scheduler.

    Build a cyclic learning-rate function. The schedule repeats every
    ``epoch_step`` steps. Inside each cycle, let ``s = t % epoch_step``:

    * while ``s < warm_step`` the value rises linearly from ``warm_lr``
      towards ``max_lr`` (slope ``(max_lr - warm_lr) / warm_step``);
    * otherwise it is ``max_lr * 0.5 * (1 + cos(pi * s / epoch_step))``.

    Both ``warm_lr`` and ``max_lr`` are additionally scaled by a global
    exponential envelope ``decay ** (t / total_step)``. This is 1 at ``t == 0``
    and ``decay`` at ``t == total_step``. When ``total_step`` is None the
    envelope is disabled (factor 1).

    Args:
        warm_lr: value at the start of each warmup ramp (before the envelope).
        max_lr: peak value reached at the end of warmup (before the envelope).
        warm_step: number of warmup steps per cycle; <= 0 disables warmup.
        epoch_step: cycle length in steps; must be positive.
        decay: envelope value reached at ``t == total_step``.
        total_step: horizon of the global envelope, or None to disable it.

    Returns:
        A function ``f(t)`` that maps a step index (int or numpy array) to
        the schedule value(s) as a numpy array (0-d for a scalar ``t``).

    Raises:
        ValueError: if ``epoch_step`` <= 0, or ``total_step`` is given and <= 0.
    """
    if epoch_step <= 0:
        raise ValueError("epoch_step must be positive")
    if total_step is not None and total_step <= 0:
        raise ValueError("total_step must be positive or None")

    # np.where below evaluates BOTH branches. When warm_step <= 0 the warmup
    # branch is never selected, but it must still not divide by zero.
    warm_span = warm_step if warm_step > 0 else 1

    def global_decay(t):
        # Exponential envelope shared by the warmup and cosine phases.
        if total_step is None:
            return 1.0
        return decay ** (t / total_step)

    def cycle_step(t):
        # Position of step t inside the current cycle.
        return t % epoch_step

    def warm_func(step, t):
        scale = global_decay(t)
        max_lr_ = max_lr * scale
        warm_lr_ = warm_lr * scale
        return (max_lr_ - warm_lr_) / warm_span * step + warm_lr_

    def decay_func(step, t):
        max_lr_ = max_lr * global_decay(t)
        # NOTE(review): the cosine phase is measured from the start of the
        # cycle (step / epoch_step), not from the end of warmup, so the value
        # right after warmup is slightly below max_lr. Confirm this is intended.
        return max_lr_ * 0.5 * (1 + np.cos(np.pi * step / epoch_step))

    def schedule(t):
        step = cycle_step(t)
        # np.where keeps the schedule vectorised: t may be an int or an array.
        return np.where(step < warm_step, warm_func(step, t), decay_func(step, t))

    return schedule
结果如下：



