第一版代码(存在bug隐患)
// 执行ocr调用
func (o *ocrSvc)doOcr(ctx context.Context,uid int)(interface,err){
// 如果调用次数超过了指定限制,就直接拒绝此次请求
ok,err := o.checkMinute(uid)
if err != nil {
return nil,err
}
if !ok {
return nil,errors.News("frequently called")
}
// 执行第三方ocr调用(伪代码:模拟一个rpc接口)
ocrRes,err := doOcrByThird()
if err != nil {
return nil,err
}
// 调用成功则执行 incr操作
if err := o.redis.Incr(ctx,buildUserOcrCountKey(uid));err!=nil{
return nil,err
}
return ocrRes,nil
}
// 校验每分钟调用次数是否超过限制
func (o *ocrSvc)checkMinute (ctx context.Context,uid int) (bool, error) {
minuteCount, err := o.redis.Get(ctx, buildUserOcrCountKey(uid))
if err != nil && !errors.Is(err, eredis.Nil) {
log.Error("checkMinute: redis.Get failed", zap.Error(err))
return false, constx.ErrServer
}
if errors.Is(err, eredis.Nil) {
// 过期了,或者没有该用户的调用次数记录(设置初始值为0,过期时间为1分钟)
o.redis.Set(ctx, buildUserOcrCountKey(uid),0,time.Minute)
return true, nil
}
// 已经超过每分钟的调用次数
if cast.ToInt(minuteCount) >= config.UserOcrMinuteCount() {
log.Warn("checkMinute: user FrequentlyCalled", zap.Int64("uid", uid), zap.String("minuteCount", minuteCount))
return false, nil
}
return true, nil
}
详解
这一版代码存在什么问题呢?问题出在了 开始判断没有超出限制,然后执行第三方rpc接口调用也成功了,接下来直接进行计数加一(incr)操作有问题,如下图所示
说明:
- 假设当前用户在进行
ocr
识别时,未超过调用次数。但是在redis
中的ttl
还剩1
秒钟 - 然后调用第三方
ocr
进行识别,加入耗时超过了1
秒 - 识别成功后,调用次数
+1
。这里就很有可能出问题,比如:在incr
的时候刚好该key
在1s
前过期了,那么redis
是怎么做的呢,它会将该key
的值设置为1
,ttl
设置为-1
,ttl
设置为-1
,ttl
设置为-1
(重要的事情说三遍),-1表示没有过期时间 - 这时候
bug
就出现了,用户的调用次数一直在涨,并且也不会过期,达到临界值时用户的请求就会被拒掉,相当于该用户之后都不能访问这个接口了,并且这种key
变多后,由于没有过期时间,还会一直占用redis
的内存。
总结
以上代码说明了一个问题,也就是incr
和expire
必须具备原子性。而我们第一版代码显然在边界条件下是不满足要求的,极有可能造成bug
,影响用户体验,强烈不推荐使用,接下来一步一步引入修正后的代码
第二版代码(几乎无隐患)
从对第一版代码的分析可知,是由于查询次数还没有达到限制后,又进行了一些rpc
调用,或者处理了一些其他业务逻辑,这个时间内,可能key
过期了,然后我们直接使用incr
进行计数加一,导致了永不过期的key
产生。那么我们是不是可以在incr
前先保证key
还没有过期就行呢?答案是可以的,代码如下:
// 执行ocr调用
func (o *ocrSvc)doOcr(ctx context.Context,uid int)(interface,err){
// 如果调用次数超过了指定限制,就直接拒绝此次请求
ok,err := o.checkMinute(uid)
if err != nil {
return nil,err
}
if !ok {
return nil,errors.News("frequently called")
}
// 执行第三方ocr调用(伪代码:模拟一个rpc接口)
ocrRes,err := doOcrByThird()
if err != nil {
return nil,err
}
// 调用成功则执行 incr操作
exists, err := o.redis.Exists(ctx, buildUserOcrCountKey(uid)).Result()
if err != nil {
log.Error("doOcr: redis.Exists failed", zap.Error(err))
return nil, err
}
if exists == 1 { // key存在,计数加1
if err := o.redis.Incr(ctx,buildUserOcrCountKey(uid));err!=nil{
return nil,err
}
} else { // key不存在,设置key与过期时间
if err := o.redis.Set(ctx,buildUserOcrCountKey(uid),1,expireTime);err!=nil{
return nil,err
}
}
return ocrRes,nil
}
// 校验每分钟调用次数是否超过限制
func (o *ocrSvc)checkMinute (ctx context.Context,uid int) (bool, error) {
minuteCount, err := o.redis.Get(ctx, buildUserOcrCountKey(uid))
if err != nil && !errors.Is(err, eredis.Nil) {
log.Error("checkMinute: redis.Get failed", zap.Error(err))
return false, constx.ErrServer
}
if errors.Is(err, eredis.Nil) {
// 过期了,或者没有该用户的调用次数记录(设置初始值为0,过期时间为1分钟)
o.redis.Set(ctx, buildUserOcrCountKey(uid),0,time.Minute)
return true, nil
}
// 已经超过每分钟的调用次数
if cast.ToInt(minuteCount) >= config.UserOcrMinuteCount() {
log.Warn("checkMinute: user FrequentlyCalled", zap.Int64("uid", uid), zap.String("minuteCount", minuteCount))
return false, nil
}
return true, nil
}
与第一版的差异主要在于如下代码:
- 在需要
incr
操作前,我们先查看key
是否存在(且没有过期)- 确保存在后,立即
incr
(这两步间隔几乎可以忽略,所以几乎可以避免第一版中的问题) - 如果不存在则设置
key
并设置过期时间。
- 确保存在后,立即
注:redis中的incr命令是不会改变key的过期时间的
// 调用成功则执行 incr操作
exists, err := o.redis.Exists(ctx, buildUserOcrCountKey(uid)).Result()
if err != nil {
log.Error("doOcr: redis.Exists failed", zap.Error(err))
return nil, err
}
if exists == 1 { // key存在,计数加1
if err := o.redis.Incr(ctx,buildUserOcrCountKey(uid));err!=nil{
return nil,err
}
} else { // key不存在,设置key与过期时间
if err := o.redis.Set(ctx,buildUserOcrCountKey(uid),1,expireTime);err!=nil{
return nil,err
}
}
还有一种方式是查看key
的过期时间,使用ttl
,这样即使在极端情况下通过incr
设置出了没有过期时间的key
,也会在第二次访问的时候通过Set
设置过期时间了。
注:ttl命令返回值是键的剩余时间(单位是秒)。当键不存在时,ttl命令会返回-2。没有为键设置过期时间(即永久存在,这是建立一个键后的默认情况)返回-1。
// 调用成功则执行 incr操作
cnt, err := o.redis.Ttl(ctx, buildUserOcrCountKey(uid)).Result()
if err != nil {
log.Error("doOcr: redis.Ttl failed", zap.Error(err))
return nil, err
}
if cnt >= 1 { // key存在,且还没有过期
if err := o.redis.Incr(ctx,buildUserOcrCountKey(uid));err!=nil{
return nil,err
}
} else { // key马上过期0,key没有过期时间-1, key不存在-2
if err := o.redis.Set(ctx,buildUserOcrCountKey(uid),1,expireTime);err!=nil{
return nil,err
}
}
第三版代码(完美无瑕)
第二版代码中的两种方式其实已经可以在工作中使用了,但如果追求完美无瑕的话,ttl
版本的代码在极端情况下还是有点瑕疵,比如极端情况下,key
过期时间还有1s
过期,然后我们用incr
去累加,但是网络延迟了,导致命令到达redis
服务器的时候,key
已经过期了,尽管第二次访问会用set
重置key
并设置过期时间,但是万一该用户再也不来访问了呢?这时候这个key
就会永远占据着内存了。
将incr+expire
放在lua
脚本中执行保证原子性是最完美的。废话不多说了,直接上代码
// 执行ocr调用
func (o *ocrSvc)doOcr(ctx context.Context,uid int)(interface,err){
// 如果调用次数超过了指定限制,就直接拒绝此次请求
ok,err := o.checkMinute(uid)
if err != nil {
return nil,err
}
if !ok {
return nil,errors.News("frequently called")
}
// 执行第三方ocr调用((伪代码:模拟一个rpc接口))
ocrRes,err := doOcrByThird()
if err != nil {
return nil,err
}
// 调用成功则执行 incr操作
if err := o.incrCount(ctx,buildUserOcrCountKey(uid));err!=nil{
return nil,err
}
return ocrRes,nil
}
func (o *ocrSvc) incrCount(ctx context.Context, uid int64) error {
/*
此段lua脚本的作用:
第一步,先执行incr操作
local current = redis.call('incr',KEYS[1])
第二步,看下该key的ttl
local t = redis.call('ttl',KEYS[1]);
第三步,如果ttl为-1(永不过期)
if t == -1 then
则重新设置过期时间为 「一分钟」
redis.call('expire',KEYS[1],ARGV[1])
end;
*/
script := redis.NewScript(
`local current = redis.call('incr',KEYS[1]);
local t = redis.call('ttl',KEYS[1]);
if t == -1 then
redis.call('expire',KEYS[1],ARGV[1])
end;
return current
`)
var (
expireTime = 60 // 60 秒
)
_, err := script.Run(ctx, b.redis.Client(), []string{buildUserOcrCountKey(uid)}, expireTime).Result()
if err != nil {
return err
}
return nil
}
// 校验每分钟调用次数是否超过
func (o *ocrSvc)checkMinute (ctx context.Context,uid int) (bool, error) {
minuteCount, err := o.redis.Get(ctx, buildUserOcrCountKey(uid))
if err != nil && !errors.Is(err, eredis.Nil) {
elog.Error("checkMinute: redis.Get failed", zap.Error(err))
return false, constx.ErrServer
}
if errors.Is(err, eredis.Nil) {
// 第二版代码中在check时不进行初始化操作
// 过期了,或者没有该用户的调用次数记录(设置初始值为0,过期时间为1分钟)
// o.redis.Set(ctx, buildUserOcrCountKey(uid),0,time.Minute)
return true, nil
}
// 已经超过每分钟的调用次数
if cast.ToInt(minuteCount) >= config.UserOcrMinuteCount() {
elog.Warn("checkMinute: user FrequentlyCalled", zap.Int64("uid", uid), zap.String("minuteCount", minuteCount))
return false, nil
}
return true, nil
}