栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Java

Java时间轮算法的实现代码示例

Java 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Java时间轮算法的实现代码示例

考虑这样一个场景,现在有5000个任务,要让这5000个任务每隔5分钟触发某个操作,怎么去实现这个需求。大部分人首先想到的是使用定时器,但是5000个任务,你就要用5000个定时器,一个定时器就是一个线程,你懂了吧,这种方法肯定是不行的。

针对这个场景,催生了时间轮算法,时间轮到底是什么?我一贯的风格,自行谷歌去。大发慈悲,发个时间轮介绍你们看看,看文字和图就好了,代码不要看了,那个文章里的代码运行不起来,时间轮介绍。

看好了介绍,我们就开始动手吧。

开发环境:idea + jdk1.8 + maven

新建一个maven工程

 

创建如下的目录结构

 

不要忘了pom.xml中添加netty库


    
    <dependency>
      <groupId>io.netty</groupId>
      <artifactId>netty-all</artifactId>
      <version>4.1.5.Final</version>
    </dependency>
  

代码如下

Timeout.java

package com.tanghuachun.timer;
/**
 * Handle to a {@link TimerTask} that has been scheduled on a {@link Timer}.
 */
public interface Timeout {

  /** Returns the timer that scheduled this timeout. */
  Timer timer();

  /** Returns the task that runs when this timeout expires. */
  TimerTask task();

  /**
   * Attempts to cancel the scheduled task.
   *
   * @return {@code true} if this call cancelled the task, {@code false} otherwise
   */
  boolean cancel();

  /** Returns whether this timeout has been cancelled. */
  boolean isCancelled();

  /** Returns whether this timeout has expired. */
  boolean isExpired();
}

Timer.java

package com.tanghuachun.timer;
import java.util.Set;
import java.util.concurrent.TimeUnit;

/**
 * Schedules {@link TimerTask}s to run once after a delay, each with a caller-supplied string
 * argument.
 */
public interface Timer {

  /**
   * Schedules {@code task} to run once after {@code delay}.
   *
   * @param task  the task to run
   * @param delay how long to wait before running the task
   * @param unit  the unit of {@code delay}
   * @param argv  an opaque string passed back to {@link TimerTask#run(Timeout, String)}
   * @return a handle for the scheduled task, which can be used to cancel it
   */
  Timeout newTimeout(TimerTask task, long delay, TimeUnit unit, String argv);

  /**
   * Stops this timer.
   *
   * @return the timeouts that were still pending when the timer stopped
   */
  Set<Timeout> stop();
}

TimerTask.java

package com.tanghuachun.timer;
/**
 * Work executed by a {@link Timer} when the associated {@link Timeout} expires.
 */
@FunctionalInterface
public interface TimerTask {

  /**
   * Executes the task.
   *
   * @param timeout the handle that was returned when this task was scheduled
   * @param argv    the string argument supplied when this task was scheduled
   * @throws Exception if the task fails
   */
  void run(Timeout timeout, String argv) throws Exception;
}

TimerWheel.java


package com.tanghuachun.timer;
import io.netty.util.*;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.util.Collections;
import java.util.HashSet;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;

/**
 * A hashed-wheel {@link Timer} that hands a caller-supplied {@code String} argument to each
 * {@link TimerTask}. Its structure closely follows Netty's {@code HashedWheelTimer}.
 *
 * <p>The wheel is an array of buckets whose length is {@code ticksPerWheel} rounded up to a power
 * of two. A single worker thread, created from the supplied {@link ThreadFactory} and started
 * lazily on the first {@link #newTimeout}, advances one bucket per {@code tickDuration}. It runs
 * the tasks of every timeout in that bucket whose remaining round count has reached zero. Tasks
 * therefore run on the worker thread, with a resolution of one tick.
 *
 * <p>New and cancelled timeouts reach the worker through multi-producer/single-consumer queues.
 * They are moved into (or removed from) the buckets on the worker's next tick.
 */
