前言

  • 在使用ChatGPT的时候,发现输入prompt后,是使用流式的效果返回数据,给用户的是一个打字机的效果。查看其网络请求,发现这个接口的通响应类型是text/event-stream,一种基于EventStream的事件流。
  • 那么为什么要这样传输呢?从使用场景上来说,ChatGPT是一个基于深度学习的大型语言模型,处理自然语言需要大量的计算资源和时间,那么响应速度肯定是比一般业务要慢的,那么接口等待时间过长,显然也不合适,那么对于这种对话场景,采用SSE技术边计算边返回,避免用户因为等待时间过长而关闭页面。

概述

SSE(Server Sent Event),直译为服务器发送事件,也就是服务器主动发送事件,客户端可以获取到服务器发送的事件。

  • 常见的HTTP交互方式主要是客户端发起请求,然后服务端响应,然后一次性请求完毕。但是在SSE的使用场景下,客户端发起请求,然后建立SSE连接一直保持,服务端就可以返回数据给客户端。
  • SSE简单来说就是服务器主动向前端推送数据的一种技术,它是单向的。SSE适用于消息推送、监控等只需要服务端推送数据的场景中。

特点

  • 服务端主动推送
    1. HTML5新标准,用于从服务端试试推送数据到浏览器端。
    2. 直接建立在当前HTTP连接上,本质上是一个HTTP长连接。

SSE与WebSocket的区别

  • SSE是单工的,只能由服务端想客户端发送消息,而WebSocket是双工的
SSE WebScoket
http 协议 独立的 websocket 协议
轻量,使用简单 相对复杂
默认支持断线重连 需要自己实现断线重连
文本传输 二进制传输
支持自定义发送的消息类型 -

SSE规范

  • 在HTML5中,服务端SSE一般要遵循以下要求
    1. 请求头:开启长连接 + 流式传递
      1
      2
      3
      Content-Type: text/event-stream;charset=UTF-8
      Cache-Control: no-cache
      Connection: keep-alive
    2. 数据格式:服务端发送的消息,由message组成,其格式如下
      1
      field:value

SSE实践

  • 这里简单做一个时钟效果,有服务端主动推送当前时间数据给前端,前端页面接收后展示。

SseEmitter类简介

  • SpringBoot使用SseEmitter来支持SSE,并对SSE规范做了一些封装,使用起来非常简单。我们在操作SseEmitter对象时,只需要关注发送的消息文本即可。
  • SseEmittter类的几个方法:
    1. send():发送数据,如果传入的是一个非SseEventBuilder对象,那么传递参数会被封装到data中。
    2. complete():表示执行完成,会断开连接(如果是一些轮询进度的任务,我们可以在任务进度完成时,主动断开连接)
    3. onTimeout():连接超时时回调触发。
    4. onCompletion():结束之后的回调触发。
    5. onError():报错时的回调触发。

示例Demo

