JAVA 中常用 webSocket 进行前后端的通信,比如在适当的时候通知前端进行刷新操作。通常情况下没有问题(消息传递可靠性不做考虑),但是一旦后端突然接收到大量请求,需要向前端发送大量 socket 通知其刷新。这种情况下会给前端带去很大的压力,很有可能刷新不过来,造成前端页面卡死。

       本文通过对“向同一客户端发送的大量同类socket消息“进行过滤来进行限流操作。

核心代码

package com.ysu.ems.web;

import java.io.IOException;
import javax.annotation.Resource;
import org.springframework.stereotype.Component;
import com.alibaba.fastjson.JSONObject;
import com.ysu.ems.pojo.RedisUtil;
/**
 * 2019/08/21
 *
 */
@Component
public class HighConcurrencyWebSocketServer {
	
	private static WebSocketServer websocketServer;//普通发送socket消息的方法类,代码附后
	@Resource
 	public void setWebSocketServer(WebSocketServer websocketServer) {
 		this.websocketServer = websocketServer;
 	}
	private static RedisUtil redisUtil;//连接redis的工具类,代码附后
	@Resource
 	public void setRedisUtil(RedisUtil redisUtil) {
 		this.redisUtil = redisUtil;
 	}
	
	public static void sendInfo(Object ob,String message,String userId) throws IOException  {
		StringBuffer keyBuffer=new StringBuffer();//redis缓存的key值
		JSONObject jSONObject=(JSONObject)ob;
		keyBuffer.append("highConWebSocketKey-");//key值前缀
		keyBuffer.append(jSONObject.get("type").toString());//socket的type值
		keyBuffer.append("-"+userId);//socket将要送达的客户端的userId
		if(!redisUtil.exists(keyBuffer.toString())){//如果缓存中不存在
			//存入缓存
			redisUtil.set(keyBuffer.toString(),"1",1L);//1秒后失效,value值无意义
			//开定时器,1秒之后再发出去
			Thread socketSender=new Thread(()->{
				try {
					Thread.sleep(1000);//先休眠1秒
					websocketServer.sendInfo(ob, message, userId);//发送socket					
				} catch (Exception e) {
					e.printStackTrace();
				}
			});
			socketSender.start();
		}
	}
}

上述代码中的 Object ob对象的值类似这样:

JSONObject jsonObject = new JSONObject();
jsonObject.put("type", 6);//socket类型,和前端约定好的,不同值代表不同意义
jsonObject.put("info", "这是要发送的socket解释信息");

核心思路就是:把短时间内(比如1秒内),向同一用户发送的同类socket消息过滤掉不发送,并确保该段时间内的最新消息(因为是同类的,所以最先到达的消息也等同于最新的)能够发送到前端,避免最后到达的那条信息被过滤而没被发送,休眠时间可以比缓存失效时间长一点点。

比如1秒内向 A 用户客户端发送了 100 条同样的信息要求前端刷新,这样过滤后前端就只会收到第一条消息(它和最后一条最新的消息是等价的),后面的99条冗余消息就被后端过滤掉了,前端在1秒内顶多只刷新一次(而且一定是最新的消息通知的),不会导致页面卡死。

 

连接redis

package com.ysu.ems.pojo;
import java.io.Serializable;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
 
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ValueOperations;
import org.springframework.data.redis.serializer.GenericJackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.stereotype.Component;
 
/**
 * 
 * @Description: spring boot 的redis工具类  
 * 创建日期:2019/05/26
 * @function redisTemplateInit 配置Redis防止乱码   2019/08/21
 */
@SuppressWarnings("unchecked")
@Component
public class RedisUtil {
	 @SuppressWarnings("rawtypes")
	    @Autowired
	    private RedisTemplate redisTemplate;
	 
	 	@Bean
	    public RedisTemplate redisTemplateInit() {
	        //设置序列化Key的实例化对象
	        redisTemplate.setKeySerializer(new StringRedisSerializer());
	        //设置序列化Value的实例化对象
	        redisTemplate.setValueSerializer(new GenericJackson2JsonRedisSerializer());
	        return redisTemplate;
	    }
	    /**
	     * 批量删除对应的value
	     * 
	     * @param keys
	     */
	    public void remove(final String... keys) {
	        for (String key : keys) {
	            remove(key);
	        }
	    }
	 
	    /**
	     * 批量删除key
	     * 
	     * @param pattern
	     */
	    public void removePattern(final String pattern) {
	        Set<Serializable> keys = redisTemplate.keys(pattern);
	        if (keys.size() > 0)
	            redisTemplate.delete(keys);
	    }
	 
	    /**
	     * 删除对应的value
	     * 
	     * @param key
	     */
	    public void remove(final String key) {
	        if (exists(key)) {
	            redisTemplate.delete(key);
	        }
	    }
	 
	    /**
	     * 判断缓存中是否有对应的value
	     * 
	     * @param key
	     * @return
	     */
	    public boolean exists(final String key) {
	        return redisTemplate.hasKey(key);
	    }
	 