public class TimerWheel implements Timer {

  static final InternalLogger logger =
      InternalLoggerFactory.getInstance(TimerWheel.class);

  // Shared by all instances; each instance whose `leak` is non-null is tracked until stop() closes it.
  private static final ResourceLeakDetector<TimerWheel> leakDetector =
      ResourceLeakDetectorFactory.instance()
          .newResourceLeakDetector(TimerWheel.class, 1, Runtime.getRuntime().availableProcessors() * 4L);

  // Atomic view of the volatile `workerState` field. PlatformDependent's updater is tried first
  // (it may return null); otherwise fall back to the JDK's reflection-based updater.
  private static final AtomicIntegerFieldUpdater<TimerWheel> WORKER_STATE_UPDATER;

  static {
    AtomicIntegerFieldUpdater<TimerWheel> workerStateUpdater =
        PlatformDependent.newAtomicIntegerFieldUpdater(TimerWheel.class, "workerState");
    if (workerStateUpdater == null) {
      workerStateUpdater = AtomicIntegerFieldUpdater.newUpdater(TimerWheel.class, "workerState");
    }
    WORKER_STATE_UPDATER = workerStateUpdater;
  }

  private final ResourceLeak leak;
  private final Worker worker = new Worker();
  private final Thread workerThread;

  public static final int WORKER_STATE_INIT = 0;
  public static final int WORKER_STATE_STARTED = 1;
  public static final int WORKER_STATE_SHUTDOWN = 2;
  @SuppressWarnings({ "unused", "FieldMayBeFinal", "RedundantFieldInitialization" })
  private volatile int workerState = WORKER_STATE_INIT; // 0 - init, 1 - started, 2 - shut down

  // Tick length in nanoseconds.
  private final long tickDuration;
  private final HashedWheelBucket[] wheel;
  // wheel.length - 1; wheel.length is a power of two, so (tick & mask) == tick % wheel.length.
  private final int mask;
  private final CountDownLatch startTimeInitialized = new CountDownLatch(1);
  // Timeouts created by newTimeout() that the worker has not yet placed into a bucket.
  private final Queue<HashedWheelTimeout> timeouts = PlatformDependent.newMpscQueue();
  // Timeouts cancelled by any thread, removed from their bucket by the worker on its next tick.
  private final Queue<HashedWheelTimeout> cancelledTimeouts = PlatformDependent.newMpscQueue();

  // System.nanoTime() when the worker started; 0 means "not yet initialized".
  private volatile long startTime;

  /**
   * Creates a timer that uses {@link Executors#defaultThreadFactory()}, a 100 ms tick and 512
   * ticks per wheel.
   */
  public TimerWheel() {
    this(Executors.defaultThreadFactory());
  }

  /**
   * Creates a timer that uses {@link Executors#defaultThreadFactory()} and 512 ticks per wheel.
   *
   * @param tickDuration the duration between ticks
   * @param unit         the unit of {@code tickDuration}
   */
  public TimerWheel(long tickDuration, TimeUnit unit) {
    this(Executors.defaultThreadFactory(), tickDuration, unit);
  }

  /**
   * Creates a timer that uses {@link Executors#defaultThreadFactory()}.
   *
   * @param tickDuration  the duration between ticks
   * @param unit          the unit of {@code tickDuration}
   * @param ticksPerWheel the number of buckets in the wheel
   */
  public TimerWheel(long tickDuration, TimeUnit unit, int ticksPerWheel) {
    this(Executors.defaultThreadFactory(), tickDuration, unit, ticksPerWheel);
  }

  /**
   * Creates a timer with a 100 ms tick and 512 ticks per wheel.
   *
   * @param threadFactory used to create the worker thread
   */
  public TimerWheel(ThreadFactory threadFactory) {
    this(threadFactory, 100, TimeUnit.MILLISECONDS);
  }

