为了防止恶意占用网络连接资源,需要在 websocket 连接中加入拦截器。但是在查找了大量网络资源后,我没有找到对使用注解 @ServerEndpoint 建立的 websocket 连接进行拦截的方式,只有一篇博文提到可以在 @ServerEndpoint 中加入自定义的配置器。
附:文章出处
该方式需要实现 ServerEndpointConfig.Configurator 内部类中的 modifyHandshake 方法来进行拦截。我尝试之后发现它没有第二种方式简单,而且第二种方式具有通用性,比第一种方式要好一点,可以根据自己的情况进行选择。
二、第二种方式实现Websocket集群及通信 集群只需要加入SpringCloud依赖并接入注册中心,再使用网关进行统一转发、负载均衡即可搭建集群,与上一篇博文一致,在此篇不做展示。效果展示:
如果说第一种方式是一个websocket连接一个解决方案,那么第二种方式就是websocket集体注册,共享解决方案。具体代码如下
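1、目录结构
2、pom依赖(原文的 XML 标签在提取时丢失,以下根据残留的坐标还原)
<parent>
    <artifactId>spring-boot-starter-parent</artifactId>
    <groupId>org.springframework.boot</groupId>
    <version>2.1.6.RELEASE</version>
</parent>
<properties>
    <!-- 原文此处仅残留 "8 8",推测为编译的 source/target 版本 -->
    <maven.compiler.source>8</maven.compiler.source>
    <maven.compiler.target>8</maven.compiler.target>
</properties>
<dependencies>
    <dependency>
        <groupId>org.springframework</groupId>
        <artifactId>spring-context</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework</groupId>
        <artifactId>spring-beans</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework</groupId>
        <artifactId>spring-aspects</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-redis</artifactId>
        <version>1.4.1.RELEASE</version>
    </dependency>
    <dependency>
        <groupId>org.springframework.cloud</groupId>
        <artifactId>spring-cloud-starter</artifactId>
        <version>2.1.0.RELEASE</version>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-websocket</artifactId>
    </dependency>
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
    </dependency>
    <dependency>
        <groupId>com.alibaba</groupId>
        <artifactId>fastjson</artifactId>
        <version>1.2.9</version>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-amqp</artifactId>
    </dependency>
</dependencies>
3、代码根据结构从上至下
①SpringWebSocketConfig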
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.handler.TextWebSocketHandler;
@Configuration
@EnableWebMvc
@EnableWebSocket
public class SpringWebSocketConfig extends WebMvcConfigurerAdapter implements WebSocketConfigurer {
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
//.setAllowedOrigins("*") 允许跨域访问
registry.addHandler(webSocketHandler(),"/webSocket").addInterceptors(new SpringWebSocketHandlerInterceptor()).setAllowedOrigins("*");
registry.addHandler(webSocketHandler(), "/sockjs/socketServer.do").addInterceptors(new SpringWebSocketHandlerInterceptor()).setAllowedOrigins("*");
}
@Bean
public TextWebSocketHandler webSocketHandler(){
return new SpringWebSocketHandler();
}
}
②Controller层没有实际意义,不写也可以
③SpringWebSocketHandler
import com.alibaba.fastjson.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.io.IOException;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Text WebSocket handler that keeps a registry of the sessions connected to this node, keyed by
 * user id, and routes each incoming message to its addressee.
 *
 * <p>If the addressee is not connected to this node the message is handed to {@link FanoutSender}
 * so that it can be published to the message queue.
 */
@Component
public class SpringWebSocketHandler extends TextWebSocketHandler {

    /**
     * Handshake attribute that carries the user id. This is the key the handshake interceptor
     * stores the id under, so every lookup in this class must use the same key.
     */
    private static final String SID_ATTRIBUTE = "exam_sid";

    /** Sessions connected to this node, keyed by user id. Also read by the queue listener. */
    public static final ConcurrentHashMap<String, WebSocketSession> map = new ConcurrentHashMap<>();

    private static final Logger logger = LoggerFactory.getLogger(SpringWebSocketHandler.class);

    @Autowired
    FanoutSender fanoutSender;

    public SpringWebSocketHandler() {
    }

    /**
     * Registers the new session under its user id and greets the client.
     *
     * <p>A session without a user id cannot be addressed, so it is closed instead of registered.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        String sid = sidOf(session);
        if (sid == null) {
            logger.warn("websocket session {} carries no user id, closing it", session.getId());
            session.close(CloseStatus.POLICY_VIOLATION);
            return;
        }
        map.put(sid, session);
        // The registry size is the online count; a separate counter drifts when a connect fails
        // half-way or the same user connects twice.
        logger.info("connect to the websocket success......当前数量:{}", map.size());
        // Business logic hooks in here, e.g. pushing offline messages once the user is connected.
        session.sendMessage(new TextMessage("连接成功"));
    }

    /** Removes the closed session from the registry. */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
        logger.debug("websocket connection closed......");
        String sid = sidOf(session);
        logger.info("用户{}已退出!", sid);
        unregister(sid, session);
        logger.info("剩余在线用户:{}", map.size());
    }

    /**
     * Delivers a client message to the user named in its {@code toUserId} field, or publishes it to
     * the queue when that user is not connected to this node.
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        String payload = message.getPayload();
        String toSid = targetOf(payload);
        if (toSid == null) {
            // ConcurrentHashMap rejects null keys, so a message without an addressee is dropped here.
            logger.warn("message without toUserId ignored: {}", payload);
            return;
        }
        WebSocketSession target = map.get(toSid);
        if (target != null) {
            target.sendMessage(message);
            return;
        }
        Map<String, String> sendMap = new ConcurrentHashMap<>();
        sendMap.put(toSid, payload);
        // Serialise as real JSON: Map.toString() is not a parseable format and forces the consumer
        // to rewrite '=' characters, which corrupts payloads that contain '='.
        String envelope = JSONObject.toJSONString(sendMap);
        logger.info("getPayLoad():{}", payload);
        logger.info("messge:{}", message.toString());
        logger.info("map数据:{}", envelope);
        fanoutSender.sendMessage(envelope);
    }

    /**
     * Delivers {@code message} to the user named in its {@code toUserId} field if that user is
     * connected to this node; otherwise does nothing.
     *
     * @param message JSON text message carrying a {@code toUserId} field
     * @throws Exception if the payload cannot be parsed or the send fails
     */
    public void sendMessage(TextMessage message) throws Exception {
        String toSid = targetOf(message.getPayload());
        WebSocketSession target = toSid == null ? null : map.get(toSid);
        if (target != null) {
            target.sendMessage(message);
        }
    }

    /**
     * Drops the failed session from the registry, tells the client about the error while the
     * connection is still usable, and then closes it.
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        String sid = sidOf(session);
        logger.warn("websocket transport error for user {}", sid, exception);
        unregister(sid, session);
        if (session.isOpen()) {
            try {
                // The notice has to go out before close(); a closed session cannot send.
                session.sendMessage(new TextMessage("发生异常错误"));
            } catch (IOException e) {
                logger.debug("could not notify user {} about the transport error", sid, e);
            } finally {
                session.close();
            }
        }
        logger.debug("websocket connection closed......");
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    /**
     * Sends {@code message} to the user registered under {@code sid} if that session is open.
     * Send failures are logged, not propagated.
     *
     * @param sid     user id of the recipient
     * @param message message to deliver
     */
    public void sendMessageToUser(String sid, TextMessage message) {
        if (sid == null) {
            return;
        }
        // The registry is keyed by user id, so a direct lookup replaces the scan over all sessions.
        WebSocketSession target = map.get(sid);
        if (target == null || !target.isOpen()) {
            return;
        }
        try {
            target.sendMessage(message);
        } catch (IOException e) {
            logger.error("failed to send message to user {}", sid, e);
        }
    }

    /**
     * Convenience overload of {@link #sendMessageToUser(String, TextMessage)} for plain text.
     *
     * @param sid     user id of the recipient
     * @param message text to deliver
     */
    public void sendMessageToUser(String sid, String message) {
        // Build the frame from the String itself; going through getBytes() would encode with the
        // platform charset while TextMessage decodes bytes as UTF-8.
        this.sendMessageToUser(sid, new TextMessage(message));
    }

    /**
     * Broadcasts {@code message} to every open session on this node.
     *
     * @param message message to deliver
     * @throws IOException if sending to any session fails; remaining sessions are not attempted
     */
    public void sendMessageToUsers(TextMessage message) throws IOException {
        for (WebSocketSession session : map.values()) {
            if (session.isOpen()) {
                session.sendMessage(message);
            }
        }
    }

    /** Returns the user id stored on the session during the handshake, or null if absent. */
    private static String sidOf(WebSocketSession session) {
        Object sid = session.getAttributes().get(SID_ATTRIBUTE);
        return sid == null ? null : sid.toString();
    }

    /** Extracts the addressee from a JSON payload; null if the payload is empty or names none. */
    private static String targetOf(String payload) {
        WebSocketVO vo = JSONObject.parseObject(payload, WebSocketVO.class);
        return vo == null ? null : vo.getToUserId();
    }

    /** Removes the registry entry only if it still points at this exact session. */
    private static void unregister(String sid, WebSocketSession session) {
        if (sid != null) {
            // Two-argument remove keeps a newer session of the same user registered.
            map.remove(sid, session);
        }
    }
}
④SpringWebSocketHandlerInterceptor
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
import javax.servlet.http.HttpSession;
import java.util.Map;
/**
 * Handshake interceptor that admits a WebSocket connection only when the request carries the
 * expected {@code token} parameter and a user id can be determined.
 *
 * <p>The user id is stored in the handshake attributes under {@code "exam_sid"}; the WebSocket
 * handler reads it back through {@code WebSocketSession.getAttributes()} and must use that key.
 */
@Component
public class SpringWebSocketHandlerInterceptor extends HttpSessionHandshakeInterceptor {

    /** Value the {@code token} request parameter must have. */
    // NOTE(review): a fixed shared token is not real authentication - confirm this is only a demo.
    private static final String REQUIRED_TOKEN = "EXAM_PERMISSION";

    /** Key of the user id in both the HTTP session and the handshake attributes. */
    private static final String SID_ATTRIBUTE = "exam_sid";

    /**
     * Validates the handshake request and records the user id for the WebSocket handler.
     *
     * @param attributes handshake attributes that become the WebSocket session attributes
     * @return {@code true} to proceed with the handshake; {@code false} when the request is not a
     *         servlet request, the token is missing or wrong, or no user id is available
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
            Map<String, Object> attributes) throws Exception {
        if (!(request instanceof ServletServerHttpRequest)) {
            return false;
        }
        ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request;

        // Check the token first, with the constant on the left: a missing parameter is rejected
        // instead of throwing, and rejected requests never allocate an HTTP session.
        String token = servletRequest.getServletRequest().getParameter("token");
        if (!REQUIRED_TOKEN.equals(token)) {
            return false;
        }

        HttpSession session = servletRequest.getServletRequest().getSession(true);
        if (session == null) {
            return false;
        }
        // Prefer the id already bound to the HTTP session, fall back to the "userId" parameter.
        String userName = (String) session.getAttribute(SID_ATTRIBUTE);
        if (userName == null) {
            userName = servletRequest.getServletRequest().getParameter("userId");
        }
        if (userName == null || userName.isEmpty()) {
            // Without an id the handler could not register or address the session.
            return false;
        }
        attributes.put(SID_ATTRIBUTE, userName);
        return true;
    }

    /** Delegates to the superclass; no extra work is done after the handshake. */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
            Exception ex) {
        super.afterHandshake(request, response, wsHandler, ex);
    }
}
⑤RetryCache该类是用来缓存队列消息和消息重发的,属于消息安全
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * In-memory cache of outgoing queue messages, with a background pass that re-sends entries that
 * have stayed in the cache too long.
 *
 * <p>An entry older than {@link Constant#VALID_TIME} is re-sent through the registered
 * {@link SendMessage}; an entry older than three times that value is given up and removed.
 * Entries leave the cache only through {@link #del(String)} or that background pass.
 */
@Slf4j
@Component
public class RetryCache {

    /** Sender used for re-sends; volatile because the retry thread reads it. */
    private volatile SendMessage sendMessage;

    /** Set by {@link #stop()}; volatile so the retry thread sees the change. */
    private volatile boolean stop = false;

    private final Map<String, MessageWithTime> map = new ConcurrentHashMap<>();
    private final AtomicInteger id = new AtomicInteger();

    /** The single retry thread, created on the first call to {@link #sender(SendMessage)}. */
    private Thread retryThread;

    /** A cached message together with the time (epoch millis) it was added. */
    @Data
    @AllArgsConstructor
    @NoArgsConstructor
    private static class MessageWithTime {
        long time;
        Object message;
    }

    /**
     * Registers the sender used for re-sends and starts the retry thread if it is not running.
     *
     * <p>NOTE(review): the sender should not add the re-sent message to this cache again without
     * something removing it on confirmation, otherwise every re-send creates a new entry.
     *
     * @param sendMessage sender invoked for every entry that is due for a retry
     */
    public synchronized void sender(SendMessage sendMessage) {
        this.sendMessage = sendMessage;
        // Start at most one worker; repeated registration only swaps the sender.
        if (retryThread == null) {
            startRetry();
        }
    }

    /** Asks the retry thread to finish; it exits at its next wake-up. */
    public synchronized void stop() {
        stop = true;
        if (retryThread != null) {
            retryThread.interrupt();
        }
    }

    /**
     * Produces the next cache id.
     *
     * @return a new id, unique within this instance
     */
    public String generaterId() {
        return "" + id.incrementAndGet();
    }

    /**
     * Caches {@code message} under {@code id}, stamped with the current time.
     *
     * @param id      cache key, typically from {@link #generaterId()}
     * @param message message to keep for a possible re-send
     */
    public void add(String id, Object message) {
        map.put(id, new MessageWithTime(System.currentTimeMillis(), message));
    }

    /**
     * Removes the entry stored under {@code id}, if any.
     *
     * @param id cache key to remove
     */
    public void del(String id) {
        map.remove(id);
    }

    /** Starts the background thread that runs {@link #retryLoop()}. */
    private void startRetry() {
        Thread worker = new Thread(this::retryLoop, "retry-cache");
        // Daemon, so a worker that was never stopped cannot keep the JVM alive on shutdown.
        worker.setDaemon(true);
        retryThread = worker;
        worker.start();
    }

    /** Wakes up every {@link Constant#VALID_TIME} millis and re-sends or expires cached entries. */
    private void retryLoop() {
        while (!stop) {
            try {
                // Sleep for the retry interval. Sleeping for currentTimeMillis() would suspend the
                // thread for decades, so no retry would ever run.
                Thread.sleep(Constant.VALID_TIME);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
            long now = System.currentTimeMillis();
            for (Map.Entry<String, MessageWithTime> entry : map.entrySet()) {
                retryOrExpire(entry.getKey(), entry.getValue(), now);
            }
        }
    }

    /** Expires an entry past three intervals, re-sends one past a single interval. */
    private void retryOrExpire(String key, MessageWithTime pending, long now) {
        if (pending == null) {
            return;
        }
        if (pending.getTime() + 3 * Constant.VALID_TIME < now) {
            // NOTE(review): the text says "3 min" but the threshold is 3 * VALID_TIME milliseconds.
            log.info("send message failed after 3 min {}", pending);
            del(key);
        } else if (pending.getTime() + Constant.VALID_TIME < now) {
            try {
                DetailResult detailRes = sendMessage.send(pending.getMessage());
                if (detailRes != null && detailRes.isSuccess()) {
                    del(key);
                }
            } catch (RuntimeException e) {
                // One failing re-send must not terminate the retry thread.
                log.warn("retry of cached message {} failed", key, e);
            }
        }
    }
}
⑥RabbitMQConfig
import lombok.extern.slf4j.Slf4j;
import org.springframework.amqp.core.*;
import org.springframework.amqp.rabbit.connection.CachingConnectionFactory;
import org.springframework.amqp.rabbit.connection.ConnectionFactory;
import org.springframework.amqp.rabbit.connection.CorrelationData;
import org.springframework.amqp.rabbit.core.RabbitTemplate;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
@Slf4j
public class RabbitMQConfig {
@Value("${spring.rabbitmq.addresses}")
public String addresses;
@Value("${spring.rabbitmq.port}")
public String port;
@Value("${spring.rabbitmq.username}")
private String username;
@Value("${spring.rabbitmq.password}")
private String password;
@Value("${spring.rabbitmq.virtual-host}")
private String virtualHost;
@Value("${spring.rabbitmq.publisher-/confirm/is}")
private boolean publisher/confirm/is;
@Value("${tdt.queue}")
public String queue;
@Value("${tdt.exchange}")
public String exchange;
@Autowired
RetryCache retryCache;
@Bean
public ConnectionFactory connectionFactory(){
CachingConnectionFactory connectionFactory = new CachingConnectionFactory();
connectionFactory.setHost("127.0.0.1");
connectionFactory.setPort(Integer.valueOf(port));
connectionFactory.setUsername(username);
connectionFactory.setPassword(password);
connectionFactory.setVirtualHost(virtualHost);
connectionFactory.setPublisherConfirms(publisher/confirm/is);
return connectionFactory;
}
@Bean
public Queue queueTdt(){
log.info("创建队列成功:{}",queue);
return new Queue(queue);
}
@Bean
public FanoutExchange fanoutExchangeTdt(){
log.info("创建交换机成功:{}",exchange);
return new FanoutExchange(exchange);
}
@Bean
public Binding bindingTdt(){
Binding bind = BindingBuilder.bind(queueTdt()).to(fanoutExchangeTdt());
log.info("交换机队列绑定成功");
return bind;
}
@Bean
public RabbitTemplate rabbitTemplate(){
RabbitTemplate rabbitTemplate = new RabbitTemplate(connectionFactory());
//TODO 失败通知
rabbitTemplate.setMandatory(true);
//TODO 失败回调
rabbitTemplate.setReturnCallback(returnCallback());
//TODO 发送方确认
rabbitTemplate.setConfirmCallback(confirmCallback());
return rabbitTemplate;
}
//===============发送方确认===============
public RabbitTemplate.ConfirmCallback confirmCallback(){
return new RabbitTemplate.ConfirmCallback(){
@Override
public void confirm(CorrelationData correlationData,
boolean ack, String cause) {
if (ack) {
log.info("发送者确认发送给mq成功");
} else {
//处理失败的消息
log.info("发送者发送给mq失败,考虑重发:"+cause);
}
}
};
}
//===============失败通知===============
public RabbitTemplate.ReturnCallback returnCallback(){
return new RabbitTemplate.ReturnCallback(){
@Override
public void returnedMessage(Message message,
int replyCode,
String replyText,
String exchange,
String routingKey) {
log.info("无法路由的消息,需要考虑另外处理。");
log.info("Returned replyText:"+replyText);
log.info("Returned exchange:"+exchange);
log.info("Returned routingKey:"+routingKey);
String msgJson = new String(message.getBody());
log.info("Returned Message:"+msgJson);
}
};
}
}
⑦Constant
/**
 * Shared constants.
 */
public class Constant {

    /**
     * Interval after which a cached message is considered due for a retry. It is added to
     * {@code System.currentTimeMillis()} timestamps and used as a sleep duration, so the unit is
     * milliseconds.
     */
    // Upper-case 'L': a lower-case 'l' suffix is easily misread as the digit one.
    public static final long VALID_TIME = 3600L;

    /** Not instantiable: this class only holds constants. */
    private Constant() {
    }
}
⑧ReceiveMessage和SendMessage
/**
 * Callback for handling a received message.
 *
 * <p>NOTE(review): no implementation or caller is visible in this file - confirm intended use.
 */
public interface ReceiveMessage {
/**
 * Handles one received message.
 *
 * @param obj the received message
 * @return the outcome of handling the message
 */
DetailResult receive(Object obj);
}
/**
 * Strategy for sending one message. The retry cache calls it to re-send cached messages, and
 * the fanout sender implements it.
 */
public interface SendMessage {
/**
 * Sends one message.
 *
 * @param obj the message to send
 * @return the outcome; callers check {@code isSuccess()} to decide whether the send worked
 */
DetailResult send(Object obj);
}
⑨FanoutReceiver
import com.alibaba.fastjson.JSONException;
import com.alibaba.fastjson.JSONObject;
import com.rabbitmq.client.Channel;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.amqp.rabbit.annotation.RabbitHandler;
import org.springframework.amqp.rabbit.annotation.RabbitListener;
import org.springframework.amqp.support.AmqpHeaders;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.handler.annotation.Header;
import org.springframework.stereotype.Component;
import java.io.IOException;
import java.util.Map;
import java.util.Set;
@Component
@Slf4j
public class FanoutReceiver {
private static Logger logger = LoggerFactory.getLogger(FanoutReceiver.class);
@Autowired
SpringWebSocketHandler handler;
@RabbitHandler
@RabbitListener(queues = "queue_mqsocket")//动态绑定
public void receiveMessage(String jsonObject, Channel channel, @Header(AmqpHeaders.DELIVERY_TAG) long tag) {
//返回字符串
try{
log.info("队列接收到消息:{}",jsonObject);
jsonObject = jsonObject.replace("=",":");
Map mapstr = JSONObject.parseObject(jsonObject, Map.class);
Set> entries = mapstr.entrySet();
for (Map.Entry entry : entries) {
Object obj = entry.getKey();
String sid = obj.toString();
Object objs = entry.getValue();
String message = objs.toString();
if(SpringWebSocketHandler.map.containsKey(sid)){
handler.sendMessageToUser(sid,message);
}
}
}catch (JSonException e){
e.printStackTrace();
return;
}
try {
channel.basicAck(tag,false);
} catch (IOException e) {
e.printStackTrace();
}
}
}
⑩DetailResult
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Outcome of a send or receive operation: a success flag plus an optional detail object.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class DetailResult {

    /** {@code true} when the operation succeeded. */
    private boolean flag;

    /** Additional detail about the outcome. */
    public Object message;

    /**
     * Reports whether the operation succeeded.
     *
     * @return the value of the success flag
     */
    public boolean isSuccess() {
        return flag;
    }
}
⑩①FanoutSender
import cn.tdt.rabbitmq.cache.RetryCache;
import cn.tdt.rabbitmq.function.SendMessage;
import cn.tdt.rabbitmq.result.DetailResult;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.springframework.amqp.rabbit.connection.CorrelationData;
import org.springframework.amqp.rabbit.core.RabbitTemplate;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
/**
 * Publishes messages to the fanout exchange and records each one in the {@link RetryCache}
 * before it is sent.
 *
 * <p>The cache id travels with the message as correlation data so a publisher-confirm callback
 * can identify the cached entry. This class never removes cache entries itself.
 *
 * <p>NOTE(review): the cache's retry pass only runs after {@code RetryCache.sender(...)} is
 * called, which this class does not do.
 */
@Component
@Slf4j
public class FanoutSender implements SendMessage {

    /** Exchange every message is published to. */
    private static final String EXCHANGE = "exchange_mqsocket";

    // Injected, so the Spring-managed cache is used rather than a private instance that no other
    // bean can see.
    @Autowired
    RetryCache retryCache;
    @Autowired
    RabbitTemplate rabbitTemplate;

    /**
     * Publishes {@code obj}, retrying once immediately if the first attempt fails.
     *
     * @param obj message to publish
     */
    public void sendMessage(Object obj) {
        rabbitTemplate.setExchange(EXCHANGE);
        log.info("【消息发送者】发送消息到fanout交换机" + JSONObject.toJSONString(obj));
        // Cache once and reuse the id for the retry, so a retry does not add a second entry.
        String id = retryCache.generaterId();
        retryCache.add(id, obj);
        // publish() reports failure through its result rather than by throwing, so the retry
        // decision has to be based on that result.
        if (publish(id, obj).isSuccess()) {
            return;
        }
        log.info("send failed, retrying once");
        if (!publish(id, obj).isSuccess()) {
            log.info("retry send failed");
        }
    }

    /**
     * Caches the message locally, then publishes it.
     *
     * @param message message to publish
     * @return a result whose {@code isSuccess()} is false when publishing threw
     */
    @Override
    public DetailResult send(Object message) {
        String id = retryCache.generaterId();
        retryCache.add(id, message);
        return publish(id, message);
    }

    /** Publishes one message with its cache id as correlation data; never throws. */
    private DetailResult publish(String id, Object message) {
        try {
            rabbitTemplate.convertAndSend(EXCHANGE, "", message, new CorrelationData(id));
        } catch (RuntimeException e) {
            log.warn("failed to publish message {} to {}", id, EXCHANGE, e);
            return new DetailResult(false, "");
        }
        return new DetailResult(true, "");
    }
}
⑩②WebSocketVO
import lombok.Data;
/**
 * Shape of the JSON text a client sends over the WebSocket; the handler parses each incoming
 * payload into this class. Accessors are generated by Lombok's {@code @Data}.
 */
@Data
public class WebSocketVO {
// User id of the addressee; the handler looks the target session up by this value.
private String toUserId;
// Not read anywhere in this file. NOTE(review): presumably a message category - confirm.
private String msgType;
// Not read anywhere in this file. NOTE(review): presumably the message content - confirm.
private String msgInfo;
}
⑩③启动类
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
/**
 * Spring Boot entry point of the application.
 */
@SpringBootApplication
public class SSOApplicationRun {

    /**
     * Boots the Spring application context.
     *
     * @param args command-line arguments passed through to Spring Boot
     */
    public static void main(String[] args) {
        SpringApplication application = new SpringApplication(SSOApplicationRun.class);
        application.run(args);
    }
}
前端代码
WebSocket
hello socket
【userId】:
【toUserId】:
【msgType】:
【msgInfo】:
【操作】:
【操作】:
遇到的问题:
在往消息队列中发送消息时,如果发送的消息是对象,会在接收消息时自动添加Properties字段属性,影响json转换,解决方案为把发送的消息对象转换为字符串对象或者json串。
待优化的问题:
在单点登录系统中,拦截websocket请求后,未登录用户进行重定向。



