分布式websocket推送
場景
項目中用到websocket推送消息,後台是分布式部署的,需要通過websocket將預警消息推送給前台。直接添加websocket後出現了一個問題:假設有兩台服務S1、S2,客戶端C和後端服務建立連接的時候經過負載均衡分配給了S1。如果S1後台收到了預警消息,此時可以直接推送給客戶端C;但是假如S2後台收到了預警消息,也要推送給客戶端C,而此時S2並沒有和客戶端C建立連接,該消息就會丟失而無法推送給客戶端。
解決方案
使用MQ解耦消息和websocket服務端:收到預警消息後不是直接推送到客戶端,而是發送到MQ,然後由每一個websocket服務端通過監聽/拉取MQ中的消息進行判斷和推送(每個服務端只推送給與自己建立了連接的客戶端,因此無論客戶端連在哪一台上都能收到)。當然,消息體的格式需要按你的業務來設計。
實現
既然要使用MQ,我們該如何選型呢?其實市面上常見的MQ都夠用了,比如RocketMQ、ActiveMQ、RabbitMQ等,Kafka也可以(不過有點兒大材小用了)。因為業務上不希望引入新的組件,而項目中剛好用到了Redis,所以決定用Redis的發布訂閱(Pub/Sub)功能解決。需要注意:Redis的發布訂閱不會持久化消息,訂閱者離線期間發布的消息會直接丟失;如果業務要求消息必達,應選用真正的MQ。
代碼
websocket
配置類
EndpointConfig
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import javax.websocket.server.ServerEndpointConfig;
/**
 * JSR-356 configurator that hands endpoint creation over to Spring, so that a
 * {@code @ServerEndpoint} class can be a Spring bean with injected dependencies.
 *
 * <p>The websocket container instantiates this configurator itself (it is referenced via
 * {@code configurator = EndpointConfig.class}), so the application context is kept in a static
 * field. The instance Spring creates as a bean exists only to receive that context through
 * {@link ApplicationContextAware}.
 */
public class EndpointConfig extends ServerEndpointConfig.Configurator implements ApplicationContextAware {
    private static volatile BeanFactory context;

    /**
     * Stores the Spring context for later endpoint lookups.
     *
     * @param applicationContext the running application context
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        context = applicationContext;
    }

    /**
     * Returns the Spring bean of the requested endpoint type instead of a fresh instance.
     *
     * @param clazz endpoint class requested by the container
     * @return the bean registered for {@code clazz}
     */
    @Override
    public <T> T getEndpointInstance(Class<T> clazz) throws InstantiationException {
        final BeanFactory factory = context;
        return factory.getBean(clazz);
    }
}
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
`WebSocketConfig02`類
// @EnableWebSocket  // not required for this setup
/**
 * Registers the beans needed to run JSR-356 {@code @ServerEndpoint} endpoints under Spring.
 */
@Configuration
public class WebSocketConfig02 {

    /**
     * Spring-managed configurator instance; it receives the application context so that
     * endpoints can later be resolved as beans.
     *
     * @return a new configurator
     */
    @Bean
    public EndpointConfig newConfig() {
        return new EndpointConfig();
    }

