
A Second Way to Implement a WebSocket Cluster and Communication (with an Interceptor)


I. Drawbacks of the First Approach

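To keep malicious clients from tying up connection resources, the WebSocket connection needs an interceptor. After searching extensively, however, I found no way to intercept connections opened through the @ServerEndpoint annotation. The only blog post I found adds a custom configurator to @ServerEndpoint and intercepts the handshake by implementing the modifyHandshake method of the ServerEndpointConfig.Configurator inner class.

Source of that post:

I tried that approach. It is not as simple as the second approach, and the second approach is also more general, so it is somewhat better than the first. Choose whichever suits your situation.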

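II. Implementing WebSocket Clustering and Communication with the Second Approach

For the cluster itself, add the Spring Cloud dependencies and a service registry, then use a gateway for unified forwarding and load balancing. This is the same as in the previous post, so it is not shown here.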

Demo:

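If the first approach gives each WebSocket connection its own solution, the second approach registers all WebSocket connections collectively and shares them across the cluster. The code follows.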
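The "shared" idea can be modelled in a few lines. The following is a minimal plain-JDK sketch, with no Spring or RabbitMQ, assuming the load balancer may put the sender and the recipient on different nodes. Each node keeps only its own sessions, delivers locally when it can, and otherwise publishes to a broadcast bus that every node listens on. `Node` and `FanoutBus` are hypothetical stand-ins for `SpringWebSocketHandler` and the RabbitMQ fanout exchange, and an inbox list stands in for a real `WebSocketSession`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

final class ClusterRoutingSketch {

    /** Hypothetical stand-in for the fanout exchange: every subscribed node gets every message. */
    static final class FanoutBus {
        private final List<Node> subscribers = new ArrayList<>();

        void subscribe(Node node) {
            subscribers.add(node);
        }

        void publish(String toUserId, String payload) {
            for (Node node : subscribers) {
                node.onBroadcast(toUserId, payload);
            }
        }
    }

    /** Hypothetical stand-in for one application instance holding its own WebSocket sessions. */
    static final class Node {
        final String name;
        private final FanoutBus bus;
        // userId -> messages delivered to that user's (simulated) session on this node
        final Map<String, List<String>> localSessions = new ConcurrentHashMap<>();

        Node(String name, FanoutBus bus) {
            this.name = name;
            this.bus = bus;
            bus.subscribe(this);
        }

        void connect(String userId) {
            localSessions.put(userId, new ArrayList<>());
        }

        /** Same decision as handleTextMessage: deliver locally if possible, else broadcast. */
        void send(String toUserId, String payload) {
            List<String> inbox = localSessions.get(toUserId);
            if (inbox != null) {
                inbox.add(payload);
            } else {
                bus.publish(toUserId, payload);
            }
        }

        /** Same decision as FanoutReceiver: only the node holding the session delivers. */
        void onBroadcast(String toUserId, String payload) {
            List<String> inbox = localSessions.get(toUserId);
            if (inbox != null) {
                inbox.add(payload);
            }
        }
    }

    public static void main(String[] args) {
        FanoutBus bus = new FanoutBus();
        Node a = new Node("node-a", bus);
        Node b = new Node("node-b", bus);
        a.connect("user1");
        b.connect("user2");

        a.send("user2", "hello from user1"); // user2 lives on node-b, so this goes through the bus
        System.out.println(b.localSessions.get("user2")); // [hello from user1]
    }
}
```

Because the bus reaches every node, a recipient who is offline everywhere is silently dropped; reporting "user not online" for remote recipients would need an extra reply path.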

1. Directory structure

  

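2. pom dependencies
    <parent>
        <artifactId>spring-boot-starter-parent</artifactId>
        <groupId>org.springframework.boot</groupId>
        <version>2.1.6.RELEASE</version>
    </parent>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.springframework</groupId>
            <artifactId>spring-context</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework</groupId>
            <artifactId>spring-beans</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework</groupId>
            <artifactId>spring-aspects</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-redis</artifactId>
            <version>1.4.1.RELEASE</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.cloud</groupId>
            <artifactId>spring-cloud-starter</artifactId>
            <version>2.1.0.RELEASE</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-websocket</artifactId>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.9</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-amqp</artifactId>
        </dependency>
    </dependencies>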
3. Code, following the directory structure from top to bottom

①SpringWebSocketConfig
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;


@Configuration
@EnableWebMvc
@EnableWebSocket
public class SpringWebSocketConfig implements WebMvcConfigurer, WebSocketConfigurer {

    // Reuse the @Component handler instead of creating a second instance with new
    @Autowired
    private SpringWebSocketHandler springWebSocketHandler;

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // setAllowedOrigins("*") allows cross-origin connections
        registry.addHandler(springWebSocketHandler, "/webSocket")
                .addInterceptors(new SpringWebSocketHandlerInterceptor())
                .setAllowedOrigins("*");
        registry.addHandler(springWebSocketHandler, "/sockjs/socketServer.do")
                .addInterceptors(new SpringWebSocketHandlerInterceptor())
                .setAllowedOrigins("*");
    }
}
②Controller layer: it has no real purpose here and can be omitted.

③SpringWebSocketHandler
import com.alibaba.fastjson.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;


@Component
public class SpringWebSocketHandler extends TextWebSocketHandler {
    // Attribute key written by SpringWebSocketHandlerInterceptor.beforeHandshake
    private static final String SID_KEY = "exam_sid";

    @Autowired
    FanoutSender fanoutSender;

    // Number of WebSocket connections currently open on this node
    private static final AtomicInteger ati = new AtomicInteger();

    // userId -> session, for users connected to this node only
    public static final ConcurrentHashMap<String, WebSocketSession> map = new ConcurrentHashMap<>();

    private static final Logger logger = LoggerFactory.getLogger(SpringWebSocketHandler.class);

    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        Object sid = session.getAttributes().get(SID_KEY);
        if (sid == null) {
            // The handshake supplied no user id, so this session could never be addressed
            session.close(CloseStatus.POLICY_VIOLATION);
            return;
        }
        map.put(sid.toString(), session);
        int num = ati.incrementAndGet();
        logger.info("Connected to the websocket successfully. Current connections: {}", num);
        // Put your own business logic here, e.g. push offline messages once the user has logged in
        session.sendMessage(new TextMessage("Connected successfully"));
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
        Object sid = session.getAttributes().get(SID_KEY);
        if (sid == null) {
            return; // never registered, never counted
        }
        logger.info("User {} has left", sid);
        // Remove only this session, in case the same user has already reconnected
        map.remove(sid.toString(), session);
        int num = ati.decrementAndGet();
        logger.info("Connections still open: {}", num);
    }

    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        String payload = message.getPayload();
        WebSocketVO webSocketVO = JSONObject.parseObject(payload, WebSocketVO.class);
        String toSid = webSocketVO.getToUserId();
        if (toSid == null) {
            logger.warn("Message without toUserId ignored: {}", payload);
            return;
        }
        WebSocketSession target = map.get(toSid);
        if (target != null) {
            // The recipient is connected to this node: deliver directly
            target.sendMessage(message);
        } else {
            // The recipient may be on another node: broadcast {toSid=payload} through the fanout exchange
            ConcurrentHashMap<String, String> sendMap = new ConcurrentHashMap<>();
            sendMap.put(toSid, payload);
            logger.info("Forwarding to the fanout exchange: {}", sendMap);
            fanoutSender.sendMessage(sendMap.toString());
        }
    }

    // Delivers a message only if its recipient is connected to this node
    public void sendMessage(TextMessage message) throws IOException {
        WebSocketVO webSocketVO = JSONObject.parseObject(message.getPayload(), WebSocketVO.class);
        String toSid = webSocketVO.getToUserId();
        WebSocketSession target = toSid == null ? null : map.get(toSid);
        if (target != null) {
            target.sendMessage(message);
        }
    }

    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        logger.warn("WebSocket transport error on session {}", session.getId(), exception);
        // The session can no longer be written to, so just close it;
        // afterConnectionClosed then removes it from the map and updates the counter
        if (session.isOpen()) {
            session.close(CloseStatus.SERVER_ERROR);
        }
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    // Sends a message to one user if that user is connected to this node
    public void sendMessageToUser(String sid, TextMessage message) {
        WebSocketSession session = map.get(sid);
        if (session != null && session.isOpen()) {
            try {
                session.sendMessage(message);
            } catch (IOException e) {
                logger.error("Failed to send a message to user {}", sid, e);
            }
        }
    }

    public void sendMessageToUser(String sid, String message) {
        this.sendMessageToUser(sid, new TextMessage(message));
    }

    // Sends a message to every user connected to this node
    public void sendMessageToUsers(TextMessage message) throws IOException {
        for (WebSocketSession session : map.values()) {
            if (session.isOpen()) {
                session.sendMessage(message);
            }
        }
    }
}
④SpringWebSocketHandlerInterceptor
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
import java.util.Map;


@Component
public class SpringWebSocketHandlerInterceptor extends HttpSessionHandshakeInterceptor {

    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
                                   Map<String, Object> attributes) throws Exception {
        // Read the parameters passed in the URL and hand them to the WebSocketHandler through attributes;
        // the handler reads them back with WebSocketSession.getAttributes().
        // Which values to store here is up to your own requirements.
        if (request instanceof ServletServerHttpRequest) {
            HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
            String sid = servletRequest.getParameter("userId");
            String token = servletRequest.getParameter("token");
            // Refuse the handshake if the token is missing or wrong (constant first, so no NPE)
            if (!"EXAM_PERMISSION".equals(token)) {
                return false;
            }
            HttpSession session = servletRequest.getSession(true);
            // Prefer a user id already stored in the HTTP session, else fall back to the userId parameter
            String userName = (String) session.getAttribute("exam_sid");
            if (userName == null) {
                userName = sid;
            }
            if (userName == null) {
                return false;
            }
            attributes.put("exam_sid", userName);
            return true;
        }
        return false;
    }

    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
                               Exception ex) {
        super.afterHandshake(request, response, wsHandler, ex);
    }
}
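The interceptor's decision is easy to test in isolation. Below is a plain-JDK sketch of the same checks over a raw query string: the token must equal `EXAM_PERMISSION` and a `userId` must be present, and the accepted id is stored under `exam_sid`. It leaves out the HttpSession lookup, and `HandshakeCheck`, `parseQuery` and the `Optional` return value are hypothetical, not Spring API. The token is compared with `"EXAM_PERMISSION".equals(token)`, so a request without a token is refused rather than causing a NullPointerException.

```java
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

final class HandshakeCheck {
    static final String EXPECTED_TOKEN = "EXAM_PERMISSION";

    /** Parses "a=1&b=2" into a map, URL-decoding keys and values. */
    static Map<String, String> parseQuery(String query) {
        Map<String, String> params = new HashMap<>();
        if (query == null || query.isEmpty()) {
            return params;
        }
        for (String pair : query.split("&")) {
            int eq = pair.indexOf('=');
            String key = eq >= 0 ? pair.substring(0, eq) : pair;
            String value = eq >= 0 ? pair.substring(eq + 1) : "";
            params.put(URLDecoder.decode(key, StandardCharsets.UTF_8),
                       URLDecoder.decode(value, StandardCharsets.UTF_8));
        }
        return params;
    }

    /**
     * Returns the attributes to hand to the WebSocket handler, or empty to refuse
     * the handshake (what returning false from beforeHandshake does).
     */
    static Optional<Map<String, Object>> beforeHandshake(String query) {
        Map<String, String> params = parseQuery(query);
        if (!EXPECTED_TOKEN.equals(params.get("token"))) {
            return Optional.empty();
        }
        String userId = params.get("userId");
        if (userId == null || userId.isEmpty()) {
            return Optional.empty();
        }
        Map<String, Object> attributes = new HashMap<>();
        attributes.put("exam_sid", userId);
        return Optional.of(attributes);
    }

    public static void main(String[] args) {
        System.out.println(beforeHandshake("userId=u1&token=EXAM_PERMISSION")); // Optional[{exam_sid=u1}]
        System.out.println(beforeHandshake("userId=u1"));                        // Optional.empty
    }
}
```

Keeping the rule in a pure function like this lets it be unit-tested without starting a servlet container; the Spring interceptor then only has to extract the query parameters and copy the result into `attributes`.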
⑤RetryCache: this class caches queued messages and resends them, as part of message reliability
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;


@Slf4j
@Component
public class RetryCache {
    private SendMessage sendMessage;

    // volatile so the retry thread sees the change made by stop()
    private volatile boolean stop = false;

    // message id -> message plus the time it was cached
    private final Map<String, MessageWithTime> map = new ConcurrentHashMap<>();

    private final AtomicInteger id = new AtomicInteger();

    @Data
    @AllArgsConstructor
    @NoArgsConstructor
    private static class MessageWithTime {
        long time;
        Object message;
    }

    // Registers the sender used for resends and starts the retry thread
    public void sender(SendMessage sendMessage) {
        this.sendMessage = sendMessage;
        startRetry();
    }

    public String generaterId() {
        return "" + id.incrementAndGet();
    }

    public void add(String id, Object message) {
        map.put(id, new MessageWithTime(System.currentTimeMillis(), message));
    }

    public void del(String id) {
        map.remove(id);
    }

    public void stop() {
        stop = true;
    }

    // Background thread that periodically resends cached messages
    private void startRetry() {
        Thread retryThread = new Thread(() -> {
            while (!stop) {
                try {
                    // Check once every VALID_TIME milliseconds
                    Thread.sleep(Constant.VALID_TIME);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
                long now = System.currentTimeMillis();
                for (Map.Entry<String, MessageWithTime> entry : map.entrySet()) {
                    MessageWithTime messageWithTime = entry.getValue();
                    if (messageWithTime.getTime() + 3 * Constant.VALID_TIME < now) {
                        // Give up after three intervals
                        log.info("send message failed after {} ms: {}", 3 * Constant.VALID_TIME, messageWithTime);
                        del(entry.getKey());
                    } else if (messageWithTime.getTime() + Constant.VALID_TIME < now) {
                        // Older than one interval and still cached: resend
                        DetailResult detailRes = sendMessage.send(messageWithTime.getMessage());
                        if (detailRes.isSuccess()) {
                            del(entry.getKey());
                        }
                    }
                }
            }
        }, "retry-cache");
        retryThread.setDaemon(true);
        retryThread.start();
    }
}
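The retry rules are easier to reason about when the time check is separated from the thread. A minimal plain-JDK sketch, assuming the same rules as `RetryCache` (resend once an entry is older than one interval, drop it once older than three): `sweep(now)` performs one pass and can be unit-tested with fixed timestamps, while a `ScheduledExecutorService` replaces the hand-written `while`/`sleep` loop. `RetrySketch`, its `Sender` interface and the `Pending` holder are hypothetical names, not part of the project.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

final class RetrySketch {
    /** Hypothetical counterpart of SendMessage: returns true when the resend succeeded. */
    interface Sender {
        boolean send(Object message);
    }

    /** A cached message and the time (ms) it was added. */
    private static final class Pending {
        final long time;
        final Object message;

        Pending(long time, Object message) {
            this.time = time;
            this.message = message;
        }
    }

    private final Map<String, Pending> entries = new ConcurrentHashMap<>();
    private final AtomicInteger ids = new AtomicInteger();
    private final long validTimeMillis;
    private final Sender sender;

    RetrySketch(long validTimeMillis, Sender sender) {
        this.validTimeMillis = validTimeMillis;
        this.sender = sender;
    }

    /** Caches a message at time {@code now}; the returned id could serve as a correlation id. */
    String add(Object message, long now) {
        String id = Integer.toString(ids.incrementAndGet());
        entries.put(id, new Pending(now, message));
        return id;
    }

    /** Removes a message, e.g. once the broker has confirmed it. */
    void del(String id) {
        entries.remove(id);
    }

    int size() {
        return entries.size();
    }

    /** One pass: drop entries older than 3 intervals, resend entries older than 1 interval. */
    void sweep(long now) {
        // Removing while iterating is safe: ConcurrentHashMap iterators are weakly consistent
        for (Map.Entry<String, Pending> e : entries.entrySet()) {
            long age = now - e.getValue().time;
            if (age > 3 * validTimeMillis) {
                entries.remove(e.getKey());                     // give up
            } else if (age > validTimeMillis && sender.send(e.getValue().message)) {
                entries.remove(e.getKey());                     // resent successfully
            }
        }
    }

    /** Runs sweep() every validTimeMillis on one daemon thread; shut the executor down to stop. */
    ScheduledExecutorService start() {
        ScheduledExecutorService ses = Executors.newSingleThreadScheduledExecutor(r -> {
            Thread t = new Thread(r, "retry-cache");
            t.setDaemon(true);
            return t;
        });
        ses.scheduleWithFixedDelay(() -> sweep(System.currentTimeMillis()),
                validTimeMillis, validTimeMillis, TimeUnit.MILLISECONDS);
        return ses;
    }

    public static void main(String[] args) {
        RetrySketch cache = new RetrySketch(3600, message -> true);
        cache.add("hello", 0);
        cache.sweep(1_000);                // 1 s old: too young to resend
        System.out.println(cache.size());  // 1
        cache.sweep(4_000);                // 4 s old: resent successfully and removed
        System.out.println(cache.size());  // 0
    }
}
```

Passing `now` in explicitly is the design choice that makes the retry policy testable; only `start()` touches the real clock.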
⑥RabbitMQConfig
import lombok.extern.slf4j.Slf4j;
import org.springframework.amqp.core.*;
import org.springframework.amqp.rabbit.connection.CachingConnectionFactory;
import org.springframework.amqp.rabbit.connection.ConnectionFactory;
import org.springframework.amqp.rabbit.connection.CorrelationData;
import org.springframework.amqp.rabbit.core.RabbitTemplate;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;


@Configuration
@Slf4j
public class RabbitMQConfig {
    @Value("${spring.rabbitmq.addresses}")
    public String addresses;

    @Value("${spring.rabbitmq.port}")
    public String port;

    @Value("${spring.rabbitmq.username}")
    private String username;

    @Value("${spring.rabbitmq.password}")
    private String password;

    @Value("${spring.rabbitmq.virtual-host}")
    private String virtualHost;

    @Value("${spring.rabbitmq.publisher-confirms}")
    private boolean publisherConfirms;

    @Value("${tdt.queue}")
    public String queue;

    @Value("${tdt.exchange}")
    public String exchange;

    @Autowired
    RetryCache retryCache;


    @Bean
    public ConnectionFactory connectionFactory(){
        CachingConnectionFactory connectionFactory = new CachingConnectionFactory();
        // Connects to a local broker; the addresses value injected above is not used here
        connectionFactory.setHost("127.0.0.1");
        connectionFactory.setPort(Integer.valueOf(port));
        connectionFactory.setUsername(username);
        connectionFactory.setPassword(password);
        connectionFactory.setVirtualHost(virtualHost);
        connectionFactory.setPublisherConfirms(publisherConfirms);
        return connectionFactory;
    }

    @Bean
    public Queue queueTdt(){
        log.info("Queue created: {}", queue);
        return new Queue(queue);
    }

    @Bean
    public FanoutExchange fanoutExchangeTdt(){
        log.info("Exchange created: {}", exchange);
        return new FanoutExchange(exchange);
    }

    @Bean
    public Binding bindingTdt(){
        Binding bind = BindingBuilder.bind(queueTdt()).to(fanoutExchangeTdt());
        log.info("Queue bound to the exchange");
        return bind;
    }

    @Bean
    public RabbitTemplate rabbitTemplate(){
        RabbitTemplate rabbitTemplate = new RabbitTemplate(connectionFactory());
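        // Return unroutable messages to the sender instead of silently dropping them
        rabbitTemplate.setMandatory(true);
        // Callback for returned (unroutable) messages
        rabbitTemplate.setReturnCallback(returnCallback());
        // Publisher-confirm callback
        rabbitTemplate.setConfirmCallback(confirmCallback());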
        return rabbitTemplate;
    }

    //=============== Publisher confirms ===============
    public RabbitTemplate.ConfirmCallback confirmCallback(){
        return new RabbitTemplate.ConfirmCallback(){
            @Override
            public void confirm(CorrelationData correlationData,
                                boolean ack, String cause) {
                if (ack) {
                    log.info("Broker confirmed the message");
                } else {
                    // Handle the failed message
                    log.info("Broker did not confirm the message, consider resending: {}", cause);
                }
            }
        };
    }

    //=============== Returned (unroutable) messages ===============
    public RabbitTemplate.ReturnCallback returnCallback(){
        return new RabbitTemplate.ReturnCallback(){
            @Override
            public void returnedMessage(Message message,
                                        int replyCode,
                                        String replyText,
                                        String exchange,
                                        String routingKey) {
                log.info("Message could not be routed and needs separate handling.");
                log.info("Returned replyText: {}", replyText);
                log.info("Returned exchange: {}", exchange);
                log.info("Returned routingKey: {}", routingKey);
                String msgJson = new String(message.getBody());
                log.info("Returned Message: {}", msgJson);
            }
        };
    }
}
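One RabbitMQ detail decides whether the broadcast actually reaches every node: a fanout exchange copies each message to every bound queue, but consumers that share one queue compete for its messages. So if every instance declares and listens on the same queue name, each message reaches only one node. Giving each instance its own queue bound to the fanout exchange (in Spring AMQP, for example, an `AnonymousQueue` bean) restores the broadcast. The plain-JDK model below only illustrates those two delivery rules; `BrokerModel` is a hypothetical name and none of it is the RabbitMQ client API.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class BrokerModel {
    /** A queue with competing consumers: each message goes to one consumer, round-robin. */
    static final class Queue {
        private final List<List<String>> consumers = new ArrayList<>();
        private int next = 0;

        List<String> addConsumer() {
            List<String> inbox = new ArrayList<>();
            consumers.add(inbox);
            return inbox;
        }

        void deliver(String message) {
            if (consumers.isEmpty()) {
                return; // this model drops it; a real broker keeps it in the queue
            }
            consumers.get(next).add(message);
            next = (next + 1) % consumers.size();
        }
    }

    /** A fanout exchange: every bound queue gets its own copy of each message. */
    static final class FanoutExchange {
        private final Map<String, Queue> bound = new LinkedHashMap<>();

        Queue bind(String queueName) {
            return bound.computeIfAbsent(queueName, name -> new Queue());
        }

        void publish(String message) {
            bound.values().forEach(q -> q.deliver(message));
        }
    }

    public static void main(String[] args) {
        // Shared queue: both nodes consume from "queue_mqsocket"
        FanoutExchange shared = new FanoutExchange();
        List<String> nodeA = shared.bind("queue_mqsocket").addConsumer();
        List<String> nodeB = shared.bind("queue_mqsocket").addConsumer();
        shared.publish("m1");
        shared.publish("m2");
        System.out.println(nodeA + " " + nodeB); // [m1] [m2]

        // One queue per node: both nodes see every message
        FanoutExchange perNode = new FanoutExchange();
        List<String> nodeC = perNode.bind("queue-node-c").addConsumer();
        List<String> nodeD = perNode.bind("queue-node-d").addConsumer();
        perNode.publish("m1");
        System.out.println(nodeC + " " + nodeD); // [m1] [m1]
    }
}
```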
⑦Constant
public class Constant {
    // Retry interval in milliseconds (compared against System.currentTimeMillis() in RetryCache)
    public static final long VALID_TIME = 3600L;
}
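Because `RetryCache` compares `Constant.VALID_TIME` with differences of `System.currentTimeMillis()` values, 3600 means milliseconds. A quick check of the resulting timings with `java.time.Duration` (the class name `RetryTimings` is only for illustration):

```java
import java.time.Duration;

final class RetryTimings {
    static final long VALID_TIME = 3600L; // same value as Constant.VALID_TIME, in milliseconds

    public static void main(String[] args) {
        Duration retryAfter = Duration.ofMillis(VALID_TIME);       // entry becomes eligible for a resend
        Duration giveUpAfter = Duration.ofMillis(3 * VALID_TIME);  // entry is dropped
        System.out.println(retryAfter + " " + giveUpAfter);        // PT3.6S PT10.8S
    }
}
```

So an unconfirmed entry becomes eligible for a resend after roughly 3.6 s and is dropped after roughly 10.8 s; if a window measured in minutes is intended, the constant has to be scaled accordingly (for example `60_000L` for one minute).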
⑧ReceiveMessage and SendMessage
public interface ReceiveMessage {
    DetailResult receive(Object obj);
}
public interface SendMessage {
    DetailResult send(Object obj);
}
⑨FanoutReceiver
import com.alibaba.fastjson.JSONException;
import com.alibaba.fastjson.JSONObject;
import com.rabbitmq.client.Channel;
import lombok.extern.slf4j.Slf4j;
import org.springframework.amqp.rabbit.annotation.RabbitListener;
import org.springframework.amqp.support.AmqpHeaders;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.handler.annotation.Header;
import org.springframework.stereotype.Component;
import java.io.IOException;
import java.util.Map;


@Component
@Slf4j
public class FanoutReceiver {
    @Autowired
    SpringWebSocketHandler handler;

    // The queue name must match the tdt.queue property used in RabbitMQConfig.
    // Calling basicAck/basicNack here assumes spring.rabbitmq.listener.simple.acknowledge-mode=manual.
    @RabbitListener(queues = "queue_mqsocket")
    public void receiveMessage(String jsonObject, Channel channel, @Header(AmqpHeaders.DELIVERY_TAG) long tag) throws IOException {
        log.info("Message received from the queue: {}", jsonObject);
        try {
            // The sender publishes Map.toString() ({sid=payload}); turn "=" into ":" so fastjson can parse it
            JSONObject mapstr = JSONObject.parseObject(jsonObject.replace("=", ":"));
            for (Map.Entry<String, Object> entry : mapstr.entrySet()) {
                String sid = entry.getKey();
                String message = String.valueOf(entry.getValue());
                // Only the node that holds the recipient's session delivers the message
                if (SpringWebSocketHandler.map.containsKey(sid)) {
                    handler.sendMessageToUser(sid, message);
                }
            }
        } catch (JSONException e) {
            log.error("Discarding a message that cannot be parsed: {}", jsonObject, e);
            // Reject without requeueing so the message does not stay unacknowledged
            channel.basicNack(tag, false, false);
            return;
        }
        channel.basicAck(tag, false);
    }
}
⑩DetailResult
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;


@Data
@AllArgsConstructor
@NoArgsConstructor
public class DetailResult {
    private boolean flag;
    private Object message;

    public boolean isSuccess() {
        return flag;
    }
}
⑪FanoutSender
import cn.tdt.rabbitmq.cache.RetryCache;
import cn.tdt.rabbitmq.function.SendMessage;
import cn.tdt.rabbitmq.result.DetailResult;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.springframework.amqp.rabbit.core.RabbitTemplate;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;


@Component
@Slf4j
public class FanoutSender implements SendMessage {
    // A plain instance rather than the Spring bean; sender() is never called on it,
    // so no retry thread runs and nothing removes the cached entries
    RetryCache retryCache = new RetryCache();
    @Autowired
    RabbitTemplate rabbitTemplate;

    // Sends a message to the fanout exchange, retrying once if the first attempt fails
    public void sendMessage(Object obj) {
        log.info("[sender] sending message to the fanout exchange: {}", JSONObject.toJSONString(obj));
        if (!send(obj).isSuccess()) {
            log.info("send failed, retrying once");
            if (!send(obj).isSuccess()) {
                log.info("retry send failed");
            }
        }
    }

    // Cache the message locally before sending it
    @Override
    public DetailResult send(Object message) {
        try {
            String id = retryCache.generaterId();
            retryCache.add(id, message);
            // Fanout exchanges ignore the routing key, so an empty string is fine
            rabbitTemplate.convertAndSend("exchange_mqsocket", "", message);
//            rabbitTemplate.correlationConvertAndSend(message, new CorrelationData(id));
        } catch (Exception e) {
            log.warn("send failed", e);
            return new DetailResult(false, "");
        }
        return new DetailResult(true, "");
    }
}
⑫WebSocketVO
import lombok.Data;


@Data
public class WebSocketVO {
    private String toUserId;
    private String msgType;
    private String msgInfo;
}
⑬Application class
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;


@SpringBootApplication
public class SSOApplicationRun {
    public static void main(String[] args) {
        SpringApplication.run(SSOApplicationRun.class,args);
    }
}

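Front-end code

The front end is a page titled "WebSocket" that shows "hello socket" and a form with the fields 【userId】, 【toUserId】, 【msgType】 and 【msgInfo】, followed by two 【Action】 rows.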
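As a stand-in for the browser page, here is a minimal client sketch using the JDK 11 `java.net.http.WebSocket` API. It assumes the server from this post runs at `ws://localhost:8080` (host and port are assumptions), connects to `/webSocket` with the `userId` and `token` query parameters that the interceptor reads, and sends a JSON message with the `WebSocketVO` fields; the `msgType` value `"1"` is a placeholder. The JSON is built by hand only to keep the example dependency-free, and `WsClientSketch` is a hypothetical name.

```java
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.WebSocket;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

final class WsClientSketch {
    /** ws://.../webSocket?userId=...&token=..., the parameters the interceptor reads. */
    static URI endpoint(String base, String userId, String token) {
        return URI.create(base + "/webSocket?userId=" + enc(userId) + "&token=" + enc(token));
    }

    /** JSON with the WebSocketVO fields; only quotes and backslashes are escaped. */
    static String message(String toUserId, String msgType, String msgInfo) {
        return "{\"toUserId\":\"" + esc(toUserId) + "\",\"msgType\":\"" + esc(msgType)
                + "\",\"msgInfo\":\"" + esc(msgInfo) + "\"}";
    }

    private static String enc(String s) {
        return URLEncoder.encode(s, StandardCharsets.UTF_8);
    }

    private static String esc(String s) {
        return s.replace("\\", "\\\\").replace("\"", "\\\"");
    }

    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            System.out.println("usage: WsClientSketch <userId> <toUserId>");
            return;
        }
        CountDownLatch greeted = new CountDownLatch(1);
        WebSocket.Listener listener = new WebSocket.Listener() {
            @Override
            public CompletionStage<?> onText(WebSocket webSocket, CharSequence data, boolean last) {
                System.out.println("received: " + data);
                greeted.countDown();
                webSocket.request(1); // ask for the next frame
                return null;
            }
        };
        // Assumed server address; adjust host and port to your deployment
        WebSocket ws = HttpClient.newHttpClient().newWebSocketBuilder()
                .buildAsync(endpoint("ws://localhost:8080", args[0], "EXAM_PERMISSION"), listener)
                .join();
        greeted.await(5, TimeUnit.SECONDS); // the server sends a greeting after connecting
        ws.sendText(message(args[1], "1", "hello"), true).join();
        Thread.sleep(1000); // give any reply time to arrive
        ws.sendClose(WebSocket.NORMAL_CLOSURE, "bye").join();
    }
}
```

Running two copies with swapped arguments against two different nodes behind the gateway is a quick way to exercise the cross-node path.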

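Problems encountered:

When sending to the message queue, if the message is an object, a Properties field is added automatically by the time the message is received, which breaks the JSON conversion. The fix is to convert the message object into a String or a JSON string before sending it.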
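A related pitfall in this code is the envelope format: `handleTextMessage` publishes `sendMap.toString()` (for example `{u2={"toUserId":"u2",...}}`), and `FanoutReceiver` turns it back into JSON with `replace("=", ":")`. That replace also rewrites any `=` inside the user's own message. The plain-JDK sketch below reproduces the problem and shows one dependency-free alternative: a hypothetical envelope that puts the recipient id on the first line and the untouched payload after it. Serializing the map with fastjson's `JSONObject.toJSONString` on the sending side and parsing it without the replace on the receiving side would be another option. `EnvelopeSketch` and its methods are illustrative names only.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

final class EnvelopeSketch {
    /** What the sender/receiver pair does to the text before parsing it. */
    static String toStringRoundTrip(String toUserId, String payload) {
        Map<String, String> sendMap = new ConcurrentHashMap<>();
        sendMap.put(toUserId, payload);
        return sendMap.toString().replace("=", ":");
    }

    /** Hypothetical envelope: first line is the recipient, the rest is the payload, untouched. */
    static String encode(String toUserId, String payload) {
        if (toUserId.contains("\n")) {
            throw new IllegalArgumentException("userId must not contain a newline");
        }
        return toUserId + "\n" + payload;
    }

    static String[] decode(String envelope) {
        int nl = envelope.indexOf('\n');
        if (nl < 0) {
            throw new IllegalArgumentException("not an envelope: " + envelope);
        }
        return new String[] {envelope.substring(0, nl), envelope.substring(nl + 1)};
    }

    public static void main(String[] args) {
        String payload = "{\"toUserId\":\"u2\",\"msgInfo\":\"1+1=2\"}";
        System.out.println(toStringRoundTrip("u2", payload));
        // {u2:{"toUserId":"u2","msgInfo":"1+1:2"}}   <- the user's "=" has been changed

        String[] parts = decode(encode("u2", payload));
        System.out.println(parts[0] + " -> " + parts[1]);
        // u2 -> {"toUserId":"u2","msgInfo":"1+1=2"}
    }
}
```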

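Open issues:

In the single sign-on system, users who are not logged in should be redirected after their WebSocket request has been intercepted.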

Reposted from www.mshxw.com
Article URL: https://www.mshxw.com/it/531630.html