基于信号量的分布式限流


背景

有一个定时任务通过MapReduce的形式执行,父任务中进行分页计算,子任务根据分页信息查询数据库,然后处理数据,因为一些原因,需要控制并发查询数据库的线程数量,简单来说,就是做个限流。

解决方案

如果一次只允许一个线程查询数据库,通过加锁就可以解决,但如果多个线程,加锁再额外维护一个计数器也可以解决,但加锁的时候为了防止抢锁失败自旋占用CPU,一般会设置休眠,这就可能导致休眠时有线程释放锁,但是线程不能及时被唤醒,白白浪费一些时间,如下图所示。

思考片刻,突然想到操作系统用来进程同步的信号量机制,一个进程释放资源,紧接着会唤醒另外一个,不存在等待,效率很高。

复习一下 —— 信号量

信号量是操作系统中用来进行进程同步控制的一种机制,分为整型信号量和记录性信号量,本质就是一个变量,可以用来表示系统中某种资源的数量。
整形信号量只有一个变量表示资源数量,没有抢到资源的进程会自旋,进入忙等,占用CPU。
记录型信号量有一个变量表示资源数量,还有一个等待队列,没有抢到资源的进程进入等待队列,让出CPU,切换到阻塞状态等待资源,符合让权等待原则。

typedef struct {
    int value; // 某种资源的数量
    struct process *L; // 等待队列
} semaphore;

void wait(semaphore S) { // wait原语
    S.value--;
    if(S.value < 0) {
        bloack(S.L); // 阻塞进程
    }
}

void signal(semaphore S) { // signal原语
    S.value++;
    if(S.value <= 0) {
        wakeup(S.L); // 唤醒进程
    }
}

典型的信号量使用

信号量的使用有很多经典模型,比如:

  • 生产消费模型
  • 读写者模型
  • 哲学家就餐模型
  • 理发师睡觉模型

以最简单的生产消费模型为例,来看下信号量的使用:

/*
 * 盘子只能放一个水果
 * 爸爸往盘子里放苹果,妈妈往盘子里放橘子
 * 女儿只从盘子里拿苹果吃,儿子只从盘子里拿橘子吃
 */

semaphore apple = 0; // 盘子中苹果的数量
semaphore orange = 0; // 盘子中橘子数量
semaphore plate = 1; // 盘子中可以放水果的数量

dad() {
    while(1) {
        准备一个苹果;
        P(plate);  // 盘子有水果则被阻塞
        把苹果放入盘子;
        V(apple);
    }
}

mom() {
    while(1) {
        准备一个橘子;
        P(plate);
        把橘子放入盘子;
        V(orange);
    }
}

doughter() {
    while(1) {
        P(apple);  // 没有苹果则被阻塞
        从盘子中取出苹果;
        V(plate);
        吃掉苹果;
    }
}

son() {
    while(1) {
        P(orange);  // 没有橘子则被阻塞
        从盘子中取出橘子;
        V(plate);
        吃掉橘子;
    }
}

基于Redis实现分布式信号量

实现思路

信号量实现需要考虑信号量获取、计数、释放等问题,如果要实现让权等待,还需要考虑线程阻塞问题。同时,要有信号量过期机制,防止机器获取信号量后突然挂掉,无法打释放信号量,导致其他线程无法获取信号量。

综上,信号量实现可以基于Redis的list或者zset数据结构,使用list实现比较简单,类似生产消费模型,最初生产n个信号量放到list,通过RPOP/LPOP命令或者BRPOP/BLPOP命令获取信号量,使用完成后将信号量再push回list即可。但是使用list实现存在一个初始化过程,为了确保数据一致性,需要额外加锁,如果使用zset可以规避掉初始化过程,只需要考虑信号量的申请和释放,使用起来更方便一些。

基于ZSET实现

信号量获取:使用zset的ZADD命令将机器加入到集合中,确保score递增,然后ZRANK判断当前所处排名,如果未超过信号量数量则获取成功,如果要实现未获取到信号量线程阻塞,需要单独维护阻塞队列。
信号量释放:直接ZREM删除集合中表示该机器的标识即可。
信号量计数:zset本身有序,维护好score即可。
信号量过期机制:如果使用时间戳作为score,只要在每次获取信号量时删除一下过期的数据即可。如果使用单独序号,需要额外维护一个记录获取时间的zset,用来做超时检测。

代码实现

使用zset实现需要维护一个信号量集合,一个时间戳集合(检测信号量过期),一个阻塞队列,一个信号量序列号自增值。

申请信号量时,先通过时间戳集合清除过期信号量,然后将时间戳集合与信号量集合取交集,即可清除过期的信号量,然后获取一个信号量序列号,再将机器唯一标识作为member,分别以获取信号量时的时间戳和信号量序列号作为score加入到时间戳集合和信号量集合,因为信号量序列号严格递增,所以判断机器在信号量集合中的位次,再与信号量总量比较即可得知信号量申请结果。

释放信号量只需要将两个集合中的数据删除,再唤醒一个线程即可。