  /**
   * Creates a timer with 512 ticks per wheel.
   *
   * @param threadFactory used to create the worker thread
   * @param tickDuration  the duration between ticks
   * @param unit          the unit of {@code tickDuration}
   */
  public TimerWheel(
      ThreadFactory threadFactory, long tickDuration, TimeUnit unit) {
    this(threadFactory, tickDuration, unit, 512);
  }

  /**
   * Creates a timer with leak detection enabled.
   *
   * @param threadFactory used to create the worker thread
   * @param tickDuration  the duration between ticks
   * @param unit          the unit of {@code tickDuration}
   * @param ticksPerWheel the number of buckets in the wheel
   */
  public TimerWheel(
      ThreadFactory threadFactory,
      long tickDuration, TimeUnit unit, int ticksPerWheel) {
    this(threadFactory, tickDuration, unit, ticksPerWheel, true);
  }

  /**
   * Creates a timer. The worker thread is created here but only started by {@link #start()}.
   *
   * @param threadFactory used to create the worker thread
   * @param tickDuration  the duration between ticks
   * @param unit          the unit of {@code tickDuration}
   * @param ticksPerWheel the number of buckets; rounded up to the next power of two
   * @param leakDetection whether to register this instance with the leak detector; it is
   *                      registered regardless when the worker thread is not a daemon
   * @throws NullPointerException     if {@code threadFactory} or {@code unit} is null
   * @throws IllegalArgumentException if {@code tickDuration} or {@code ticksPerWheel} is not
   *                                  positive, {@code ticksPerWheel} exceeds 2^30, or the tick
   *                                  duration in nanoseconds times the wheel length would overflow
   */
  public TimerWheel(
      ThreadFactory threadFactory,
      long tickDuration, TimeUnit unit, int ticksPerWheel, boolean leakDetection) {

    if (threadFactory == null) {
      throw new NullPointerException("threadFactory");
    }
    if (unit == null) {
      throw new NullPointerException("unit");
    }
    if (tickDuration <= 0) {
      throw new IllegalArgumentException("tickDuration must be greater than 0: " + tickDuration);
    }
    if (ticksPerWheel <= 0) {
      throw new IllegalArgumentException("ticksPerWheel must be greater than 0: " + ticksPerWheel);
    }

    // Normalize ticksPerWheel to a power of two and initialize the wheel.
    wheel = createWheel(ticksPerWheel);
    mask = wheel.length - 1;

    // Convert tickDuration to nanos.
    this.tickDuration = unit.toNanos(tickDuration);

    // Prevent overflow of tickDuration * (tick + 1) within one revolution of the wheel.
    if (this.tickDuration >= Long.MAX_VALUE / wheel.length) {
      throw new IllegalArgumentException(String.format(
          "tickDuration: %d (expected: 0 < tickDuration in nanos < %d)",
          this.tickDuration, Long.MAX_VALUE / wheel.length));
    }
    workerThread = threadFactory.newThread(worker);

    leak = leakDetection || !workerThread.isDaemon() ? leakDetector.open(this) : null;
  }

  /**
   * Allocates the bucket array.
   *
   * @param ticksPerWheel requested bucket count, in (0, 2^30]
   * @return a new array of empty buckets, its length rounded up to a power of two
   */
  private static HashedWheelBucket[] createWheel(int ticksPerWheel) {
    if (ticksPerWheel <= 0) {
      throw new IllegalArgumentException(
          "ticksPerWheel must be greater than 0: " + ticksPerWheel);
    }
    if (ticksPerWheel > 1073741824) {
      throw new IllegalArgumentException(
          "ticksPerWheel may not be greater than 2^30: " + ticksPerWheel);
    }

    ticksPerWheel = normalizeTicksPerWheel(ticksPerWheel);
    HashedWheelBucket[] wheel = new HashedWheelBucket[ticksPerWheel];
    for (int i = 0; i < wheel.length; i++) {
      wheel[i] = new HashedWheelBucket();
    }
    return wheel;
  }

