68. redis计数与限流中incr+expire的坑以及解决办法(Lua+TTL)-二、代码演进

时间:2024-02-01 22:09:24

第一版代码(存在bug隐患)

// 执行ocr调用
func (o *ocrSvc)doOcr(ctx context.Context,uid int)(interface,err){
	// 如果调用次数超过了指定限制,就直接拒绝此次请求
	ok,err := o.checkMinute(uid)
	if err != nil {
		return nil,err
	}
	if !ok {
		return nil,errors.News("frequently called")
	}
	// 执行第三方ocr调用(伪代码:模拟一个rpc接口)
	ocrRes,err := doOcrByThird()
	if err != nil {
		return nil,err
	}
	// 调用成功则执行 incr操作
	if err := o.redis.Incr(ctx,buildUserOcrCountKey(uid));err!=nil{
	   return nil,err
	}
	return ocrRes,nil
}
 
// 校验每分钟调用次数是否超过限制
func (o *ocrSvc)checkMinute (ctx context.Context,uid int) (bool, error) {
	minuteCount, err := o.redis.Get(ctx, buildUserOcrCountKey(uid))
	if err != nil && !errors.Is(err, eredis.Nil) {
		log.Error("checkMinute: redis.Get failed", zap.Error(err))
		return false, constx.ErrServer
	}
	if errors.Is(err, eredis.Nil) {
		// 过期了,或者没有该用户的调用次数记录(设置初始值为0,过期时间为1分钟)
		o.redis.Set(ctx, buildUserOcrCountKey(uid),0,time.Minute)
		return true, nil
	}
	// 已经超过每分钟的调用次数
	if cast.ToInt(minuteCount) >= config.UserOcrMinuteCount() {
		log.Warn("checkMinute: user FrequentlyCalled", zap.Int64("uid", uid), zap.String("minuteCount", minuteCount))
		return false, nil
	}
	return true, nil
}

详解
这一版代码存在什么问题呢?问题出在了 开始判断没有超出限制,然后执行第三方rpc接口调用也成功了,接下来直接进行计数加一(incr)操作有问题,如下图所示

在这里插入图片描述

说明:

  1. 假设当前用户在进行ocr识别时,未超过调用次数。但是在redis中的ttl还剩1秒钟
  2. 然后调用第三方ocr进行识别,加入耗时超过了1
  3. 识别成功后,调用次数+1。这里就很有可能出问题,比如:在incr的时候刚好该key1s前过期了,那么redis是怎么做的呢,它会将该key的值设置为1ttl设置为-1ttl设置为-1ttl设置为-1重要的事情说三遍),-1表示没有过期时间
  4. 这时候bug就出现了,用户的调用次数一直在涨,并且也不会过期,达到临界值时用户的请求就会被拒掉,相当于该用户之后都不能访问这个接口了,并且这种key变多后,由于没有过期时间,还会一直占用redis的内存。

总结
以上代码说明了一个问题,也就是increxpire必须具备原子性。而我们第一版代码显然在边界条件下是不满足要求的,极有可能造成bug,影响用户体验,强烈不推荐使用,接下来一步一步引入修正后的代码

第二版代码(几乎无隐患)

从对第一版代码的分析可知,是由于查询次数还没有达到限制后,又进行了一些rpc调用,或者处理了一些其他业务逻辑,这个时间内,可能key过期了,然后我们直接使用incr进行计数加一,导致了永不过期的key产生。那么我们是不是可以在incr前先保证key还没有过期就行呢?答案是可以的,代码如下:

// 执行ocr调用
func (o *ocrSvc)doOcr(ctx context.Context,uid int)(interface,err){
	// 如果调用次数超过了指定限制,就直接拒绝此次请求
	ok,err := o.checkMinute(uid)
	if err != nil {
		return nil,err
	}
	if !ok {
		return nil,errors.News("frequently called")
	}
	// 执行第三方ocr调用(伪代码:模拟一个rpc接口)
	ocrRes,err := doOcrByThird()
	if err != nil {
		return nil,err
	}
	
	// 调用成功则执行 incr操作
	exists, err := o.redis.Exists(ctx, buildUserOcrCountKey(uid)).Result()
	if err != nil {
		log.Error("doOcr: redis.Exists failed", zap.Error(err))
		return nil, err
	}
	if exists == 1 { // key存在,计数加1
		if err := o.redis.Incr(ctx,buildUserOcrCountKey(uid));err!=nil{
	    return nil,err
	}
	} else { // key不存在,设置key与过期时间
		if err := o.redis.Set(ctx,buildUserOcrCountKey(uid),1,expireTime);err!=nil{
		   return nil,err
		}
	}
	return ocrRes,nil
}
 
// 校验每分钟调用次数是否超过限制
func (o *ocrSvc)checkMinute (ctx context.Context,uid int) (bool, error) {
	minuteCount, err := o.redis.Get(ctx, buildUserOcrCountKey(uid))
	if err != nil && !errors.Is(err, eredis.Nil) {
		log.Error("checkMinute: redis.Get failed", zap.Error(err))
		return false, constx.ErrServer
	}
	if errors.Is(err, eredis.Nil) {
		// 过期了,或者没有该用户的调用次数记录(设置初始值为0,过期时间为1分钟)
		o.redis.Set(ctx, buildUserOcrCountKey(uid),0,time.Minute)
		return true, nil
	}
	// 已经超过每分钟的调用次数
	if cast.ToInt(minuteCount) >= config.UserOcrMinuteCount() {
		log.Warn("checkMinute: user FrequentlyCalled", zap.Int64("uid", uid), zap.String("minuteCount", minuteCount))
		return false, nil
	}
	return true, nil
}

