package com.example.test.juc;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.LockSupport;
public class MyReentrantLock implements Lock {
private static final int UNLOCK = -1;
private static final int FIRST = 1;
private boolean fair;
private Thread current;
private AtomicInteger state = new AtomicInteger(UNLOCK);
private Condition condition ;
abstract class MyAbstractCondition implements Condition{
protected Object lockObj = new Object();
}
class MyNonFairCondition extends MyAbstractCondition{
@Override
public void await() throws InterruptedException {
synchronized (lockObj){
//挂起当前线程,等待其他线程调用该对象的notify唤醒
lockObj.wait();
}
}
@Override
public void awaitUninterruptibly() {
}
@Override
public long awaitNanos(long nanosTimeout) throws InterruptedException {
return 0;
}
@Override
public boolean await(long time, TimeUnit unit) throws InterruptedException {
return false;
}
@Override
public boolean awaitUntil(Date deadline) throws InterruptedException {
return false;
}
@Override
public void signal() {
synchronized (lockObj){
//唤醒一个等待该对象的线程
lockObj.notify();
}
}
@Override
public void signalAll() {
synchronized (lockObj){
//唤醒所有等待该对象的线程
lockObj.notifyAll();
}
}
}
class MyFairCondition extends MyAbstractCondition{
private HashSet threadHashSet = new HashSet<>();
private Queue threadQueue = new linkedList<>();
@Override
public void await() throws InterruptedException {
Thread thread = Thread.currentThread();
//没有正在排队就放进队列排队
if (!threadHashSet.contains(thread)){
threadQueue.add(thread);
threadHashSet.add(thread);
}
//把当前线程挂起
LockSupport.park();
}
@Override
public void awaitUninterruptibly() {
}
@Override
public long awaitNanos(long nanosTimeout) throws InterruptedException {
Thread thread = Thread.currentThread();
//没有正在排队就放进队列排队
if (!threadHashSet.contains(thread)){
threadQueue.add(thread);
threadHashSet.add(thread);
}
//挂起,最多持续一段时间
LockSupport.parkNanos(nanosTimeout);
return 0;
}
@Override
public boolean await(long time, TimeUnit unit) throws InterruptedException {
awaitNanos(TimeUnit.NANOSECONDS.convert(time,unit));
return true;
}
@Override
public boolean awaitUntil(Date deadline) throws InterruptedException {
awaitNanos(TimeUnit.NANOSECONDS.convert(deadline.getTime() - System.currentTimeMillis(),TimeUnit.MILLISECONDS));
return false;
}
@Override
public void signal() {
synchronized (lockObj){
if (threadQueue.size() > 0){
//让线程起来
Thread thread = threadQueue.poll();
threadHashSet.remove(thread);
LockSupport.unpark(thread);
}
}
}
@Override
public void signalAll() {
synchronized (lockObj){
//唤醒所有的等待线程
while (!threadQueue.isEmpty()){
Thread thread = threadQueue.poll();
LockSupport.unpark(thread);
}
threadHashSet.clear();
}
}
}
public MyReentrantLock(){
this(false);
}
public MyReentrantLock(boolean fair){
this.fair = fair;
condition = newCondition();
}
@Override
public void lock() {
Thread thread = Thread.currentThread();
for (;;){
if (current != null){
if (current.equals(thread)){
//重入
//自增完成,函数返回
state.incrementAndGet();
return;
}else {
//不是自己的锁,让这个线程等
try {
condition.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}else {
//抢占锁
//尝试锁定
if (state.compareAndSet(UNLOCK,FIRST)){
//锁定成功
current = Thread.currentThread();
return;
}else {
//锁定失败,被别人抢走了,那继续等吧
try {
condition.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
}
}
@Override
public void lockInterruptibly() throws InterruptedException {
}
@Override
public boolean tryLock() {
Thread thread = Thread.currentThread();
//有线程占有了
if (current != null){
if (current.equals(thread)){
//重入
//自增完成,函数返回
state.incrementAndGet();
return true;
}
}else {
//抢占锁
//尝试锁定
if (state.compareAndSet(UNLOCK,FIRST)){
//锁定成功
current = Thread.currentThread();
return true;
}
}
return false;
}
@Override
public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
long expire =System.currentTimeMillis() + TimeUnit.MILLISECONDS.convert(time,unit);
while (System.currentTimeMillis() < expire){
if (tryLock()){
return true;
}
}
return false;
}
@Override
public void unlock() {
Thread thread = Thread.currentThread();
//有线程占有并且属于当前线程
if (current != null && current.equals(thread)){
//开始解锁
int state = this.state.decrementAndGet();
//全部解锁,释放掉锁
if (state == 0){
current = null;
this.state.compareAndSet(state,UNLOCK);
//唤醒一个线程
condition.signal();
}
}
}
@Override
public Condition newCondition() {
if (this.fair){
return new MyFairCondition();
}
return new MyNonFairCondition();
}
}
测试类（MyReentrantLockTest）
package com.example.test;
import com.example.test.juc.MyReentrantLock;
import org.junit.jupiter.api.Test;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
/**
 * Exercises {@link MyReentrantLock} by letting two threads sell a shared pool of tickets.
 *
 * <p>Each test fails unless every ticket is sold exactly once. That only holds if the lock
 * provides mutual exclusion and, for {@link #testReentrant()}, re-entrancy. Any exception thrown
 * inside a seller thread also fails the test instead of being printed and ignored.
 */
public class MyReentrantLockTest {

    /** Size of the ticket pool each test starts with. */
    private static final int TOTAL_TICKETS = 1000;

    /** Tickets still for sale; only read or written by seller threads while holding the lock. */
    int ticket = TOTAL_TICKETS;

    /** Two threads sell tickets one at a time under a fair lock. */
    @Test
    public void testLock() {
        MyReentrantLock lock = new MyReentrantLock(true);
        AtomicInteger sell = new AtomicInteger(0);
        Runnable seller = () -> {
            String name = Thread.currentThread().getName();
            for (;;) {
                System.out.println(name + " 尝试获得锁");
                // lock() goes before try: if it failed, there would be nothing to unlock.
                lock.lock();
                try {
                    System.out.println(name + " 获得了锁");
                    if (ticket <= 0) {
                        return;
                    }
                    ticket--;
                    System.out.println(name + " 卖了一张票,剩下" + ticket + "张票");
                    sell.incrementAndGet();
                } finally {
                    System.out.println(name + " 释放了锁");
                    lock.unlock();
                }
                pause();
            }
        };
        runSellers(seller, "【t1线程】", "【t2线程】");
        System.out.println("最终两个线程总共卖出了" + sell + "张票");
        assertAllTicketsSold(sell);
    }

    /** Two threads each sell up to two tickets per turn, re-acquiring the lock they already hold. */
    @Test
    public void testReentrant() {
        MyReentrantLock lock = new MyReentrantLock();
        AtomicInteger sell = new AtomicInteger(0);
        Runnable seller = () -> {
            for (;;) {
                lock.lock();
                try {
                    if (ticket <= 0) {
                        return;
                    }
                    sellTicket(sell);
                    // Re-acquire while already holding the lock: a non-re-entrant lock would deadlock here.
                    lock.lock();
                    try {
                        if (ticket > 0) {
                            sellTicket(sell);
                        }
                    } finally {
                        lock.unlock();
                    }
                } finally {
                    lock.unlock();
                }
            }
        };
        runSellers(seller, "【t1线程】", "【t2线程】");
        System.out.println("两个线程总共卖出" + sell.intValue() + " 张票");
        assertAllTicketsSold(sell);
    }

    /**
     * Starts one thread per name running {@code seller} and waits for all of them.
     *
     * @throws AssertionError if any seller thread died with an exception, or if the waiting
     *                        thread was interrupted
     */
    private static void runSellers(Runnable seller, String... names) {
        Thread[] threads = new Thread[names.length];
        Throwable[] failure = new Throwable[1];
        for (int i = 0; i < names.length; i++) {
            threads[i] = new Thread(seller, names[i]);
            threads[i].setUncaughtExceptionHandler((thread, error) -> {
                synchronized (failure) {
                    if (failure[0] == null) {
                        failure[0] = error;
                    }
                }
            });
            threads[i].start();
        }
        try {
            for (Thread thread : threads) {
                thread.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new AssertionError("interrupted while waiting for seller threads", e);
        }
        synchronized (failure) {
            if (failure[0] != null) {
                throw new AssertionError("a seller thread failed", failure[0]);
            }
        }
    }

    /** Fails unless every ticket was sold exactly once. Call only after the sellers have been joined. */
    private void assertAllTicketsSold(AtomicInteger sell) {
        if (sell.get() != TOTAL_TICKETS || ticket != 0) {
            throw new AssertionError("expected " + TOTAL_TICKETS + " tickets sold and 0 left, but "
                    + sell.get() + " were sold and " + ticket + " are left");
        }
    }

    /** Sleeps about 1 ms so the other seller gets a chance at the lock; keeps the interrupt flag. */
    private static void pause() {
        try {
            Thread.sleep(1);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Sells one ticket and logs it. Caller must hold the lock. */
    private void sellTicket(AtomicInteger sell) {
        ticket--;
        sell.incrementAndGet();
        System.out.println(Thread.currentThread().getName() + " 售出了一张票,剩余:" + ticket + "张票");
    }
}
公平锁（new MyReentrantLock(true)）
非公平锁（new MyReentrantLock(false)）



