package com.mzl.xx.web;

import com.mzl.xx.base.cache.StringCacheClient;
import com.mzl.xx.config.security.SpringSecurityUtils;
import com.mzl.xx.service.mqtt.MqttService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.atomic.AtomicInteger;

@ServerEndpoint(value = "/socket/{type}")
@Component("websocket")
@Slf4j
public class WebSocketServer {
    //静态变量,用来记录当前在线连接数。应该把它设计成线程安全的。
    private static int onlineCount = 0;

    //concurrent包的线程安全Set,用来存放每个客户端对应的MyWebSocket对象。使用Map,对websocket连接分类。
    private static final Map<String, CopyOnWriteArraySet<WebSocketServer>> WEB_SOCKET_SETS = new HashMap<>();

    //与某个客户端的连接会话,需要通过它来给客户端发送数据
    private Session session;

//    private StringCacheClient stringCacheClient = SpringSecurityUtils.getBean(StringCacheClient.class);

//    private MqttService mqttService = SpringSecurityUtils.getBean(MqttService.class);

    /**
     * 连接建立成功调用的方法
     */
    @OnOpen
    public void onOpen(@PathParam("type") String type, Session session) {
        this.session = session;
        this.session.setMaxIdleTimeout(30 * 1000);
        Map<String, List<String>> map = session.getRequestParameterMap();
        String sessionId = session.getId();
        log.info("==> sessionId: " + sessionId + "; param: " + map);
//        List<String> sLs = map.get("token");
//        if (sLs == null || sLs.size() == 0) {
//            try {
//                sendMessage("参数无效");
//                session.close();
//            } catch (IOException e) {
//            }
//            return;
//        }
//        String token = sLs.get(0);
//        String id = stringCacheClient.get(token);
//        if (StringUtils.isEmpty(id)) {
//            try {
//                sendMessage("身份无效");
//                session.close();
//            } catch (IOException e) {
//            }
//            return;
//        }

        // 加入set中
        if (WEB_SOCKET_SETS.containsKey(type)) {
            WEB_SOCKET_SETS.get(type).add(this);
        } else {
            CopyOnWriteArraySet<WebSocketServer> webSocketSet = new CopyOnWriteArraySet<WebSocketServer>();
            webSocketSet.add(this);
            WEB_SOCKET_SETS.put(type, webSocketSet);
        }
        addOnlineCount();           //在线数加1
//        mqttService.start();
        log.info("=======WebSocket======= 有新连接加入!当前在线人数为 {}", getOnlineCount());
    }

    /**
     * 连接关闭调用的方法
     */
    @OnClose
    public void onClose(@PathParam("type") String type) {
        CopyOnWriteArraySet<WebSocketServer> webSocketSet = WEB_SOCKET_SETS.get(type);

        log.info("==> connect count before close: " + webSocketSet.size());
        String key = session.getId();
        // 从set中删除
        webSocketSet.remove(this);
        subOnlineCount();           //在线数减1
        log.info("==> connect count after close: " + webSocketSet.size());
        log.info("==> " + key + " 已关闭");
        if (onlineCount <= 0) {
//            mqttService.stop();
        }
    }

    /**
     * 收到客户端消息后调用的方法
     *
     * @param message 客户端发送过来的消息
     */
    @OnMessage
    public void onMessage(@PathParam("type") String type, String message) {
        String key = session.getId();
        log.info("==> 心跳检测:" + key + " >> " + message);
    }

    /**
     * @param session
     * @param error
     */
    @OnError
    public void onError(Session session, Throwable error) {
        String key = session.getId();
        log.error("==> " + key + " ==> " + error.getMessage(), error);
    }


    public void sendMessage(String message) throws IOException {
        if (this != null && this.session != null
                && this.session.getBasicRemote() != null) {
            this.session.getBasicRemote().sendText(message);
        }

    }

    @Override
    public int hashCode() {
        return Objects.hash(session);
    }

    /**
     * 群发自定义消息
     */
    public void sendData(String message, String type) {
        if (!org.springframework.util.StringUtils.hasText(type)) {
            return;
        }
        // 根据type区分要推送的是车间视图还是车间生产循环的数据
        CopyOnWriteArraySet<WebSocketServer> webSockTests = WEB_SOCKET_SETS.get(type);
        if (!CollectionUtils.isEmpty(webSockTests)) {
            webSockTests.forEach(item -> {
                try {
                    item.sendMessage(message);
                    log.info("=======WebSocketServer======= 实时推送消息成功 : {}", message);
                } catch (IOException e) {
                    log.error("=======WebSocket======= 发送消息错误", e);
                }
            });
        }
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof WebSocketServer)) {
            return false;
        }

        WebSocketServer that = (WebSocketServer) o;

        return session.equals(that.session);
    }

    // 该方***导致部署到服务器上找不到webSocket的Bean
//    @Override
//    public int hashCode() {
//        return session.hashCode();
//    }

    public static synchronized int getOnlineCount() {
        return onlineCount;
    }

    public static synchronized void addOnlineCount() {
        WebSocketServer.onlineCount++;
    }

    public static synchronized void subOnlineCount() {
        WebSocketServer.onlineCount--;
    }

}