栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Java

Java 手写一个可重入锁(带注释)

Java 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Java 手写一个可重入锁(带注释)

手写一个可重入锁
package com.example.test.juc;

import java.util.*;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.LockSupport;


/**
 * A hand-written reentrant mutual-exclusion {@link Lock}.
 *
 * <p>{@code state} is {@code UNLOCK} while the lock is free and the owner's hold count otherwise;
 * {@code current} records the owning thread. Threads that cannot take the lock immediately are
 * appended to a FIFO queue and parked with {@link LockSupport}. Only the queue head may take the
 * lock, and every full release unparks the current head.
 *
 * <p>A fair lock never lets {@link #lock()} barge ahead of queued threads; a non-fair lock makes
 * one attempt before queueing. {@link #tryLock()} always barges (as in
 * {@code java.util.concurrent.locks.ReentrantLock}).
 */
public class MyReentrantLock implements Lock {
    /** {@code state} value meaning that no thread holds the lock. */
    private static final int UNLOCK = -1;
    /** {@code state} value right after the first (non-reentrant) acquisition. */
    private static final int FIRST = 1;

    /** {@code Waiter.status}: still blocked in an await method. */
    private static final int WAITING = 0;
    /** {@code Waiter.status}: claimed by signal()/signalAll(). */
    private static final int SIGNALLED = 1;
    /** {@code Waiter.status}: gave up because of a timeout or an interrupt. */
    private static final int CANCELLED = 2;

    /** Result of a blocking acquire or await attempt. */
    private enum Outcome { SUCCESS, TIMED_OUT, INTERRUPTED }

    /** One thread blocked in a condition's await method. */
    private static final class Waiter {
        final Thread thread;
        /** Moves from WAITING to exactly one of SIGNALLED (by a signaller) or CANCELLED (by the waiter). */
        final AtomicInteger status = new AtomicInteger(WAITING);

        Waiter(Thread thread) {
            this.thread = thread;
        }
    }

    /** true: lock() never barges, so queued threads are served in arrival order. */
    private final boolean fair;
    /** Owning thread, or null while the lock is free; volatile so every thread reads the latest owner. */
    private volatile Thread current;
    /** UNLOCK while free, otherwise the owner's hold count (FIRST after the first acquisition). */
    private final AtomicInteger state = new AtomicInteger(UNLOCK);
    /** Threads blocked waiting for the lock, in arrival order; only the head may take the lock. */
    private final Queue<Thread> waiters = new ConcurrentLinkedQueue<>();

    /**
     * Condition bound to the enclosing lock. await releases the lock completely (remembering the hold
     * count), blocks until signalled / timed out / interrupted, and re-acquires the lock with the same
     * hold count before returning. As with ReentrantLock's conditions, every method requires the
     * calling thread to hold the lock and throws IllegalMonitorStateException otherwise.
     */
    abstract class MyAbstractCondition implements Condition {
        /** Threads blocked in await, oldest first. */
        private final Deque<Waiter> conditionWaiters = new ConcurrentLinkedDeque<>();
        /** true: signal() wakes the longest-waiting thread; false: the most recently arrived one. */
        private final boolean fifo;

        MyAbstractCondition(boolean fifo) {
            this.fifo = fifo;
        }

        @Override
        public void await() throws InterruptedException {
            if (awaitCore(true, false, 0L) == Outcome.INTERRUPTED) {
                throw new InterruptedException();
            }
        }

        @Override
        public void awaitUninterruptibly() {
            awaitCore(false, false, 0L);
        }

        @Override
        public long awaitNanos(long nanosTimeout) throws InterruptedException {
            long deadline = System.nanoTime() + nanosTimeout;
            if (awaitCore(true, true, nanosTimeout) == Outcome.INTERRUPTED) {
                throw new InterruptedException();
            }
            // Estimate of the time left; <= 0 means the timeout elapsed.
            return deadline - System.nanoTime();
        }

        @Override
        public boolean await(long time, TimeUnit unit) throws InterruptedException {
            return awaitTimed(unit.toNanos(time));
        }

        @Override
        public boolean awaitUntil(Date deadline) throws InterruptedException {
            return awaitTimed(TimeUnit.MILLISECONDS.toNanos(deadline.getTime() - System.currentTimeMillis()));
        }

