栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Java

使用注解实现限流

Java 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

使用注解实现限流

使用注解实现限流 一、需求概述
  1. 使用注解添加到类上,则此类的所有方法都实现限流。
  2. 需要能精细化到一定时间窗口内,限流多少次,超过了次数则给予提示。
  3. 此限流针对相同用户且相同ip地址情况下,访问某个接口是达到限流的效果,因为不同用户,不在一个限流维度里。
  4. 相同用户,不同ip地址也不再一个相同限流维度里,因为可能出现一号多地登录,那么这个俩个人都对各自的请求限流,互不相干。
  5. 也可切换不针对用户,只针对某一个接口方法进行限流。
二、实现思路

利用Spring Aop 做前置增强,每次请求前,通过获取 Redis 中已经关于请求次数的数据进行对比,从而判断是否请求频繁。

三、实现代码
  1. 创建spring 项目,添加依赖:


    4.0.0
    
        org.springframework.boot
        spring-boot-starter-parent
        2.6.7
         
    
    org.javaboy
    rate_limiter
    0.0.1-SNAPSHOT
    rate_limiter
    Demo project for Spring Boot
    
       1.8
    
    
        
            org.springframework.boot
            spring-boot-starter-data-redis
        
        
            org.springframework.boot
            spring-boot-starter-web
        

        
            org.springframework.boot
            spring-boot-starter-aop
        

        
            org.springframework.boot
            spring-boot-starter-test
            test
        
    

    
        
            
                org.springframework.boot
                spring-boot-maven-plugin
            
        
    

  1. 添加 Redis 配置信息
spring.redis.host=localhost
spring.redis.port=6379
  1. 自定义限流注解
package org.javaboy.rate_limiter.annotation;

import org.javaboy.rate_limiter.enums.LimitType;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;


@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {
    
    String key() default "rate_limit:";

    
    int time() default 60;

    
    int count() default 100;

    
    LimitType limitType() default LimitType.DEFAULT;

}
package org.javaboy.rate_limiter.enums;



public enum LimitType {

    
    DEFAULT,
    
    IP
}
  1. 对于redis 的相关操作,我们利用 lua脚本
local key = KEYS[1]    --redis key
local time = tonumber(ARGV[1])		-- 时间窗
local count = tonumber(ARGV[2])		-- 时间窗内限定的次数
local current = redis.call('get', key)		
if current and tonumber(current) > count then
    return tonumber(current)
end
current = redis.call('incr', key)
if tonumber(current) == 1 then
    redis.call('expire', key, time)
end
return tonumber(current)

这里为什么要用 lua 脚本进行redis 操作呢?

虽然redis服务是单线程的服务,单步的redis操作是线程安全的,但是当我们在高并发的情况下,需要一系列的redis逻辑操作,而这些操作需要保证线程安全和原子性。这时候就需要Lua登场。
Lua 为静态语言提供更多的灵活性,Lua体积小、启动速度快。 Redis Lua 脚本出现之前 Redis 是没有服务器端运算能力的,主要是用来存储,用做缓存,运算是在客户端进行。有了 Lua 的支持,客户端可以定义对键值的运算,减少编译的次数,总之。可以让 Redis 更为灵活。redis 甚至在源代码中加入了Lua脚本的解释器,eval。

该脚本里面的遇到的 Lua 语法介绍:

  • call() 的参数就是发给Redis的命令:首先set key value ,然后 get key,这两个命令将依次执行,当这个脚本执行时,Redis服务不会做任何操作(单线程),它将非常快速运行。
  • 我们将会访问两个Lua表:KEYSARGV。表单是关联性数组和结构化数据的Lua唯一机制。对于我们的意图,你可以把它们看做是一个你所熟悉的任意语言对等的数组,但是提醒两个很容易困扰到新手的两个Lua定则:
    表是基于1的,也就是说索引以数值1开始。所以在表中的第一个元素就是KEYS[1],第二个就是KEY[2]等等。
    表中不能有nil值。如果一个操作表中有[1, nil, 3, 4],那么结果将会是[1]——表将会在第一个nil截断。
    当调用这个脚本时,我们还需要传递KEYS和ARGV表的值,为Redis编写Lua脚本时,每个KEY都是通过KEYS表指定。ARGV表用来传递参数。
  1. 添加 Redis 序列化配置以及 redis 读取该lua脚本配置
