package com.mzl.xx.web;
import com.mzl.xx.base.cache.StringCacheClient;
import com.mzl.xx.config.security.SpringSecurityUtils;
import com.mzl.xx.service.mqtt.MqttService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.atomic.AtomicInteger;
@ServerEndpoint(value = "/socket/{type}")
@Component("websocket")
@Slf4j
public class WebSocketServer {
//静态变量,用来记录当前在线连接数。应该把它设计成线程安全的。
private static int onlineCount = 0;
//concurrent包的线程安全Set,用来存放每个客户端对应的MyWebSocket对象。使用Map,对websocket连接分类。
private static final Map<String, CopyOnWriteArraySet<WebSocketServer>> WEB_SOCKET_SETS = new HashMap<>();
//与某个客户端的连接会话,需要通过它来给客户端发送数据
private Session session;
// private StringCacheClient stringCacheClient = SpringSecurityUtils.getBean(StringCacheClient.class);
// private MqttService mqttService = SpringSecurityUtils.getBean(MqttService.class);
/**
* 连接建立成功调用的方法
*/
@OnOpen
public void onOpen(@PathParam("type") String type, Session session) {
this.session = session;
this.session.setMaxIdleTimeout(30 * 1000);
Map<String, List<String>> map = session.getRequestParameterMap();
String sessionId = session.getId();
log.info("==> sessionId: " + sessionId + "; param: " + map);
// List<String> sLs = map.get("token");
// if (sLs == null || sLs.size() == 0) {
// try {
// sendMessage("参数无效");
// session.close();
// } catch (IOException e) {
// }
// return;
// }
// String token = sLs.get(0);
// String id = stringCacheClient.get(token);
// if (StringUtils.isEmpty(id)) {
// try {
// sendMessage("身份无效");
// session.close();
// } catch (IOException e) {
// }
// return;
// }
// 加入set中
if (WEB_SOCKET_SETS.containsKey(type)) {
WEB_SOCKET_SETS.get(type).add(this);
} else {
CopyOnWriteArraySet<WebSocketServer> webSocketSet = new CopyOnWriteArraySet<WebSocketServer>();
webSocketSet.add(this);
WEB_SOCKET_SETS.put(type, webSocketSet);
}
addOnlineCount(); //在线数加1
// mqttService.start();
log.info("=======WebSocket======= 有新连接加入!当前在线人数为 {}", getOnlineCount());
}
/**
* 连接关闭调用的方法
*/
@OnClose
public void onClose(@PathParam("type") String type) {
CopyOnWriteArraySet<WebSocketServer> webSocketSet = WEB_SOCKET_SETS.get(type);
log.info("==> connect count before close: " + webSocketSet.size());
String key = session.getId();
// 从set中删除
webSocketSet.remove(this);
subOnlineCount(); //在线数减1
log.info("==> connect count after close: " + webSocketSet.size());
log.info("==> " + key + " 已关闭");
if (onlineCount <= 0) {
// mqttService.stop();
}
}
/**
* 收到客户端消息后调用的方法
*
* @param message 客户端发送过来的消息
*/
@OnMessage
public void onMessage(@PathParam("type") String type, String message) {
String key = session.getId();
log.info("==> 心跳检测:" + key + " >> " + message);
}
/**
* @param session
* @param error
*/
@OnError
public void onError(Session session, Throwable error) {
String key = session.getId();
log.error("==> " + key + " ==> " + error.getMessage(), error);
}
public void sendMessage(String message) throws IOException {
if (this != null && this.session != null
&& this.session.getBasicRemote() != null) {
this.session.getBasicRemote().sendText(message);
}
}
@Override
public int hashCode() {
return Objects.hash(session);
}
/**
* 群发自定义消息
*/
public void sendData(String message, String type) {
if (!org.springframework.util.StringUtils.hasText(type)) {
return;
}
// 根据type区分要推送的是车间视图还是车间生产循环的数据
CopyOnWriteArraySet<WebSocketServer> webSockTests = WEB_SOCKET_SETS.get(type);
if (!CollectionUtils.isEmpty(webSockTests)) {
webSockTests.forEach(item -> {
try {
item.sendMessage(message);
log.info("=======WebSocketServer======= 实时推送消息成功 : {}", message);
} catch (IOException e) {
log.error("=======WebSocket======= 发送消息错误", e);
}
});
}
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof WebSocketServer)) {
return false;
}
WebSocketServer that = (WebSocketServer) o;
return session.equals(that.session);
}
// 该方***导致部署到服务器上找不到webSocket的Bean
// @Override
// public int hashCode() {
// return session.hashCode();
// }
public static synchronized int getOnlineCount() {
return onlineCount;
}
public static synchronized void addOnlineCount() {
WebSocketServer.onlineCount++;
}
public static synchronized void subOnlineCount() {
WebSocketServer.onlineCount--;
}
}