Переглянути джерело

refactor(sessds): Factor out stream scheduler into its own module

ieQu1 2 роки тому
батько
коміт
4f4831fe7f

+ 182 - 350
apps/emqx/src/emqx_persistent_session_ds.erl

@@ -115,17 +115,6 @@
     extra := map()
 }.
 
-%%%%% Session sequence numbers:
--define(next(QOS), {0, QOS}).
-%% Note: we consider the sequence number _committed_ once the full
-%% packet MQTT flow is completed for the sequence number. That is,
-%% when we receive PUBACK for the QoS1 message, or PUBCOMP, or PUBREC
-%% with Reason code > 0x80 for QoS2 message.
--define(committed(QOS), {1, QOS}).
-%% For QoS2 messages we also need to store the sequence number of the
-%% last PUBREL message:
--define(pubrec, 2).
-
 -define(TIMER_PULL, timer_pull).
 -define(TIMER_GET_STREAMS, timer_get_streams).
 -define(TIMER_BUMP_LAST_ALIVE_AT, timer_bump_last_alive_at).
@@ -156,7 +145,9 @@
     subscriptions_cnt,
     subscriptions_max,
     inflight_cnt,
-    inflight_max
+    inflight_max,
+    mqueue_len,
+    mqueue_dropped
 ]).
 
 %%