package org.javaboy.rate_limiter.config;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.scripting.support.ResourceScriptSource;


@Configuration
public class RedisConfig {

    @Bean
    RedisTemplate redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        RedisTemplate template = new RedisTemplate<>();
        template.setConnectionFactory(redisConnectionFactory);
        Jackson2JsonRedisSerializer serializer = new Jackson2JsonRedisSerializer<>(Object.class);
        template.setKeySerializer(serializer);
        template.setHashKeySerializer(serializer);
        template.setValueSerializer(serializer);
        template.setHashValueSerializer(serializer);
        return template;
    }

    @Bean
    DefaultRedisScript limitScript() {
        DefaultRedisScript script = new DefaultRedisScript<>();
        script.setResultType(Long.class);
        script.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
        return script;
    }
}

 
  1. 自定义限流异常
package org.javaboy.rate_limiter.exception;

public class RateLimitException extends Exception {
    public RateLimitException(String message) {
        super(message);
    }
}
package org.javaboy.rate_limiter.exception;

import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;

import java.util.HashMap;
import java.util.Map;


@RestControllerAdvice
public class GlobalException {
    @ExceptionHandler(RateLimitException.class)
    public Map rateLimitException(RateLimitException e) {
        Map map = new HashMap<>();
        map.put("status", 500);
        map.put("message", e.getMessage());
        return map;
    }
}

7.添加 ip 获取工具

package org.javaboy.rate_limiter.utils;

import javax.servlet.http.HttpServletRequest;
import java.net.InetAddress;
import java.net.UnknownHostException;


public class IpUtils {
    
    public static String getIpAddr(HttpServletRequest request) {
        if (request == null) {
            return "unknown";
        }
        String ip = request.getHeader("x-forwarded-for");
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("X-Forwarded-For");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("X-Real-IP");
        }

        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
        }

        return "0:0:0:0:0:0:0:1".equals(ip) ? "127.0.0.1" : getMultistageReverseProxyIp(ip);
    }

    
    public static boolean internalIp(String ip) {
        byte[] addr = textToNumericFormatV4(ip);
        return internalIp(addr) || "127.0.0.1".equals(ip);
    }

    
    private static boolean internalIp(byte[] addr) {
        if (addr == null || addr.length < 2) {
            return true;
        }
        final byte b0 = addr[0];
        final byte b1 = addr[1];
        // 10.x.x.x/8
        final byte SECTION_1 = 0x0A;
        // 172.16.x.x/12
        final byte SECTION_2 = (byte) 0xAC;
        final byte SECTION_3 = (byte) 0x10;
        final byte SECTION_4 = (byte) 0x1F;
        // 192.168.x.x/16
        final byte SECTION_5 = (byte) 0xC0;
        final byte SECTION_6 = (byte) 0xA8;
        switch (b0) {
            case SECTION_1:
                return true;
            case SECTION_2:
                if (b1 >= SECTION_3 && b1 <= SECTION_4) {
                    return true;
                }
            case SECTION_5:
                switch (b1) {
                    case SECTION_6:
                        return true;
                }
            default:
                return false;
        }
    }

    
    public static byte[] textToNumericFormatV4(String text) {
        if (text.length() == 0) {
            return null;
        }

        byte[] bytes = new byte[4];
        String[] elements = text.split("\.", -1);
        try {
            long l;
            int i;
            switch (elements.length) {
                case 1:
                    l = Long.parseLong(elements[0]);
                    if ((l < 0L) || (l > 4294967295L)) {
                        return null;
                    }
                    bytes[0] = (byte) (int) (l >> 24 & 0xFF);
                    bytes[1] = (byte) (int) ((l & 0xFFFFFF) >> 16 & 0xFF);
                    bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF);
                    bytes[3] = (byte) (int) (l & 0xFF);
                    break;
                case 2:
                    l = Integer.parseInt(elements[0]);
                    if ((l < 0L) || (l > 255L)) {
                        return null;
                    }
                    bytes[0] = (byte) (int) (l & 0xFF);
                    l = Integer.parseInt(elements[1]);
                    if ((l < 0L) || (l > 16777215L)) {
                        return null;
                    }
                    bytes[1] = (byte) (int) (l >> 16 & 0xFF);
                    bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF);
                    bytes[3] = (byte) (int) (l & 0xFF);
                    break;
                case 3:
                    for (i = 0; i < 2; ++i) {
                        l = Integer.parseInt(elements[i]);
                        if ((l < 0L) || (l > 255L)) {
                            return null;
                        }
                        bytes[i] = (byte) (int) (l & 0xFF);
                    }
                    l = Integer.parseInt(elements[2]);
                    if ((l < 0L) || (l > 65535L)) {
                        return null;
                    }
                    bytes[2] = (byte) (int) (l >> 8 & 0xFF);
                    bytes[3] = (byte) (int) (l & 0xFF);
                    break;
                case 4:
                    for (i = 0; i < 4; ++i) {
                        l = Integer.parseInt(elements[i]);
                        if ((l < 0L) || (l > 255L)) {
                            return null;
                        }
                        bytes[i] = (byte) (int) (l & 0xFF);
                    }
                    break;
                default:
                    return null;
            }
        } catch (NumberFormatException e) {
            return null;
        }
        return bytes;
    }

    
    public static String getHostIp() {
        try {
            return InetAddress.getLocalHost().getHostAddress();
        } catch (UnknownHostException e) {
        }
        return "127.0.0.1";
    }

    
    public static String getHostName() {
        try {
            return InetAddress.getLocalHost().getHostName();
        } catch (UnknownHostException e) {
        }
        return "未知";
    }

    
    public static String getMultistageReverseProxyIp(String ip) {
        // 多级反向代理检测
        if (ip != null && ip.indexOf(",") > 0) {
            final String[] ips = ip.trim().split(",");
            for (String subIp : ips) {
                if (false == isUnknown(subIp)) {
                    ip = subIp;
                    break;
                }
            }
        }
        return ip;
    }

    
    public static boolean isUnknown(String checkString) {
        return checkString == null || checkString.length() == 0 || "unknown".equalsIgnoreCase(checkString);
    }
}