  /** Returns the smallest power of two that is {@code >= ticksPerWheel}. */
  private static int normalizeTicksPerWheel(int ticksPerWheel) {
    int normalizedTicksPerWheel = 1;
    while (normalizedTicksPerWheel < ticksPerWheel) {
      normalizedTicksPerWheel <<= 1;
    }
    return normalizedTicksPerWheel;
  }

  /**
   * Starts the worker thread if it has not been started yet, then blocks until the worker has
   * recorded its start time. {@link #newTimeout} calls this automatically.
   *
   * <p>If the calling thread is interrupted while waiting, it keeps waiting. The interrupt status
   * is restored before returning.
   *
   * @throws IllegalStateException if the timer has already been stopped
   */
  public void start() {
    switch (WORKER_STATE_UPDATER.get(this)) {
      case WORKER_STATE_INIT:
        if (WORKER_STATE_UPDATER.compareAndSet(this, WORKER_STATE_INIT, WORKER_STATE_STARTED)) {
          workerThread.start();
        }
        break;
      case WORKER_STATE_STARTED:
        break;
      case WORKER_STATE_SHUTDOWN:
        throw new IllegalStateException("cannot be started once stopped");
      default:
        throw new Error("Invalid WorkerState");
    }

    // Wait until the startTime is initialized by the worker.
    boolean interrupted = false;
    while (startTime == 0) {
      try {
        startTimeInitialized.await();
      } catch (InterruptedException ignored) {
        // The worker sets startTime right after it starts, so keep waiting, but remember the
        // interrupt so it is not lost for the caller.
        interrupted = true;
      }
    }
    if (interrupted) {
      Thread.currentThread().interrupt();
    }
  }

  /**
   * Stops the worker thread, waiting for it to terminate.
   *
   * @return the timeouts that were neither expired nor cancelled when the worker stopped; empty
   *     if the worker was never started or the timer was already stopped
   * @throws IllegalStateException if called from the worker thread, i.e. from inside a
   *                               {@link TimerTask}
   */
  @Override
  public Set<Timeout> stop() {
    if (Thread.currentThread() == workerThread) {
      throw new IllegalStateException(
          TimerWheel.class.getSimpleName() +
              ".stop() cannot be called from " +
              TimerTask.class.getSimpleName());
    }

    if (!WORKER_STATE_UPDATER.compareAndSet(this, WORKER_STATE_STARTED, WORKER_STATE_SHUTDOWN)) {
      // workerState can be 0 or 2 at this moment - let it always be 2.
      WORKER_STATE_UPDATER.set(this, WORKER_STATE_SHUTDOWN);

      if (leak != null) {
        leak.close();
      }

      return Collections.emptySet();
    }

    // Interrupt the worker out of its sleep and wait for it to exit; preserve our own interrupt.
    boolean interrupted = false;
    while (workerThread.isAlive()) {
      workerThread.interrupt();
      try {
        workerThread.join(100);
      } catch (InterruptedException ignored) {
        interrupted = true;
      }
    }

    if (interrupted) {
      Thread.currentThread().interrupt();
    }

    if (leak != null) {
      leak.close();
    }
    return worker.unprocessedTimeouts();
  }

