package com.example;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.text.MessageFormat;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractOwnableSynchronizer;
import java.util.concurrent.locks.LockSupport;
/**
 * A distributed, reentrant lock backed by a PostgreSQL table ({@code lock_noblock}).
 *
 * <p>Each lock is one row keyed by {@code lock_name}. {@code state = 1} means the row is held,
 * {@code state = 0} means it is free. {@code expire_time} is a lease: a held row whose lease has
 * passed may be taken over by another acquirer. As a result, a process that dies while holding
 * the lock does not leave it stuck forever. Lease times are computed with the database clock
 * ({@code now()}), not the client clock.
 *
 * <p>Supported:
 * <pre>
 * lock()                          blocking acquire, reentrant
 * tryLock()                       non-blocking acquire, reentrant
 * tryLock(time, unit)             acquire, waiting at most {@code time}; reentrant
 * unlock()                        release; the DB row is freed when the hold count reaches zero
 * </pre>
 * Not supported yet:
 * <pre>
 * lockInterruptibly
 * tryLock(time) responding to interruption
 * newCondition
 * automatic lease renewal (a holder running longer than its lease can lose exclusivity)
 * </pre>
 *
 * <p>Assumes {@code lock_noblock.lock_name} has a unique constraint. Without one,
 * {@code on conflict do nothing} cannot prevent duplicate rows. TODO confirm against the schema.
 *
 * <p>NOTE(review): the table stores no owner identity. Suppose a holder's lease expires and
 * another process takes the lock over. When the original holder then calls {@code unlock()}, it
 * frees the other process's lock. Fixing this needs an owner column (for example a UUID token)
 * in {@code lock_noblock}.
 */
public class DBLock extends AbstractOwnableSynchronizer {
    /** Creates the lock row (free, lease already ended) if it does not exist yet. */
    private static final String INSERT_SQL =
            "insert into lock_noblock(lock_name,state,expire_time) values(?,?,now()) on conflict do nothing";

    /**
     * Sets state and a new lease (now + ? ms) if the row is in the expected state or its lease
     * has passed.
     */
    private static final String CAS_SET_EXPIRE_SQL =
            "update lock_noblock set state = ?, expire_time = now() + ? * interval '1 millisecond'"
                    + " where lock_name = ? and (state = ? or expire_time < now())";

    /** Pushes the lease of a held row to now + ? ms. */
    private static final String RENEW_EXPIRE_SQL =
            "update lock_noblock set expire_time = now() + ? * interval '1 millisecond'"
                    + " where lock_name = ? and state = 1";

    /**
     * Remaining wait, in nanoseconds, below which {@code tryLock(time, ...)} retries without
     * parking. This mirrors the spin threshold used by {@code AbstractQueuedSynchronizer}.
     */
    static final long spinForTimeoutThreshold = 1000L;

    /** Default lease in milliseconds, used when the caller supplies no positive lease time. */
    static final long expire = 30000L;

    /**
     * Local reentrant hold count, meaningful only to the owning thread. Do not use
     * {@code state == 0} to decide that the DB lock has been released.
     */
    private int state;
    private final DataSource ds;
    private final String lockName;
    private final LockDBSupport lockDBSupport;

    /**
     * Creates the lock handle and makes sure its row exists in {@code lock_noblock}.
     *
     * @param lockName name of the lock row
     * @param ds data source used for every lock operation
     * @throws RuntimeException wrapping the {@link SQLException} if the row cannot be created
     */
    public DBLock(String lockName, DataSource ds) {
        this.lockName = lockName;
        this.ds = ds;
        initData();
        this.lockDBSupport = new LockDBSupport(lockName, ds);
    }

    /** Acquires the lock, blocking until it is available, with the default lease ({@link #expire} ms). */
    public void lock() {
        lock(-1, TimeUnit.SECONDS);
    }