8.编写限流切面

package org.javaboy.rate_limiter.aspectj;

import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.javaboy.rate_limiter.annotation.RateLimiter;
import org.javaboy.rate_limiter.enums.LimitType;
import org.javaboy.rate_limiter.exception.RateLimitException;
import org.javaboy.rate_limiter.utils.IpUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import java.lang.reflect.Method;
import java.util.Collections;

@Aspect
@Component
public class RateLimiterAspect {

    private static final Logger logger = LoggerFactory.getLogger(RateLimiterAspect.class);

    @Autowired
    RedisTemplate redisTemplate;

    @Autowired
    RedisScript redisScript;


    @Before("@annotation(rateLimiter)")
    public void before(JoinPoint jp, RateLimiter rateLimiter) throws RateLimitException {
        int time = rateLimiter.time();
        int count = rateLimiter.count();
        String combineKey = getCombineKey(rateLimiter, jp);
        try {
            Long number = redisTemplate.execute(redisScript, Collections.singletonList(combineKey), time, count);
            if (number == null || number.intValue() > count) {
                //超过限流阈值
                logger.info("当前接口以达到最大限流次数");
                throw new RateLimitException("访问过于频繁,请稍后访问");
            }
            logger.info("一个时间窗内请求次数:{},当前请求次数:{},缓存的 key 为 {}", count, number, combineKey);
        } catch (Exception e) {
            throw e;
        }
    }

    
    private String getCombineKey(RateLimiter rateLimiter, JoinPoint jp) {
        StringBuffer key = new StringBuffer(rateLimiter.key());
        if (rateLimiter.limitType() == LimitType.IP) {
            key.append(IpUtils.getIpAddr(((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest()))
                    .append("-");
        }
        MethodSignature signature = (MethodSignature) jp.getSignature();
        Method method = signature.getMethod();
        key.append(method.getDeclaringClass().getName())
                .append("-")
                .append(method.getName());
        return key.toString();
    }
}
  1. 测试
package org.javaboy.rate_limiter.controller;

import org.javaboy.rate_limiter.annotation.RateLimiter;
import org.javaboy.rate_limiter.enums.LimitType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;


@RestController
public class HelloController {
    @GetMapping("/hello")
    
    @RateLimiter(time = 10, count = 3,limitType = LimitType.IP)
    public String hello() {
        return "hello";
    }
}

当 10 秒内访问超过 3 次之后,抛出异常:

转载请注明:文章转载自 www.mshxw.com
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号