	    /**
	     * 读取缓存
	     * 
	     * @param key
	     * @return
	     */
	    public String get(final String key) {
	        Object result = null;
	        redisTemplate.setValueSerializer(new StringRedisSerializer());
	        ValueOperations<Serializable, Object> operations = redisTemplate.opsForValue();
	        result = operations.get(key);
	        if (result == null) {
	            return null;
	        }
	        return result.toString();
	    }
	 
	    /**
	     * 写入缓存
	     * 
	     * @param key
	     * @param value
	     * @return
	     */
	    public boolean set(final String key, Object value) {
	        boolean result = false;
	        try {
	            ValueOperations<Serializable, Object> operations = redisTemplate.opsForValue();
	            operations.set(key, value);
	            result = true;
	        } catch (Exception e) {
	            e.printStackTrace();
	        }
	        return result;
	    }
	 
	    /**
	     * 写入缓存
	     * 
	     * @param key
	     * @param value
	     * @return
	     */
	    public boolean set(final String key, Object value, Long expireTime) {
	        boolean result = false;
	        try {
	            ValueOperations<Serializable, Object> operations = redisTemplate.opsForValue();
	            operations.set(key, value);
	            redisTemplate.expire(key, expireTime, TimeUnit.SECONDS);
	            result = true;
	        } catch (Exception e) {
	            e.printStackTrace();
	        }
	        return result;
	    }
	 
	    public boolean hmset(String key, Map<String, String> value) {
	        boolean result = false;
	        try {
	            redisTemplate.opsForHash().putAll(key, value);
	            result = true;
	        } catch (Exception e) {
	            e.printStackTrace();
	        }
	        return result;
	    }
	 
	    public Map<String, String> hmget(String key) {
	        Map<String, String> result = null;
	        try {
	            result = redisTemplate.opsForHash().entries(key);
	        } catch (Exception e) {
	            e.printStackTrace();
	        }
	        return result;
	    }
}

我用maven管理的,服务器那边redis需要你们自己配置一下,redis连接池我用spring-boot管理的。在配置文件中写一下就行

#redis缓存配置

spring.redis.host=服务器ip
spring.redis.port=6379
#spring.redis.password=
#spring.redis.database=1
#spring.redis.pool.max-active=80
#spring.redis.pool.max-wait=-1
#spring.redis.pool.max-idle=500
#spring.redis.pool.min-idle=0
#spring.redis.timeout=100000

webSocket实现

文中的 websocketServer.sendInfo(ob, message, userId);就是普通的发送socket消息的实现方法

package com.ysu.ems.web;

import java.io.IOException;
import java.io.OutputStream;
import java.net.Socket;
import java.util.List;
import java.util.concurrent.CopyOnWriteArraySet;

import javax.annotation.Resource;
import javax.websocket.EncodeException;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;

import java.util.Date;
import java.text.SimpleDateFormat;

import org.springframework.stereotype.Component;

import com.alibaba.fastjson.JSONObject;





@ServerEndpoint("/websocket/{sid}")
@Component

public class WebSocketServer {

    private static int onlineCount = 0;
    private static CopyOnWriteArraySet<WebSocketServer> webSocketSet = new CopyOnWriteArraySet<WebSocketServer>();

    private Session session;

    private String sid="";

  
    @OnOpen
    public void onOpen(Session session,@PathParam("sid") String sid) {
        this.session = session;
        webSocketSet.add(this);
        addOnlineCount();
        this.sid=sid;
    }


    @OnClose
    public void onClose() {
        webSocketSet.remove(this);
        subOnlineCount();
    }


    @OnMessage
    public void onMessage(String message, Session session) {
        //log.info("收到来自窗口"+sid+"的信息:"+message);

        for (WebSocketServer item : webSocketSet) {
            try {
                item.sendMessage(item.sid);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }


    @OnError
    public void onError(Session session, Throwable error) {
        //log.error("发生错误");
        error.printStackTrace();
    }

    public void sendMessage(String message) throws IOException {
        this.session.getAsyncRemote().sendText(message);
    }
  


    public static void sendInfo(Object ob,String message,@PathParam("sid") String sid) throws IOException {
        //log.info("推送消息到窗口"+sid+",推送内容:"+message);
    	
        for (WebSocketServer item : webSocketSet) {
   
            try {
                //这里可以设定只推送给这个sid的,为null则全部推送
                if(sid==null) {
                 if(message!=null)
                    item.sendMessage(message);
                  if(ob!=null)
                    item.sendMessage(JSONObject.toJSONString(ob));
                }else if(item.sid.equals(sid)){
                	if(message!=null)
                        item.sendMessage(message);
                    if(ob!=null)
                        item.sendMessage(JSONObject.toJSONString(ob));
                }
            } catch (IOException e) {
                continue;
            }
        }
    }

    public static synchronized int getOnlineCount() {
        return onlineCount;
    }

    public static synchronized void addOnlineCount() {
        WebSocketServer.onlineCount++;
    }

    public static synchronized void subOnlineCount() {
        WebSocketServer.onlineCount--;
    }
}