可重入锁实现

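在项目中遇到了需要可重入缓存锁的场景。缓存锁基于 Redis 分布式锁实现,而 Redis 的 SETNX/MSETNX 加锁本身不具备重入能力,于是考虑用 ThreadLocal 记录当前线程对每个 key 的持有次数,在本机(同一 JVM 内的同一线程)实现可重入:首次加锁时才真正去 Redis 获取,之后的重入只增加计数,计数归零时才删除 Redis 中的 key。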

ReentrantCacheLockManager:

/**
 * Cache lock manager that makes a Redis-backed distributed lock reentrant within one thread.
 *
 * <p>Keys are acquired through {@link ICacheService#msetnx}. A per-thread hold count kept in
 * {@link LockKeyCountHolder} lets the same thread re-enter a key it already holds without going
 * back to the cache. The cache key is deleted only when the thread's hold count for that key
 * drops back to zero.
 *
 * <p>NOTE(review): the lock value is the constant {@code "1"}, so the cache entry carries no owner
 * token. If a lock expires and is re-acquired elsewhere, a later {@code unlock} from the original
 * holder still deletes it. Likewise, a thread whose cache lock has expired keeps re-entering on
 * the strength of its local count alone. Consider an owner token plus compare-and-delete if that
 * matters for the use case.
 */
@Component
public class ReentrantCacheLockManager implements ICacheLockManager {

    @Resource
    private ICacheService cacheService;

    /**
     * Tries to acquire the lock for a single key.
     *
     * @param key          lock key
     * @param expireSecond expiry in seconds, passed to the cache when the key is newly acquired
     * @return {@code true} if the current thread now holds the key
     */
    @Override
    public boolean tryLock(String key, int expireSecond) {
        return tryLock(Collections.singletonList(key), expireSecond);
    }

    /**
     * Tries to acquire the locks for all given keys.
     *
     * <p>Keys the current thread already holds are re-entered without contacting the cache. The
     * remaining keys are requested together with a single {@code msetnx} call (assumed to follow
     * Redis MSETNX all-or-nothing semantics — confirm in ICacheService). If that call returns 0,
     * nothing is recorded and {@code false} is returned. On success, every listed key's hold count
     * is incremented. A key listed twice is counted twice, so it must also be unlocked twice.
     *
     * @param keyList      lock keys; must not be empty
     * @param expireSecond expiry in seconds for newly acquired keys
     * @return {@code true} if the current thread now holds every key
     */
    @Override
    public boolean tryLock(List<String> keyList, int expireSecond) {
        Assert.isTrue(CollectionUtils.isNotEmpty(keyList), "获取锁失败:key 不能为空");

        // 1. Keys this thread does not hold yet must be acquired from the cache.
        List<String> unCacheKeyList = keyList.stream()
                .filter(k -> !LockKeyCountHolder.exist(k))
                .collect(Collectors.toList());

        if (CollectionUtils.isNotEmpty(unCacheKeyList)) {
            // Duplicate keys collapse into one entry; the value is only a placeholder.
            Map<String, String> lockMap = unCacheKeyList.stream()
                    .collect(Collectors.toMap(Function.identity(), k -> "1", (a, b) -> b));
            long msetnx = cacheService.msetnx(lockMap, expireSecond);
            // 0 means the batch was not acquired.
            if (msetnx == 0L) {
                return false;
            }
        }

        // 2. Record one more hold for every listed key.
        for (String key : keyList) {
            LockKeyCountHolder.increment(key);
        }
        return true;
    }

    /**
     * Releases one hold on a single key for the current thread.
     *
     * @param key lock key
     */
    @Override
    public void unlock(String key) {
        unlock(Collections.singletonList(key));
    }

    /**
     * Releases one hold on each key for the current thread.
     *
     * <p>Keys whose hold count drops to zero are deleted from the cache in one call. Keys the
     * current thread does not hold are ignored, so an unbalanced or foreign unlock cannot delete a
     * lock that belongs to another thread or process.
     *
     * @param keyList lock keys to release
     */
    @Override
    public void unlock(List<String> keyList) {
        List<String> releaseKeyList = Lists.newArrayList();

        for (String key : keyList) {
            // LockKeyCountHolder.decrement also returns 0 for keys that are not held, so ownership
            // must be checked first; otherwise someone else's lock would be deleted.
            if (LockKeyCountHolder.exist(key) && LockKeyCountHolder.decrement(key) == 0) {
                releaseKeyList.add(key);
            }
        }

        // Nothing reached zero (e.g. a nested unlock): skip the call entirely. Redis rejects a DEL
        // with no keys (assuming cacheService.del maps onto DEL).
        if (CollectionUtils.isNotEmpty(releaseKeyList)) {
            cacheService.del(releaseKeyList.toArray(new String[0]));
        }
    }

}

LockKeyCountHolder:

/**
 * Per-thread reentrancy counter for lock keys.
 *
 * <p>Each thread sees its own {@code lockKey -> hold count} map, so counts recorded on one
 * thread are invisible to every other thread. A key is present only while its count is positive.
 * Once a thread's map becomes empty it is removed from the {@link ThreadLocal}, so pooled threads
 * do not retain it.
 *
 * @author mianxian
 * 2023/8/2 14:56
 */
public final class LockKeyCountHolder {

    private LockKeyCountHolder() {
    }

    /**
     * lockKey -> hold count for the current thread.
     *
     * <p>Each map is only touched by its owning thread. A ConcurrentHashMap is kept for parity
     * with the previous implementation; it also rejects {@code null} keys with an NPE.
     */
    private static final ThreadLocal<Map<String, Integer>> THREAD_LOCAL =
            ThreadLocal.withInitial(java.util.concurrent.ConcurrentHashMap::new);

    /**
     * Increments the current thread's count for {@code key}.
     *
     * @param key key
     * @return the count after incrementing
     * @author ccomma
     */
    public static int increment(String key) {
        return THREAD_LOCAL.get().merge(key, 1, Integer::sum);
    }

    /**
     * Returns the current thread's count for {@code key}.
     *
     * @param key key
     * @return the count, or 0 if the key is not held
     * @author ccomma
     */
    public static int get(String key) {
        return THREAD_LOCAL.get().getOrDefault(key, 0);
    }

    /**
     * Tells whether the current thread's count for {@code key} is greater than zero.
     *
     * @param key key
     * @return {@code true} if the count is greater than zero
     * @author ccomma
     */
    public static boolean exist(String key) {
        return get(key) > 0;
    }

    /**
     * Decrements the current thread's count for {@code key}.
     *
     * <p>The entry is removed when it reaches zero. Keys that are not held are left untouched, and
     * 0 is returned for them, so callers cannot tell "just released" from "never held" by the
     * return value alone.
     *
     * @param key key
     * @return the count after decrementing (0 if released or not held)
     * @author ccomma
     */
    public static int decrement(String key) {
        Map<String, Integer> counts = THREAD_LOCAL.get();
        // Returning null from the remapping function removes the entry once its count hits zero.
        Integer remaining = counts.computeIfPresent(key, (k, count) -> count > 1 ? count - 1 : null);
        if (counts.isEmpty()) {
            // Nothing is held any more: drop the map so long-lived pooled threads don't keep it.
            THREAD_LOCAL.remove();
        }
        return remaining == null ? 0 : remaining;
    }

}