    /**
     * Acquires the lock, blocking (via {@code LockDBSupport.park()}) until it is available.
     *
     * @param leaseTime how long the DB lease lasts before another acquirer may take the lock over.
     *     A value that is non-positive, or shorter than one millisecond, selects the default lease
     *     of {@link #expire} ms.
     * @param unit unit of {@code leaseTime}
     */
    public void lock(long leaseTime, TimeUnit unit) {
        long leaseMillis = toLeaseMillis(leaseTime, unit);
        while (!tryAcquire(leaseMillis)) {
            lockDBSupport.park();
        }
    }

    /**
     * Acquires the lock only if it is free (or its lease has passed) at the time of the call, or
     * if the current thread already holds it.
     *
     * @return true if the lock was acquired or re-entered
     */
    public boolean tryLock() {
        return tryAcquire(expire);
    }

    /**
     * Acquires the lock, waiting at most {@code time}, with the default lease.
     *
     * @return true if the lock was acquired or re-entered
     */
    public boolean tryLock(long time, TimeUnit unit) {
        return tryLock(time, -1, unit);
    }

    /**
     * Acquires the lock, waiting at most {@code time}. One attempt is always made, even when
     * {@code time <= 0}, as with {@link java.util.concurrent.locks.Lock#tryLock(long, TimeUnit)}.
     *
     * @param time maximum time to wait
     * @param leaseTime DB lease length. A value that is non-positive, or shorter than one
     *     millisecond, selects the default {@link #expire} ms.
     * @param unit unit of both {@code time} and {@code leaseTime}
     * @return true if the lock was acquired or re-entered, false if the wait timed out
     */
    public boolean tryLock(long time, long leaseTime, TimeUnit unit) {
        long leaseMillis = toLeaseMillis(leaseTime, unit);
        if (tryAcquire(leaseMillis)) {
            return true;
        }
        long nanosTimeout = unit.toNanos(time);
        if (nanosTimeout <= 0L) {
            return false;
        }
        final long deadline = System.nanoTime() + nanosTimeout;
        while (true) {
            nanosTimeout = deadline - System.nanoTime();
            if (nanosTimeout <= 0L) {
                return false;
            }
            // For very short remaining waits, retrying immediately is cheaper than parking.
            if (nanosTimeout > spinForTimeoutThreshold) {
                lockDBSupport.parkNanos(nanosTimeout);
            }
            if (tryAcquire(leaseMillis)) {
                return true;
            }
        }
    }

    /**
     * Releases one hold of the lock. The DB row is only freed, and waiters woken, when the
     * current thread's hold count drops to zero.
     *
     * @throws IllegalMonitorStateException if the current thread does not hold the lock
     */
    public void unlock() {
        if (!isHeldByCurrentThread()) {
            throw new IllegalMonitorStateException();
        }
        int c = getState() - 1;
        setState(c);
        if (c == 0) {
            // Clear local ownership first. If the DB release below fails, the row is freed when
            // its lease ends.
            setExclusiveOwnerThread(null);
            // Release the row. A lease of 0 ms sets expire_time to now().
            compareAndSetStateAndSetExpire(1, 0, 0);
            lockDBSupport.unpark();
        }
    }

    /**
     * Inserts the lock row (state 0) unless a row with the same name already exists.
     *
     * @throws RuntimeException wrapping the {@link SQLException} on database failure
     */
    private void initData() {
        try (Connection conn = ds.getConnection();
             PreparedStatement ps = conn.prepareStatement(INSERT_SQL)) {
            conn.setAutoCommit(true);
            ps.setString(1, lockName);
            ps.setInt(2, 0);
            ps.execute();
        } catch (SQLException e) {
            throw new RuntimeException(MessageFormat.format("无法创建[{0}]锁", lockName), e);
        }
    }

    private int getState() {
        return state;
    }

    private void setState(int state) {
        this.state = state;
    }