为了确保获取信号量操作的原子性,使用lua脚本操作Redis,lua脚本内容如下。

-- KEYS[1] 时间戳集合key
-- KEYS[2] 信号量集合key
-- ARGV[1] 获取信号量时的时间戳
-- ARGV[2] 机器唯一标识
-- ARGV[3] 信号量总数
-- ARGV[4] 过期信号量时间戳(小于此时间戳认为信号量过期)

-- 移除过期的信号量
redis.call("ZREMRANGEBYSCORE", KEYS[1], 0, ARGV[4])
redis.call("ZINTERSTORE", KEYS[2], 2, KEYS[1], KEYS[2], "AGGREGATE", "MIN")

-- 获取信号量
local cnt = redis.call("INCR", KEYS[3])
redis.call("ZADD", KEYS[1], ARGV[1], ARGV[2])
redis.call("ZADD", KEYS[2], cnt, ARGV[2])

-- 当前排名不超过信号量总量,则获取成功
local res = redis.call("ZRANK", KEYS[2], ARGV[2])
if res < tonumber(ARGV[3]) then
    return 'T'
else
    return 'F'
end
-- KEYS[1] 信号量集合key
-- KEYS[2] 时间戳集合key
-- KEYS[3] 阻塞队列key
-- ARGV[1] 机器唯一标识

-- 释放信号量
redis.call("ZREM", KEYS[1], ARGV[1])
redis.call("ZREM", KEYS[2], ARGV[1])
-- 唤醒一个等待者
redis.call("LPUSH", KEYS[3], "n")

Java代码如下:

import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

@Slf4j
@Component
public class TairSemaphoreManager {

    @Resource
    private RedisManager redisManager;
    
    // 阻塞重试次数
    public static int SEMAPHORE_BLOCK_RETRY = 3;
    
    private String identify;
    
    private String acquireScript;
    
    private String releaseScript;
    
    @PostConstruct
    public void init() {
        try (InputStream acquireInput = Thread.currentThread().getContextClassLoader().getResourceAsStream("lua/semaphore.lua");
             InputStream releaseInput = Thread.currentThread().getContextClassLoader().getResourceAsStream("lua/release.lua");) {
            // 加载脚本
            byte[] by = new byte[acquireInput.available()];
            acquireInput.read(by);
            acquireScript = new String(by, StandardCharsets.UTF_8);
            byte[] bytes = new byte[releaseInput.available()];
            releaseInput.read(bytes);
            releaseScript = new String(bytes, StandardCharsets.UTF_8);
            // 获取唯一标识
            identify = Runtime.getRuntime().exec("hostname").toString();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
    
    /**
     * * 获取信号量(非阻塞)
     *
     * @param semaphoreKey 信号量标志
     * @param timerKey     计时器标志
     * @param sequenceKey  序列号标志
     * @param maxPermit    信号量最大数量
     * @param expire       信号量有效时间(毫秒)
     * @return true:成功 false:失败
     */
    public boolean tryAcquire(String semaphoreKey, String timerKey, String sequenceKey, int maxPermit, int expire) {
        long timestamp = System.currentTimeMillis();
        String invalidTime = String.valueOf(timestamp - expire);
        List<String> keys = Arrays.asList(timerKey, semaphoreKey, sequenceKey);
        List<String> args = Arrays.asList(String.valueOf(timestamp), identify, String.valueOf(maxPermit), invalidTime);
        String res = (String) redisManager.runLuaScript(acquireScript, keys, args);
    
        return "T".equals(res);
    }
    
    /**
     * 获取信号量(阻塞)
     *
     * @param semaphoreKey 信号量标志
     * @param timerKey     计时器标志
     * @param sequenceKey  序列号标志
     * @param waitKey      等待标志
     * @param timeout      等待超时时间(秒)
     * @param maxPermit    信号量最大数量
     * @param expire       信号量有效时间(毫秒)
     * @return true:成功 false:失败
     */
    public boolean acquire(String semaphoreKey, String timerKey, String sequenceKey, String waitKey,
                              int timeout, int maxPermit, int expire) {
    
        boolean res = tryAcquire(semaphoreKey, timerKey, sequenceKey, maxPermit, expire);
    
        if (!res) {
            int i = 0;
            while (i < SEMAPHORE_BLOCK_RETRY) {
                List<String> list = redisManager.brpop(waitKey, timeout);
                if (list != null) {
                    return true;
                }
                i++;
            }
            return false;
        }
    
        return true;
    }
    
    /**
     * 释放信号量
     *
     * @param semaphoreKey 信号量标志
     * @param timerKey     计时器标志
     * @param waitKey      等待标志
     */
    public void release(String semaphoreKey, String timerKey, String waitKey) {
        List<String> keys = Arrays.asList(semaphoreKey, timerKey, waitKey);
        redisManager.runLuaScript(releaseScript, keys, Collections.singletonList(identify));
    }
}

声明:迟於|版权所有,违者必究|如未注明,均为原创|本网站采用BY-NC-SA协议进行授权

转载:转载请注明原文链接 - 基于信号量的分布式限流


栖迟於一丘