背景
有一个定时任务通过MapReduce的形式执行,父任务中进行分页计算,子任务根据分页信息查询数据库,然后处理数据,因为一些原因,需要控制并发查询数据库的线程数量,简单来说,就是做个限流。
解决方案
如果一次只允许一个线程查询数据库,通过加锁就可以解决,但如果多个线程,加锁再额外维护一个计数器也可以解决,但加锁的时候为了防止抢锁失败自旋占用CPU,一般会设置休眠,这就可能导致休眠时有线程释放锁,但是线程不能及时被唤醒,白白浪费一些时间,如下图所示。
思考片刻,突然想到操作系统用来进程同步的信号量机制,一个进程释放资源,紧接着会唤醒另外一个,不存在等待,效率很高。
复习一下 —— 信号量
信号量是操作系统中用来进行进程同步控制的一种机制,分为整型信号量和记录性信号量,本质就是一个变量,可以用来表示系统中某种资源的数量。
整形信号量只有一个变量表示资源数量,没有抢到资源的进程会自旋,进入忙等,占用CPU。
记录型信号量有一个变量表示资源数量,还有一个等待队列,没有抢到资源的进程进入等待队列,让出CPU,切换到阻塞状态等待资源,符合让权等待原则。
typedef struct {
int value; // 某种资源的数量
struct process *L; // 等待队列
} semaphore;
void wait(semaphore S) { // wait原语
S.value--;
if(S.value < 0) {
bloack(S.L); // 阻塞进程
}
}
void signal(semaphore S) { // signal原语
S.value++;
if(S.value <= 0) {
wakeup(S.L); // 唤醒进程
}
}
典型的信号量使用
信号量的使用有很多经典模型,比如:
- 生产消费模型
- 读写者模型
- 哲学家就餐模型
- 理发师睡觉模型
以最简单的生产消费模型为例,来看下信号量的使用:
/*
* 盘子只能放一个水果
* 爸爸往盘子里放苹果,妈妈往盘子里放橘子
* 女儿只从盘子里拿苹果吃,儿子只从盘子里拿橘子吃
*/
semaphore apple = 0; // 盘子中苹果的数量
semaphore orange = 0; // 盘子中橘子数量
semaphore plate = 1; // 盘子中可以放水果的数量
dad() {
while(1) {
准备一个苹果;
P(plate); // 盘子有水果则被阻塞
把苹果放入盘子;
V(apple);
}
}
mom() {
while(1) {
准备一个橘子;
P(plate);
把橘子放入盘子;
V(orange);
}
}
doughter() {
while(1) {
P(apple); // 没有苹果则被阻塞
从盘子中取出苹果;
V(plate);
吃掉苹果;
}
}
son() {
while(1) {
P(orange); // 没有橘子则被阻塞
从盘子中取出橘子;
V(plate);
吃掉橘子;
}
}
基于Redis实现分布式信号量
实现思路
信号量实现需要考虑信号量获取、计数、释放等问题,如果要实现让权等待,还需要考虑线程阻塞问题。同时,要有信号量过期机制,防止机器获取信号量后突然挂掉,无法打释放信号量,导致其他线程无法获取信号量。
综上,信号量实现可以基于Redis的list或者zset数据结构,使用list实现比较简单,类似生产消费模型,最初生产n个信号量放到list,通过RPOP/LPOP命令或者BRPOP/BLPOP命令获取信号量,使用完成后将信号量再push回list即可。但是使用list实现存在一个初始化过程,为了确保数据一致性,需要额外加锁,如果使用zset可以规避掉初始化过程,只需要考虑信号量的申请和释放,使用起来更方便一些。
基于ZSET实现
信号量获取:使用zset的ZADD命令将机器加入到集合中,确保score递增,然后ZRANK判断当前所处排名,如果未超过信号量数量则获取成功,如果要实现未获取到信号量线程阻塞,需要单独维护阻塞队列。
信号量释放:直接ZREM删除集合中表示该机器的标识即可。
信号量计数:zset本身有序,维护好score即可。
信号量过期机制:如果使用时间戳作为score,只要在每次获取信号量时删除一下过期的数据即可。如果使用单独序号,需要额外维护一个记录获取时间的zset,用来做超时检测。
代码实现
使用zset实现需要维护一个信号量集合,一个时间戳集合(检测信号量过期),一个阻塞队列,一个信号量序列号自增值。
申请信号量时,先通过时间戳集合清除过期信号量,然后将时间戳集合与信号量集合取交集,即可清除过期的信号量,然后获取一个信号量序列号,再将机器唯一标识作为member,分别以获取信号量时的时间戳和信号量序列号作为score加入到时间戳集合和信号量集合,因为信号量序列号严格递增,所以判断机器在信号量集合中的位次,再与信号量总量比较即可得知信号量申请结果。
释放信号量只需要将两个集合中的数据删除,再唤醒一个线程即可。
为了确保获取信号量操作的原子性,使用lua脚本操作Redis,lua脚本内容如下。
-- KEYS[1] 时间戳集合key
-- KEYS[2] 信号量集合key
-- ARGV[1] 获取信号量时的时间戳
-- ARGV[2] 机器唯一标识
-- ARGV[3] 信号量总数
-- ARGV[4] 过期信号量时间戳(小于此时间戳认为信号量过期)
-- 移除过期的信号量
redis.call("ZREMRANGEBYSCORE", KEYS[1], 0, ARGV[4])
redis.call("ZINTERSTORE", KEYS[2], 2, KEYS[1], KEYS[2], "AGGREGATE", "MIN")
-- 获取信号量
local cnt = redis.call("INCR", KEYS[3])
redis.call("ZADD", KEYS[1], ARGV[1], ARGV[2])
redis.call("ZADD", KEYS[2], cnt, ARGV[2])
-- 当前排名不超过信号量总量,则获取成功
local res = redis.call("ZRANK", KEYS[2], ARGV[2])
if res < tonumber(ARGV[3]) then
return 'T'
else
return 'F'
end
-- KEYS[1] 信号量集合key
-- KEYS[2] 时间戳集合key
-- KEYS[3] 阻塞队列key
-- ARGV[1] 机器唯一标识
-- 释放信号量
redis.call("ZREM", KEYS[1], ARGV[1])
redis.call("ZREM", KEYS[2], ARGV[1])
-- 唤醒一个等待者
redis.call("LPUSH", KEYS[3], "n")
Java代码如下:
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@Slf4j
@Component
public class TairSemaphoreManager {
@Resource
private RedisManager redisManager;
// 阻塞重试次数
public static int SEMAPHORE_BLOCK_RETRY = 3;
private String identify;
private String acquireScript;
private String releaseScript;
@PostConstruct
public void init() {
try (InputStream acquireInput = Thread.currentThread().getContextClassLoader().getResourceAsStream("lua/semaphore.lua");
InputStream releaseInput = Thread.currentThread().getContextClassLoader().getResourceAsStream("lua/release.lua");) {
// 加载脚本
byte[] by = new byte[acquireInput.available()];
acquireInput.read(by);
acquireScript = new String(by, StandardCharsets.UTF_8);
byte[] bytes = new byte[releaseInput.available()];
releaseInput.read(bytes);
releaseScript = new String(bytes, StandardCharsets.UTF_8);
// 获取唯一标识
identify = Runtime.getRuntime().exec("hostname").toString();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* * 获取信号量(非阻塞)
*
* @param semaphoreKey 信号量标志
* @param timerKey 计时器标志
* @param sequenceKey 序列号标志
* @param maxPermit 信号量最大数量
* @param expire 信号量有效时间(毫秒)
* @return true:成功 false:失败
*/
public boolean tryAcquire(String semaphoreKey, String timerKey, String sequenceKey, int maxPermit, int expire) {
long timestamp = System.currentTimeMillis();
String invalidTime = String.valueOf(timestamp - expire);
List<String> keys = Arrays.asList(timerKey, semaphoreKey, sequenceKey);
List<String> args = Arrays.asList(String.valueOf(timestamp), identify, String.valueOf(maxPermit), invalidTime);
String res = (String) redisManager.runLuaScript(acquireScript, keys, args);
return "T".equals(res);
}
/**
* 获取信号量(阻塞)
*
* @param semaphoreKey 信号量标志
* @param timerKey 计时器标志
* @param sequenceKey 序列号标志
* @param waitKey 等待标志
* @param timeout 等待超时时间(秒)
* @param maxPermit 信号量最大数量
* @param expire 信号量有效时间(毫秒)
* @return true:成功 false:失败
*/
public boolean acquire(String semaphoreKey, String timerKey, String sequenceKey, String waitKey,
int timeout, int maxPermit, int expire) {
boolean res = tryAcquire(semaphoreKey, timerKey, sequenceKey, maxPermit, expire);
if (!res) {
int i = 0;
while (i < SEMAPHORE_BLOCK_RETRY) {
List<String> list = redisManager.brpop(waitKey, timeout);
if (list != null) {
return true;
}
i++;
}
return false;
}
return true;
}
/**
* 释放信号量
*
* @param semaphoreKey 信号量标志
* @param timerKey 计时器标志
* @param waitKey 等待标志
*/
public void release(String semaphoreKey, String timerKey, String waitKey) {
List<String> keys = Arrays.asList(semaphoreKey, timerKey, waitKey);
redisManager.runLuaScript(releaseScript, keys, Collections.singletonList(identify));
}
}
jiyouzhan
这篇文章写得深入浅出,让我这个小白也看懂了!
迟於
@jiyouzhan :