  /**
   * Schedules {@code task} to run once on the worker thread after {@code delay}, starting the
   * worker if needed. The timeout is placed into its bucket on the worker's next tick.
   *
   * @param task  the task to run
   * @param delay how long to wait; very large values are clamped rather than overflowing
   * @param unit  the unit of {@code delay}
   * @param argv  passed unchanged to {@link TimerTask#run(Timeout, String)}
   * @return a handle that can cancel the task
   * @throws NullPointerException  if {@code task} or {@code unit} is null
   * @throws IllegalStateException if the timer has been stopped
   */
  @Override
  public Timeout newTimeout(TimerTask task, long delay, TimeUnit unit, String argv) {
    if (task == null) {
      throw new NullPointerException("task");
    }
    if (unit == null) {
      throw new NullPointerException("unit");
    }
    start();

    // Add the timeout to the timeout queue which will be processed on the next tick.
    // During processing all the queued HashedWheelTimeouts will be added to the correct HashedWheelBucket.
    // The deadline is relative to startTime.
    long deadline = System.nanoTime() + unit.toNanos(delay) - startTime;
    // Guard against overflow: a huge positive delay must not wrap into a deadline in the past.
    if (delay > 0 && deadline < 0) {
      deadline = Long.MAX_VALUE;
    }
    HashedWheelTimeout timeout = new HashedWheelTimeout(this, task, deadline, argv);
    timeouts.add(timeout);
    return timeout;
  }

  /** The worker loop; runs on {@link #workerThread} and is the only code touching the buckets. */
  private final class Worker implements Runnable {
    // Filled when the loop exits; returned by stop().
    private final Set<Timeout> unprocessedTimeouts = new HashSet<>();

    // Number of ticks processed so far.
    private long tick;

    /**
     * Records the start time, then processes one bucket per tick until the state leaves
     * STARTED. Afterwards it collects every still-pending timeout into
     * {@link #unprocessedTimeouts}.
     */
    @Override
    public void run() {
      // Initialize the startTime.
      startTime = System.nanoTime();
      if (startTime == 0) {
        // We use 0 as an indicator for the uninitialized value here, so make sure it's not 0 when initialized.
        startTime = 1;
      }

      // Notify the other threads waiting for the initialization at start().
      startTimeInitialized.countDown();

      do {
        final long deadline = waitForNextTick();
        if (deadline > 0) {
          int idx = (int) (tick & mask);
          processCancelledTasks();
          HashedWheelBucket bucket = wheel[idx];
          transferTimeoutsToBuckets();
          bucket.expireTimeouts(deadline);
          tick++;
        }
      } while (WORKER_STATE_UPDATER.get(TimerWheel.this) == WORKER_STATE_STARTED);

      // Fill the unprocessedTimeouts so we can return them from stop() method.
      for (HashedWheelBucket bucket : wheel) {
        bucket.clearTimeouts(unprocessedTimeouts);
      }
      for (;;) {
        HashedWheelTimeout timeout = timeouts.poll();
        if (timeout == null) {
          break;
        }
        if (!timeout.isCancelled()) {
          unprocessedTimeouts.add(timeout);
        }
      }
      processCancelledTasks();
    }

    /**
     * Moves queued timeouts into their buckets. The bucket index is the deadline's tick modulo
     * the wheel length (never earlier than the current tick); remainingRounds is the number of
     * full revolutions to wait.
     */
    private void transferTimeoutsToBuckets() {
      // transfer only max. 100000 timeouts per tick to prevent a thread to stale the workerThread when it just
      // adds new timeouts in a loop.
      for (int i = 0; i < 100000; i++) {
        HashedWheelTimeout timeout = timeouts.poll();
        if (timeout == null) {
          // all processed
          break;
        }
        if (timeout.state() == HashedWheelTimeout.ST_CANCELLED) {
          // Was cancelled in the meantime.
          continue;
        }

        long calculated = timeout.deadline / tickDuration;
        timeout.remainingRounds = (calculated - tick) / wheel.length;

        final long ticks = Math.max(calculated, tick); // Ensure we don't schedule for past.
        int stopIndex = (int) (ticks & mask);

        HashedWheelBucket bucket = wheel[stopIndex];
        bucket.addTimeout(timeout);
      }
    }