    /**
     * Converts a caller-supplied lease into milliseconds. A result of 0 or less (a non-positive
     * or sub-millisecond lease) falls back to the default {@link #expire}. A zero or negative
     * lease would write an already-expired row, and that would let other processes take the lock
     * immediately.
     */
    private static long toLeaseMillis(long leaseTime, TimeUnit unit) {
        long millis = leaseTime > 0 ? unit.toMillis(leaseTime) : 0L;
        return millis > 0 ? millis : expire;
    }

    /**
     * One acquisition attempt. If the current thread already holds the lock, this increments the
     * hold count. Otherwise it tries to claim the DB row with the given lease.
     *
     * @param leaseMillis lease length in milliseconds for a fresh acquisition
     * @return true if the lock is now held by the current thread
     */
    private boolean tryAcquire(long leaseMillis) {
        if (isHeldByCurrentThread()) {
            int nextc = getState() + 1;
            if (nextc < 0) { // overflow
                throw new Error("Maximum lock count exceeded");
            }
            setState(nextc);
            return true;
        }
        if (compareAndSetStateAndSetExpire(0, 1, leaseMillis)) {
            setState(1);
            setExclusiveOwnerThread(Thread.currentThread());
            // TODO: start lease renewal here (see renewal()); without it the lease can run out
            // while the holder is still working.
            return true;
        }
        return false;
    }

    /**
     * Compare-and-set on the lock row. The row's state is set to {@code update} and its lease to
     * {@code now() + expireMillis}. This happens only if the row's state is {@code expect} or its
     * lease has already passed.
     *
     * @param expect state the row must currently have (ignored if the lease has passed)
     * @param update new state
     * @param expireMillis new lease length in milliseconds, counted from the database's now()
     * @return true if exactly one row was updated
     * @throws RuntimeException wrapping the {@link SQLException} on database failure
     */
    private boolean compareAndSetStateAndSetExpire(int expect, int update, long expireMillis) {
        try (Connection conn = ds.getConnection();
             PreparedStatement ps = conn.prepareStatement(CAS_SET_EXPIRE_SQL)) {
            // Commit each statement immediately, as initData does. A pooled connection with
            // autocommit off could otherwise roll the update back.
            conn.setAutoCommit(true);
            ps.setInt(1, update);
            ps.setLong(2, expireMillis);
            ps.setString(3, lockName);
            ps.setInt(4, expect);
            return ps.executeUpdate() == 1;
        } catch (SQLException e) {
            throw new RuntimeException(MessageFormat.format("无法更新[{0}]锁", lockName), e);
        }
    }

    /**
     * Pushes the lease of this lock's row to {@code now() + leaseMillis} if the row is held.
     * This is not called yet. It is intended for a future watchdog that renews the lease while
     * the holder is still running.
     *
     * <p>NOTE(review): the table has no owner column, so this cannot verify that this process is
     * the current holder.
     *
     * @param leaseMillis new lease length in milliseconds
     * @return true if a held row was updated
     */
    private boolean renewal(long leaseMillis) {
        try (Connection conn = ds.getConnection();
             PreparedStatement ps = conn.prepareStatement(RENEW_EXPIRE_SQL)) {
            conn.setAutoCommit(true);
            ps.setLong(1, leaseMillis);
            ps.setString(2, lockName);
            return ps.executeUpdate() == 1;
        } catch (SQLException e) {
            throw new RuntimeException(MessageFormat.format("无法续期[{0}]锁", lockName), e);
        }
    }

    /**
     * Tells whether the current thread holds the lock.
     *
     * <p>{@code exclusiveOwnerThread} does not need to be volatile here. Another thread can only
     * ever observe null, a different thread, or a stale value. None of those equals that
     * thread's own {@code Thread} object unless that thread itself set it. So a stale read still
     * yields the correct answer of {@code false}.
     *
     * @return true if the lock is held by the current thread
     */
    private boolean isHeldByCurrentThread() {
        return Thread.currentThread() == getExclusiveOwnerThread();
    }
}