
SpringBoot + Redis with a Custom Annotation to Prevent API Abuse

I. Why do people spam APIs?
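There are always people with nothing better to do than cause trouble. My own Redis server, for example, was once attacked because I had not set a password on it. That is why API security remains a hot topic. Scalpers use scripts to grab train tickets on 12306 (China's railway booking site) and resell them, and some companies attack their competitors' servers. Here is a concrete case: if every call to an SMS endpoint costs a few cents in carrier fees, anyone with a little technical skill can write a script that hammers it. Your SMS bill then grows very quickly.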

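Other people simply flood your server with requests. The server keeps creating new sessions (each with a new JSESSIONID) and other objects until it runs out of memory and goes down. Some endpoints therefore need to cap how many requests they accept within a given period. In this article, that cap is what "protecting an API against abuse" means.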

II. Approach
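Limit each client IP to a fixed number of requests within a given time window.

Implementation idea: keep a counter in Redis.
- The key combines the client's IP with other information (here, the request URI).
- The value is the number of times that client has called the endpoint.
- On the first request, the counter is set to 1 and stored with an expiry time.
- On each later request, the value is read. If it is still below the limit, it is incremented; otherwise the request is rejected with a business exception.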
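To make the counter semantics concrete, here is a minimal in-memory sketch of the same fixed-window logic. It assumes a single JVM:
- a `HashMap` stands in for Redis;
- each entry stores the count plus the time its window ends, playing the role of the key's TTL;
- the current time is passed in explicitly so the behaviour is easy to check.

The class `FixedWindowCounter` and its methods are hypothetical illustrations, not part of the original project.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical in-memory stand-in for the Redis counter (single JVM only, not the article's code)
class FixedWindowCounter {
    // what is stored under each key: the request count and the moment (ms) the key "expires"
    private static final class Entry {
        int count;
        long expiresAtMillis;

        Entry(int count, long expiresAtMillis) {
            this.count = count;
            this.expiresAtMillis = expiresAtMillis;
        }
    }

    private final Map<String, Entry> store = new HashMap<>();
    private final int maxCount;
    private final long windowMillis;

    FixedWindowCounter(int maxCount, int seconds) {
        this.maxCount = maxCount;
        this.windowMillis = seconds * 1000L;
    }

    // key = prefix + client IP + request URI, the same shape the interceptor builds
    static String key(String ip, String uri) {
        return "accessLimit:" + ip + uri;
    }

    // returns true if the request may proceed, false once the limit is reached in the current window
    synchronized boolean tryAcquire(String key, long nowMillis) {
        Entry e = store.get(key);
        if (e == null || nowMillis >= e.expiresAtMillis) {
            // first request, or the previous window has expired: counter = 1, new window starts
            store.put(key, new Entry(1, nowMillis + windowMillis));
            return true;
        }
        if (e.count >= maxCount) {
            return false; // limit already reached inside this window
        }
        e.count++; // still below the limit: count this request
        return true;
    }

    public static void main(String[] args) {
        FixedWindowCounter limiter = new FixedWindowCounter(5, 5);
        String k = key("10.0.0.1", "/token/accessLimit");
        for (int i = 1; i <= 6; i++) {
            System.out.println("request " + i + " allowed=" + limiter.tryAcquire(k, 1_000));
        }
        // once the 5-second window has elapsed, counting starts again from 1
        System.out.println("after window allowed=" + limiter.tryAcquire(k, 6_000));
    }
}
```

With `maxCount = 5` and `seconds = 5`, the first five calls in a window are allowed and the sixth returns `false`. A call made after the window has elapsed starts a fresh count.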

First, the project's code structure; you can organize your own code the same way.

[Figure: project code structure]

III. Implementation
1. First, define an annotation. Placing it on a controller method means calls to that endpoint are subject to a request-count limit.

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

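@Target({ElementType.METHOD}) // the annotation can only be placed on methods
@Retention(RetentionPolicy.RUNTIME) // kept at runtime so the interceptor can read it via reflection
public @interface AccessLimit {

    int maxCount(); // maximum number of requests allowed within the window
    int seconds();  // length of the window, in seconds

}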
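The `RetentionPolicy.RUNTIME` setting is what lets the interceptor find the annotation with `method.getAnnotation(...)`. With the default retention (`CLASS`), `getAnnotation` would return `null` and the limit would silently never apply.

A small plain-Java check of that lookup follows. `DemoController` and `AnnotationLookupDemo` are hypothetical classes used only for illustration.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

class AnnotationLookupDemo {
    // returns the @AccessLimit on a public no-arg method, or null if the method is not rate limited
    static AccessLimit limitOf(Class<?> type, String methodName) {
        try {
            Method method = type.getMethod(methodName);
            return method.getAnnotation(AccessLimit.class);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException("No public method " + methodName, e);
        }
    }

    public static void main(String[] args) {
        AccessLimit limit = limitOf(DemoController.class, "limited");
        System.out.println(limit.maxCount() + " requests per " + limit.seconds() + " s");
        System.out.println(limitOf(DemoController.class, "unlimited")); // prints null
    }
}

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@interface AccessLimit {
    int maxCount();
    int seconds();
}

// Hypothetical controller: one annotated method, one plain method
class DemoController {
    @AccessLimit(maxCount = 5, seconds = 5)
    public String limited() {
        return "ok";
    }

    public String unlimited() {
        return "ok";
    }
}
```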

2. The ResponseCode enum, which holds the status codes and messages returned by the API

public enum ResponseCode {
    // system-level codes
    SUCCESS(0, "Operation succeeded"),
    ERROR(1, "Operation failed"),
    SERVER_ERROR(500, "Server error"),

    // general-purpose codes, 1xxxx
    ILLEGAL_ARGUMENT(10000, "Invalid argument"),
    REPETITIVE_OPERATION(10001, "Please do not repeat the operation"),
    ACCESS_LIMIT(10002, "Too many requests, please try again later");

    // enum constants are shared singletons, so their fields are final and have no setters
    private final Integer code;
    private final String msg;

    ResponseCode(Integer code, String msg) {
        this.code = code;
        this.msg = msg;
    }

    public Integer getCode() {
        return code;
    }

    public String getMsg() {
        return msg;
    }
}

The ServerResponse class, the common wrapper (status, message, data) returned by API calls

import com.fasterxml.jackson.annotation.JsonIgnore;

import java.io.Serializable;

public class ServerResponse implements Serializable {
    private static final long serialVersionUID = 7498483649536881777L;

    private Integer status;

    private String msg;

    private Object data;

    public ServerResponse() {
    }

    public ServerResponse(Integer status, String msg, Object data) {
        this.status = status;
        this.msg = msg;
        this.data = data;
    }

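    @JsonIgnore
    public boolean isSuccess() {
        // compare Integer values with equals(): == compares references and only works inside the Integer cache
        return ResponseCode.SUCCESS.getCode().equals(this.status);
    }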
    public static ServerResponse success() {
        return new ServerResponse(ResponseCode.SUCCESS.getCode(), null, null);
    }
    public static ServerResponse success(String msg) {
        return new ServerResponse(ResponseCode.SUCCESS.getCode(), msg, null);
    }
    public static ServerResponse success(Object data) {
        return new ServerResponse(ResponseCode.SUCCESS.getCode(), null, data);
    }
    public static ServerResponse success(String msg, Object data) {
        return new ServerResponse(ResponseCode.SUCCESS.getCode(), msg, data);
    }
    public static ServerResponse error(String msg) {
        return new ServerResponse(ResponseCode.ERROR.getCode(), msg, null);
    }
    public static ServerResponse error(Object data) {
        return new ServerResponse(ResponseCode.ERROR.getCode(), null, data);
    }
    public static ServerResponse error(String msg, Object data) {
        return new ServerResponse(ResponseCode.ERROR.getCode(), msg, data);
    }
    public Integer getStatus() {
        return status;
    }
    public void setStatus(Integer status) {
        this.status = status;
    }
    public String getMsg() {
        return msg;
    }
    public void setMsg(String msg) {
        this.msg = msg;
    }
    public Object getData() {
        return data;
    }
    public void setData(Object data) {
        this.data = data;
    }
}
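`status` is an `Integer`, so comparing it with `==` compares object references rather than values. That happens to work for the success code 0, because `Integer.valueOf` caches at least the range -128 to 127. It is not guaranteed for codes such as 10002, which is why a value comparison with `equals`/`Objects.equals` is the safer choice.

A short illustration; `IntegerEqualityDemo` is just an illustrative name.

```java
import java.util.Objects;

class IntegerEqualityDemo {
    // compares boxed status codes by value; null-safe
    static boolean sameCode(Integer a, Integer b) {
        return Objects.equals(a, b);
    }

    public static void main(String[] args) {
        Integer status = Integer.valueOf(10002);
        Integer expected = Integer.valueOf(10002);
        // reference comparison: not guaranteed to be true outside the -128..127 cache
        System.out.println("==      : " + (status == expected));
        // value comparison: true whenever the numbers are equal
        System.out.println("equals  : " + sameCode(status, expected));
        System.out.println("null ok : " + sameCode(null, expected));
    }
}
```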

3. Jedis connection pool configuration

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;

@Configuration
public class JedisConfig {

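    @Bean(name = "jedisPoolConfig")
    public JedisPoolConfig jedisPoolConfig() {
        JedisPoolConfig jedisPoolConfig = new JedisPoolConfig();
        jedisPoolConfig.setMaxTotal(500);                          // max connections in the pool
        jedisPoolConfig.setMaxIdle(200);                           // max idle connections kept
        jedisPoolConfig.setNumTestsPerEvictionRun(1024);           // idle connections examined per eviction run
        jedisPoolConfig.setTimeBetweenEvictionRunsMillis(30000);   // run the idle evictor every 30 s
        jedisPoolConfig.setMinEvictableIdleTimeMillis(-1);         // never evict based on idle time alone...
        jedisPoolConfig.setSoftMinEvictableIdleTimeMillis(10000);  // ...but evict connections idle > 10 s while more than minIdle remain
        jedisPoolConfig.setMaxWaitMillis(1500);                    // borrow timeout; only applies when blockWhenExhausted is true
        jedisPoolConfig.setTestOnBorrow(true);                     // validate a connection before handing it out
        jedisPoolConfig.setTestWhileIdle(true);                    // validate idle connections during eviction runs
        jedisPoolConfig.setTestOnReturn(false);
        jedisPoolConfig.setJmxEnabled(true);                       // expose pool metrics via JMX
        jedisPoolConfig.setBlockWhenExhausted(false);              // fail immediately instead of waiting when the pool is exhausted
        return jedisPoolConfig;
    }

    @Bean
    public JedisPool redisPool() {
        // hard-coded for the demo; a real project would read these from configuration
        String host = "127.0.0.1";
        int port = 6379;
        return new JedisPool(jedisPoolConfig(), host, port);
    }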
}

4. The rate-limiting interceptor

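import java.lang.reflect.Method;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;

@Component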
public class AccessLimitInterceptor implements HandlerInterceptor {

    private static final String ACCESS_LIMIT_PREFIX = "accessLimit:";

    @Autowired
    private JedisUtil jedisUtil;

    // Runs before the controller method
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
        // Only controller methods (HandlerMethod) can carry the annotation; let everything else through
        if (!(handler instanceof HandlerMethod)) {
            return true;
        }
        // Cast to HandlerMethod to reach the underlying java.lang.reflect.Method
        HandlerMethod handlerMethod = (HandlerMethod) handler;
        Method method = handlerMethod.getMethod();
        // Read the @AccessLimit annotation; methods without it are not limited
        AccessLimit annotation = method.getAnnotation(AccessLimit.class);
        if (annotation != null) {
            check(annotation, request);
        }
        return true;
    }

    private void check(AccessLimit annotation, HttpServletRequest request) {
        // Read the parameters declared on the annotation
        int maxCount = annotation.maxCount();
        int seconds = annotation.seconds();

        // Key = prefix + client IP + request URI, so each client is counted separately per endpoint
        String key = ACCESS_LIMIT_PREFIX + IpUtil.getIpAddress(request) + request.getRequestURI();

        // Value = how many times this client has called the endpoint in the current window.
        // JedisUtil.get returns null if the key does not exist (or if Redis is unreachable).
        String value = jedisUtil.get(key);
        if (value == null) {
            // First request in this window: start the counter at 1 with a TTL of `seconds`
            jedisUtil.set(key, String.valueOf(1), seconds);
            return;
        }
        int count = Integer.parseInt(value);
        if (count >= maxCount) {
            // Limit reached within the window: reject the request
            throw new ServiceException(ResponseCode.ACCESS_LIMIT.getMsg());
        }
        Long ttl = jedisUtil.ttl(key);
        if (ttl == null || ttl <= 0) {
            // The key has just expired (or has no TTL): start a new window
            jedisUtil.set(key, String.valueOf(1), seconds);
        } else {
            // Still below maxCount: increment and keep the remaining TTL.
            // Note: GET followed by SET is not atomic, so concurrent requests can read the same count.
            jedisUtil.set(key, String.valueOf(count + 1), ttl.intValue());
        }
    }

    // Runs after the controller method, before the view is rendered
    @Override
    public void postHandle(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, Object o, ModelAndView modelAndView) throws Exception {
    }

    // Runs after the request has completed (after view rendering)
    @Override
    public void afterCompletion(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, Object o, Exception e) throws Exception {
    }
}
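One caveat with `check()`: reading the counter, comparing it and writing it back are separate Redis round trips. Several concurrent requests from the same client can therefore read the same value and all pass.

In Redis this is usually solved with an atomic `INCR`: the command returns the new value, and the expiry is set when it returns 1. The two commands are often wrapped in a Lua script so they run as one unit.

The same principle can be illustrated in plain Java with `ConcurrentHashMap.merge`, which performs the increment atomically per key. This is a single-JVM sketch of the idea, not a replacement for the Redis version. Expiry is left out to keep the focus on atomicity, and `AtomicCounterDemo` is a hypothetical name.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class AtomicCounterDemo {
    private final ConcurrentHashMap<String, Integer> counters = new ConcurrentHashMap<>();
    private final int maxCount;

    AtomicCounterDemo(int maxCount) {
        this.maxCount = maxCount;
    }

    // like Redis INCR: merge() runs atomically per key, so every caller sees a distinct new value
    boolean tryAcquire(String key) {
        int count = counters.merge(key, 1, Integer::sum);
        return count <= maxCount;
    }

    // fires `requests` concurrent calls for one key and returns how many were allowed
    static int allowedUnderLoad(int maxCount, int threads, int requests) {
        AtomicCounterDemo limiter = new AtomicCounterDemo(maxCount);
        AtomicInteger allowed = new AtomicInteger();
        CountDownLatch start = new CountDownLatch(1);
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        for (int i = 0; i < requests; i++) {
            pool.execute(() -> {
                try {
                    start.await(); // release all tasks together to maximise contention
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
                if (limiter.tryAcquire("accessLimit:10.0.0.1/token/accessLimit")) {
                    allowed.incrementAndGet();
                }
            });
        }
        start.countDown();
        pool.shutdown();
        try {
            pool.awaitTermination(10, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return allowed.get();
    }

    public static void main(String[] args) {
        System.out.println("allowed = " + allowedUnderLoad(5, 8, 200));
    }
}
```

However many threads race, exactly `maxCount` requests get a value of `maxCount` or less. That is the guarantee the GET-then-SET sequence cannot give.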

Register the interceptor:

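import com.cd.interceptor.AccessLimitInterceptor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.UrlBasedCorsConfigurationSource;
import org.springframework.web.filter.CorsFilter;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;


@Configuration
public class WebMvcConfig implements WebMvcConfigurer {

    // injected by Spring (not created with new) so that its @Autowired JedisUtil is populated
    @Autowired
    private AccessLimitInterceptor accessLimitInterceptor;

    @Bean
    public CorsFilter corsFilter() {
        final UrlBasedCorsConfigurationSource urlBasedCorsConfigurationSource = new UrlBasedCorsConfigurationSource();
        final CorsConfiguration corsConfiguration = new CorsConfiguration();
        corsConfiguration.setAllowCredentials(true);
        // Spring 5.3+ rejects allowCredentials(true) combined with addAllowedOrigin("*"); an origin pattern is accepted
        corsConfiguration.addAllowedOriginPattern("*");
        corsConfiguration.addAllowedHeader("*");
        corsConfiguration.addAllowedMethod("*");
        // path pattern assumed: apply the CORS settings to all endpoints
        urlBasedCorsConfigurationSource.registerCorsConfiguration("/**", corsConfiguration);
        return new CorsFilter(urlBasedCorsConfigurationSource);
    }

    // register the rate-limiting interceptor so its preHandle runs before controller methods
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(accessLimitInterceptor);
    }
}

5. A custom business exception, thrown by the interceptor when the limit is reached
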
public class ServiceException extends RuntimeException{

    private String code;
    private String msg;
    public ServiceException() {
    }
    public ServiceException(String msg) {
        super(msg); // also expose the message through getMessage()
        this.msg = msg;
    }
    public ServiceException(String code, String msg) {
        super(msg);
        this.code = code;
        this.msg = msg;
    }
    public String getCode() {
        return code;
    }
    public void setCode(String code) {
        this.code = code;
    }
    public String getMsg() {
        return msg;
    }
    public void setMsg(String msg) {
        this.msg = msg;
    }
}

6. The service layer

import com.cd.common.ServerResponse;

public interface TokenService {
    // endpoint used to test the rate limit
    ServerResponse accessLimit();
}

The implementation class:

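import com.cd.common.ServerResponse;
import org.springframework.stereotype.Service;

@Service
public class TokenServiceImpl implements TokenService {

    // rate-limit test: simply returns a success response
    @Override
    public ServerResponse accessLimit() {
        return ServerResponse.success("accessLimit: success");
    }
}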

7. Utility classes

IP utility class

import javax.servlet.http.HttpServletRequest;

public class IpUtil {

    
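    // Note: these headers come from the client and can be forged unless a trusted reverse proxy
    // overwrites them, so a script could rotate fake IPs to get around the limit
    public static String getIpAddress(HttpServletRequest request) {
        String ip = request.getHeader("x-forwarded-for");
        // X-Forwarded-For may be a chain "client, proxy1, proxy2"; keep the first (originating) entry
        if (ip != null && ip.contains(",")) {
            ip = ip.split(",")[0].trim();
        }
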
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }

        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }

        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_CLIENT_IP");
        }

        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_X_FORWARDED_FOR");
        }

        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
        }
        return ip;
    }
}
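`X-Forwarded-For` can carry a comma-separated chain (`client, proxy1, proxy2`), in which only the first entry names the originating client. If the raw header ends up in the Redis key, the same client can be counted under several different keys.

Below is a standalone sketch of picking the first usable entry. It also skips empty and `"unknown"` entries. `ForwardedForParser.firstForwardedIp` is a hypothetical helper. It does nothing to make the header trustworthy; that still depends on the proxy setup.

```java
import java.util.Arrays;

class ForwardedForParser {
    // returns the first non-empty, non-"unknown" address in an X-Forwarded-For value, or null if none
    static String firstForwardedIp(String headerValue) {
        if (headerValue == null) {
            return null;
        }
        return Arrays.stream(headerValue.split(","))
                .map(String::trim)
                .filter(s -> !s.isEmpty() && !"unknown".equalsIgnoreCase(s))
                .findFirst()
                .orElse(null);
    }

    public static void main(String[] args) {
        System.out.println(firstForwardedIp("203.0.113.7, 10.0.0.2, 10.0.0.3")); // 203.0.113.7
        System.out.println(firstForwardedIp("unknown, 198.51.100.4"));           // 198.51.100.4
        System.out.println(firstForwardedIp(""));                                // null
    }
}
```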

Jedis utility class

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;

@Component
@Slf4j
public class JedisUtil {

    @Autowired
    private JedisPool jedisPool;

    // 获取到操作redis的客户端对象
    private Jedis getJedis() {
        return jedisPool.getResource();
    }

    
    public String set(String key, String value) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            return jedis.set(key, value);
        } catch (Exception e) {
            log.error("set key:{} value:{} error", key, value, e);
            return null;
        } finally {
            close(jedis);
        }
    }

    
    public String set(String key, String value, int expireTime) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            return jedis.setex(key, expireTime, value);
        } catch (Exception e) {
            log.error("set key:{} value:{} expireTime:{} error", key, value, expireTime, e);
            return null;
        } finally {
            close(jedis);
        }
    }

    
    public String get(String key) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            return jedis.get(key);
        } catch (Exception e) {
            log.error("get key:{} error", key, e);
            return null;
        } finally {
            close(jedis);
        }
    }

    
    public Long del(String key) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            return jedis.del(key.getBytes());
        } catch (Exception e) {
            log.error("del key:{} error", key, e);
            return null;
        } finally {
            close(jedis);
        }
    }

    
    public Boolean exists(String key) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            return jedis.exists(key.getBytes());
        } catch (Exception e) {
            log.error("exists key:{} error", key, e);
            return null;
        } finally {
            close(jedis);
        }
    }

    
    public Long expire(String key, int expireTime) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            return jedis.expire(key.getBytes(), expireTime);
        } catch (Exception e) {
            log.error("expire key:{} error", key, e);
            return null;
        } finally {
            close(jedis);
        }
    }

    
    public Long ttl(String key) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            return jedis.ttl(key);
        } catch (Exception e) {
            log.error("ttl key:{} error", key, e);
            return null;
        } finally {
            close(jedis);
        }
    }

    private void close(Jedis jedis) {
        if (null != jedis) {
            jedis.close();
        }
    }

}

Time utility class (Joda-Time)

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.joda.time.DateTime;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;

import java.util.Date;

@Slf4j
public class JodaTimeUtil {

    private static final String STANDARD_FORMAT = "yyyy-MM-dd HH:mm:ss";

    
    public static String dateToStr(Date date) {
        return dateToStr(date, STANDARD_FORMAT);
    }

    
    public static String dateToStr(Date date, String format) {
        if (date == null) {
            return null;
        }
        format = StringUtils.isBlank(format) ? STANDARD_FORMAT : format;
        DateTime dateTime = new DateTime(date);
        return dateTime.toString(format);
    }
    
    public static Date strToDate(String timeStr) {
        return strToDate(timeStr, STANDARD_FORMAT);
    }
    
    public static Date strToDate(String timeStr, String format) {
        if (StringUtils.isBlank(timeStr)) {
            return null;
        }
        format = StringUtils.isBlank(format) ? STANDARD_FORMAT : format;
        org.joda.time.format.DateTimeFormatter dateTimeFormatter = DateTimeFormat.forPattern(format);
        DateTime dateTime;
        try {
            dateTime = dateTimeFormatter.parseDateTime(timeStr);
        } catch (Exception e) {
            log.error("strToDate error: timeStr: {}", timeStr, e);
            return null;
        }
        return dateTime.toDate();
    }

    
    public static Boolean isTimeExpired(Date date) {
        String timeStr = dateToStr(date);
        return isBeforeNow(timeStr);
    }
    
    public static Boolean isTimeExpired(String timeStr) {
        if (StringUtils.isBlank(timeStr)) {
            return true;
        }
        return isBeforeNow(timeStr);
    }
    
    private static Boolean isBeforeNow(String timeStr) {
        DateTimeFormatter format = DateTimeFormat.forPattern(STANDARD_FORMAT);
        DateTime dateTime;
        try {
            dateTime = DateTime.parse(timeStr, format);
        } catch (Exception e) {
            log.error("isBeforeNow error: timeStr: {}", timeStr, e);
            return null;
        }
        return dateTime.isBeforeNow();
    }
    
    public static Date plusDays(Date date, int days) {
        return plusOrMinusDays(date, days, 0);
    }
    
    public static Date minusDays(Date date, int days) {
        return plusOrMinusDays(date, days, 1);
    }
    
    private static Date plusOrMinusDays(Date date, int days, Integer type) {
        if (null == date) {
            return null;
        }
        DateTime dateTime = new DateTime(date);
        if (type == 0) {
            dateTime = dateTime.plusDays(days);
        } else {
            dateTime = dateTime.minusDays(days);
        }
        return dateTime.toDate();
    }
    
    public static Date plusMinutes(Date date, int minutes) {
        return plusOrMinusMinutes(date, minutes, 0);
    }
    
    public static Date minusMinutes(Date date, int minutes) {
        return plusOrMinusMinutes(date, minutes, 1);
    }
    
    private static Date plusOrMinusMinutes(Date date, int minutes, Integer type) {
        if (null == date) {
            return null;
        }
        DateTime dateTime = new DateTime(date);
        if (type == 0) {
            dateTime = dateTime.plusMinutes(minutes);
        } else {
            dateTime = dateTime.minusMinutes(minutes);
        }
        return dateTime.toDate();
    }
    
    public static Date plusMonths(Date date, int months) {
        return plusOrMinusMonths(date, months, 0);
    }
    
    public static Date minusMonths(Date date, int months) {
        return plusOrMinusMonths(date, months, 1);
    }
    
    private static Date plusOrMinusMonths(Date date, int months, Integer type) {
        if (null == date) {
            return null;
        }
        DateTime dateTime = new DateTime(date);
        if (type == 0) {
            dateTime = dateTime.plusMonths(months);
        } else {
            dateTime = dateTime.minusMonths(months);
        }
        return dateTime.toDate();
    }
    
    public static Boolean isBetweenStartAndEndTime(Date target, Date startTime, Date endTime) {
        if (null == target || null == startTime || null == endTime) {
            return false;
        }
        DateTime dateTime = new DateTime(target);
        return dateTime.isAfter(startTime.getTime()) && dateTime.isBefore(endTime.getTime());
    }
}
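Joda-Time is in maintenance mode, and its authors recommend moving to `java.time` (JSR-310), which has been part of the JDK since Java 8. Here is a sketch of how the two most common helpers, `dateToStr` and `strToDate`, might look with `java.time`. It assumes the same `yyyy-MM-dd HH:mm:ss` pattern and the system default time zone. The class name `JavaTimeUtil` is hypothetical.

```java
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.time.format.DateTimeParseException;
import java.util.Date;

final class JavaTimeUtil {
    // DateTimeFormatter is immutable and thread-safe, so one shared instance is enough
    private static final DateTimeFormatter STANDARD =
            DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    private JavaTimeUtil() {
    }

    // Date -> "yyyy-MM-dd HH:mm:ss" in the system default zone; null stays null
    static String dateToStr(Date date) {
        if (date == null) {
            return null;
        }
        return date.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime().format(STANDARD);
    }

    // "yyyy-MM-dd HH:mm:ss" -> Date in the system default zone; null for blank or unparseable input
    static Date strToDate(String timeStr) {
        if (timeStr == null || timeStr.trim().isEmpty()) {
            return null;
        }
        try {
            LocalDateTime ldt = LocalDateTime.parse(timeStr.trim(), STANDARD);
            return Date.from(ldt.atZone(ZoneId.systemDefault()).toInstant());
        } catch (DateTimeParseException e) {
            return null;
        }
    }

    public static void main(String[] args) {
        Date d = strToDate("2021-11-20 08:30:00");
        System.out.println(dateToStr(d));
    }
}
```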

Random string utility class

import java.util.Random;
import java.util.UUID;

public class RandomUtil {

    public static final String allChar = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";

    public static final String letterChar = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";

    public static final String numberChar = "0123456789";
    public static String UUID32() {
        String str = UUID.randomUUID().toString();
        return str.replaceAll("-", "");
    }
    public static String UUID36() {
        return UUID.randomUUID().toString();
    }
    
    public static String generateStr(int length) {
        StringBuffer sb = new StringBuffer();
        Random random = new Random();
        for (int i = 0; i < length; i++) {
            sb.append(allChar.charAt(random.nextInt(allChar.length())));
        }
        return sb.toString();
    }
    
    public static String generateDigitalStr(int length) {
        StringBuffer sb = new StringBuffer();
        Random random = new Random();
        for (int i = 0; i < length; i++) {
            sb.append(numberChar.charAt(random.nextInt(numberChar.length())));
        }
        return sb.toString();
    }
    
    public static String generateLetterStr(int length) {
        StringBuffer sb = new StringBuffer();
        Random random = new Random();
        for (int i = 0; i < length; i++) {
            sb.append(letterChar.charAt(random.nextInt(letterChar.length())));
        }
        return sb.toString();
    }
    
    public static String generateLowerStr(int length) {
        return generateLetterStr(length).toLowerCase();
    }
    
    public static String generateUpperStr(int length) {
        return generateLetterStr(length).toUpperCase();
    }
    
    public static String generateZeroStr(int length) {
        StringBuffer sb = new StringBuffer();
        for (int i = 0; i < length; i++) {
            sb.append('0');
        }
        return sb.toString();
    }

    
    public static String generateStrWithZero(int num, int strLength) {
        StringBuffer sb = new StringBuffer();
        String strNum = String.valueOf(num);
        if (strLength - strNum.length() >= 0) {
            sb.append(generateZeroStr(strLength - strNum.length()));
        } else {
            throw new RuntimeException("Cannot convert number " + num + " to a string of length " + strLength + ": the number has too many digits");
        }
        sb.append(strNum);
        return sb.toString();
    }
}
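`RandomUtil` relies on `java.util.Random`, which is fine for test data but predictable. It should not generate anything an attacker gains from guessing, such as tokens or verification codes. If the strings serve a security purpose, `java.security.SecureRandom` is the usual substitute.

A minimal sketch of the alphanumeric generator on top of it; `SecureRandomUtil` is a hypothetical name.

```java
import java.security.SecureRandom;

final class SecureRandomUtil {
    private static final String ALL_CHAR =
            "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
    // one shared cryptographically strong generator, reused across calls
    private static final SecureRandom RANDOM = new SecureRandom();

    private SecureRandomUtil() {
    }

    // random alphanumeric string of the given length
    static String generateStr(int length) {
        if (length < 0) {
            throw new IllegalArgumentException("length must be >= 0");
        }
        StringBuilder sb = new StringBuilder(length);
        for (int i = 0; i < length; i++) {
            sb.append(ALL_CHAR.charAt(RANDOM.nextInt(ALL_CHAR.length())));
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(generateStr(16));
    }
}
```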

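8. Finally, put the annotation on the controller method. Calls to that endpoint from the same IP are then limited to a fixed number within the configured window.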

import com.cd.annotation.AccessLimit;
import com.cd.common.ServerResponse;
import com.cd.service.TokenService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
@RequestMapping("/token")
public class TokenController {

    @Autowired
    private TokenService tokenService;

    // at most 5 calls per client IP within a 5-second window
    @AccessLimit(maxCount = 5, seconds = 5)
    @PostMapping("/accessLimit")
    public ServerResponse accessLimit() {
        return tokenService.accessLimit();
    }
}

That completes the anti-abuse feature. Some companies use more elaborate schemes; this is a fairly simple implementation.

When reposting, please credit: this article is reposted from www.mshxw.com
Article URL: https://www.mshxw.com/it/605667.html