    /** Unlinks every timeout queued by {@link HashedWheelTimeout#cancel()} from its bucket. */
    private void processCancelledTasks() {
      for (;;) {
        HashedWheelTimeout timeout = cancelledTimeouts.poll();
        if (timeout == null) {
          // all processed
          break;
        }
        try {
          timeout.remove();
        } catch (Throwable t) {
          if (logger.isWarnEnabled()) {
            logger.warn("An exception was thrown while process a cancellation task", t);
          }
        }
      }
    }

    /**
     * Sleeps until the next tick's deadline, {@code tickDuration * (tick + 1)} ns after
     * {@code startTime}.
     *
     * @return the current time relative to {@code startTime}, or {@link Long#MIN_VALUE} if the
     *     sleep was interrupted because the timer is shutting down
     */
    private long waitForNextTick() {
      long deadline = tickDuration * (tick + 1);

      for (;;) {
        final long currentTime = System.nanoTime() - startTime;
        // Round up to whole milliseconds so we never wake before the deadline.
        long sleepTimeMs = (deadline - currentTime + 999999) / 1000000;

        if (sleepTimeMs <= 0) {
          // Long.MIN_VALUE is reserved as the shutdown sentinel, so never return it as a time.
          if (currentTime == Long.MIN_VALUE) {
            return -Long.MAX_VALUE;
          } else {
            return currentTime;
          }
        }

        // Check if we run on windows, as if thats the case we will need
        // to round the sleepTime as workaround for a bug that only affect
        // the JVM if it runs on windows.
        //
        // See https://github.com/netty/netty/issues/356
        if (PlatformDependent.isWindows()) {
          sleepTimeMs = sleepTimeMs / 10 * 10;
        }

        try {
          Thread.sleep(sleepTimeMs);
        } catch (InterruptedException ignored) {
          // stop() interrupts us after setting SHUTDOWN; any other interrupt just re-checks the time.
          if (WORKER_STATE_UPDATER.get(TimerWheel.this) == WORKER_STATE_SHUTDOWN) {
            return Long.MIN_VALUE;
          }
        }
      }
    }

    /** Returns a read-only view of the timeouts collected when the worker stopped. */
    public Set<Timeout> unprocessedTimeouts() {
      return Collections.unmodifiableSet(unprocessedTimeouts);
    }
  }

  /** A scheduled task, doubly linked into at most one bucket at a time. */
  private static final class HashedWheelTimeout implements Timeout {

    private static final int ST_INIT = 0;
    private static final int ST_CANCELLED = 1;
    private static final int ST_EXPIRED = 2;
    private static final AtomicIntegerFieldUpdater<HashedWheelTimeout> STATE_UPDATER;

    static {
      AtomicIntegerFieldUpdater<HashedWheelTimeout> updater =
          PlatformDependent.newAtomicIntegerFieldUpdater(HashedWheelTimeout.class, "state");
      if (updater == null) {
        updater = AtomicIntegerFieldUpdater.newUpdater(HashedWheelTimeout.class, "state");
      }
      STATE_UPDATER = updater;
    }

    private final TimerWheel timer;
    private final TimerTask task;
    // Nanoseconds relative to the owning timer's startTime.
    private final long deadline;
    // Passed unchanged to task.run(...).
    private final String argv;

    @SuppressWarnings({"unused", "FieldMayBeFinal", "RedundantFieldInitialization" })
    private volatile int state = ST_INIT;

    // remainingRounds will be calculated and set by Worker.transferTimeoutsToBuckets() before the
    // HashedWheelTimeout will be added to the correct HashedWheelBucket.
    long remainingRounds;

    // This will be used to chain timeouts in HashedWheelTimerBucket via a double-linked-list.
    // As only the workerThread will act on it there is no need for synchronization / volatile.
    HashedWheelTimeout next;
    HashedWheelTimeout prev;

    // The bucket to which the timeout was added
    HashedWheelBucket bucket;

    HashedWheelTimeout(TimerWheel timer, TimerTask task, long deadline, String argv) {
      this.timer = timer;
      this.task = task;
      this.deadline = deadline;
      this.argv = argv;
    }