        /** Wakes one waiting thread, skipping waiters that already timed out or were interrupted. */
        @Override
        public void signal() {
            requireOwner();
            Waiter next;
            while ((next = pollNext()) != null) {
                if (next.status.compareAndSet(WAITING, SIGNALLED)) {
                    LockSupport.unpark(next.thread);
                    return;
                }
            }
        }

        /** Wakes every thread currently waiting on this condition. */
        @Override
        public void signalAll() {
            requireOwner();
            Waiter next;
            while ((next = pollNext()) != null) {
                if (next.status.compareAndSet(WAITING, SIGNALLED)) {
                    LockSupport.unpark(next.thread);
                }
            }
        }

        /** Removes the waiter this condition's ordering policy says to signal next, or null if none. */
        private Waiter pollNext() {
            return fifo ? conditionWaiters.pollFirst() : conditionWaiters.pollLast();
        }

        /**
         * Timed, interruptible await.
         *
         * @return true if signalled, false if the timeout elapsed first
         * @throws InterruptedException if interrupted before being signalled
         */
        private boolean awaitTimed(long nanosTimeout) throws InterruptedException {
            Outcome outcome = awaitCore(true, true, nanosTimeout);
            if (outcome == Outcome.INTERRUPTED) {
                throw new InterruptedException();
            }
            return outcome == Outcome.SUCCESS;
        }

        /**
         * Shared implementation of every await variant.
         *
         * @param interruptible whether an interrupt before the signal aborts the wait
         * @param timed whether {@code nanosTimeout} bounds the wait
         * @param nanosTimeout maximum wait in nanoseconds (ignored unless {@code timed})
         * @return SUCCESS if signalled, TIMED_OUT, or INTERRUPTED; the lock is held again in every case
         */
        private Outcome awaitCore(boolean interruptible, boolean timed, long nanosTimeout) {
            requireOwner();
            if (interruptible && Thread.interrupted()) {
                return Outcome.INTERRUPTED;
            }
            Thread thread = Thread.currentThread();
            Waiter waiter = new Waiter(thread);
            // Enqueue before releasing, so a signaller that takes the lock next cannot miss us.
            conditionWaiters.addLast(waiter);
            int savedHolds = releaseFully();
            long deadline = System.nanoTime() + nanosTimeout;
            boolean interrupted = false;
            Outcome outcome = Outcome.SUCCESS;
            // Loop guards against spurious wake-ups from park.
            while (waiter.status.get() == WAITING) {
                if (timed) {
                    long remaining = deadline - System.nanoTime();
                    if (remaining <= 0L) {
                        // Losing this CAS means a signal arrived just in time: treat it as signalled.
                        if (waiter.status.compareAndSet(WAITING, CANCELLED)) {
                            outcome = Outcome.TIMED_OUT;
                        }
                        break;
                    }
                    LockSupport.parkNanos(this, remaining);
                } else {
                    LockSupport.park(this);
                }
                if (Thread.interrupted()) {
                    interrupted = true;
                    if (interruptible && waiter.status.compareAndSet(WAITING, CANCELLED)) {
                        outcome = Outcome.INTERRUPTED;
                        break;
                    }
                }
            }
            if (outcome != Outcome.SUCCESS) {
                conditionWaiters.remove(waiter);
            }
            // Re-acquire uninterruptibly, then restore the hold count given up above.
            acquire(false, false, 0L);
            state.set(savedHolds);
            if (interrupted && outcome != Outcome.INTERRUPTED) {
                // Interrupt arrived after the signal (or the wait was uninterruptible): keep it visible.
                thread.interrupt();
            }
            return outcome;
        }
    }

    /** Condition whose signal() wakes the most recently arrived waiter (no arrival-order guarantee). */
    class MyNonFairCondition extends MyAbstractCondition {
        MyNonFairCondition() {
            super(false);
        }
    }

    /** Condition whose signal() wakes waiters in arrival (FIFO) order. */
    class MyFairCondition extends MyAbstractCondition {
        MyFairCondition() {
            super(true);
        }
    }

    /** Creates a non-fair lock. */
    public MyReentrantLock() {
        this(false);
    }

    /**
     * Creates a lock with the given fairness policy.
     *
     * @param fair true so that lock() never barges ahead of threads already queued
     */
    public MyReentrantLock(boolean fair) {
        this.fair = fair;
    }

    /** Acquires the lock, blocking as long as necessary; an interrupt is remembered, not thrown. */
    @Override
    public void lock() {
        acquire(false, false, 0L);
    }