1
2
3
4
5
6
7
8
9
10
11
12
<html>
<body>
<div id="msg_from_server"></div>
</body>
<script>
const sse = new EventSource("http://localhost/sse/hello");
sse.onmessage = function (event) {
var eventVal = document.getElementById("msg_from_server");
eventVal.innerHTML = event.data;
};
</script>
</html>
  • 后端接口
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    import lombok.extern.slf4j.Slf4j;
    import org.springframework.web.bind.annotation.CrossOrigin;
    import org.springframework.web.bind.annotation.GetMapping;
    import org.springframework.web.bind.annotation.RequestMapping;
    import org.springframework.web.bind.annotation.RestController;
    import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

    import javax.servlet.http.HttpServletResponse;
    import java.time.LocalDateTime;
    import java.time.format.DateTimeFormatter;


    @Slf4j
    @CrossOrigin
    @RestController
    @RequestMapping("/sse")
    public class CommonController {
    @GetMapping("/hello")
    public SseEmitter helloworld(HttpServletResponse response) {
    response.setContentType("text/event-stream");
    response.setCharacterEncoding("utf-8");
    SseEmitter sseEmitter = new SseEmitter();
    new Thread(() -> {
    try {
    while (true) {
    Thread.sleep(1000L);
    sseEmitter.send(SseEmitter.event().data(LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"))));
    }
    } catch (Exception e) {
    log.error("Error in SSE: {}", e.getMessage());
    sseEmitter.completeWithError(e);
    }
    }).start();
    return sseEmitter;
    }
    }
  • 大功告成

注意事项

  • 这里的协议是http/1.1,仅支持6个连接数,而HTTP/2默认支持100个连接数,同时这里每30秒重新建立了一个新连接,这是SSE默认的连接超时时间,我们可以通过配置连接超时时间来达到不过期的目的,那么就需要我们在业务逻辑里手动关闭连接
  • 同时,每建立一个SSE连接的时候,都需要一个线程,那么这里就需要创建一个线程池来处理并发问题,同时也要根据自身的业务需求来做好压测。
  • 但是HTTP/2仅支持HTTPS,我这里就不演示了,感兴趣的小伙伴可以去了解一下使用OpenSSL生成一个自签名的SSL证书

工具类封装

  • 下面是我封装的一个简单的SseUtils
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    import lombok.extern.slf4j.Slf4j;
    import org.springframework.http.MediaType;
    import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

    import java.io.IOException;
    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;

    @Slf4j
    public class SseUtils {
    // timeout -> 0表示不过期,默认是30秒,超过时间未完成(断开)会抛出异常
    private static final Long DEFAULT_TIME_OUT = 0L;
    // 会话map, 方便管理连接数
    private static Map<String, SseEmitter> conversationMap = new ConcurrentHashMap<>();

    /**
    * 建立连接
    *
    * @param conversationId - 会话Id
    * @return
    */
    public static SseEmitter getConnect(String conversationId) {
    // 创建SSE
    SseEmitter sseEmitter = new SseEmitter(DEFAULT_TIME_OUT);
    // 异常
    try {
    // 设置前端重试时5s
    sseEmitter.send(SseEmitter.event().reconnectTime(5_000L).data("SSE建立成功"));
    // 连接超时
    sseEmitter.onTimeout(() -> SseUtils.timeout(conversationId));
    // 连接断开
    sseEmitter.onCompletion(() -> SseUtils.completion(conversationId));
    // 错误
    sseEmitter.onError((e) -> SseUtils.error(conversationId, e.getMessage()));
    // 添加sse
    conversationMap.put(conversationId, sseEmitter);
    // 连接成功
    log.info("创建sse连接成功 ==> 当前连接总数={}, 会话Id={}", conversationMap.size(), conversationId);
    } catch (IOException e) {
    // 日志
    log.error("前端重连异常 ==> 会话Id={}, 异常信息={}", conversationId, e.getMessage());
    }
    // 返回
    return sseEmitter;
    }

    /***
    * 获取消息实例
    *
    * @param conversationId - 会话Id
    * @return
    */
    public static SseEmitter getInstance(String conversationId) {
    return conversationMap.get(conversationId);
    }

    /***
    * 断开连接
    *
    * @param conversationId - 会话Id
    * @return
    */
    public static void disconnect(String conversationId) {
    SseUtils.getInstance(conversationId).complete();
    }

    /**
    * 给指定会话发送消息,如果发送失败,返回false
    *
    * @param conversationId - 会话Id
    * @param jsonMsg - 消息
    */
    public static boolean sendMessage(String conversationId, String jsonMsg) {
    // 判断该会话是否已建立连接
    // 已建立连接
    if (SseUtils.getIsExistClientId(conversationId)) {
    try {
    // 发送消息
    SseUtils.getInstance(conversationId).send(jsonMsg, MediaType.APPLICATION_JSON);
    return true;
    } catch (IOException e) {
    // 日志
    SseUtils.removeClientId(conversationId);
    log.error("发送消息异常 ==> 会话Id={}, 异常信息={}", conversationId, e.getMessage());
    return false;
    }
    } else {
    // 未建立连接
    log.error("连接不存在或者超时 ==> 会话Id={}会话自动关闭", conversationId);
    SseUtils.removeClientId(conversationId);
    return false;
    }
    }

    /**
    * 移除会话Id
    *
    * @param conversationId - 会话Id
    */
    public static void removeClientId(String conversationId) {
    // 不存在存在会话
    if (!SseUtils.getIsExistClientId(conversationId)) {
    return;
    }
    // 删除该会话
    conversationMap.remove(conversationId);
    // 日志
    log.info("移除会话成功 ==> 会话Id={}", conversationId);
    }

    /**
    * 获取是否存在会话
    *
    * @param conversationId - 会话Id
    */
    public static boolean getIsExistClientId(String conversationId) {
    return conversationMap.containsKey(conversationId);
    }

    /**
    * 获取当前连接总数
    *
    * @return - 连接总数
    */
    public static int getConnectTotal() {
    log.error("当前连接数:{}", conversationMap.size());
    for (String s : conversationMap.keySet()) {
    log.error("输出SSE-Map:{}", conversationMap.get(s));
    }
    return conversationMap.size();
    }

    /**
    * 超时
    *
    * @param conversationId String 会话Id
    */
    public static void timeout(String conversationId) {
    // 日志
    log.error("sse连接超时 ==> 会话Id={}", conversationId);
    // 移除会话
    SseUtils.removeClientId(conversationId);
    }

    /**
    * 完成
    *
    * @param conversationId String 会话Id
    */
    public static void completion(String conversationId) {
    // 日志
    log.info("sse连接已断开 ==> 会话Id={}", conversationId);
    // 移除会话
    SseUtils.removeClientId(conversationId);
    }

    /**
    * 错误
    *
    * @param conversationId String 会话Id
    */
    public static void error(String conversationId, String message) {
    // 日志
    log.error("sse服务异常 ==> 会话Id={}, 异常信息={}", conversationId, message);
    // 移除会话
    SseUtils.removeClientId(conversationId);
    }
    }
  • 还是用刚刚推送当前时间的例子,这里我们做一下主动关闭连接,我这里简单的逻辑就是遍历到一个整分,就停止推送
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    @GetMapping("/demo")
    public SseEmitter timeStamp(HttpServletResponse response) {
    response.setContentType("text/event-stream");
    response.setCharacterEncoding("utf-8");
    // 生成会话ID
    String conversationId = "123456";
    // 建立连接
    SseEmitter sseEmitter = SseUtils.getConnect(conversationId);
    new Thread(() -> {
    try {
    while (true) {
    Thread.sleep(1000L);
    // 向会话发送消息
    String timeStamp = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"));
    SseUtils.sendMessage(conversationId, timeStamp);
    if (timeStamp.endsWith("00")) {
    SseUtils.removeClientId(conversationId);
    break;
    }
    }
    } catch (Exception e) {
    log.error("Error in SSE: {}", e.getMessage());
    sseEmitter.completeWithError(e);
    }
    }).start();
    return sseEmitter;
    }

SSE实战

  • 我这里也是在我项目里的轮询订单进度的时候尝试用了一下,因为我这个项目也是文本生成方向的,之前是前端定时轮询我这边的接口,现在换成我主动向前端推送数据,然后前端拿到数据自己解析内容就好了。这里用的工具类就是我刚刚封装的那个
    1
    2
    3
    4
    5
    6
    7
    @CrossOrigin
    @GetMapping("/getOrderDetail")
    public SseEmitter getOrderDetailById(String orderId, HttpServletResponse httpServletResponse) {
    httpServletResponse.setContentType("text/event-stream");
    httpServletResponse.setCharacterEncoding("utf-8");
    return orderService.getOrderDetailById(orderId, httpServletResponse);
    }
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    // 简单来个线程池
    ThreadPoolExecutor executor = new ThreadPoolExecutor(10, 10, 60, TimeUnit.SECONDS, new LinkedBlockingQueue<>());

    @Override
    public SseEmitter getOrderDetailById(String orderId, HttpServletResponse httpServletResponse) {
    // 建立连接
    SseEmitter emitter = SseUtils.getConnect(orderId);
    executor.execute(() -> {
    while (true) {
    log.error("=========SSE轮询中=========");
    try {
    // 每5秒推送一次数据
    Thread.sleep(5000L);
    } catch (InterruptedException e) {
    throw new RuntimeException(e);
    }
    // 查询订单数据
    Torder torder = orderMapper.selectOne(Wrappers.lambdaQuery(Torder.class).eq(Torder::getOrderId, orderId));
    if (torder == null) {
    // 如果订单不存在,返回错误,主动断开连接
    SseUtils.sendMessage(orderId, JSON.toJSONString(ErrorCodeEnum.ORDER_ID_NOT_EXIST));
    SseUtils.removeClientId(orderId);
    break;
    }
    OrderDetailVO detailVO = new OrderDetailVO();
    detailVO.setIsExpire(stringRedisTemplate.opsForValue().get(orderId) == null);
    detailVO.setOrderId(orderId);
    detailVO.setCreateTime(torder.getCreateTime());
    detailVO.setOrderType(torder.getPolishType());
    detailVO.setAmount(torder.getAmount().doubleValue());
    // 根据不同的订单类型来封装不同的参数(这里为了满足产品的需求,想用一个接口显示不同种类订单的信息,用了SQL反模式设计数据库,导致代码很不优雅)
    if (torder.getOrderType() == 0) {
    Wrapper<Object> statusByOrderId = getStatusByOrderId(orderId);
    if (statusByOrderId.getCode() != 0) {
    // 订单状态查询异常,返回错误,主动断开连接
    SseUtils.sendMessage(orderId, JSON.toJSONString(ErrorCodeEnum.ASYNC_SERVICE_ERROR));
    SseUtils.removeClientId(orderId);
    break;
    }
    if (torder.getPolishType() == Common.POLISH_TYPE_WITH_PAPER) {
    PaperStatusByOrderIdVO paperVO = (PaperStatusByOrderIdVO) statusByOrderId.getResult();
    BeanUtils.copyProperties(paperVO, detailVO);
    detailVO.setProgress(Double.valueOf(paperVO.getProgress()));
    detailVO.setTitle(paperVO.getPaperTitle());
    detailVO.setOrderStatus(paperVO.getStatus());
    } else {
    TextStatusByOrderIdVO textVO = (TextStatusByOrderIdVO) statusByOrderId.getResult();
    BeanUtils.copyProperties(textVO, detailVO);
    detailVO.setProgress(Double.valueOf(textVO.getProgress()));
    detailVO.setTitle(textVO.getPaperTitle());
    detailVO.setOrderStatus(textVO.getStatus());
    }
    } else if (torder.getOrderType() == 1) {
    CheckpassOrder checkpassOrder = checkpassOrderMapper.selectOne(Wrappers.lambdaQuery(CheckpassOrder.class).eq(CheckpassOrder::getOrderId, orderId));
    CheckpassReport checkpassReport = checkpassReportMapper.selectOne(Wrappers.lambdaQuery(CheckpassReport.class).eq(CheckpassReport::getPaperId, checkpassOrder.getPaperId()));
    detailVO.setOrderStatus(checkpassOrder.getStatus());
    detailVO.setAuthor(checkpassReport.getAuthor());
    detailVO.setTitle(checkpassReport.getTitle());
    detailVO.setProgress(checkpassReport.getCopyPercent() == null ? 0 : checkpassReport.getCopyPercent());
    detailVO.setCheckVersion(CommonUtil.getCheckVersion(checkpassOrder.getJaneName()));
    }
    boolean flag = SseUtils.sendMessage(orderId, JSON.toJSONString(detailVO));
    if (!flag) {
    break;
    }
    if (torder.getStatus() == Common.ORDER_FINISH_STATUS) {
    // 订单完成,主动关闭连接
    try {
    emitter.send(SseEmitter.event().reconnectTime(5000L).data("SSE关闭连接"));
    } catch (IOException e) {
    throw new RuntimeException(e);
    }
    SseUtils.removeClientId(orderId);
    break;
    }
    }
    });
    return emitter;

使用过程中的一些坑

  1. 在使用过程中,浏览器中查看接口一直显示待处理状态,但我的Java服务确确实实已经推送了数据。
    • 如果你等待了一会儿,发现请求响应成功,但是一次性推送了很多条消息,那么大概率是缓冲区的问题,因为SSE是流式输出,流式输出通常会涉及到缓冲区的使用。在Java Servlet中,HttpServletResponse对象的输出流会有一个缓冲区。当使用Servlet的输出流写入数据时,这些数据首先会被写入缓冲区,然后才会被发送到客户端。所以我们需要再代码中禁用掉。
      1
      httpServletResponse.setHeader("X-Accel-Buffering", "no");
    • 同时Nginx里也要加上同样的配置,如果你中间经过了多级Nginx,需要每一级Nginx都禁用此项。
      1
      proxy_buffering off;
  2. 如果你使用了阿里云的CDN服务,那么请设置为动态加速
  3. 服务端无法到客户端网络中断:客户端网络中断后,服务端无法感知到客户端断开连接,就会导致服务端的线程中的任务一直在运行,不断地给这个客户端推送消息。解决方案如下:
    1. 通过给不同的业务场景给服务端设置不同的最大连接时长,超过这个时长,服务端会主动地去断开这个连接。
    2. 客户端感知断开连接的通知之后,如果当前订单任务还未结束,那么客户端会重新建立连接,直到订单任务结束,这样做能避免一些无效会话一直在推送消息的问题。
  4. 客户端重连机制:如果客户端因为网络问题或者其他问题进行了断线,那么客户端会根据服务端发送的retry参数设置的时间间隔进行重连,而这个时候服务端是暂时无法感知到客户端已经断线了,所以还是会在持续地去给客户端推送消息。假如客户端重连成功之后,就会出现以下两种场景:
    1. 服务端未断开连接:复用之前的连接线路,客户端会一次性收到多条断线期间未收到的消息内容,这个时候客户端使用限流,只更新最后一条消息,减少DOM渲染。
    2. 服务端主动断开了连接(订单任务结束断开/达到最大连接时长):重新建立一条线路(之前的那条线路其实还是存在的),因为是一条新线路所以之前断线时,服务端发送的消息,是收不到的。
  5. 如何保证用户在同一个业务场景下只会建立一条连接?
    1. 这也就是上面标黄处提到的问题,之前的会话id都是服务端来生成,最后修改为客户端来生成会话id并且临时保存在本地策略就是(业务ID - 用户token后20位 - 页面RUNTIME_ID),这个样做的原因主要还是确保用户在同一个业务场景下或者在断线重连时 客户端每次向服务端建立连接的会话id都是相同的,从而方便后面 服务端断开之前的线路。
    2. 由于服务端采用的是HashMap来存储每个SSE对象,所以在插入id相同的会话的时候,会直接替换map中已经存在的会话,虽然之前的会话已经不存在了,但是其建立的连接并没有真正的断开,所以服务端在新的会话插入之前,先去显式地去将之前的会话执行一次断开连接的操作,然后再去执行创建连接操作。否则,当多余的线路达到一定的数量之后,客户端会出现线路阻塞的问题。
  6. 新的会话加入之后,如何中断旧会话占用的线程?
    • 一开始的逻辑是将会话id保留在线程之中,具体流程是:判断当前会话是否存在 -> 存在就推送消息 -> sleep n秒。这样的处理的话就会出现一个问题,虽然我们在这里判断了会话id是否存在,但是由于上面我们在替换旧会话的时候,又重新创建了一个相同id的新会话(在同一个业务场景下多次建立连接,每次的会话id是一样的),所以当前线程sleep结束之后,会发现这个会话是存在的,从而会继续给这个会话推送消息。这个时候客户端会收到多个不同线程发送来的消息的问题。解决方案如下:
      1. 在每次建立连接的时候将会话和该会话的所属线程关联在一起,也就是将管理会话的map由原来的 Map<String, SseEmitter>类型,修改为: Map<String, SseEmitterInfo> 类型,其中SseEmitterInfo是我们自己封装的一个类,其中包含SseEmitter对象和建立该连接时的线程名。
      2. 在发送消息之前,需要判断当前会话是否存在,并且判断该会话所属的线程是否是当前线程,如果满足上面两个条件的话,就推送消息;否则,中断线程;这样就可以保证每一个会话只会有一个线程在推送消息。

一些补充

  • 后续实际使用的时候,我又对SseUtils进行了改进,最终版本如下
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    import com.aimc.paperreduction.common.wrapper.RWrappers;
    import com.aimc.paperreduction.model.enums.ErrorCodeEnum;
    import com.alibaba.fastjson.JSONObject;
    import lombok.Data;
    import lombok.extern.slf4j.Slf4j;
    import org.springframework.http.MediaType;
    import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

    import java.io.IOException;
    import java.util.Map;
    import java.util.Set;
    import java.util.concurrent.ConcurrentHashMap;

    @Slf4j
    public class SseUtil {
    // 为了避免内存泄露,这里最好设置一个超时时间
    private static final Long DEFAULT_TIME_OUT = 30L * 60 * 1000;
    // 会话map
    private static final Map<String, SseEmitterInfo> conversationMap = new ConcurrentHashMap<>();

    /***
    * 断开连接
    *
    * @param conversationId - 会话Id
    */
    public static void disconnect(String conversationId) {
    SseEmitterInfo instance = SseUtil.getInstance(conversationId);
    if (instance != null) {
    instance.getEmitter().complete();
    }
    }

    /**
    * 建立连接
    *
    * @param conversationId - 会话Id
    */
    public static SseEmitter getConnect(String conversationId) {
    // 创建 SseEmitterInfo
    SseEmitterInfo sseEmitterInfo = new SseEmitterInfo(conversationId);
    SseEmitter sseEmitter = new SseEmitter(DEFAULT_TIME_OUT);
    sseEmitterInfo.setEmitter(sseEmitter);
    // 异常
    try {
    // 设置前端重试时5s
    sseEmitter.send(SseEmitter.event().reconnectTime(5_000L).data(JSONObject.toJSONString(RWrappers.Fail(ErrorCodeEnum.SSE_CONNECT_SUCCESS))));
    // 连接超时
    sseEmitter.onTimeout(() -> SseUtil.timeout(conversationId));
    // 连接断开
    sseEmitter.onCompletion(() -> SseUtil.completion(conversationId));
    // 错误
    sseEmitter.onError((e) -> SseUtil.error(conversationId, e.getMessage()));
    // 添加sse
    conversationMap.put(conversationId, sseEmitterInfo);
    // 连接成功
    log.info("创建sse连接成功 ==> 当前连接总数={}, 会话Id={}", conversationMap.size(), conversationId);
    } catch (IOException e) {
    // 日志
    log.error("前端重连异常 ==> 会话Id={}, 异常信息={}", conversationId, e.getMessage());
    }
    // 返回
    return sseEmitter;
    }

    /***
    * 获取消息实例
    *
    * @param conversationId - 会话Id
    */
    public static SseEmitterInfo getInstance(String conversationId) {
    return conversationMap.get(conversationId);
    }

    /**
    * 给指定会话发送消息,如果发送失败,返回false
    *
    * @param conversationId - 会话Id
    * @param jsonMsg - 消息
    */
    public static boolean sendMessage(String conversationId, String jsonMsg) {
    // 判断该会话是否还在Map中,不存在则删除
    if (!conversationMap.containsKey(conversationId)) {
    return false;
    }
    // 已建立连接
    if (SseUtil.getIsExistClientId(conversationId)) {
    try {
    // 发送消息
    SseUtil.getInstance(conversationId).getEmitter().send(jsonMsg, MediaType.APPLICATION_JSON);
    return true;
    } catch (IOException e) {
    // 日志
    SseUtil.removeClientId(conversationId);
    log.error("发送消息异常 ==> 会话Id={}, 异常信息={}", conversationId, e.getMessage());
    return false;
    }
    } else {
    // 未建立连接
    log.error("连接不存在或者超时 ==> 会话Id={}会话自动关闭", conversationId);
    SseUtil.removeClientId(conversationId);
    return false;
    }
    }

    /**
    * 移除并断开会话Id
    *
    * @param conversationId - 会话Id
    */
    public static void removeClientId(String conversationId) {
    // 不存在存在会话
    if (!SseUtil.getIsExistClientId(conversationId)) {
    return;
    }
    // 删除该会话
    conversationMap.remove(conversationId);
    SseUtil.disconnect(conversationId);
    // 日志
    log.info("移除会话成功 ==> 会话Id={}", conversationId);
    }

    /**
    * 获取是否存在会话
    *
    * @param conversationId - 会话Id
    */
    public static boolean getIsExistClientId(String conversationId) {
    return conversationMap.containsKey(conversationId);
    }

    /**
    * 获取当前连接总数
    *
    * @return - 连接总数
    */
    public static int getConnectTotal() {
    log.error("当前连接数:{}", conversationMap.size());
    return conversationMap.size();
    }

    /**
    * 超时
    *
    * @param conversationId String 会话Id
    */
    public static void timeout(String conversationId) {
    // 日志
    log.error("sse连接超时 ==> 会话Id={}", conversationId);
    // 移除会话
    SseUtil.removeClientId(conversationId);
    }

    /**
    * 完成
    *
    * @param conversationId String 会话Id
    */
    public static void completion(String conversationId) {
    // 日志
    log.info("sse连接已断开 ==> 会话Id:{},当前剩余连接数:{}", conversationId, conversationMap.size());
    // 移除会话
    SseUtil.removeClientId(conversationId);
    }

    /**
    * 错误
    *
    * @param conversationId String 会话Id
    */
    public static void error(String conversationId, String message) {
    // 日志
    // log.error("sse服务异常 ==> 会话Id={}, 异常信息={}", conversationId, message);
    // 移除会话
    SseUtil.removeClientId(conversationId);
    }

    public static class SseEmitterInfo {
    private SseEmitter emitter;
    private String threadName;

    public SseEmitterInfo(String conversationId) {
    this.emitter = null;
    this.threadName = Thread.currentThread().getName();
    }

    public SseEmitter getEmitter() {
    return emitter;
    }

    public void setEmitter(SseEmitter emitter) {
    this.emitter = emitter;
    }

    public String getThreadName() {
    return threadName;
    }

    public void setThreadName(String threadName) {
    this.threadName = threadName;
    }
    }
    }
  • 实际使用如下
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    @PostMapping("/connSse")
    public SseEmitter connSse(String conversationId) {
    SseEmitter emitter = SseUtil.getConnect(conversationId);
    executor.execute(() -> {
    // 这里需要保证同一个会话ID只有一个线程处理
    SseUtil.getInstance(conversationId).setThreadName(Thread.currentThread().getName());
    while (!Thread.interrupted() && SseUtil.getInstance(conversationId) != null && SseUtil.getInstance(conversationId).getThreadName().equals(Thread.currentThread().getName())) {
    boolean sendSuccess = SseUtil.sendMessage(conversationId, JSONObject.toJSONString(new byte[1024 * 10]));
    log.info("向会话:{},推送", conversationId);
    if (!sendSuccess) {
    log.info("=========连接不存在,服务端主动关闭SSE连接=========");
    SseUtil.removeClientId(conversationId);
    break;
    }
    // 每一秒推送一次数据
    try {
    Thread.sleep(1000L);
    } catch (InterruptedException e) {
    throw new RuntimeException(e);
    }
    }
    });
    return emitter;
    }
  • 由于我这里的业务限制,只能这么用SSE。原有的业务逻辑是,我轮询算法接口,更新数据,然后前端轮询我的接口,更新页面状态。使用了SSE之后变成了,我轮询算法接口,更新数据,然后向前端推送数据。
  • 但是更好的处理方式是,我这边给算法提供一个回调的接口,当算法有进度更新时,调用我这个回调接口,然后我在这个回调逻辑里向前端推送数据,这样逻辑上其实是更顺的,后续有时间,打算和算法侧聊聊这块,进一步优化。