    @Override
    public Timer timer() {
      return timer;
    }

    @Override
    public TimerTask task() {
      return task;
    }

    /**
     * Marks this timeout cancelled if it is still pending. The worker unlinks it from its bucket
     * on the next tick.
     *
     * @return {@code true} if this call moved the state from INIT to CANCELLED
     */
    @Override
    public boolean cancel() {
      // only update the state it will be removed from HashedWheelBucket on next tick.
      if (!compareAndSetState(ST_INIT, ST_CANCELLED)) {
        return false;
      }
      // If a task should be canceled we put this to another queue which will be processed on each tick.
      // So this means that we will have a GC latency of max. 1 tick duration which is good enough. This way
      // we can make again use of our MpscLinkedQueue and so minimize the locking / overhead as much as possible.
      timer.cancelledTimeouts.add(this);
      return true;
    }

    /** Unlinks this timeout from its bucket, if it is in one. Called on the worker thread. */
    void remove() {
      HashedWheelBucket bucket = this.bucket;
      if (bucket != null) {
        bucket.remove(this);
      }
    }

    public boolean compareAndSetState(int expected, int state) {
      return STATE_UPDATER.compareAndSet(this, expected, state);
    }

    public int state() {
      return state;
    }

    @Override
    public boolean isCancelled() {
      return state() == ST_CANCELLED;
    }

    @Override
    public boolean isExpired() {
      return state() == ST_EXPIRED;
    }

    /**
     * Moves INIT to EXPIRED and runs the task. Does nothing if the timeout was already cancelled
     * or expired. Exceptions from the task are logged, not propagated.
     */
    public void expire() {
      if (!compareAndSetState(ST_INIT, ST_EXPIRED)) {
        return;
      }

      try {
        task.run(this, argv);
      } catch (Throwable t) {
        if (logger.isWarnEnabled()) {
          logger.warn("An exception was thrown by " + TimerTask.class.getSimpleName() + '.', t);
        }
      }
    }

    @Override
    public String toString() {
      final long currentTime = System.nanoTime();
      long remaining = deadline - currentTime + timer.startTime;

      StringBuilder buf = new StringBuilder(192)
          .append(StringUtil.simpleClassName(this))
          .append('(')
          .append("deadline: ");
      if (remaining > 0) {
        buf.append(remaining)
            .append(" ns later");
      } else if (remaining < 0) {
        buf.append(-remaining)
            .append(" ns ago");
      } else {
        buf.append("now");
      }

      if (isCancelled()) {
        buf.append(", cancelled");
      }

      return buf.append(", task: ")
          .append(task())
          .append(')')
          .toString();
    }
  }

  /**
   * One slot of the wheel: a doubly linked list of timeouts. Only the worker thread touches it, so
   * it is not synchronized.
   */
  private static final class HashedWheelBucket {
    // Used for the linked-list datastructure
    private HashedWheelTimeout head;
    private HashedWheelTimeout tail;

    /** Appends {@code timeout} to the tail of this bucket. */
    public void addTimeout(HashedWheelTimeout timeout) {
      assert timeout.bucket == null;
      timeout.bucket = this;
      if (head == null) {
        head = tail = timeout;
      } else {
        tail.next = timeout;
        timeout.prev = tail;
        tail = timeout;
      }
    }