    /**
     * Exports {@code @ServerEndpoint} endpoints to the standard websocket container.
     *
     * @return a new exporter
     */
    @Bean
    public ServerEndpointExporter serverEndpointConfig() {
        return new ServerEndpointExporter();
    }
}
websocket請求類
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.bart.websocket.configuration.EndpointConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * WebSocket endpoint that tracks connected clients and pushes text messages to them.
 *
 * <p>Two ways to receive the user id:
 * <ol>
 *   <li>{@code value = "/ws/{userId}"} with {@code onOpen(@PathParam("userId") String userId, Session session)}:
 *       the client must append the id after the slash, e.g. {@code ws://localhost:7889/ws/123},
 *       otherwise the handshake returns 404.</li>
 *   <li>{@code value = "/ws"} with {@code onOpen(Session session)}: read {@code ?userId=123}
 *       through {@code session.getRequestParameterMap()}.</li>
 * </ol>
 *
 * <p>Instances come from {@link EndpointConfig}, which returns the Spring bean. This
 * {@code @Component} declares no {@code @Scope}, so that bean is presumably a singleton shared by
 * every connection. Per-connection state therefore lives only in the static session map, never
 * in instance fields.
 *
 * @author bart
 */
@Component
@ServerEndpoint(
        value = "/ws/{userId}",
        configurator = EndpointConfig.class,
        encoders = { ProductWebSocket.MessageEncoder.class } // encoder for JSONObject payloads
)
public class ProductWebSocket {
    final static Logger log = LoggerFactory.getLogger(ProductWebSocket.class);

    /** Number of currently registered connections. */
    private static final AtomicInteger onlineCount = new AtomicInteger(0);

    /**
     * Open session -> user id. It is keyed by session because one user may hold several
     * connections (e.g. several tabs), and {@code send} delivers to all of them.
     */
    private static final ConcurrentHashMap<Session, String> userIdSessionMap = new ConcurrentHashMap<>();

    /**
     * Registers a new connection and answers it with {@code {"code":200,"message":"OK"}}.
     *
     * @param userId  user id taken from the {@code /ws/{userId}} path
     * @param session the new websocket session
     * @throws IllegalArgumentException if no user id was supplied
     */
    @OnOpen
    public void onOpen(@PathParam("userId") String userId, Session session) {
        if (userId == null) {
            log.error("websocket連接 缺少參數 id");
            throw new IllegalArgumentException("websocket連接 缺少參數 id");
        }
        log.info("websocket 新客戶端連入,用戶id:{}", userId);
        userIdSessionMap.put(session, userId);
        addOnlineCount();
        // Confirm the connection to the client.
        JSONObject jsonObject = new JSONObject();
        jsonObject.put("code", 200);
        jsonObject.put("message", "OK");
        send(userId, JSON.toJSONString(jsonObject));
    }

    /**
     * Unregisters a closed connection.
     *
     * @param session the session being closed
     */
    @OnClose
    public void onClose(Session session) {
        log.info("一個客戶端關閉連接");
        // Only decrement for sessions actually registered in onOpen. If the container calls
        // onClose for a session whose onOpen was rejected, the counter must not drift below
        // the real number of connections.
        if (userIdSessionMap.remove(session) != null) {
            subOnlineCount();
        }
    }

    /**
     * Called when a client message arrives; currently only logs it.
     *
     * @param message text received from the client
     * @param session the sending session
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        log.info("用戶發送過來的消息為:" + message);
    }

    /**
     * Called when the container reports an error on a session.
     *
     * @param session the affected session
     * @param error   the reported error
     */
    @OnError
    public void onError(Session session, Throwable error) {
        log.error("websocket出現錯誤", error);
    }

    /**
     * Sends a message to every connection of one user on this node.
     * A user id match ignores case; a {@code null} id matches nothing.
     *
     * @param id      user id
     * @param message text to send
     */
    public void send(String id, String message) {
        log.info("#### 點對點消息,userId={}", id);
        if (userIdSessionMap.isEmpty()) {
            log.warn("當前無websocket連接");
            return;
        }
        List<Session> sessionList = new ArrayList<>();
        for (Map.Entry<Session, String> entry : userIdSessionMap.entrySet()) {
            // Map values are never null (ConcurrentHashMap), so this is null-safe for id.
            if (entry.getValue().equalsIgnoreCase(id)) {
                sessionList.add(entry.getKey());
            }
        }
        if (sessionList.isEmpty()) {
            log.error("未找到當前id對應的session, id = {}", id);
            return;
        }
        for (Session session : sessionList) {
            try {
                sendText(session, message);
                log.info("推送用戶【{}】消息成功,消息為:【{}】", id, message);
            } catch (Exception e) {
                log.error("推送用戶【{}】消息失敗,消息為:【{}】,原因是:【{}】", id, message, e.getMessage(), e);
            }
        }
    }

    /**
     * Sends a message to every connection on this node.
     *
     * @param message text to send
     */
    public void broadcast(String message) {
        log.info("#### 廣播消息");
        if (userIdSessionMap.isEmpty()) {
            log.warn("當前無websocket連接");
            return;
        }
        for (Session session : userIdSessionMap.keySet()) {
            try {
                sendText(session, message);
            } catch (Exception e) {
                log.error("websocket 發送【{}】消息出錯:{}", session, e.getMessage(), e);
            }
        }
    }

    /**
     * Sends text on one session, serialized per session.
     *
     * <p>RemoteEndpoint.Basic may throw IllegalStateException when two threads send on the same
     * session concurrently. Pushes can arrive here from several threads at once (for example
     * Redis listener threads and onOpen).
     *
     * @param session target session
     * @param message text to send
     * @throws java.io.IOException if the send fails
     */
    private static void sendText(Session session, String message) throws java.io.IOException {
        synchronized (session) {
            session.getBasicRemote().sendText(message);
        }
    }

    /** @return the current number of registered connections */
    public static int getOnlineCount() {
        return onlineCount.get();
    }

    /** Increments the connection counter (atomic; no extra locking needed). */
    public static void addOnlineCount() {
        onlineCount.incrementAndGet();
    }

    /** Decrements the connection counter (atomic; no extra locking needed). */
    public static void subOnlineCount() {
        onlineCount.decrementAndGet();
    }

    /**
     * Encodes a {@link JSONObject} as its JSON text; {@code null} becomes the empty string.
     */
    public static class MessageEncoder implements Encoder.Text<JSONObject> {
        @Override
        public void init(javax.websocket.EndpointConfig endpointConfig) {
        }

        @Override
        public void destroy() {
        }

        @Override
        public String encode(JSONObject object) throws EncodeException {
            return object == null ? "" : object.toJSONString();
        }
    }
}
redis
常量類
/**
 * Redis key and channel names shared by publishers and subscribers.
 * This is a constants holder and is not meant to be instantiated or extended.
 */
public final class RedisKeyConstants {

    /** Pub/sub channel on which warning messages are published to every websocket node. */
    public static final String REDIS_TOPIC_MSG = "redis_topic_msg";

    private RedisKeyConstants() {
        // constants holder, not instantiable
    }
}
配置類
import java.util.Arrays;
import com.bart.websocket.common.RedisKeyConstants;
import com.bart.websocket.configuration.redis.listener.RedisTopicListener;
import com.bart.websocket.service.WarnMsgService;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.listener.ChannelTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
/**
 * Redis pub/sub wiring. It defines a listener container and the listener that hands messages
 * from {@link RedisKeyConstants#REDIS_TOPIC_MSG} to the {@link WarnMsgService}.
 *
 * @author bart
 */
@Configuration
public class RedisConfig {

    /**
     * Message listener container bound to the application's Redis connection.
     *
     * @param connectionFactory Redis connection factory provided by Spring
     * @return the container that subscribers register with
     */
    @Bean
    RedisMessageListenerContainer container(RedisConnectionFactory connectionFactory) {
        final RedisMessageListenerContainer listenerContainer = new RedisMessageListenerContainer();
        listenerContainer.setConnectionFactory(connectionFactory);
        return listenerContainer;
    }

    /**
     * Subscriber for the warning-message channel.
     *
     * @param container           container the listener registers itself with
     * @param stringRedisTemplate template handed to the listener
     * @param warnMsgService      service that performs the local websocket push
     * @return the configured listener
     */
    @Bean
    RedisTopicListener redisTopicListener(
            RedisMessageListenerContainer container,
            StringRedisTemplate stringRedisTemplate,
            WarnMsgService warnMsgService) {
        final ChannelTopic warnTopic = new ChannelTopic(RedisKeyConstants.REDIS_TOPIC_MSG);
        // The listener's constructor subscribes it to the given topics on the container.
        final RedisTopicListener listener =
                new RedisTopicListener(container, Arrays.asList(warnTopic), warnMsgService);
        listener.setStringRedisSerializer(new StringRedisSerializer());
        listener.setStringRedisTemplate(stringRedisTemplate);
        return listener;
    }
}
redis消息體
import com.bart.websocket.entity.WarnMsg;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Envelope published to the Redis topic: the target user plus the warning payload.
 *
 * <p>It is converted to and from JSON with fastjson by the publisher (WarnMsgServiceImpl) and
 * the subscriber (RedisTopicListener). An empty {@code userId} is treated by the listener as a
 * broadcast to all connected clients. Lombok generates accessors, equals/hashCode/toString and
 * both constructors.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class TopicMsg {
    // Target user id; empty string for a broadcast.
    private String userId;
    // Warning message to deliver.
    private WarnMsg msg;
}
監聽器
import java.util.List;
import com.alibaba.fastjson.JSON;
import com.bart.websocket.common.RedisKeyConstants;
import com.bart.websocket.entity.WarnMsg;
import com.bart.websocket.service.WarnMsgService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.data.redis.listener.Topic;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.util.StringUtils;
/**
 * Listener for the warning-message Redis topic.
 *
 * <p>Each node subscribes to {@link RedisKeyConstants#REDIS_TOPIC_MSG}. A received
 * {@link TopicMsg} is pushed to the websocket clients connected to this node: to one user when
 * {@code userId} is set, otherwise to everyone.
 *
 * @author bart
 */
public class RedisTopicListener implements MessageListener {
    private final static Logger log = LoggerFactory.getLogger(RedisTopicListener.class);

    /** Used until a serializer is injected; the constructor registers before setters run. */
    private static final StringRedisSerializer DEFAULT_SERIALIZER = new StringRedisSerializer();

    private StringRedisSerializer stringRedisSerializer;
    private StringRedisTemplate stringRedisTemplate;
    private WarnMsgService warnMsgService;

    /**
     * Creates the listener and subscribes it to {@code topics}.
     *
     * @param listenerContainer container to register with
     * @param topics            topics to subscribe to
     * @param warnMsgService    service performing the local websocket push
     */
    public RedisTopicListener(RedisMessageListenerContainer listenerContainer, List<? extends Topic> topics, WarnMsgService warnMsgService) {
        this(listenerContainer, topics);
        this.warnMsgService = warnMsgService;
    }

    /**
     * Creates the listener and subscribes it to {@code topics} without a push service.
     * Messages are then dropped with an error log.
     *
     * @param listenerContainer container to register with
     * @param topics            topics to subscribe to
     */
    public RedisTopicListener(RedisMessageListenerContainer listenerContainer, List<? extends Topic> topics) {
        listenerContainer.addMessageListener(this, topics);
    }

    /**
     * Decodes a topic message and pushes it to the local websocket clients.
     * Malformed or empty bodies are logged and skipped.
     *
     * @param message raw Redis message
     * @param pattern matched pattern (may be null for plain channel subscriptions)
     */
    @Override
    public void onMessage(Message message, byte[] pattern) {
        StringRedisSerializer serializer = serializer();
        String patternStr = serializer.deserialize(pattern);
        String channel = serializer.deserialize(message.getChannel());
        String body = serializer.deserialize(message.getBody());
        log.info("event = {}, message.channel = {}, message.body = {}", patternStr, channel, body);
        if (!RedisKeyConstants.REDIS_TOPIC_MSG.equals(channel)) {
            return;
        }
        if (warnMsgService == null) {
            log.error("no WarnMsgService configured, dropping message from channel {}", channel);
            return;
        }
        TopicMsg topicMsg;
        try {
            topicMsg = JSON.parseObject(body, TopicMsg.class);
        } catch (RuntimeException e) {
            log.error("cannot parse message from channel {}: {}", channel, body, e);
            return;
        }
        if (topicMsg == null || topicMsg.getMsg() == null) {
            log.warn("empty message from channel {}, ignored: {}", channel, body);
            return;
        }
        String userId = topicMsg.getUserId();
        WarnMsg msg = topicMsg.getMsg();
        // An empty user id means broadcast to every client on this node.
        if (!StringUtils.hasLength(userId)) {
            warnMsgService.push(msg);
        } else {
            warnMsgService.push(userId, msg);
        }
    }

    /** Returns the injected serializer, or the UTF-8 default if none has been set yet. */
    private StringRedisSerializer serializer() {
        StringRedisSerializer configured = stringRedisSerializer;
        return configured != null ? configured : DEFAULT_SERIALIZER;
    }

    public StringRedisSerializer getStringRedisSerializer() {
        return stringRedisSerializer;
    }

    public void setStringRedisSerializer(StringRedisSerializer stringRedisSerializer) {
        this.stringRedisSerializer = stringRedisSerializer;
    }

    public StringRedisTemplate getStringRedisTemplate() {
        return stringRedisTemplate;
    }

    public void setStringRedisTemplate(StringRedisTemplate stringRedisTemplate) {
        this.stringRedisTemplate = stringRedisTemplate;
    }
}
重點方法在這裡:
com.bart.websocket.configuration.redis.listener.RedisTopicListener#onMessage
測試接口
import com.bart.websocket.entity.WarnMsg;
import com.bart.websocket.service.WarnMsgService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
@RestController
public class IndexController {

    /** Title format; DateTimeFormatter is immutable and thread-safe, so one instance is shared. */
    private static final DateTimeFormatter TITLE_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    @Autowired
    WarnMsgService warnMsgService;

    /**
     * Push test endpoint.
     *
     * <p>The message is published through the Redis topic rather than pushed locally, so a client
     * receives it whichever node it is connected to. That is exactly the cross-node case this
     * endpoint is meant to demonstrate.
     *
     * @param id target user id; when absent, the message is broadcast to all users
     */
    @GetMapping("/push")
    public void initMsg(String id) {
        WarnMsg warnMsg = new WarnMsg();
        warnMsg.setTitle(LocalDateTime.now().format(TITLE_FORMAT));
        warnMsg.setBody("吃了沒?");
        warnMsgService.pushWithTopic(id, warnMsg);
    }
}
消息處理器類
WarnMsgService
接口
/**
 * Delivers {@link WarnMsg} warnings to websocket clients.
 *
 * <p>The {@code push} variants deliver only to clients connected to the current node. The
 * {@code pushWithTopic} variants publish to Redis so that every subscribed node can deliver.
 */
public interface WarnMsgService {

    /**
     * Sends {@code msg} to every client connected to this node.
     *
     * @param msg message to send
     */
    void push(WarnMsg msg);

    /**
     * Sends {@code msg} to the connections of one user on this node.
     *
     * @param userId target user id
     * @param msg    message to send
     */
    void push(String userId, WarnMsg msg);

    /**
     * Publishes {@code msg} on the Redis topic for delivery to everyone (broadcast).
     *
     * @param msg message to publish
     */
    void pushWithTopic(WarnMsg msg);

    /**
     * Publishes {@code msg} on the Redis topic for delivery to one user.
     *
     * @param userId target user id
     * @param msg    message to publish
     */
    void pushWithTopic(String userId, WarnMsg msg);
}
WarnMsgServiceImpl
實現類
package com.bart.websocket.service.impl;
import com.alibaba.fastjson.JSON;
import com.bart.websocket.common.RedisKeyConstants;
import com.bart.websocket.configuration.redis.listener.TopicMsg;
import com.bart.websocket.controller._02_spring_annotation.ProductWebSocket;
import com.bart.websocket.entity.WarnMsg;
import com.bart.websocket.service.WarnMsgService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;
import java.util.Collections;
@Service
public class WarnMsgServiceImpl implements WarnMsgService, ApplicationContextAware {
    private final static Logger log = LoggerFactory.getLogger(WarnMsgServiceImpl.class);

    /** Local websocket endpoint bean, used for direct (this-node-only) delivery. */
    ProductWebSocket webSocketHandler;

    @Autowired
    StringRedisTemplate stringRedisTemplate;

    /**
     * Resolves the websocket endpoint bean by type.
     *
     * <p>ProductWebSocket does not implement Spring's {@code WebSocketHandler}. It is registered
     * through {@code @Component}, presumably under the default name "productWebSocket". A lookup by
     * the name "webSocketHandler" with type {@code WebSocketHandler} therefore fails at startup.
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        webSocketHandler = applicationContext.getBean(ProductWebSocket.class);
        Assert.notNull(webSocketHandler, "初始化webSocketHandler失敗!");
    }

    /**
     * Broadcasts {@code msg} to every client connected to this node.
     *
     * @param msg message to send
     */
    @Override
    public void push(WarnMsg msg) {
        push("", msg);
    }

    /**
     * Sends {@code msg} to one user on this node, or to everyone when {@code userId} is empty.
     * A {@code null} body is replaced by an empty map before serialization (this mutates
     * {@code msg}).
     *
     * @param userId target user id; null or empty means broadcast
     * @param msg    message to send, must not be null
     * @throws IllegalArgumentException if {@code msg} is null
     */
    @Override
    public void push(String userId, WarnMsg msg) {
        Assert.notNull(msg, "消息對象不能為空!");
        if (msg.getBody() == null) {
            msg.setBody(Collections.emptyMap());
        }
        String json = JSON.toJSONString(msg);
        if (!StringUtils.hasLength(userId)) {
            webSocketHandler.broadcast(json);
        } else {
            webSocketHandler.send(userId, json);
        }
    }

    /*
     * Publishes to the Redis topic so that every node's listener can deliver the message.
     * Manual testing from redis-cli:
     *   SUBSCRIBE redisChat                                   // subscribe to a channel
     *   PSUBSCRIBE it* big*                                   // subscribe to channel patterns
     *   PUBLISH redisChat "Redis is a great caching technique" // publish to a channel
     *   PUNSUBSCRIBE it* big*                                 // unsubscribe from patterns
     *   UNSUBSCRIBE channel it_info big_data                  // unsubscribe from channels
     */
    @Override
    public void pushWithTopic(String userId, WarnMsg msg) {
        // A null user id is published as "", which the listener treats as a broadcast.
        if (null == userId) {
            userId = "";
        }
        if (msg == null) {
            log.debug("send to userId = [{}] msg is empty, just ignore!", userId);
            return;
        }
        String body = JSON.toJSONString(new TopicMsg(userId, msg));
        log.debug("send topic=[{}], msg=[{}]", RedisKeyConstants.REDIS_TOPIC_MSG, body);
        stringRedisTemplate.convertAndSend(RedisKeyConstants.REDIS_TOPIC_MSG, body);
    }

    /**
     * Publishes {@code msg} on the Redis topic as a broadcast.
     *
     * @param msg message to publish; ignored when null
     */
    @Override
    public void pushWithTopic(WarnMsg msg) {
        pushWithTopic("", msg);
    }
}
前端
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>websocket</title>
    <script src="js/sockjs.js"></script>
    <script src="js/jquery.min.js"></script>
</head>
<body>
<fieldset>
    <legend>User01</legend>
    <button onclick="online('bart')">上線</button>
    session:<input type="text" id="session-bart"/>
    host:<input type="text" id="host-bart" value="localhost"/>
    port:<input type="text" id="port-bart" value="8089"/>
    <div>發送消息:</div>
    <input type="text" id="msgContent-bart"/>
    <input type="button" value="點我發送" onclick="chat('bart')"/>
    <div>接受消息:</div>
    <div id="receiveMsg-bart" style="background-color: gainsboro;"></div>
</fieldset>
<script>
    // Connected clients, keyed by the panel name passed to online()/chat().
    var map = {};

    // Opens a websocket for a panel; the "session" field is used as the userId path segment.
    function online(name) {
        var host = $("#host-" + name).val();
        var port = $("#port-" + name).val();
        var session = $("#session-" + name).val();
        if (!session) {
            // The server endpoint is /ws/{userId}; without the segment the handshake returns 404.
            alert("please input session (userId) first");
            return;
        }
        var chat = new CHAT(name, "ws://" + host + ":" + port + "/ws/" + encodeURIComponent(session));
        chat.init();
        map[name] = chat;
    }

    // "Send" button handler: forwards to the panel's CHAT instance if it is online.
    function chat(name) {
        var client = map[name];
        if (client) {
            client.chat();
        } else {
            console.log("not online:", name);
        }
        return false;
    }

    // One websocket client bound to a panel.
    function CHAT(name, url) {
        var self = this;
        this.name = name;
        this.socket = null;

        this.init = function () {
            if (!('WebSocket' in window)) {
                console.log("your broswer not support websocket!");
                alert("your broswer not support websocket!");
                return;
            }
            console.log("WebSocket -> " + url);
            self.socket = new WebSocket(url);
            self.socket.onopen = function () {
                console.log("連接建立成功...");
            };
            self.socket.onclose = function () {
                console.log("連接關閉...");
            };
            self.socket.onerror = function () {
                console.log("發生錯誤...");
            };
            self.socket.onmessage = function (e) {
                var res;
                try {
                    res = JSON.parse(e.data);
                } catch (err) {
                    res = e.data; // not JSON: keep the raw text
                }
                console.log(name, res);
                // Business logic goes here; the demo appends the raw payload to the panel.
                // textContent is used so that server-sent text is never interpreted as HTML.
                var line = document.createElement("div");
                line.textContent = e.data;
                document.getElementById("receiveMsg-" + name).appendChild(line);
            };
        };

        this.chat = function () {
            var id = "msgContent-" + name;
            var value = document.getElementById(id).value;
            console.log("發送消息", id, value);
            if (!self.socket || self.socket.readyState !== WebSocket.OPEN) {
                console.log("socket not open");
                return;
            }
            var msg = {
                "type": 1, // 1: send to everyone
                "msg": value
            };
            self.socket.send(JSON.stringify(msg));
        };
    }
</script>
</body>
</html>
測試
啟動兩個後端項目,端口分別為8080和8081:
1、在瀏覽器中連接8080端口的websocket;
2、然後訪問8081的接口 http://localhost:8081/push ,發現連接在8080上的客戶端也收到了消息。