    /**
     * Acquires the lock unless the current thread is interrupted.
     *
     * @throws InterruptedException if interrupted before or while waiting
     */
    @Override
    public void lockInterruptibly() throws InterruptedException {
        if (acquire(true, false, 0L) == Outcome.INTERRUPTED) {
            throw new InterruptedException();
        }
    }

    /**
     * Acquires the lock only if it is free (or already held by the caller) at the time of the call.
     * Barges even on a fair lock.
     */
    @Override
    public boolean tryLock() {
        return tryAcquire(Thread.currentThread());
    }

    /**
     * Acquires the lock if it becomes available within the given time. A non-positive time still
     * makes one attempt.
     *
     * @return true if acquired, false if the time elapsed first
     * @throws InterruptedException if interrupted before or while waiting
     */
    @Override
    public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
        Outcome outcome = acquire(true, true, unit.toNanos(time));
        if (outcome == Outcome.INTERRUPTED) {
            throw new InterruptedException();
        }
        return outcome == Outcome.SUCCESS;
    }

    /**
     * Releases one hold. When the last hold is released, the lock becomes free and the first queued
     * thread is woken. A call from a thread that does not own the lock is silently ignored (kept
     * from the original implementation; ReentrantLock would throw IllegalMonitorStateException).
     */
    @Override
    public void unlock() {
        if (current != Thread.currentThread()) {
            return;
        }
        int holds = state.get();
        if (holds > FIRST) {
            // Only the owner writes state while the lock is held, so a plain set is safe.
            state.set(holds - 1);
        } else {
            releaseFully();
        }
    }

    /** Returns a new condition bound to this lock (FIFO signalling for a fair lock). */
    @Override
    public Condition newCondition() {
        if (this.fair) {
            return new MyFairCondition();
        }
        return new MyNonFairCondition();
    }

    /** Throws IllegalMonitorStateException unless the calling thread owns the lock. */
    private void requireOwner() {
        if (current != Thread.currentThread()) {
            throw new IllegalMonitorStateException();
        }
    }

    /** Single non-blocking attempt: reentrant increment for the owner, otherwise a CAS on a free lock. */
    private boolean tryAcquire(Thread thread) {
        if (current == thread) {
            state.incrementAndGet();
            return true;
        }
        if (state.compareAndSet(UNLOCK, FIRST)) {
            current = thread;
            return true;
        }
        return false;
    }

    /**
     * Common acquire path.
     *
     * @param interruptible whether an interrupt aborts the attempt
     * @param timed whether {@code nanosTimeout} bounds the wait
     * @param nanosTimeout maximum wait in nanoseconds (ignored unless {@code timed})
     */
    private Outcome acquire(boolean interruptible, boolean timed, long nanosTimeout) {
        if (interruptible && Thread.interrupted()) {
            return Outcome.INTERRUPTED;
        }
        Thread thread = Thread.currentThread();
        // Reentrant acquisitions always succeed at once; a non-fair lock also lets a newcomer barge.
        if ((current == thread || !fair) && tryAcquire(thread)) {
            return Outcome.SUCCESS;
        }
        return acquireQueued(thread, interruptible, timed, nanosTimeout);
    }

    /**
     * Appends the thread to the wait queue and parks it until it is the head and wins the CAS, or
     * until it times out / is interrupted (when allowed).
     */
    private Outcome acquireQueued(Thread thread, boolean interruptible, boolean timed, long nanosTimeout) {
        long deadline = System.nanoTime() + nanosTimeout;
        boolean interrupted = false;
        Outcome outcome = Outcome.SUCCESS;
        // Publish ourselves before checking the lock, so a concurrent unlock() can see and unpark us.
        waiters.add(thread);
        while (waiters.peek() != thread || !state.compareAndSet(UNLOCK, FIRST)) {
            if (timed) {
                long remaining = deadline - System.nanoTime();
                if (remaining <= 0L) {
                    outcome = Outcome.TIMED_OUT;
                    break;
                }
                LockSupport.parkNanos(this, remaining);
            } else {
                LockSupport.park(this);
            }
            if (Thread.interrupted()) {
                if (interruptible) {
                    outcome = Outcome.INTERRUPTED;
                    break;
                }
                interrupted = true;
            }
        }
        if (outcome == Outcome.SUCCESS) {
            current = thread;
        }
        waiters.remove(thread);
        if (outcome != Outcome.SUCCESS) {
            // We may have consumed the unpark meant for the head; pass it on so nobody is stranded.
            LockSupport.unpark(waiters.peek());
        }
        if (interrupted) {
            thread.interrupt();
        }
        return outcome;
    }

    /**
     * Releases every hold of the owning thread and wakes the first queued thread.
     *
     * @return the hold count that was released
     */
    private int releaseFully() {
        int holds = state.get();
        // Clear the owner before publishing the free state, so we never overwrite the next owner.
        current = null;
        state.set(UNLOCK);
        LockSupport.unpark(waiters.peek());
        return holds;
    }
}