    /**
     * Expires every timeout whose remaining rounds reached zero, drops cancelled ones, and
     * decrements the round counter of the rest.
     *
     * @param deadline the current tick's time, relative to the timer's startTime
     * @throws IllegalStateException if a due timeout's deadline is later than {@code deadline}
     */
    public void expireTimeouts(long deadline) {
      HashedWheelTimeout timeout = head;

      // process all timeouts
      while (timeout != null) {
        boolean remove = false;
        if (timeout.remainingRounds <= 0) {
          if (timeout.deadline <= deadline) {
            timeout.expire();
          } else {
            // The timeout was placed into a wrong slot. This should never happen.
            throw new IllegalStateException(String.format(
                "timeout.deadline (%d) > deadline (%d)", timeout.deadline, deadline));
          }
          remove = true;
        } else if (timeout.isCancelled()) {
          remove = true;
        } else {
          timeout.remainingRounds--;
        }
        // store reference to next as we may null out timeout.next in the remove block.
        HashedWheelTimeout next = timeout.next;
        if (remove) {
          remove(timeout);
        }
        timeout = next;
      }
    }

    /** Unlinks {@code timeout} from this bucket and clears its links. */
    public void remove(HashedWheelTimeout timeout) {
      HashedWheelTimeout next = timeout.next;
      // remove timeout that was either processed or cancelled by updating the linked-list
      if (timeout.prev != null) {
        timeout.prev.next = next;
      }
      if (timeout.next != null) {
        timeout.next.prev = timeout.prev;
      }

      if (timeout == head) {
        // if timeout is also the tail we need to adjust the entry too
        if (timeout == tail) {
          tail = null;
          head = null;
        } else {
          head = next;
        }
      } else if (timeout == tail) {
        // if the timeout is the tail modify the tail to be the prev node.
        tail = timeout.prev;
      }
      // null out prev, next and bucket to allow for GC.
      timeout.prev = null;
      timeout.next = null;
      timeout.bucket = null;
    }

    /** Empties this bucket, adding every timeout that is neither expired nor cancelled to {@code set}. */
    public void clearTimeouts(Set<Timeout> set) {
      for (;;) {
        HashedWheelTimeout timeout = pollTimeout();
        if (timeout == null) {
          return;
        }
        if (timeout.isExpired() || timeout.isCancelled()) {
          continue;
        }
        set.add(timeout);
      }
    }

    /** Removes and returns the head timeout, or {@code null} if the bucket is empty. */
    private HashedWheelTimeout pollTimeout() {
      HashedWheelTimeout head = this.head;
      if (head == null) {
        return null;
      }
      HashedWheelTimeout next = head.next;
      if (next == null) {
        tail = this.head = null;
      } else {
        this.head = next;
        next.prev = null;
      }

      // null out prev and next to allow for GC.
      head.next = null;
      head.prev = null;
      head.bucket = null;
      return head;
    }
  }
}

编写测试类Main.java

package com.tanghuachun.timer;
import java.util.concurrent.TimeUnit;


/**
 * Demo: schedules ten one-shot tasks five seconds out and prints each task's argument when it
 * fires. Once all of them have fired it stops the timer, so the JVM can exit.
 */
public class Main implements TimerTask {
  private static final int TASK_COUNT = 10;

  static final Timer timer = new TimerWheel();

  // Counted down once per fired task so main() knows when it is safe to stop the timer.
  private static final java.util.concurrent.CountDownLatch FIRED =
      new java.util.concurrent.CountDownLatch(TASK_COUNT);

  public static void main(String[] args) {
    TimerTask timerTask = new Main();
    for (int i = 0; i < TASK_COUNT; i++) {
      timer.newTimeout(timerTask, 5, TimeUnit.SECONDS, String.valueOf(i));
    }
    try {
      // TimerWheel() builds its worker with Executors.defaultThreadFactory(), whose threads are
      // not daemons, so without stop() the JVM would never exit after the tasks have fired.
      FIRED.await();
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    } finally {
      timer.stop();
    }
  }

  /** Prints the argument the task was scheduled with, then records that it fired. */
  @Override
  public void run(Timeout timeout, String argv) throws Exception {
    System.out.println("timeout, argv = " + argv);
    FIRED.countDown();
  }
}

然后就可以看到运行结果啦。

工程代码下载(以maven的方式导入)。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持考高分网。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/144677.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号