与第一版的差异主要在于如下代码:

  • 在需要incr操作前,我们先查看key是否存在(且没有过期)
    • 确保存在后,立即incr(这两步间隔几乎可以忽略,所以几乎可以避免第一版中的问题)
    • 如果不存在则设置key并设置过期时间。

注:redis中的incr命令是不会改变key的过期时间的

// 调用成功则执行 incr操作
	exists, err := o.redis.Exists(ctx, buildUserOcrCountKey(uid)).Result()
	if err != nil {
		log.Error("doOcr: redis.Exists failed", zap.Error(err))
		return nil, err
	}
	if exists == 1 { // key存在,计数加1
		if err := o.redis.Incr(ctx,buildUserOcrCountKey(uid));err!=nil{
	    return nil,err
	}
	} else { // key不存在,设置key与过期时间
		if err := o.redis.Set(ctx,buildUserOcrCountKey(uid),1,expireTime);err!=nil{
		   return nil,err
		}
	}

还有一种方式是查看key的过期时间,使用ttl,这样即使在极端情况下通过incr设置出了没有过期时间的key,也会在第二次访问的时候通过Set设置过期时间了。

注:ttl命令返回值是键的剩余时间(单位是秒)。当键不存在时,ttl命令会返回-2。没有为键设置过期时间(即永久存在,这是建立一个键后的默认情况)返回-1。

// 调用成功则执行 incr操作
	cnt, err := o.redis.Ttl(ctx, buildUserOcrCountKey(uid)).Result()
	if err != nil {
		log.Error("doOcr: redis.Ttl failed", zap.Error(err))
		return nil, err
	}
	if cnt >= 1 { // key存在,且还没有过期
		if err := o.redis.Incr(ctx,buildUserOcrCountKey(uid));err!=nil{
	    return nil,err
	}
	} else { // key马上过期0,key没有过期时间-1, key不存在-2
		if err := o.redis.Set(ctx,buildUserOcrCountKey(uid),1,expireTime);err!=nil{
		   return nil,err
		}
	}

第三版代码(完美无瑕)

第二版代码中的两种方式其实已经可以在工作中使用了,但如果追求完美无瑕的话,ttl版本的代码在极端情况下还是有点瑕疵,比如极端情况下,key过期时间还有1s过期,然后我们用incr去累加,但是网络延迟了,导致命令到达redis服务器的时候,key已经过期了,尽管第二次访问会用set重置key并设置过期时间,但是万一该用户再也不来访问了呢?这时候这个key就会永远占据着内存了。

incr+expire放在lua脚本中执行保证原子性是最完美的。废话不多说了,直接上代码

// 执行ocr调用
func (o *ocrSvc)doOcr(ctx context.Context,uid int)(interface,err){
	// 如果调用次数超过了指定限制,就直接拒绝此次请求
	ok,err := o.checkMinute(uid)
	if err != nil {
		return nil,err
	}
	if !ok {
		return nil,errors.News("frequently called")
	}
	// 执行第三方ocr调用((伪代码:模拟一个rpc接口))
	ocrRes,err := doOcrByThird()
	if err != nil {
		return nil,err
	}
	// 调用成功则执行 incr操作
	if err := o.incrCount(ctx,buildUserOcrCountKey(uid));err!=nil{
	   return nil,err
	}
	return ocrRes,nil
}
 
func (o *ocrSvc) incrCount(ctx context.Context, uid int64) error {
   /*
   此段lua脚本的作用:
     第一步,先执行incr操作
     local current = redis.call('incr',KEYS[1])
     第二步,看下该key的ttl
     local t = redis.call('ttl',KEYS[1]); 
     第三步,如果ttl为-1(永不过期)
     if t == -1 then
         则重新设置过期时间为 「一分钟」
		 redis.call('expire',KEYS[1],ARGV[1])
	 end;
   */
	script := redis.NewScript(
		`local current = redis.call('incr',KEYS[1]);
		 local t = redis.call('ttl',KEYS[1]); 
		 if t == -1 then
		 	redis.call('expire',KEYS[1],ARGV[1])
		 end;
		 return current
	`)
	var (
		expireTime = 60 // 60 秒
	)
	_, err := script.Run(ctx, b.redis.Client(), []string{buildUserOcrCountKey(uid)}, expireTime).Result()
	if err != nil {
		return err
	}
	return nil
}
 
// 校验每分钟调用次数是否超过
func (o *ocrSvc)checkMinute (ctx context.Context,uid int) (bool, error) {
	minuteCount, err := o.redis.Get(ctx, buildUserOcrCountKey(uid))
	if err != nil && !errors.Is(err, eredis.Nil) {
		elog.Error("checkMinute: redis.Get failed", zap.Error(err))
		return false, constx.ErrServer
	}
	if errors.Is(err, eredis.Nil) {
	    // 第二版代码中在check时不进行初始化操作
		// 过期了,或者没有该用户的调用次数记录(设置初始值为0,过期时间为1分钟)
		// o.redis.Set(ctx, buildUserOcrCountKey(uid),0,time.Minute)
		return true, nil
	}
	// 已经超过每分钟的调用次数
	if cast.ToInt(minuteCount) >= config.UserOcrMinuteCount() {
		elog.Warn("checkMinute: user FrequentlyCalled", zap.Int64("uid", uid), zap.String("minuteCount", minuteCount))
		return false, nil
	}
	return true, nil
}