测试类
package com.example.test;

import com.example.test.juc.MyReentrantLock;
import org.junit.jupiter.api.Test;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;


public class MyReentrantLockTest {
    int ticket = 1000;

    
    @Test
    public void testLock() {
        MyReentrantLock lock = new MyReentrantLock(true);
        AtomicInteger sell = new AtomicInteger(0);

        Thread t1 = new Thread(() -> {
            for (; ; ) {
                try {
                    System.out.println(Thread.currentThread().getName() + " 尝试获得锁");
                    lock.lock();
                    System.out.println(Thread.currentThread().getName() + " 获得了锁");
                    if (ticket > 0) {
                        ticket--;
                        System.out.println(Thread.currentThread().getName() + " 卖了一张票,剩下" + ticket + "张票");
                        sell.incrementAndGet();
                    } else {
                        return;
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                } finally {
                    System.out.println(Thread.currentThread().getName() + " 释放了锁");
                    lock.unlock();
                    try {
                        Thread.sleep(1);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            }
        }, "【t1线程】");
        t1.start();

        Thread t2 = new Thread(() -> {
            for (; ; ) {
                try {
                    System.out.println(Thread.currentThread().getName() + " 尝试获得锁");
                    lock.lock();
                    System.out.println(Thread.currentThread().getName() + " 获得了锁");
                    if (ticket > 0) {
                        ticket--;
                        System.out.println(Thread.currentThread().getName() + " 卖了一张票,剩下" + ticket + "张票");
                        sell.incrementAndGet();
                    } else {
                        return;
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                } finally {
                    System.out.println(Thread.currentThread().getName() + " 释放了锁");
                    lock.unlock();
                    try {
                        Thread.sleep(1);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            }
        }, "【t2线程】");
        t2.start();

        try {
            //等待两个线程运行结束才结束主线程
            t1.join();
            t2.join();
            System.out.println("最终两个线程总共卖出了" + sell + "张票");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    
    @Test
    public void testReentrant() {
        MyReentrantLock lock = new MyReentrantLock();
        AtomicInteger sell = new AtomicInteger(0);
        Thread t1 = new Thread(() -> {
            for (; ; ) {
                try {
                    lock.lock();
                    if (ticket > 0) {
                        //卖出一张
                        sellTicket(sell);
                        //假设线程再次锁
                        try {
                            lock.lock();
                            if (ticket > 0) {
                                //再次卖出
                                sellTicket(sell);
                            }
                        } catch (Exception e) {
                            e.printStackTrace();
                        } finally {
                            lock.unlock();
                        }
                    } else {
                        return;
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                } finally {
                    lock.unlock();
                }

            }
        });
        t1.start();


        Thread t2 = new Thread(() -> {
            for (; ; ) {
                try {
                    lock.lock();
                    if (ticket > 0) {
                        //卖出一张
                        sellTicket(sell);
                        //假设线程再次锁
                        try {
                            lock.lock();
                            if (ticket > 0) {
                                //再次卖出
                                sellTicket(sell);
                            }
                        } catch (Exception e) {
                            e.printStackTrace();
                        } finally {
                            lock.unlock();
                        }
                    } else {
                        return;
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                } finally {
                    lock.unlock();
                }

            }
        });
        t2.start();
        try {
            //等待两个线程运行结束才结束主线程
            t1.join();
            t2.join();
            System.out.println("两个线程总共卖出" + sell.intValue() + " 张票");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private void sellTicket(AtomicInteger sell) {
        ticket--;
        sell.incrementAndGet();
        System.out.println(Thread.currentThread().getName() + " 售出了一张票,剩余:" + ticket + "张票");
    }
}

公平锁

非公平锁

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/357821.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号