@@ -226,7 +217,7 @@ info(retry_interval, #{props := Conf}) ->
 % info(mqueue, #sessmem{mqueue = MQueue}) ->
 %     MQueue;
 info(mqueue_len, #{inflight := Inflight}) ->
-    emqx_persistent_session_ds_inflight:n_buffered(Inflight);
+    emqx_persistent_session_ds_inflight:n_buffered(all, Inflight);
 % info(mqueue_max, #sessmem{mqueue = MQueue}) ->
 %     emqx_mqueue:max_len(MQueue);
 info(mqueue_dropped, _Session) ->
@@ -250,7 +241,16 @@ stats(Session) ->
 %% Debug/troubleshooting
 -spec print_session(emqx_types:clientid()) -> map() | undefined.
 print_session(ClientId) ->
-    emqx_persistent_session_ds_state:print_session(ClientId).
+    case emqx_cm:lookup_channels(ClientId) of
+        [Pid] ->
+            #{channel := ChanState} = emqx_connection:get_state(Pid),
+            SessionState = emqx_channel:info(session_state, ChanState),
+            maps:update_with(s, fun emqx_persistent_session_ds_state:format/1, SessionState#{
+                '_alive' => {true, Pid}
+            });
+        [] ->
+            emqx_persistent_session_ds_state:print_session(ClientId)
+    end.
 
 %%--------------------------------------------------------------------
 %% Client -> Broker: SUBSCRIBE / UNSUBSCRIBE
@@ -420,7 +420,7 @@ handle_timeout(
     ?TIMER_PULL,
     Session0
 ) ->
-    {Publishes, Session1} = drain_buffer(fill_buffer(Session0, ClientInfo)),
+    {Publishes, Session1} = drain_buffer(fetch_new_messages(Session0, ClientInfo)),
     Timeout =
         case Publishes of
             [] ->
@@ -431,7 +431,7 @@ handle_timeout(
     Session = emqx_session:ensure_timer(?TIMER_PULL, Timeout, Session1),
     {ok, Publishes, Session};
 handle_timeout(_ClientInfo, ?TIMER_GET_STREAMS, Session0 = #{s := S0}) ->
-    S = renew_streams(S0),
+    S = emqx_persistent_session_ds_stream_scheduler:renew_streams(S0),
     Interval = emqx_config:get([session_persistence, renew_streams_interval]),
     Session = emqx_session:ensure_timer(
         ?TIMER_GET_STREAMS,
@@ -461,11 +461,11 @@ bump_last_alive(S0) ->
 
 -spec replay(clientinfo(), [], session()) ->
     {ok, replies(), session()}.
-replay(ClientInfo, [], Session0) ->
-    Streams = find_replay_streams(Session0),
+replay(ClientInfo, [], Session0 = #{s := S0}) ->
+    Streams = emqx_persistent_session_ds_stream_scheduler:find_replay_streams(S0),
     Session = lists:foldl(
-        fun({StreamKey, Stream}, SessionAcc) ->
-            replay_batch(StreamKey, Stream, SessionAcc, ClientInfo)
+        fun({_StreamKey, Stream}, SessionAcc) ->
+            replay_batch(Stream, SessionAcc, ClientInfo)
         end,
         Session0,
         Streams
@@ -474,6 +474,27 @@ replay(ClientInfo, [], Session0) ->
     %% from now on we'll rely on the normal inflight/flow control
     %% mechanisms to replay them:
     {ok, [], pull_now(Session)}.
+
+-spec replay_batch(stream_state(), session(), clientinfo()) -> session().
+replay_batch(Ifs0, Session, ClientInfo) ->
+    #ifs{
+        batch_begin_key = BatchBeginMsgKey,
+        batch_size = BatchSize,
+        it_end = ItEnd
+    } = Ifs0,
+    %% Reposition the iterator using the recorded batch begin key. TODO: retry
+    {ok, ItBegin} = emqx_ds:update_iterator(?PERSISTENT_MESSAGE_DB, ItEnd, BatchBeginMsgKey),
+    Ifs1 = Ifs0#ifs{it_end = ItBegin},
+    {Ifs, Inflight} = enqueue_batch(true, BatchSize, Ifs1, Session, ClientInfo),
+    %% Assert: re-reading the batch should reproduce the recorded stream state
+    %% (Ifs0), not the rewound one (Ifs1), whose it_end points at the batch start:
+    Ifs =:= Ifs0 orelse
+        ?SLOG(warning, #{
+            msg => "replay_inconsistency",
+            expected => Ifs0,
+            got => Ifs
+        }),
+    Session#{inflight => Inflight}.
+
 %%--------------------------------------------------------------------
 
 -spec disconnect(session(), emqx_types:conninfo()) -> {shutdown, session()}.
@@ -544,12 +565,21 @@ session_ensure_new(Id, ConnInfo, Conf) ->
     S1 = emqx_persistent_session_ds_state:set_conninfo(ConnInfo, S0),
     S2 = bump_last_alive(S1),
     S3 = emqx_persistent_session_ds_state:set_created_at(Now, S2),
-    S4 = emqx_persistent_session_ds_state:put_seqno(?next(?QOS_1), 0, S3),
-    S5 = emqx_persistent_session_ds_state:put_seqno(?committed(?QOS_1), 0, S4),
-    S6 = emqx_persistent_session_ds_state:put_seqno(?next(?QOS_2), 0, S5),
-    S7 = emqx_persistent_session_ds_state:put_seqno(?committed(?QOS_2), 0, S6),
-    S8 = emqx_persistent_session_ds_state:put_seqno(?pubrec, 0, S7),
-    S = emqx_persistent_session_ds_state:commit(S8),
+    S4 = lists:foldl(
+        fun(Track, Acc) ->
+            emqx_persistent_session_ds_state:put_seqno(Track, 0, Acc)
+        end,
+        S3,
+        [
+            ?next(?QOS_1),
+            ?dup(?QOS_1),
+            ?committed(?QOS_1),
+            ?next(?QOS_2),
+            ?dup(?QOS_2),
+            ?committed(?QOS_2)
+        ]
+    ),
+    S = emqx_persistent_session_ds_state:commit(S4),
     #{
         id => Id,
         props => Conf,
@@ -587,105 +617,88 @@ do_ensure_all_iterators_closed(_DSSessionID) ->
     ok.
 
 %%--------------------------------------------------------------------
-%% Buffer filling
+%% Normal replay:
 %%--------------------------------------------------------------------
 
-fill_buffer(Session = #{s := S}, ClientInfo) ->
-    Streams = shuffle(find_new_streams(S)),
-    ?SLOG(error, #{msg => "fill_buffer", streams => Streams}),
-    fill_buffer(Streams, Session, ClientInfo).
-
--spec shuffle([A]) -> [A].
-shuffle(L0) ->
-    L1 = lists:map(
-        fun(A) ->
-            %% maybe topic/stream prioritization could be introduced here?
-            {rand:uniform(), A}
-        end,
-        L0
-    ),
-    L2 = lists:sort(L1),
-    {_, L} = lists:unzip(L2),
-    L.
+fetch_new_messages(Session = #{s := S}, ClientInfo) ->
+    Streams = emqx_persistent_session_ds_stream_scheduler:find_new_streams(S),
+    ?SLOG(debug, #{msg => "fill_buffer", streams => Streams}),
+    fetch_new_messages(Streams, Session, ClientInfo).
 
-fill_buffer([], Session, _ClientInfo) ->
+fetch_new_messages([], Session, _ClientInfo) ->
     Session;
-fill_buffer(
-    [{StreamKey, Stream0 = #ifs{it_end = It0}} | Streams],
-    Session0 = #{s := S0, inflight := Inflight0},
-    ClientInfo
-) ->
+fetch_new_messages([I | Streams], Session0 = #{inflight := Inflight}, ClientInfo) ->
     BatchSize = emqx_config:get([session_persistence, max_batch_size]),
-    MaxBufferSize = BatchSize * 2,
-    case emqx_persistent_session_ds_inflight:n_buffered(Inflight0) < MaxBufferSize of
+    case emqx_persistent_session_ds_inflight:n_buffered(all, Inflight) >= BatchSize of
         true ->
-            case emqx_ds:next(?PERSISTENT_MESSAGE_DB, It0, BatchSize) of
-                {ok, It, []} ->
-                    S = emqx_persistent_session_ds_state:put_stream(
-                        StreamKey, Stream0#ifs{it_end = It}, S0
-                    ),
-                    fill_buffer(Streams, Session0#{s := S}, ClientInfo);
-                {ok, It, Messages} ->
-                    Session = new_batch(StreamKey, Stream0, It, Messages, Session0, ClientInfo),
-                    fill_buffer(Streams, Session, ClientInfo);
-                {ok, end_of_stream} ->
-                    S = emqx_persistent_session_ds_state:put_stream(
-                        StreamKey, Stream0#ifs{it_end = end_of_stream}, S0
-                    ),
-                    fill_buffer(Streams, Session0#{s := S}, ClientInfo)
-            end;
+            %% Buffer is full:
+            Session0;
         false ->
-            Session0
+            Session = new_batch(I, BatchSize, Session0, ClientInfo),
+            fetch_new_messages(Streams, Session, ClientInfo)
     end.
 
-new_batch(
-    StreamKey, Stream0, Iterator, [{BatchBeginMsgKey, _} | _] = Messages0, Session0, ClientInfo
-) ->
-    #{inflight := Inflight0, s := S0} = Session0,
-    FirstSeqnoQos1 = emqx_persistent_session_ds_state:get_seqno(?next(?QOS_1), S0),
-    FirstSeqnoQos2 = emqx_persistent_session_ds_state:get_seqno(?next(?QOS_2), S0),
-    NBefore = emqx_persistent_session_ds_inflight:n_buffered(Inflight0),
-    {LastSeqnoQos1, LastSeqnoQos2, Session} = do_process_batch(
-        false, FirstSeqnoQos1, FirstSeqnoQos2, Messages0, Session0, ClientInfo
-    ),
-    NAfter = emqx_persistent_session_ds_inflight:n_buffered(maps:get(inflight, Session)),
-    Stream = Stream0#ifs{
-        batch_size = NAfter - NBefore,
-        batch_begin_key = BatchBeginMsgKey,
-        first_seqno_qos1 = FirstSeqnoQos1,
-        first_seqno_qos2 = FirstSeqnoQos2,
-        last_seqno_qos1 = LastSeqnoQos1,
-        last_seqno_qos2 = LastSeqnoQos2,
-        it_end = Iterator
+new_batch({StreamKey, Ifs0}, BatchSize, Session = #{s := S0}, ClientInfo) ->
+    SN1 = emqx_persistent_session_ds_state:get_seqno(?next(?QOS_1), S0),
+    SN2 = emqx_persistent_session_ds_state:get_seqno(?next(?QOS_2), S0),
+    Ifs1 = Ifs0#ifs{
+        first_seqno_qos1 = SN1,
+        first_seqno_qos2 = SN2,
+        batch_size = 0,
+        batch_begin_key = undefined,
+        last_seqno_qos1 = SN1,
+        last_seqno_qos2 = SN2
     },
-    S1 = emqx_persistent_session_ds_state:put_seqno(?next(?QOS_1), LastSeqnoQos1, S0),
-    S2 = emqx_persistent_session_ds_state:put_seqno(?next(?QOS_2), LastSeqnoQos2, S1),
-    S = emqx_persistent_session_ds_state:put_stream(StreamKey, Stream, S2),
-    Session#{s => S}.
+    {Ifs, Inflight} = enqueue_batch(false, BatchSize, Ifs1, Session, ClientInfo),
+    S1 = emqx_persistent_session_ds_state:put_seqno(?next(?QOS_1), Ifs#ifs.last_seqno_qos1, S0),
+    S2 = emqx_persistent_session_ds_state:put_seqno(?next(?QOS_2), Ifs#ifs.last_seqno_qos2, S1),
+    S = emqx_persistent_session_ds_state:put_stream(StreamKey, Ifs, S2),
+    Session#{s => S, inflight => Inflight}.
 
-replay_batch(_StreamKey, Stream, Session0, ClientInfo) ->
+enqueue_batch(IsReplay, BatchSize, Ifs0, Session = #{inflight := Inflight0}, ClientInfo) ->
     #ifs{
-        batch_begin_key = BatchBeginMsgKey,
-        batch_size = BatchSize,
+        it_end = It0,
         first_seqno_qos1 = FirstSeqnoQos1,
-        first_seqno_qos2 = FirstSeqnoQos2,
-        it_end = ItEnd
-    } = Stream,
-    {ok, ItBegin} = emqx_ds:update_iterator(?PERSISTENT_MESSAGE_DB, ItEnd, BatchBeginMsgKey),
-    case emqx_ds:next(?PERSISTENT_MESSAGE_DB, ItBegin, BatchSize) of
-        {ok, _ItEnd, Messages} ->
-            {_LastSeqnoQo1, _LastSeqnoQos2, Session} = do_process_batch(
-                true, FirstSeqnoQos1, FirstSeqnoQos2, Messages, Session0, ClientInfo
+        first_seqno_qos2 = FirstSeqnoQos2
+    } = Ifs0,
+    case emqx_ds:next(?PERSISTENT_MESSAGE_DB, It0, BatchSize) of
+        {ok, It, []} ->
+            %% No new messages; just update the end iterator:
+            {Ifs0#ifs{it_end = It}, Inflight0};
+        {ok, end_of_stream} ->
+            %% No new messages; just update the end iterator:
+            {Ifs0#ifs{it_end = end_of_stream}, Inflight0};
+        {ok, It, [{BatchBeginMsgKey, _} | _] = Messages} ->
+            {Inflight, LastSeqnoQos1, LastSeqnoQos2} = process_batch(
+                IsReplay, Session, ClientInfo, FirstSeqnoQos1, FirstSeqnoQos2, Messages, Inflight0
             ),
-            %% TODO: check consistency of the sequence numbers
-            Session
+            Ifs = Ifs0#ifs{
+                it_end = It,
+                batch_begin_key = BatchBeginMsgKey,
+                %% TODO: it should be possible to avoid calling
+                %% length here by diffing size of inflight before
+                %% and after inserting messages:
+                batch_size = length(Messages),
+                last_seqno_qos1 = LastSeqnoQos1,
+                last_seqno_qos2 = LastSeqnoQos2
+            },
+            {Ifs, Inflight};
+        {error, _} when not IsReplay ->
+            ?SLOG(debug, #{msg => "failed_to_fetch_batch", iterator => It0}),
+            {Ifs0, Inflight0}
     end.
 
-do_process_batch(_IsReplay, LastSeqnoQos1, LastSeqnoQos2, [], Session, _ClientInfo) ->
-    {LastSeqnoQos1, LastSeqnoQos2, Session};
-do_process_batch(IsReplay, FirstSeqnoQos1, FirstSeqnoQos2, [KV | Messages], Session, ClientInfo) ->
-    #{s := S, props := #{upgrade_qos := UpgradeQoS}, inflight := Inflight0} = Session,
+process_batch(_IsReplay, _Session, _ClientInfo, LastSeqNoQos1, LastSeqNoQos2, [], Inflight) ->
+    {Inflight, LastSeqNoQos1, LastSeqNoQos2};
+process_batch(
+    IsReplay, Session, ClientInfo, FirstSeqNoQos1, FirstSeqNoQos2, [KV | Messages], Inflight0
+) ->
+    #{s := S, props := #{upgrade_qos := UpgradeQoS}} = Session,
     {_DsMsgKey, Msg0 = #message{topic = Topic}} = KV,
+    Comm1 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_1), S),
+    Comm2 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_2), S),
+    Dup1 = emqx_persistent_session_ds_state:get_seqno(?dup(?QOS_1), S),
+    Dup2 = emqx_persistent_session_ds_state:get_seqno(?dup(?QOS_2), S),
     Subs = emqx_persistent_session_ds_state:get_subscriptions(S),
     Msgs = [
         Msg
@@ -695,266 +708,85 @@ do_process_batch(IsReplay, FirstSeqnoQos1, FirstSeqnoQos2, [KV | Messages], Sess
             emqx_session:enrich_message(ClientInfo, Msg0, SubOpts, UpgradeQoS)
         end
     ],
-    CommittedQos1 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_1), S),
-    CommittedQos2 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_2), S),
-    {Inflight, LastSeqnoQos1, LastSeqnoQos2} = lists:foldl(
-        fun(Msg = #message{qos = Qos}, {Inflight1, SeqnoQos10, SeqnoQos20}) ->
+    {Inflight, LastSeqNoQos1, LastSeqNoQos2} = lists:foldl(
+        fun(Msg = #message{qos = Qos}, {Acc, SeqNoQos10, SeqNoQos20}) ->
             case Qos of
                 ?QOS_0 ->
-                    SeqnoQos1 = SeqnoQos10,
-                    SeqnoQos2 = SeqnoQos20,
-                    PacketId = undefined;
+                    SeqNoQos1 = SeqNoQos10,
+                    SeqNoQos2 = SeqNoQos20;
                 ?QOS_1 ->
-                    SeqnoQos1 = inc_seqno(?QOS_1, SeqnoQos10),
-                    SeqnoQos2 = SeqnoQos20,
-                    PacketId = seqno_to_packet_id(?QOS_1, SeqnoQos1);
+                    SeqNoQos1 = inc_seqno(?QOS_1, SeqNoQos10),
+                    SeqNoQos2 = SeqNoQos20;
                 ?QOS_2 ->
-                    SeqnoQos1 = SeqnoQos10,
-                    SeqnoQos2 = inc_seqno(?QOS_2, SeqnoQos20),
-                    PacketId = seqno_to_packet_id(?QOS_2, SeqnoQos2)
+                    SeqNoQos1 = SeqNoQos10,
+                    SeqNoQos2 = inc_seqno(?QOS_2, SeqNoQos20)
             end,
-            %% ?SLOG(debug, #{
-            %%     msg => "out packet",
-            %%     qos => Qos,
-            %%     packet_id => PacketId,
-            %%     enriched => emqx_message:to_map(Msg),
-            %%     original => emqx_message:to_map(Msg0),
-            %%     upgrade_qos => UpgradeQoS
-            %% }),
-
-            %% Handle various situations where we want to ignore the packet:
-            Inflight2 =
-                case IsReplay of
-                    true when Qos =:= ?QOS_0 ->
-                        Inflight1;
-                    true when Qos =:= ?QOS_1, SeqnoQos1 < CommittedQos1 ->
-                        Inflight1;
-                    true when Qos =:= ?QOS_2, SeqnoQos2 < CommittedQos2 ->
-                        Inflight1;
-                    _ ->
-                        emqx_persistent_session_ds_inflight:push({PacketId, Msg}, Inflight1)
-                end,
             {
-                Inflight2,
-                SeqnoQos1,
-                SeqnoQos2
+                case Msg#message.qos of
+                    ?QOS_0 when IsReplay ->
+                        %% We ignore QoS 0 messages during replay:
+                        Acc;
+                    ?QOS_0 ->
+                        emqx_persistent_session_ds_inflight:push({undefined, Msg}, Acc);
+                    ?QOS_1 when SeqNoQos1 =< Comm1 ->
+                        %% QoS1 message has been acked by the client, ignore:
+                        Acc;
+                    ?QOS_1 when SeqNoQos1 =< Dup1 ->
+                        %% QoS1 message has been sent but not
+                        %% acked. Retransmit:
+                        Msg1 = emqx_message:set_flag(dup, true, Msg),
+                        emqx_persistent_session_ds_inflight:push({SeqNoQos1, Msg1}, Acc);
+                    ?QOS_1 ->
+                        emqx_persistent_session_ds_inflight:push({SeqNoQos1, Msg}, Acc);
+                    ?QOS_2 when SeqNoQos2 =< Comm2 ->
+                        %% QoS2 message has been PUBCOMP'ed by the client, ignore:
+                        Acc;
+                    ?QOS_2 when SeqNoQos2 =< Dup2 ->
+                        %% QoS2 message has been PUBREC'ed by the client, resend PUBREL:
+                        emqx_persistent_session_ds_inflight:push({pubrel, SeqNoQos2}, Acc);
+                    ?QOS_2 ->
+                        %% MQTT standard 4.3.3: DUP flag is never set for QoS2 messages:
+                        emqx_persistent_session_ds_inflight:push({SeqNoQos2, Msg}, Acc)
+                end,
+                SeqNoQos1,
+                SeqNoQos2
             }
         end,
-        {Inflight0, FirstSeqnoQos1, FirstSeqnoQos2},
+        {Inflight0, FirstSeqNoQos1, FirstSeqNoQos2},
         Msgs
     ),
-    do_process_batch(
-        IsReplay, LastSeqnoQos1, LastSeqnoQos2, Messages, Session#{inflight => Inflight}, ClientInfo
+    process_batch(
+        IsReplay, Session, ClientInfo, LastSeqNoQos1, LastSeqNoQos2, Messages, Inflight
     ).
 
 %%--------------------------------------------------------------------
 %% Buffer drain
 %%--------------------------------------------------------------------
 
-drain_buffer(Session = #{inflight := Inflight0}) ->
-    {Messages, Inflight} = emqx_persistent_session_ds_inflight:pop(Inflight0),
-    {Messages, Session#{inflight => Inflight}}.
+drain_buffer(Session = #{inflight := Inflight0, s := S0}) ->
+    {Publishes, Inflight, S} = do_drain_buffer(Inflight0, S0, []),
+    {Publishes, Session#{inflight => Inflight, s := S}}.
 
-%%--------------------------------------------------------------------
-%% Stream renew
-%%--------------------------------------------------------------------
-
-%% erlfmt-ignore
--define(fully_replayed(STREAM, COMMITTEDQOS1, COMMITTEDQOS2),
-    ((STREAM#ifs.last_seqno_qos1 =< COMMITTEDQOS1 orelse STREAM#ifs.last_seqno_qos1 =:= undefined) andalso
-     (STREAM#ifs.last_seqno_qos2 =< COMMITTEDQOS2 orelse STREAM#ifs.last_seqno_qos2 =:= undefined))).
-
-%% erlfmt-ignore
--define(last_replayed(STREAM, NEXTQOS1, NEXTQOS2),
-    ((STREAM#ifs.last_seqno_qos1 == NEXTQOS1 orelse STREAM#ifs.last_seqno_qos1 =:= undefined) andalso
-     (STREAM#ifs.last_seqno_qos2 == NEXTQOS2 orelse STREAM#ifs.last_seqno_qos2 =:= undefined))).
-
--spec find_replay_streams(session()) ->
-    [{emqx_persistent_session_ds_state:stream_key(), stream_state()}].
-find_replay_streams(#{s := S}) ->
-    CommQos1 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_1), S),
-    CommQos2 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_2), S),
-    Streams = emqx_persistent_session_ds_state:fold_streams(
-        fun(Key, Stream, Acc) ->
-            case Stream of
-                #ifs{
-                    first_seqno_qos1 = F1,
-                    first_seqno_qos2 = F2,
-                    last_seqno_qos1 = L1,
-                    last_seqno_qos2 = L2
-                } when F1 >= CommQos1, L1 =< CommQos1, F2 >= CommQos2, L2 =< CommQos2 ->
-                    [{Key, Stream} | Acc];
-                _ ->
-                    Acc
-            end
-        end,
-        [],
-        S
-    ),
-    lists:sort(
-        fun(
-            #ifs{first_seqno_qos1 = A1, first_seqno_qos2 = A2},
-            #ifs{first_seqno_qos1 = B1, first_seqno_qos2 = B2}
-        ) ->
-            case A1 =:= A2 of
-                true -> B1 =< B2;
-                false -> A1 < A2
-            end
-        end,
-        Streams
-    ).
-
--spec find_new_streams(emqx_persistent_session_ds_state:t()) ->
-    [{emqx_persistent_session_ds_state:stream_key(), stream_state()}].
-find_new_streams(S) ->
-    %% FIXME: this function is currently very sensitive to the
-    %% consistency of the packet IDs on both broker and client side.
-    %%
-    %% If the client fails to properly ack packets due to a bug, or a
-    %% network issue, or if the state of streams and seqno tables ever
-    %% become de-synced, then this function will return an empty list,
-    %% and the replay cannot progress.
-    %%
-    %% In other words, this function is not robust, and we should find
-    %% some way to get the replays un-stuck at the cost of potentially
-    %% losing messages during replay (or just kill the stuck channel
-    %% after timeout?)
-    CommQos1 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_1), S),
-    CommQos2 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_2), S),
-    emqx_persistent_session_ds_state:fold_streams(
-        fun
-            (Key, Stream, Acc) when ?fully_replayed(Stream, CommQos1, CommQos2) ->
-                %% This stream has been full acked by the client. It
-                %% means we can get more messages from it:
-                [{Key, Stream} | Acc];
-            (_Key, _Stream, Acc) ->
-                Acc
-        end,
-        [],
-        S
-    ).
-
--spec renew_streams(emqx_persistent_session_ds_state:t()) -> emqx_persistent_session_ds_state:t().
-renew_streams(S0) ->
-    S1 = remove_old_streams(S0),
-    subs_fold(
-        fun(TopicFilterBin, _Subscription = #{start_time := StartTime, id := SubId}, S2) ->
-            TopicFilter = emqx_topic:words(TopicFilterBin),
-            Streams = select_streams(
-                SubId,
-                emqx_ds:get_streams(?PERSISTENT_MESSAGE_DB, TopicFilter, StartTime),
-                S2
-            ),
-            lists:foldl(
-                fun(I, Acc) ->
-                    ensure_iterator(TopicFilter, StartTime, SubId, I, Acc)
-                end,
-                S2,
-                Streams
-            )
-        end,
-        S1,
-        S1
-    ).
-
-ensure_iterator(TopicFilter, StartTime, SubId, {{RankX, RankY}, Stream}, S) ->
-    Key = {SubId, Stream},
-    case emqx_persistent_session_ds_state:get_stream(Key, S) of
+do_drain_buffer(Inflight0, S0, Acc) ->
+    case emqx_persistent_session_ds_inflight:pop(Inflight0) of
         undefined ->
-            {ok, Iterator} = emqx_ds:make_iterator(
-                ?PERSISTENT_MESSAGE_DB, Stream, TopicFilter, StartTime
-            ),
-            NewStreamState = #ifs{
-                rank_x = RankX,
-                rank_y = RankY,
-                it_end = Iterator
-            },
-            emqx_persistent_session_ds_state:put_stream(Key, NewStreamState, S);
-        #ifs{} ->
-            S
-    end.
-
-select_streams(SubId, Streams0, S) ->
-    TopicStreamGroups = maps:groups_from_list(fun({{X, _}, _}) -> X end, Streams0),
-    maps:fold(
-        fun(RankX, Streams, Acc) ->
-            select_streams(SubId, RankX, Streams, S) ++ Acc
-        end,
-        [],
-        TopicStreamGroups
-    ).
-
-select_streams(SubId, RankX, Streams0, S) ->
-    %% 1. Find the streams with the rank Y greater than the recorded one:
-    Streams1 =
-        case emqx_persistent_session_ds_state:get_rank({SubId, RankX}, S) of
-            undefined ->
-                Streams0;
-            ReplayedY ->
-                [I || I = {{_, Y}, _} <- Streams0, Y > ReplayedY]
-        end,
-    %% 2. Sort streams by rank Y:
-    Streams = lists:sort(
-        fun({{_, Y1}, _}, {{_, Y2}, _}) ->
-            Y1 =< Y2
-        end,
-        Streams1
-    ),
-    %% 3. Select streams with the least rank Y:
-    case Streams of
-        [] ->
-            [];
-        [{{_, MinRankY}, _} | _] ->
-            lists:takewhile(fun({{_, Y}, _}) -> Y =:= MinRankY end, Streams)
-    end.
-
--spec remove_old_streams(emqx_persistent_session_ds_state:t()) ->
-    emqx_persistent_session_ds_state:t().
-remove_old_streams(S0) ->
-    CommQos1 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_1), S0),
-    CommQos2 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_2), S0),
-    %% 1. For each subscription, find the X ranks that were fully replayed:
-    Groups = emqx_persistent_session_ds_state:fold_streams(
-        fun({SubId, _Stream}, StreamState = #ifs{rank_x = RankX, rank_y = RankY, it_end = It}, Acc) ->
-            Key = {SubId, RankX},
-            IsComplete =
-                It =:= end_of_stream andalso ?fully_replayed(StreamState, CommQos1, CommQos2),
-            case {maps:get(Key, Acc, undefined), IsComplete} of
-                {undefined, true} ->
-                    Acc#{Key => {true, RankY}};
-                {_, false} ->
-                    Acc#{Key => false};
-                _ ->
-                    Acc
-            end
-        end,
-        #{},
-        S0
-    ),
-    %% 2. Advance rank y for each fully replayed set of streams:
-    S1 = maps:fold(
-        fun
-            (Key, {true, RankY}, Acc) ->
-                emqx_persistent_session_ds_state:put_rank(Key, RankY, Acc);
-            (_, _, Acc) ->
-                Acc
-        end,
-        S0,
-        Groups
-    ),
-    %% 3. Remove the fully replayed streams:
-    emqx_persistent_session_ds_state:fold_streams(
-        fun(Key = {SubId, _Stream}, #ifs{rank_x = RankX, rank_y = RankY}, Acc) ->
-            case emqx_persistent_session_ds_state:get_rank({SubId, RankX}, Acc) of
-                MinRankY when RankY < MinRankY ->
-                    emqx_persistent_session_ds_state:del_stream(Key, Acc);
-                _ ->
-                    Acc
+            {lists:reverse(Acc), Inflight0, S0};
+        {{pubrel, SeqNo}, Inflight} ->
+            Publish = {pubrel, seqno_to_packet_id(?QOS_2, SeqNo)},
+            do_drain_buffer(Inflight, S0, [Publish | Acc]);
+        {{SeqNo, Msg}, Inflight} ->
+            case Msg#message.qos of
+                ?QOS_0 ->
+                    do_drain_buffer(Inflight, S0, [{undefined, Msg} | Acc]);
+                ?QOS_1 ->
+                    S = emqx_persistent_session_ds_state:put_seqno(?dup(?QOS_1), SeqNo, S0),
+                    Publish = {seqno_to_packet_id(?QOS_1, SeqNo), Msg},
+                    do_drain_buffer(Inflight, S, [Publish | Acc]);
+                ?QOS_2 ->
+                    Publish = {seqno_to_packet_id(?QOS_2, SeqNo), Msg},
+                    do_drain_buffer(Inflight, S0, [Publish | Acc])
             end
-        end,
-        S1,
-        S1
-    ).
+    end.
 
 %%--------------------------------------------------------------------------------
 
@@ -1023,7 +855,7 @@ commit_seqno(Track, PacketId, Session = #{id := SessionId, s := S}) ->
             Old = ?committed(?QOS_1),
             Next = ?next(?QOS_1);
         pubrec ->
-            Old = ?pubrec,
+            Old = ?dup(?QOS_2),
             Next = ?next(?QOS_2);
         pubcomp ->
             Old = ?committed(?QOS_2),

+ 27 - 12
apps/emqx/src/emqx_persistent_session_ds.hrl

@@ -25,25 +25,40 @@
 -define(SESSION_COMMITTED_OFFSET_TAB, emqx_ds_committed_offset_tab).
 -define(DS_MRIA_SHARD, emqx_ds_session_shard).
 
-%% State of the stream:
+%%%%% Session sequence numbers:
+
+%%
+%%   -----|----------|----------|------> seqno
+%%        |          |          |
+%%   committed      dup       next
+
+%% Seqno becomes committed after receiving PUBACK for QoS1 or PUBCOMP
+%% for QoS2.
+-define(committed(QOS), {0, QOS}).
+%% Seqno becomes dup:
+%%
+%% 1. After broker sends QoS1 message to the client
+%% 2. After it receives PUBREC from the client for the QoS2 message
+-define(dup(QOS), {1, QOS}).
+%% Last seqno assigned to some message (that may reside in the
+%% mqueue). Its tag must differ from those of ?committed and ?dup:
+-define(next(QOS), {2, QOS}).
+
+%%%%% State of the stream:
 -record(ifs, {
     rank_x :: emqx_ds:rank_x(),
     rank_y :: emqx_ds:rank_y(),
     %% Iterator at the end of the last batch:
-    it_end :: emqx_ds:iterator() | undefined | end_of_stream,
-    %% Size of the last batch:
-    batch_size :: pos_integer() | undefined,
+    it_end :: emqx_ds:iterator() | end_of_stream,
     %% Key that points at the beginning of the batch:
     batch_begin_key :: binary() | undefined,
-    %% Number of messages collected in the last batch:
-    batch_n_messages :: pos_integer() | undefined,
+    batch_size = 0 :: non_neg_integer(),
     %% Session sequence number at the time when the batch was fetched:
-    first_seqno_qos1 :: emqx_persistent_session_ds:seqno() | undefined,
-    first_seqno_qos2 :: emqx_persistent_session_ds:seqno() | undefined,
-    %% Sequence numbers that the client must PUBACK or PUBREL
-    %% before we can consider the batch to be fully replayed:
-    last_seqno_qos1 :: emqx_persistent_session_ds:seqno() | undefined,
-    last_seqno_qos2 :: emqx_persistent_session_ds:seqno() | undefined
+    first_seqno_qos1 = 0 :: emqx_persistent_session_ds:seqno(),
+    first_seqno_qos2 = 0 :: emqx_persistent_session_ds:seqno(),
+    %% Session sequence numbers at the end of the last batch:
+    last_seqno_qos1 = 0 :: emqx_persistent_session_ds:seqno(),
+    last_seqno_qos2 = 0 :: emqx_persistent_session_ds:seqno()
 }).
 
 %% TODO: remove

+ 52 - 28
apps/emqx/src/emqx_persistent_session_ds_inflight.erl

@@ -16,7 +16,7 @@
 -module(emqx_persistent_session_ds_inflight).
 
 %% API:
--export([new/1, push/2, pop/1, n_buffered/1, n_inflight/1, inc_send_quota/1, receive_maximum/1]).
+-export([new/1, push/2, pop/1, n_buffered/2, n_inflight/1, inc_send_quota/1, receive_maximum/1]).
 
 %% behavior callbacks:
 -export([]).
@@ -44,6 +44,10 @@
 
 -type t() :: #inflight{}.
 
+-type payload() ::
+    {emqx_persistent_session_ds:seqno() | undefined, emqx_types:message()}
+    | {pubrel, emqx_persistent_session_ds:seqno()}.
+
 %%================================================================================
 %% API functions
 %%================================================================================
@@ -56,10 +60,12 @@ new(ReceiveMaximum) when ReceiveMaximum > 0 ->
 receive_maximum(#inflight{receive_maximum = ReceiveMaximum}) ->
     ReceiveMaximum.
 
--spec push({emqx_types:packet_id() | undefined, emqx_types:message()}, t()) -> t().
-push(Val = {_PacketId, Msg}, Rec) ->
+-spec push(payload(), t()) -> t().
+push(Payload = {pubrel, _SeqNo}, Rec = #inflight{queue = Q}) ->
+    Rec#inflight{queue = queue:in(Payload, Q)};
+push(Payload = {_, Msg}, Rec) ->
     #inflight{queue = Q0, n_qos0 = NQos0, n_qos1 = NQos1, n_qos2 = NQos2} = Rec,
-    Q = queue:in(Val, Q0),
+    Q = queue:in(Payload, Q0),
     case Msg#message.qos of
         ?QOS_0 ->
             Rec#inflight{queue = Q, n_qos0 = NQos0 + 1};
@@ -69,12 +75,49 @@ push(Val = {_PacketId, Msg}, Rec) ->
             Rec#inflight{queue = Q, n_qos2 = NQos2 + 1}
     end.
 
--spec pop(t()) -> {[{emqx_types:packet_id() | undefined, emqx_types:message()}], t()}.
-pop(Inflight = #inflight{receive_maximum = ReceiveMaximum}) ->
-    do_pop(ReceiveMaximum, Inflight, []).
+%% Take the next queued payload, as long as the number of in-flight
+%% packets is below `receive_maximum'. Returns `undefined' when the
+%% send quota is exhausted or the queue is empty.
+-spec pop(t()) -> {payload(), t()} | undefined.
+pop(Inflight0) ->
+    #inflight{receive_maximum = Max, n_inflight = NInflight, queue = Queue0} = Inflight0,
+    case NInflight < Max andalso queue:out(Queue0) of
+        {{value, Item}, Queue} ->
+            Inflight1 = Inflight0#inflight{queue = Queue},
+            Inflight =
+                case Item of
+                    {pubrel, _} ->
+                        %% Only the queue changes; all counters stay as they are:
+                        Inflight1;
+                    {_, #message{qos = ?QOS_0}} ->
+                        %% QoS 0 does not take up a slot in the in-flight window:
+                        Inflight1#inflight{n_qos0 = Inflight1#inflight.n_qos0 - 1};
+                    {_, #message{qos = ?QOS_1}} ->
+                        Inflight1#inflight{
+                            n_qos1 = Inflight1#inflight.n_qos1 - 1, n_inflight = NInflight + 1
+                        };
+                    {_, #message{qos = ?QOS_2}} ->
+                        Inflight1#inflight{
+                            n_qos2 = Inflight1#inflight.n_qos2 - 1, n_inflight = NInflight + 1
+                        }
+                end,
+            {Item, Inflight};
+        _ ->
+            undefined
+    end.
 
--spec n_buffered(t()) -> non_neg_integer().
-n_buffered(#inflight{n_qos0 = NQos0, n_qos1 = NQos1, n_qos2 = NQos2}) ->
+%% Number of buffered messages (pushed, but not yet popped) of the
+%% given QoS, or of all QoS levels combined when called with `all'.
+-spec n_buffered(0..2 | all, t()) -> non_neg_integer().
+n_buffered(?QOS_0, Rec = #inflight{}) ->
+    Rec#inflight.n_qos0;
+n_buffered(?QOS_1, Rec = #inflight{}) ->
+    Rec#inflight.n_qos1;
+n_buffered(?QOS_2, Rec = #inflight{}) ->
+    Rec#inflight.n_qos2;
+n_buffered(all, Rec = #inflight{}) ->
+    Rec#inflight.n_qos0 + Rec#inflight.n_qos1 + Rec#inflight.n_qos2.
 
 -spec n_inflight(t()) -> non_neg_integer().
@@ -90,22 +133,3 @@ inc_send_quota(Rec = #inflight{n_inflight = NInflight0}) ->
 %%================================================================================
 %% Internal functions
 %%================================================================================
-
-do_pop(ReceiveMaximum, Rec0 = #inflight{n_inflight = NInflight, queue = Q0}, Acc) ->
-    case NInflight < ReceiveMaximum andalso queue:out(Q0) of
-        {{value, Val}, Q} ->
-            #inflight{n_qos0 = NQos0, n_qos1 = NQos1, n_qos2 = NQos2} = Rec0,
-            {_PacketId, #message{qos = Qos}} = Val,
-            Rec =
-                case Qos of
-                    ?QOS_0 ->
-                        Rec0#inflight{queue = Q, n_qos0 = NQos0 - 1};
-                    ?QOS_1 ->
-                        Rec0#inflight{queue = Q, n_qos1 = NQos1 - 1, n_inflight = NInflight + 1};
-                    ?QOS_2 ->
-                        Rec0#inflight{queue = Q, n_qos2 = NQos2 - 1, n_inflight = NInflight + 1}
-                end,
-            do_pop(ReceiveMaximum, Rec, [Val | Acc]);
-        _ ->
-            {lists:reverse(Acc), Rec0}
-    end.

+ 8 - 1
apps/emqx/src/emqx_persistent_session_ds_state.erl

@@ -41,6 +41,7 @@
 
 -export_type([t/0, subscriptions/0, seqno_type/0, stream_key/0, rank_key/0]).
 
+-include("emqx_mqtt.hrl").
 -include("emqx_persistent_session_ds.hrl").
 
 %%================================================================================
@@ -89,7 +90,13 @@
         ?last_subid => integer()
     }.
 
--type seqno_type() :: term().
+-type seqno_type() ::
+    ?next(?QOS_1)
+    | ?dup(?QOS_1)
+    | ?committed(?QOS_1)
+    | ?next(?QOS_2)
+    | ?dup(?QOS_2)
+    | ?committed(?QOS_2).
 
 -opaque t() :: #{
     id := emqx_persistent_session_ds:id(),

+ 247 - 0
apps/emqx/src/emqx_persistent_session_ds_stream_scheduler.erl

@@ -0,0 +1,247 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2023-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+-module(emqx_persistent_session_ds_stream_scheduler).
+
+%% API:
+-export([find_new_streams/1, find_replay_streams/1]).
+-export([renew_streams/1]).
+
+%% behavior callbacks:
+-export([]).
+
+%% internal exports:
+-export([]).
+
+-export_type([]).
+
+-include("emqx_mqtt.hrl").
+-include("emqx_persistent_session_ds.hrl").
+
+%%================================================================================
+%% Type declarations
+%%================================================================================
+
+%%================================================================================
+%% API functions
+%%================================================================================
+
+%% Return the streams that still have unacknowledged messages, ordered
+%% by `compare_streams/2'.
+-spec find_replay_streams(emqx_persistent_session_ds_state:t()) ->
+    [{emqx_persistent_session_ds_state:stream_key(), emqx_persistent_session_ds:stream_state()}].
+find_replay_streams(S) ->
+    CommQos1 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_1), S),
+    CommQos2 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_2), S),
+    %% 1. Collect all the streams of the session:
+    AllStreams = emqx_persistent_session_ds_state:fold_streams(
+        fun(Key, Stream, Acc) -> [{Key, Stream} | Acc] end,
+        [],
+        S
+    ),
+    %% 2. Keep only the ones that aren't fully acked:
+    Unacked = [
+        KV
+     || KV = {_Key, Stream} <- AllStreams,
+        not is_fully_acked(CommQos1, CommQos2, Stream)
+    ],
+    lists:sort(fun compare_streams/2, Unacked).
+
+%% Return, in random order, the streams whose previously sent batches
+%% have been fully acknowledged by the client.
+-spec find_new_streams(emqx_persistent_session_ds_state:t()) ->
+    [{emqx_persistent_session_ds_state:stream_key(), emqx_persistent_session_ds:stream_state()}].
+find_new_streams(S) ->
+    %% FIXME: this function is currently very sensitive to the
+    %% consistency of the packet IDs on both broker and client side.
+    %%
+    %% If the client fails to properly ack packets due to a bug, or a
+    %% network issue, or if the state of streams and seqno tables ever
+    %% become de-synced, then this function will return an empty list,
+    %% and the replay cannot progress.
+    %%
+    %% In other words, this function is not robust, and we should find
+    %% some way to get the replays un-stuck at the cost of potentially
+    %% losing messages during replay (or just kill the stuck channel
+    %% after timeout?)
+    CommQos1 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_1), S),
+    CommQos2 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_2), S),
+    AllStreams = emqx_persistent_session_ds_state:fold_streams(
+        fun(Key, Stream, Acc) -> [{Key, Stream} | Acc] end,
+        [],
+        S
+    ),
+    Acked = [
+        KV
+     || KV = {_Key, Stream} <- AllStreams,
+        is_fully_acked(CommQos1, CommQos2, Stream)
+    ],
+    shuffle(Acked).
+
+%% Drop the fully replayed streams from the session state, then, for
+%% every subscription, look up its streams in the message DB and make
+%% sure that each selected stream has an iterator.
+-spec renew_streams(emqx_persistent_session_ds_state:t()) -> emqx_persistent_session_ds_state:t().
+renew_streams(S0) ->
+    S1 = remove_fully_replayed_streams(S0),
+    Subscriptions = emqx_persistent_session_ds_state:get_subscriptions(S1),
+    emqx_topic_gbt:fold(fun renew_subscription_streams/3, S1, Subscriptions).
+
+%% Fold step of `renew_streams/1': handle a single subscription.
+renew_subscription_streams(Key, #{start_time := StartTime, id := SubId}, S) ->
+    TopicFilter = emqx_topic:words(emqx_trie_search:get_topic(Key)),
+    Available = emqx_ds:get_streams(?PERSISTENT_MESSAGE_DB, TopicFilter, StartTime),
+    Selected = select_streams(SubId, Available, S),
+    lists:foldl(
+        fun(RankedStream, Acc) ->
+            ensure_iterator(TopicFilter, StartTime, SubId, RankedStream, Acc)
+        end,
+        S,
+        Selected
+    ).
+
+%%================================================================================
+%% Internal functions
+%%================================================================================
+
+%% Make sure the session state has an entry for the given stream of
+%% the subscription. Streams that are already tracked are left as is;
+%% for a new stream an iterator is created and stored as `it_end'.
+ensure_iterator(TopicFilter, StartTime, SubId, {{RankX, RankY}, Stream}, S) ->
+    StreamKey = {SubId, Stream},
+    case emqx_persistent_session_ds_state:get_stream(StreamKey, S) of
+        #ifs{} ->
+            S;
+        undefined ->
+            {ok, It} = emqx_ds:make_iterator(
+                ?PERSISTENT_MESSAGE_DB, Stream, TopicFilter, StartTime
+            ),
+            emqx_persistent_session_ds_state:put_stream(
+                StreamKey,
+                #ifs{rank_x = RankX, rank_y = RankY, it_end = It},
+                S
+            )
+    end.
+
+%% Group the streams by rank X and pick the streams to consume from
+%% each group independently (see `select_streams/4').
+select_streams(SubId, Streams0, S) ->
+    RankXOf = fun({{RankX, _RankY}, _Stream}) -> RankX end,
+    maps:fold(
+        fun(RankX, Group, Selected) ->
+            select_streams(SubId, RankX, Group, S) ++ Selected
+        end,
+        [],
+        maps:groups_from_list(RankXOf, Streams0)
+    ).
+
+%% Among the streams sharing the same rank X, select the ones with the
+%% smallest rank Y that is still greater than the rank Y recorded for
+%% the `{SubId, RankX}' pair (if any).
+select_streams(SubId, RankX, Streams0, S) ->
+    ReplayedY = emqx_persistent_session_ds_state:get_rank({SubId, RankX}, S),
+    Candidates = newer_than_rank(ReplayedY, Streams0),
+    ByRankY = fun({{_, Y1}, _}, {{_, Y2}, _}) -> Y1 =< Y2 end,
+    case lists:sort(ByRankY, Candidates) of
+        [] ->
+            [];
+        Sorted = [{{_, MinRankY}, _} | _] ->
+            lists:takewhile(fun({{_, Y}, _}) -> Y =:= MinRankY end, Sorted)
+    end.
+
+%% Keep the streams with the rank Y greater than the recorded one;
+%% when nothing has been recorded yet, every stream is kept.
+newer_than_rank(undefined, Streams) ->
+    Streams;
+newer_than_rank(ReplayedY, Streams) ->
+    [I || I = {{_, Y}, _} <- Streams, Y > ReplayedY].
+
+%% Garbage-collect the streams that the session no longer needs:
+%%
+%% 1. For every `{SubId, RankX}' group, check whether all of its
+%%    streams are fully replayed (iterator at `end_of_stream' and all
+%%    messages acked).
+%% 2. Record rank Y for each such group.
+%% 3. Delete the streams whose rank Y is below the recorded rank of
+%%    their group.
+-spec remove_fully_replayed_streams(emqx_persistent_session_ds_state:t()) ->
+    emqx_persistent_session_ds_state:t().
+remove_fully_replayed_streams(S0) ->
+    CommQos1 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_1), S0),
+    CommQos2 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_2), S0),
+    %% 1. For each subscription, find the X ranks that were fully replayed:
+    Groups = emqx_persistent_session_ds_state:fold_streams(
+        fun({SubId, _Stream}, StreamState = #ifs{rank_x = RankX, rank_y = RankY}, Acc) ->
+            Key = {SubId, RankX},
+            case
+                {maps:get(Key, Acc, undefined), is_fully_replayed(CommQos1, CommQos2, StreamState)}
+            of
+                {undefined, true} ->
+                    Acc#{Key => {true, RankY}};
+                {_, false} ->
+                    %% A single unfinished stream disqualifies the whole group:
+                    Acc#{Key => false};
+                _ ->
+                    Acc
+            end
+        end,
+        #{},
+        S0
+    ),
+    %% 2. Advance rank y for each fully replayed set of streams:
+    S1 = maps:fold(
+        fun
+            (Key, {true, RankY}, Acc) ->
+                emqx_persistent_session_ds_state:put_rank(Key, RankY, Acc);
+            (_, _, Acc) ->
+                Acc
+        end,
+        S0,
+        Groups
+    ),
+    %% 3. Remove the fully replayed streams:
+    emqx_persistent_session_ds_state:fold_streams(
+        fun(Key = {SubId, _Stream}, #ifs{rank_x = RankX, rank_y = RankY}, Acc) ->
+            case emqx_persistent_session_ds_state:get_rank({SubId, RankX}, Acc) of
+                undefined ->
+                    %% No rank has been recorded for this group, so
+                    %% nothing in it is known to be replayed. This
+                    %% clause must be explicit: in the Erlang term
+                    %% order any number is smaller than any atom, so
+                    %% a numeric `RankY < undefined' is always true
+                    %% and would delete a live stream.
+                    Acc;
+                MinRankY when RankY < MinRankY ->
+                    emqx_persistent_session_ds_state:del_stream(Key, Acc);
+                _ ->
+                    Acc
+            end
+        end,
+        S1,
+        S1
+    ).
+
+%% Ordering function for `lists:sort/2': returns `true' when the first
+%% stream should be placed before (or is equal to) the second one.
+%% Streams are ordered by the first QoS1 sequence number, ties are
+%% resolved by the first QoS2 sequence number.
+%%
+%% `find_replay_streams/1' sorts a list of `{Key, StreamState}' pairs,
+%% so the pairs must be accepted here; matching bare `#ifs{}' records
+%% only would fail with `function_clause' as soon as there are two or
+%% more streams to compare. Bare records are still accepted.
+compare_streams({_KeyA, A = #ifs{}}, {_KeyB, B = #ifs{}}) ->
+    compare_streams(A, B);
+compare_streams(
+    #ifs{first_seqno_qos1 = A1, first_seqno_qos2 = A2},
+    #ifs{first_seqno_qos1 = B1, first_seqno_qos2 = B2}
+) ->
+    case A1 =:= B1 of
+        true ->
+            A2 =< B2;
+        false ->
+            A1 < B1
+    end.
+
+%% A stream is fully replayed when its iterator has reached
+%% `end_of_stream' and everything sent from it has been acked.
+is_fully_replayed(Comm1, Comm2, S = #ifs{it_end = end_of_stream}) ->
+    is_fully_acked(Comm1, Comm2, S);
+is_fully_replayed(_Comm1, _Comm2, #ifs{}) ->
+    false.
+
+%% True when both committed sequence numbers have caught up with the
+%% last sequence numbers recorded in the stream state.
+is_fully_acked(Comm1, Comm2, #ifs{last_seqno_qos1 = Last1, last_seqno_qos2 = Last2}) ->
+    Last1 =< Comm1 andalso Last2 =< Comm2.
+
+%% Return the elements of the list in random order: every element is
+%% tagged with a random float, and the list is sorted by the tag.
+-spec shuffle([A]) -> [A].
+shuffle(L0) ->
+    %% maybe topic/stream prioritization could be introduced here?
+    Tagged = [{rand:uniform(), A} || A <- L0],
+    [A || {_Tag, A} <- lists:sort(Tagged)].

+ 8 - 17
apps/emqx/test/emqx_persistent_session_SUITE.erl

@@ -713,8 +713,8 @@ t_publish_many_while_client_is_gone_qos1(Config) ->
 
     ct:pal("Msgs2 = ~p", [Msgs2]),
 
-    ?assert(NMsgs2 < NPubs, Msgs2),
-    ?assert(NMsgs2 > NPubs2, Msgs2),
+    ?assert(NMsgs2 =< NPubs, {NMsgs2, '=<', NPubs}),
+    ?assert(NMsgs2 > NPubs2, {NMsgs2, '>', NPubs2}),
     ?assert(NMsgs2 >= NPubs - NAcked, Msgs2),
     NSame = NMsgs2 - NPubs2,
     ?assert(
@@ -782,9 +782,8 @@ t_publish_many_while_client_is_gone(Config) ->
     ClientOpts = [
         {proto_ver, v5},
         {clientid, ClientId},
-        %,
-        {properties, #{'Session-Expiry-Interval' => 30}}
-        %{auto_ack, never}
+        {properties, #{'Session-Expiry-Interval' => 30}},
+        {auto_ack, never}
         | Config
     ],
 
@@ -811,12 +810,12 @@ t_publish_many_while_client_is_gone(Config) ->
     Msgs1 = receive_messages(NPubs1),
     ct:pal("Msgs1 = ~p", [Msgs1]),
     NMsgs1 = length(Msgs1),
-    NPubs1 =:= NMsgs1 orelse
-        throw_with_debug_info({NPubs1, '==', NMsgs1}, ClientId),
+    ?assertEqual(NPubs1, NMsgs1, emqx_persistent_session_ds:print_session(ClientId)),
 
     ?assertEqual(
         get_topicwise_order(Pubs1),
-        get_topicwise_order(Msgs1)
+        get_topicwise_order(Msgs1),
+        emqx_persistent_session_ds:print_session(ClientId)
     ),
 
     %% PUBACK every QoS 1 message.
@@ -1088,14 +1087,6 @@ skip_ds_tc(Config) ->
     end.
 
 throw_with_debug_info(Error, ClientId) ->
-    Info =
-        case emqx_cm:lookup_channels(ClientId) of
-            [Pid] ->
-                #{channel := ChanState} = emqx_connection:get_state(Pid),
-                SessionState = emqx_channel:info(session_state, ChanState),
-                maps:update_with(s, fun emqx_persistent_session_ds_state:format/1, SessionState);
-            [] ->
-                no_channel
-        end,
+    Info = emqx_persistent_session_ds:print_session(ClientId),
     ct:pal("!!! Assertion failed: ~p~nState:~n~p", [Error, Info]),
     exit(Error).