
Merge pull request #12029 from keynslug/ft/EMQX-11049/qos2

feat(sessds): provide QoS2 message replay support
Andrew Mayorov 2 years ago
Parent
Commit
ccef91437d

+ 2 - 2
apps/emqx/rebar.config

@@ -45,7 +45,7 @@
             {meck, "0.9.2"},
             {proper, "1.4.0"},
             {bbmustache, "1.10.0"},
-            {emqtt, {git, "https://github.com/emqx/emqtt", {tag, "1.9.1"}}}
+            {emqtt, {git, "https://github.com/emqx/emqtt", {tag, "1.9.6"}}}
         ]},
         {extra_src_dirs, [{"test", [recursive]},
                           {"integration_test", [recursive]}]}
@@ -55,7 +55,7 @@
             {meck, "0.9.2"},
             {proper, "1.4.0"},
             {bbmustache, "1.10.0"},
-            {emqtt, {git, "https://github.com/emqx/emqtt", {tag, "1.9.1"}}}
+            {emqtt, {git, "https://github.com/emqx/emqtt", {tag, "1.9.6"}}}
         ]},
         {extra_src_dirs, [{"test", [recursive]}]}
     ]}

+ 1 - 0
apps/emqx/src/emqx_channel.erl

@@ -423,6 +423,7 @@ handle_in(
             {ok, Channel}
     end;
 handle_in(
+    %% TODO: Why discard the Reason Code?
     ?PUBREC_PACKET(PacketId, _ReasonCode, Properties),
     Channel =
         #channel{clientinfo = ClientInfo, session = Session}

+ 339 - 150
apps/emqx/src/emqx_persistent_message_ds_replayer.erl

@@ -19,7 +19,13 @@
 -module(emqx_persistent_message_ds_replayer).
 
 %% API:
--export([new/0, open/1, next_packet_id/1, replay/1, commit_offset/3, poll/3, n_inflight/1]).
+-export([new/0, open/1, next_packet_id/1, n_inflight/1]).
+
+-export([poll/4, replay/2, commit_offset/4]).
+
+-export([seqno_to_packet_id/1, packet_id_to_seqno/2]).
+
+-export([committed_until/2]).
 
 %% internal exports:
 -export([]).
@@ -27,7 +33,6 @@
 -export_type([inflight/0, seqno/0]).
 
 -include_lib("emqx/include/logger.hrl").
--include_lib("emqx_utils/include/emqx_message.hrl").
 -include("emqx_persistent_session_ds.hrl").
 
 -ifdef(TEST).
@@ -35,6 +40,13 @@
 -include_lib("eunit/include/eunit.hrl").
 -endif.
 
+-define(EPOCH_SIZE, 16#10000).
+
+-define(ACK, 0).
+-define(COMP, 1).
+
+-define(TRACK_FLAG(WHICH), (1 bsl WHICH)).
+
 %%================================================================================
 %% Type declarations
 %%================================================================================
@@ -42,15 +54,23 @@
 %% Note: sequence numbers are monotonic; they don't wrap around:
 -type seqno() :: non_neg_integer().
 
+-type track() :: ack | comp.
+-type commit_type() :: rec.
+
 -record(inflight, {
     next_seqno = 1 :: seqno(),
-    acked_until = 1 :: seqno(),
+    commits = #{ack => 1, comp => 1, rec => 1} :: #{track() | commit_type() => seqno()},
     %% Ranges are sorted in ascending order of their sequence numbers.
     offset_ranges = [] :: [ds_pubrange()]
 }).
 
 -opaque inflight() :: #inflight{}.
 
+-type reply_fun() :: fun(
+    (seqno(), emqx_types:message()) ->
+        emqx_session:replies() | {_AdvanceSeqno :: false, emqx_session:replies()}
+).
+
 %%================================================================================
 %% API funcions
 %%================================================================================
@@ -61,10 +81,12 @@ new() ->
 
 -spec open(emqx_persistent_session_ds:id()) -> inflight().
 open(SessionId) ->
-    Ranges = ro_transaction(fun() -> get_ranges(SessionId) end),
-    {AckedUntil, NextSeqno} = compute_inflight_range(Ranges),
+    {Ranges, RecUntil} = ro_transaction(
+        fun() -> {get_ranges(SessionId), get_committed_offset(SessionId, rec)} end
+    ),
+    {Commits, NextSeqno} = compute_inflight_range(Ranges),
     #inflight{
-        acked_until = AckedUntil,
+        commits = Commits#{rec => RecUntil},
         next_seqno = NextSeqno,
         offset_ranges = Ranges
     }.
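
Note (illustrative, not part of the diff): each value in the commits map is an "until" watermark, i.e. the first seqno that has not yet been committed on that track, and committed_until/2 simply reads these values. A session with nothing stored yet opens with roughly this state:

#inflight{
    next_seqno = 1,
    commits = #{ack => 1, comp => 1, rec => 1},
    offset_ranges = []
}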
@@ -75,15 +97,30 @@ next_packet_id(Inflight0 = #inflight{next_seqno = LastSeqno}) ->
     {seqno_to_packet_id(LastSeqno), Inflight}.
 
 -spec n_inflight(inflight()) -> non_neg_integer().
-n_inflight(#inflight{next_seqno = NextSeqno, acked_until = AckedUntil}) ->
-    range_size(AckedUntil, NextSeqno).
+n_inflight(#inflight{offset_ranges = Ranges}) ->
+    %% TODO
+    %% This is not very efficient. Instead, we can take the maximum of
+    %% `range_size(AckedUntil, NextSeqno)` and `range_size(CompUntil, NextSeqno)`.
+    %% This won't be exact number but a pessimistic estimate, but this way we
+    %% will penalize clients that PUBACK QoS 1 messages but don't PUBCOMP QoS 2
+    %% messages for some reason. For that to work, we need to additionally track
+    %% actual `AckedUntil` / `CompUntil` during `commit_offset/4`.
+    lists:foldl(
+        fun
+            (#ds_pubrange{type = ?T_CHECKPOINT}, N) ->
+                N;
+            (#ds_pubrange{type = ?T_INFLIGHT, id = {_, First}, until = Until}, N) ->
+                N + range_size(First, Until)
+        end,
+        0,
+        Ranges
+    ).
 
--spec replay(inflight()) ->
-    {emqx_session:replies(), inflight()}.
-replay(Inflight0 = #inflight{acked_until = AckedUntil, offset_ranges = Ranges0}) ->
+-spec replay(reply_fun(), inflight()) -> {emqx_session:replies(), inflight()}.
+replay(ReplyFun, Inflight0 = #inflight{offset_ranges = Ranges0}) ->
     {Ranges, Replies} = lists:mapfoldr(
         fun(Range, Acc) ->
-            replay_range(Range, AckedUntil, Acc)
+            replay_range(ReplyFun, Range, Acc)
         end,
         [],
         Ranges0
@@ -91,43 +128,49 @@ replay(Inflight0 = #inflight{acked_until = AckedUntil, offset_ranges = Ranges0})
     Inflight = Inflight0#inflight{offset_ranges = Ranges},
     {Replies, Inflight}.
 
--spec commit_offset(emqx_persistent_session_ds:id(), emqx_types:packet_id(), inflight()) ->
-    {_IsValidOffset :: boolean(), inflight()}.
+-spec commit_offset(emqx_persistent_session_ds:id(), Offset, emqx_types:packet_id(), inflight()) ->
+    {_IsValidOffset :: boolean(), inflight()}
+when
+    Offset :: track() | commit_type().
 commit_offset(
     SessionId,
+    Track,
     PacketId,
-    Inflight0 = #inflight{
-        acked_until = AckedUntil, next_seqno = NextSeqno
-    }
-) ->
-    case packet_id_to_seqno(NextSeqno, PacketId) of
-        Seqno when Seqno >= AckedUntil andalso Seqno < NextSeqno ->
+    Inflight0 = #inflight{commits = Commits}
+) when Track == ack orelse Track == comp ->
+    case validate_commit(Track, PacketId, Inflight0) of
+        CommitUntil when is_integer(CommitUntil) ->
             %% TODO
-            %% We do not preserve `acked_until` in the database. Instead, we discard
+            %% We do not preserve `CommitUntil` in the database. Instead, we discard
             %% fully acked ranges from the database. In effect, this means that the
-            %% most recent `acked_until` the client has sent may be lost in case of a
+            %% most recent `CommitUntil` the client has sent may be lost in case of a
             %% crash or client loss.
-            Inflight1 = Inflight0#inflight{acked_until = next_seqno(Seqno)},
-            Inflight = discard_acked(SessionId, Inflight1),
+            Inflight1 = Inflight0#inflight{commits = Commits#{Track := CommitUntil}},
+            Inflight = discard_committed(SessionId, Inflight1),
             {true, Inflight};
-        OutOfRange ->
-            ?SLOG(warning, #{
-                msg => "out-of-order_ack",
-                acked_until => AckedUntil,
-                acked_seqno => OutOfRange,
-                next_seqno => NextSeqno,
-                packet_id => PacketId
-            }),
+        false ->
+            {false, Inflight0}
+    end;
+commit_offset(
+    SessionId,
+    CommitType = rec,
+    PacketId,
+    Inflight0 = #inflight{commits = Commits}
+) ->
+    case validate_commit(CommitType, PacketId, Inflight0) of
+        CommitUntil when is_integer(CommitUntil) ->
+            update_committed_offset(SessionId, CommitType, CommitUntil),
+            Inflight = Inflight0#inflight{commits = Commits#{CommitType := CommitUntil}},
+            {true, Inflight};
+        false ->
             {false, Inflight0}
     end.
 
--spec poll(emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
+-spec poll(reply_fun(), emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
     {emqx_session:replies(), inflight()}.
-poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff ->
-    #inflight{next_seqno = NextSeqNo0, acked_until = AckedSeqno} =
-        Inflight0,
+poll(ReplyFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < ?EPOCH_SIZE ->
     FetchThreshold = max(1, WindowSize div 2),
-    FreeSpace = AckedSeqno + WindowSize - NextSeqNo0,
+    FreeSpace = WindowSize - n_inflight(Inflight0),
     case FreeSpace >= FetchThreshold of
         false ->
             %% TODO: this branch is meant to avoid fetching data from
@@ -138,9 +181,25 @@ poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff
         true ->
             %% TODO: Wrap this in `mria:async_dirty/2`?
             Streams = shuffle(get_streams(SessionId)),
-            fetch(SessionId, Inflight0, Streams, FreeSpace, [])
+            fetch(ReplyFun, SessionId, Inflight0, Streams, FreeSpace, [])
     end.
 
+%% Which seqno this track is committed until.
+%% "Until" means this is first seqno that is _not yet committed_ for this track.
+-spec committed_until(track() | commit_type(), inflight()) -> seqno().
+committed_until(Track, #inflight{commits = Commits}) ->
+    maps:get(Track, Commits).
+
+-spec seqno_to_packet_id(seqno()) -> emqx_types:packet_id() | 0.
+seqno_to_packet_id(Seqno) ->
+    Seqno rem ?EPOCH_SIZE.
+
+%% Reconstruct session counter by adding most significant bits from
+%% the current counter to the packet id.
+-spec packet_id_to_seqno(emqx_types:packet_id(), inflight()) -> seqno().
+packet_id_to_seqno(PacketId, #inflight{next_seqno = NextSeqno}) ->
+    packet_id_to_seqno_(NextSeqno, PacketId).
+
 %%================================================================================
 %% Internal exports
 %%================================================================================
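
Note (worked example, not part of the diff): packet_id_to_seqno/2 above restores the high bits that were dropped when the 16-bit packet id was derived from the seqno. Assuming next_seqno = 16#10002 (epoch 1):

%% Packet id 2 falls into the current epoch:
%%   Epoch = 16#10002 bsr 16 = 1, (1 bsl 16) + 2 = 16#10002 =< 16#10002 -> seqno 16#10002.
%% Packet id 16#FFFF would exceed next_seqno, so it is attributed to the previous epoch:
%%   (1 bsl 16) + 16#FFFF = 16#1FFFF > 16#10002 -> 16#1FFFF - ?EPOCH_SIZE = 16#FFFF.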
@@ -150,18 +209,34 @@ poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff
 %%================================================================================
 
 compute_inflight_range([]) ->
-    {1, 1};
+    {#{ack => 1, comp => 1}, 1};
 compute_inflight_range(Ranges) ->
     _RangeLast = #ds_pubrange{until = LastSeqno} = lists:last(Ranges),
-    RangesUnacked = lists:dropwhile(
-        fun(#ds_pubrange{type = T}) -> T == checkpoint end,
+    AckedUntil = find_committed_until(ack, Ranges),
+    CompUntil = find_committed_until(comp, Ranges),
+    Commits = #{
+        ack => emqx_maybe:define(AckedUntil, LastSeqno),
+        comp => emqx_maybe:define(CompUntil, LastSeqno)
+    },
+    {Commits, LastSeqno}.
+
+find_committed_until(Track, Ranges) ->
+    RangesUncommitted = lists:dropwhile(
+        fun(Range) ->
+            case Range of
+                #ds_pubrange{type = ?T_CHECKPOINT} ->
+                    true;
+                #ds_pubrange{type = ?T_INFLIGHT, tracks = Tracks} ->
+                    not has_track(Track, Tracks)
+            end
+        end,
         Ranges
     ),
-    case RangesUnacked of
-        [#ds_pubrange{id = {_, AckedUntil}} | _] ->
-            {AckedUntil, LastSeqno};
+    case RangesUncommitted of
+        [#ds_pubrange{id = {_, CommittedUntil}} | _] ->
+            CommittedUntil;
         [] ->
-            {LastSeqno, LastSeqno}
+            undefined
     end.
 
 -spec get_ranges(emqx_persistent_session_ds:id()) -> [ds_pubrange()].
@@ -173,21 +248,22 @@ get_ranges(SessionId) ->
     ),
     mnesia:match_object(?SESSION_PUBRANGE_TAB, Pat, read).
 
-fetch(SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
+fetch(ReplyFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
     #inflight{next_seqno = FirstSeqno, offset_ranges = Ranges} = Inflight0,
     ItBegin = get_last_iterator(DSStream, Ranges),
     {ok, ItEnd, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, ItBegin, N),
     case Messages of
         [] ->
-            fetch(SessionId, Inflight0, Streams, N, Acc);
+            fetch(ReplyFun, SessionId, Inflight0, Streams, N, Acc);
         _ ->
-            {Publishes, UntilSeqno} = publish(FirstSeqno, Messages, _PreserveQoS0 = true),
-            Size = range_size(FirstSeqno, UntilSeqno),
             %% We need to preserve the iterator pointing to the beginning of the
             %% range, so that we can replay it if needed.
+            {Publishes, {UntilSeqno, Tracks}} = publish(ReplyFun, FirstSeqno, Messages),
+            Size = range_size(FirstSeqno, UntilSeqno),
             Range0 = #ds_pubrange{
                 id = {SessionId, FirstSeqno},
-                type = inflight,
+                type = ?T_INFLIGHT,
+                tracks = Tracks,
                 until = UntilSeqno,
                 stream = DSStream#ds_stream.ref,
                 iterator = ItBegin
@@ -196,25 +272,25 @@ fetch(SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
             %% ...Yet we need to keep the iterator pointing past the end of the
             %% range, so that we can pick up where we left off: it will become
             %% `ItBegin` of the next range for this stream.
-            Range = Range0#ds_pubrange{iterator = ItEnd},
+            Range = keep_next_iterator(ItEnd, Range0),
             Inflight = Inflight0#inflight{
                 next_seqno = UntilSeqno,
                 offset_ranges = Ranges ++ [Range]
             },
-            fetch(SessionId, Inflight, Streams, N - Size, [Publishes | Acc])
+            fetch(ReplyFun, SessionId, Inflight, Streams, N - Size, [Publishes | Acc])
     end;
-fetch(_SessionId, Inflight, _Streams, _N, Acc) ->
+fetch(_ReplyFun, _SessionId, Inflight, _Streams, _N, Acc) ->
     Publishes = lists:append(lists:reverse(Acc)),
     {Publishes, Inflight}.
 
-discard_acked(
+discard_committed(
     SessionId,
-    Inflight0 = #inflight{acked_until = AckedUntil, offset_ranges = Ranges0}
+    Inflight0 = #inflight{commits = Commits, offset_ranges = Ranges0}
 ) ->
     %% TODO: This could be kept and incrementally updated in the inflight state.
     Checkpoints = find_checkpoints(Ranges0),
     %% TODO: Wrap this in `mria:async_dirty/2`?
-    Ranges = discard_acked_ranges(SessionId, AckedUntil, Checkpoints, Ranges0),
+    Ranges = discard_committed_ranges(SessionId, Commits, Checkpoints, Ranges0),
     Inflight0#inflight{offset_ranges = Ranges}.
 
 find_checkpoints(Ranges) ->
@@ -227,84 +303,178 @@ find_checkpoints(Ranges) ->
         Ranges
     ).
 
-discard_acked_ranges(
+discard_committed_ranges(
     SessionId,
-    AckedUntil,
+    Commits,
     Checkpoints,
-    [Range = #ds_pubrange{until = Until, stream = StreamRef} | Rest]
-) when Until =< AckedUntil ->
-    %% This range has been fully acked.
-    %% Either discard it completely, or preserve the iterator for the next range
-    %% over this stream (i.e. a checkpoint).
-    RangeKept =
-        case maps:get(StreamRef, Checkpoints) of
-            CP when CP > Until ->
-                discard_range(Range),
-                [];
-            Until ->
-                [checkpoint_range(Range)]
+    Ranges = [Range = #ds_pubrange{until = Until, stream = StreamRef} | Rest]
+) ->
+    case discard_committed_range(Commits, Range) of
+        discard ->
+            %% This range has been fully committed.
+            %% Either discard it completely, or preserve the iterator for the next range
+            %% over this stream (i.e. a checkpoint).
+            RangeKept =
+                case maps:get(StreamRef, Checkpoints) of
+                    CP when CP > Until ->
+                        discard_range(Range),
+                        [];
+                    Until ->
+                        [checkpoint_range(Range)]
+                end,
+            %% Since we're (intentionally) not using transactions here, it's important to
+            %% issue database writes in the same order in which ranges are stored: from
+            %% the oldest to the newest. This is also why we need to compute which ranges
+            %% should become checkpoints before we start writing anything.
+            RangeKept ++ discard_committed_ranges(SessionId, Commits, Checkpoints, Rest);
+        keep ->
+            %% This range has not been fully committed.
+            [Range | discard_committed_ranges(SessionId, Commits, Checkpoints, Rest)];
+        keep_all ->
+            %% The rest of ranges (if any) still have uncommitted messages.
+            Ranges;
+        TracksLeft ->
+            %% Only some track has been committed.
+            %% Preserve the uncommitted tracks in the database.
+            RangeKept = Range#ds_pubrange{tracks = TracksLeft},
+            preserve_range(restore_first_iterator(RangeKept)),
+            [RangeKept | discard_committed_ranges(SessionId, Commits, Checkpoints, Rest)]
+    end;
+discard_committed_ranges(_SessionId, _Commits, _Checkpoints, []) ->
+    [].
+
+discard_committed_range(_Commits, #ds_pubrange{type = ?T_CHECKPOINT}) ->
+    discard;
+discard_committed_range(
+    #{ack := AckedUntil, comp := CompUntil},
+    #ds_pubrange{until = Until}
+) when Until > AckedUntil andalso Until > CompUntil ->
+    keep_all;
+discard_committed_range(Commits, #ds_pubrange{until = Until, tracks = Tracks}) ->
+    case discard_tracks(Commits, Until, Tracks) of
+        0 ->
+            discard;
+        Tracks ->
+            keep;
+        TracksLeft ->
+            TracksLeft
+    end.
+
+discard_tracks(#{ack := AckedUntil, comp := CompUntil}, Until, Tracks) ->
+    TAck =
+        case Until > AckedUntil of
+            true -> ?TRACK_FLAG(?ACK) band Tracks;
+            false -> 0
+        end,
+    TComp =
+        case Until > CompUntil of
+            true -> ?TRACK_FLAG(?COMP) band Tracks;
+            false -> 0
         end,
-    %% Since we're (intentionally) not using transactions here, it's important to
-    %% issue database writes in the same order in which ranges are stored: from
-    %% the oldest to the newest. This is also why we need to compute which ranges
-    %% should become checkpoints before we start writing anything.
-    RangeKept ++ discard_acked_ranges(SessionId, AckedUntil, Checkpoints, Rest);
-discard_acked_ranges(_SessionId, _AckedUntil, _Checkpoints, Ranges) ->
-    %% The rest of ranges (if any) still have unacked messages.
-    Ranges.
+    TAck bor TComp.
 
 replay_range(
-    Range0 = #ds_pubrange{type = inflight, id = {_, First}, until = Until, iterator = It},
-    AckedUntil,
+    ReplyFun,
+    Range0 = #ds_pubrange{type = ?T_INFLIGHT, id = {_, First}, until = Until, iterator = It},
     Acc
 ) ->
     Size = range_size(First, Until),
-    FirstUnacked = max(First, AckedUntil),
-    {ok, ItNext, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, It, Size),
-    MessagesUnacked =
-        case FirstUnacked of
-            First ->
-                Messages;
-            _ ->
-                lists:nthtail(range_size(First, FirstUnacked), Messages)
-        end,
-    MessagesReplay = [emqx_message:set_flag(dup, true, Msg) || Msg <- MessagesUnacked],
+    {ok, ItNext, MessagesUnacked} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, It, Size),
     %% Asserting that range is consistent with the message storage state.
-    {Replies, Until} = publish(FirstUnacked, MessagesReplay, _PreserveQoS0 = false),
+    {Replies, {Until, _TracksInitial}} = publish(ReplyFun, First, MessagesUnacked),
     %% Again, we need to keep the iterator pointing past the end of the
     %% range, so that we can pick up where we left off.
-    Range = Range0#ds_pubrange{iterator = ItNext},
+    Range = keep_next_iterator(ItNext, Range0),
     {Range, Replies ++ Acc};
-replay_range(Range0 = #ds_pubrange{type = checkpoint}, _AckedUntil, Acc) ->
+replay_range(_ReplyFun, Range0 = #ds_pubrange{type = ?T_CHECKPOINT}, Acc) ->
     {Range0, Acc}.
 
-publish(FirstSeqNo, Messages, PreserveQos0) ->
-    do_publish(FirstSeqNo, Messages, PreserveQos0, []).
+validate_commit(
+    Track,
+    PacketId,
+    Inflight = #inflight{commits = Commits, next_seqno = NextSeqno}
+) ->
+    Seqno = packet_id_to_seqno_(NextSeqno, PacketId),
+    CommittedUntil = maps:get(Track, Commits),
+    CommitNext = get_commit_next(Track, Inflight),
+    case Seqno >= CommittedUntil andalso Seqno < CommitNext of
+        true ->
+            next_seqno(Seqno);
+        false ->
+            ?SLOG(warning, #{
+                msg => "out-of-order_commit",
+                track => Track,
+                packet_id => PacketId,
+                commit_seqno => Seqno,
+                committed_until => CommittedUntil,
+                commit_next => CommitNext
+            }),
+            false
+    end.
+
+get_commit_next(ack, #inflight{next_seqno = NextSeqno}) ->
+    NextSeqno;
+get_commit_next(rec, #inflight{next_seqno = NextSeqno}) ->
+    NextSeqno;
+get_commit_next(comp, #inflight{commits = Commits}) ->
+    maps:get(rec, Commits).
+
+publish(ReplyFun, FirstSeqno, Messages) ->
+    lists:mapfoldl(
+        fun(Message, {Seqno, TAcc}) ->
+            case ReplyFun(Seqno, Message) of
+                {_Advance = false, Reply} ->
+                    {Reply, {Seqno, TAcc}};
+                Reply ->
+                    NextSeqno = next_seqno(Seqno),
+                    NextTAcc = add_msg_track(Message, TAcc),
+                    {Reply, {NextSeqno, NextTAcc}}
+            end
+        end,
+        {FirstSeqno, 0},
+        Messages
+    ).
+
+add_msg_track(Message, Tracks) ->
+    case emqx_message:qos(Message) of
+        1 -> ?TRACK_FLAG(?ACK) bor Tracks;
+        2 -> ?TRACK_FLAG(?COMP) bor Tracks;
+        _ -> Tracks
+    end.
 
-do_publish(SeqNo, [], _, Acc) ->
-    {lists:reverse(Acc), SeqNo};
-do_publish(SeqNo, [#message{qos = 0} | Messages], false, Acc) ->
-    do_publish(SeqNo, Messages, false, Acc);
-do_publish(SeqNo, [#message{qos = 0} = Message | Messages], true, Acc) ->
-    do_publish(SeqNo, Messages, true, [{undefined, Message} | Acc]);
-do_publish(SeqNo, [Message | Messages], PreserveQos0, Acc) ->
-    PacketId = seqno_to_packet_id(SeqNo),
-    do_publish(next_seqno(SeqNo), Messages, PreserveQos0, [{PacketId, Message} | Acc]).
+keep_next_iterator(ItNext, Range = #ds_pubrange{iterator = ItFirst, misc = Misc}) ->
+    Range#ds_pubrange{
+        iterator = ItNext,
+        %% We need to keep the first iterator around, in case we need to preserve
+        %% this range again, updating still uncommitted tracks it's part of.
+        misc = Misc#{iterator_first => ItFirst}
+    }.
+
+restore_first_iterator(Range = #ds_pubrange{misc = Misc = #{iterator_first := ItFirst}}) ->
+    Range#ds_pubrange{
+        iterator = ItFirst,
+        misc = maps:remove(iterator_first, Misc)
+    }.
 
 -spec preserve_range(ds_pubrange()) -> ok.
-preserve_range(Range = #ds_pubrange{type = inflight}) ->
+preserve_range(Range = #ds_pubrange{type = ?T_INFLIGHT}) ->
     mria:dirty_write(?SESSION_PUBRANGE_TAB, Range).
 
+has_track(ack, Tracks) ->
+    (?TRACK_FLAG(?ACK) band Tracks) > 0;
+has_track(comp, Tracks) ->
+    (?TRACK_FLAG(?COMP) band Tracks) > 0.
+
 -spec discard_range(ds_pubrange()) -> ok.
 discard_range(#ds_pubrange{id = RangeId}) ->
     mria:dirty_delete(?SESSION_PUBRANGE_TAB, RangeId).
 
 -spec checkpoint_range(ds_pubrange()) -> ds_pubrange().
-checkpoint_range(Range0 = #ds_pubrange{type = inflight}) ->
-    Range = Range0#ds_pubrange{type = checkpoint},
+checkpoint_range(Range0 = #ds_pubrange{type = ?T_INFLIGHT}) ->
+    Range = Range0#ds_pubrange{type = ?T_CHECKPOINT, misc = #{}},
     ok = mria:dirty_write(?SESSION_PUBRANGE_TAB, Range),
     Range;
-checkpoint_range(Range = #ds_pubrange{type = checkpoint}) ->
+checkpoint_range(Range = #ds_pubrange{type = ?T_CHECKPOINT}) ->
     %% This range should have been checkpointed already.
     Range.
 
@@ -320,6 +490,21 @@ get_last_iterator(DSStream = #ds_stream{ref = StreamRef}, Ranges) ->
 get_streams(SessionId) ->
     mnesia:dirty_read(?SESSION_STREAM_TAB, SessionId).
 
+-spec get_committed_offset(emqx_persistent_session_ds:id(), _Name) -> seqno().
+get_committed_offset(SessionId, Name) ->
+    case mnesia:read(?SESSION_COMMITTED_OFFSET_TAB, {SessionId, Name}) of
+        [] ->
+            1;
+        [#ds_committed_offset{until = Seqno}] ->
+            Seqno
+    end.
+
+-spec update_committed_offset(emqx_persistent_session_ds:id(), _Name, seqno()) -> ok.
+update_committed_offset(SessionId, Name, Until) ->
+    mria:dirty_write(?SESSION_COMMITTED_OFFSET_TAB, #ds_committed_offset{
+        id = {SessionId, Name}, until = Until
+    }).
+
 next_seqno(Seqno) ->
     NextSeqno = Seqno + 1,
     case seqno_to_packet_id(NextSeqno) of
@@ -332,26 +517,15 @@ next_seqno(Seqno) ->
             NextSeqno
     end.
 
-%% Reconstruct session counter by adding most significant bits from
-%% the current counter to the packet id.
--spec packet_id_to_seqno(_Next :: seqno(), emqx_types:packet_id()) -> seqno().
-packet_id_to_seqno(NextSeqNo, PacketId) ->
-    Epoch = NextSeqNo bsr 16,
-    case packet_id_to_seqno_(Epoch, PacketId) of
-        N when N =< NextSeqNo ->
+packet_id_to_seqno_(NextSeqno, PacketId) ->
+    Epoch = NextSeqno bsr 16,
+    case (Epoch bsl 16) + PacketId of
+        N when N =< NextSeqno ->
             N;
-        _ ->
-            packet_id_to_seqno_(Epoch - 1, PacketId)
+        N ->
+            N - ?EPOCH_SIZE
     end.
 
--spec packet_id_to_seqno_(non_neg_integer(), emqx_types:packet_id()) -> seqno().
-packet_id_to_seqno_(Epoch, PacketId) ->
-    (Epoch bsl 16) + PacketId.
-
--spec seqno_to_packet_id(seqno()) -> emqx_types:packet_id() | 0.
-seqno_to_packet_id(Seqno) ->
-    Seqno rem 16#10000.
-
 range_size(FirstSeqno, UntilSeqno) ->
     %% This function assumes that gaps in the sequence ID occur _only_ when the
     %% packet ID wraps.
@@ -379,19 +553,19 @@ ro_transaction(Fun) ->
 %% This test only tests boundary conditions (to make sure property-based test didn't skip them):
 packet_id_to_seqno_test() ->
     %% Packet ID = 1; first epoch:
-    ?assertEqual(1, packet_id_to_seqno(1, 1)),
-    ?assertEqual(1, packet_id_to_seqno(10, 1)),
-    ?assertEqual(1, packet_id_to_seqno(1 bsl 16 - 1, 1)),
-    ?assertEqual(1, packet_id_to_seqno(1 bsl 16, 1)),
+    ?assertEqual(1, packet_id_to_seqno_(1, 1)),
+    ?assertEqual(1, packet_id_to_seqno_(10, 1)),
+    ?assertEqual(1, packet_id_to_seqno_(1 bsl 16 - 1, 1)),
+    ?assertEqual(1, packet_id_to_seqno_(1 bsl 16, 1)),
     %% Packet ID = 1; second and 3rd epochs:
-    ?assertEqual(1 bsl 16 + 1, packet_id_to_seqno(1 bsl 16 + 1, 1)),
-    ?assertEqual(1 bsl 16 + 1, packet_id_to_seqno(2 bsl 16, 1)),
-    ?assertEqual(2 bsl 16 + 1, packet_id_to_seqno(2 bsl 16 + 1, 1)),
+    ?assertEqual(1 bsl 16 + 1, packet_id_to_seqno_(1 bsl 16 + 1, 1)),
+    ?assertEqual(1 bsl 16 + 1, packet_id_to_seqno_(2 bsl 16, 1)),
+    ?assertEqual(2 bsl 16 + 1, packet_id_to_seqno_(2 bsl 16 + 1, 1)),
     %% Packet ID = 16#ffff:
     PID = 1 bsl 16 - 1,
-    ?assertEqual(PID, packet_id_to_seqno(PID, PID)),
-    ?assertEqual(PID, packet_id_to_seqno(1 bsl 16, PID)),
-    ?assertEqual(1 bsl 16 + PID, packet_id_to_seqno(2 bsl 16, PID)),
+    ?assertEqual(PID, packet_id_to_seqno_(PID, PID)),
+    ?assertEqual(PID, packet_id_to_seqno_(1 bsl 16, PID)),
+    ?assertEqual(1 bsl 16 + PID, packet_id_to_seqno_(2 bsl 16, PID)),
     ok.
 
 packet_id_to_seqno_test_() ->
@@ -406,8 +580,8 @@ packet_id_to_seqno_prop() ->
             SeqNo,
             seqno_gen(NextSeqNo),
             begin
-                PacketId = SeqNo rem 16#10000,
-                ?assertEqual(SeqNo, packet_id_to_seqno(NextSeqNo, PacketId)),
+                PacketId = seqno_to_packet_id(SeqNo),
+                ?assertEqual(SeqNo, packet_id_to_seqno_(NextSeqNo, PacketId)),
                 true
             end
         )
@@ -437,27 +611,42 @@ range_size_test_() ->
 compute_inflight_range_test_() ->
     [
         ?_assertEqual(
-            {1, 1},
+            {#{ack => 1, comp => 1}, 1},
             compute_inflight_range([])
         ),
         ?_assertEqual(
-            {12, 42},
+            {#{ack => 12, comp => 13}, 42},
             compute_inflight_range([
-                #ds_pubrange{id = {<<>>, 1}, until = 2, type = checkpoint},
-                #ds_pubrange{id = {<<>>, 4}, until = 8, type = checkpoint},
-                #ds_pubrange{id = {<<>>, 11}, until = 12, type = checkpoint},
-                #ds_pubrange{id = {<<>>, 12}, until = 13, type = inflight},
-                #ds_pubrange{id = {<<>>, 13}, until = 20, type = inflight},
-                #ds_pubrange{id = {<<>>, 20}, until = 42, type = inflight}
+                #ds_pubrange{id = {<<>>, 1}, until = 2, type = ?T_CHECKPOINT},
+                #ds_pubrange{id = {<<>>, 4}, until = 8, type = ?T_CHECKPOINT},
+                #ds_pubrange{id = {<<>>, 11}, until = 12, type = ?T_CHECKPOINT},
+                #ds_pubrange{
+                    id = {<<>>, 12},
+                    until = 13,
+                    type = ?T_INFLIGHT,
+                    tracks = ?TRACK_FLAG(?ACK)
+                },
+                #ds_pubrange{
+                    id = {<<>>, 13},
+                    until = 20,
+                    type = ?T_INFLIGHT,
+                    tracks = ?TRACK_FLAG(?COMP)
+                },
+                #ds_pubrange{
+                    id = {<<>>, 20},
+                    until = 42,
+                    type = ?T_INFLIGHT,
+                    tracks = ?TRACK_FLAG(?ACK) bor ?TRACK_FLAG(?COMP)
+                }
             ])
         ),
         ?_assertEqual(
-            {13, 13},
+            {#{ack => 13, comp => 13}, 13},
             compute_inflight_range([
-                #ds_pubrange{id = {<<>>, 1}, until = 2, type = checkpoint},
-                #ds_pubrange{id = {<<>>, 4}, until = 8, type = checkpoint},
-                #ds_pubrange{id = {<<>>, 11}, until = 12, type = checkpoint},
-                #ds_pubrange{id = {<<>>, 12}, until = 13, type = checkpoint}
+                #ds_pubrange{id = {<<>>, 1}, until = 2, type = ?T_CHECKPOINT},
+                #ds_pubrange{id = {<<>>, 4}, until = 8, type = ?T_CHECKPOINT},
+                #ds_pubrange{id = {<<>>, 11}, until = 12, type = ?T_CHECKPOINT},
+                #ds_pubrange{id = {<<>>, 12}, until = 13, type = ?T_CHECKPOINT}
             ])
         )
     ].
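
Note (illustrative, not part of the diff): discard_committed_range/2 decides the fate of each range purely from the ack and comp watermarks. Assuming Commits = #{ack => 10, comp => 5}:

%% Checkpoint ranges are always discarded:
%%   discard_committed_range(Commits, #ds_pubrange{type = ?T_CHECKPOINT}) -> discard.
%% An inflight range with until = 12 is newer than both watermarks (12 > 10, 12 > 5):
%%   -> keep_all, so it and every following range stay untouched.
%% An inflight range with until = 8 and tracks = ?TRACK_FLAG(?ACK) is fully PUBACKed:
%%   -> discard.
%% An inflight range with until = 8 and tracks = ?TRACK_FLAG(?ACK) bor ?TRACK_FLAG(?COMP)
%% is PUBACKed but not yet PUBCOMPed:
%%   -> ?TRACK_FLAG(?COMP) is returned as TracksLeft and the range is preserved
%%      with only the comp track remaining.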

+ 90 - 12
apps/emqx/src/emqx_persistent_session_ds.erl

@@ -247,6 +247,7 @@ print_session(ClientId) ->
                         session => Session,
                         streams => mnesia:read(?SESSION_STREAM_TAB, ClientId),
                         pubranges => session_read_pubranges(ClientId),
+                        offsets => session_read_offsets(ClientId),
                         subscriptions => session_read_subscriptions(ClientId)
                     };
                 [] ->
@@ -327,12 +328,13 @@ publish(_PacketId, Msg, Session) ->
     {ok, emqx_types:message(), replies(), session()}
     | {error, emqx_types:reason_code()}.
 puback(_ClientInfo, PacketId, Session = #{id := Id, inflight := Inflight0}) ->
-    case emqx_persistent_message_ds_replayer:commit_offset(Id, PacketId, Inflight0) of
+    case emqx_persistent_message_ds_replayer:commit_offset(Id, ack, PacketId, Inflight0) of
         {true, Inflight} ->
             %% TODO
-            Msg = #message{},
+            Msg = emqx_message:make(Id, <<>>, <<>>),
             {ok, Msg, [], Session#{inflight => Inflight}};
         {false, _} ->
+            %% Invalid Packet Id
             {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND}
     end.
 
@@ -343,9 +345,16 @@ puback(_ClientInfo, PacketId, Session = #{id := Id, inflight := Inflight0}) ->
 -spec pubrec(emqx_types:packet_id(), session()) ->
     {ok, emqx_types:message(), session()}
     | {error, emqx_types:reason_code()}.
-pubrec(_PacketId, _Session = #{}) ->
-    % TODO: stub
-    {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND}.
+pubrec(PacketId, Session = #{id := Id, inflight := Inflight0}) ->
+    case emqx_persistent_message_ds_replayer:commit_offset(Id, rec, PacketId, Inflight0) of
+        {true, Inflight} ->
+            %% TODO
+            Msg = emqx_message:make(Id, <<>>, <<>>),
+            {ok, Msg, Session#{inflight => Inflight}};
+        {false, _} ->
+            %% Invalid Packet Id
+            {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND}
+    end.
 
 %%--------------------------------------------------------------------
 %% Client -> Broker: PUBREL
@@ -364,9 +373,16 @@ pubrel(_PacketId, Session = #{}) ->
 -spec pubcomp(clientinfo(), emqx_types:packet_id(), session()) ->
     {ok, emqx_types:message(), replies(), session()}
     | {error, emqx_types:reason_code()}.
-pubcomp(_ClientInfo, _PacketId, _Session = #{}) ->
-    % TODO: stub
-    {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND}.
+pubcomp(_ClientInfo, PacketId, Session = #{id := Id, inflight := Inflight0}) ->
+    case emqx_persistent_message_ds_replayer:commit_offset(Id, comp, PacketId, Inflight0) of
+        {true, Inflight} ->
+            %% TODO
+            Msg = emqx_message:make(Id, <<>>, <<>>),
+            {ok, Msg, [], Session#{inflight => Inflight}};
+        {false, _} ->
+            %% Invalid Packet Id
+            {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND}
+    end.
 
 %%--------------------------------------------------------------------
 
@@ -383,7 +399,18 @@ handle_timeout(
     pull,
     Session = #{id := Id, inflight := Inflight0, receive_maximum := ReceiveMaximum}
 ) ->
-    {Publishes, Inflight} = emqx_persistent_message_ds_replayer:poll(Id, Inflight0, ReceiveMaximum),
+    {Publishes, Inflight} = emqx_persistent_message_ds_replayer:poll(
+        fun
+            (_Seqno, Message = #message{qos = ?QOS_0}) ->
+                {false, {undefined, Message}};
+            (Seqno, Message) ->
+                PacketId = emqx_persistent_message_ds_replayer:seqno_to_packet_id(Seqno),
+                {PacketId, Message}
+        end,
+        Id,
+        Inflight0,
+        ReceiveMaximum
+    ),
     IdlePollInterval = emqx_config:get([session_persistence, idle_poll_interval]),
     Timeout =
         case Publishes of
@@ -393,7 +420,7 @@ handle_timeout(
                 0
         end,
     ensure_timer(pull, Timeout),
-    {ok, Publishes, Session#{inflight => Inflight}};
+    {ok, Publishes, Session#{inflight := Inflight}};
 handle_timeout(_ClientInfo, get_streams, Session) ->
     renew_streams(Session),
     ensure_timer(get_streams),
@@ -407,7 +434,24 @@ handle_timeout(_ClientInfo, bump_last_alive_at, Session0) ->
 -spec replay(clientinfo(), [], session()) ->
     {ok, replies(), session()}.
 replay(_ClientInfo, [], Session = #{inflight := Inflight0}) ->
-    {Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(Inflight0),
+    AckedUntil = emqx_persistent_message_ds_replayer:committed_until(ack, Inflight0),
+    RecUntil = emqx_persistent_message_ds_replayer:committed_until(rec, Inflight0),
+    CompUntil = emqx_persistent_message_ds_replayer:committed_until(comp, Inflight0),
+    ReplyFun = fun
+        (_Seqno, #message{qos = ?QOS_0}) ->
+            {false, []};
+        (Seqno, #message{qos = ?QOS_1}) when Seqno < AckedUntil ->
+            [];
+        (Seqno, #message{qos = ?QOS_2}) when Seqno < CompUntil ->
+            [];
+        (Seqno, #message{qos = ?QOS_2}) when Seqno < RecUntil ->
+            PacketId = emqx_persistent_message_ds_replayer:seqno_to_packet_id(Seqno),
+            {pubrel, PacketId};
+        (Seqno, Message) ->
+            PacketId = emqx_persistent_message_ds_replayer:seqno_to_packet_id(Seqno),
+            {PacketId, emqx_message:set_flag(dup, true, Message)}
+    end,
+    {Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(ReplyFun, Inflight0),
     {ok, Replies, Session#{inflight := Inflight}}.
 
 %%--------------------------------------------------------------------
@@ -521,11 +565,22 @@ create_tables() ->
             {attributes, record_info(fields, ds_pubrange)}
         ]
     ),
+    ok = mria:create_table(
+        ?SESSION_COMMITTED_OFFSET_TAB,
+        [
+            {rlog_shard, ?DS_MRIA_SHARD},
+            {type, set},
+            {storage, storage()},
+            {record_name, ds_committed_offset},
+            {attributes, record_info(fields, ds_committed_offset)}
+        ]
+    ),
     ok = mria:wait_for_tables([
         ?SESSION_TAB,
         ?SESSION_SUBSCRIPTIONS_TAB,
         ?SESSION_STREAM_TAB,
-        ?SESSION_PUBRANGE_TAB
+        ?SESSION_PUBRANGE_TAB,
+        ?SESSION_COMMITTED_OFFSET_TAB
     ]),
     ok.
 
@@ -629,6 +684,7 @@ session_drop(DSSessionId) ->
     transaction(fun() ->
         ok = session_drop_subscriptions(DSSessionId),
         ok = session_drop_pubranges(DSSessionId),
+        ok = session_drop_offsets(DSSessionId),
         ok = session_drop_streams(DSSessionId),
         ok = mnesia:delete(?SESSION_TAB, DSSessionId, write)
     end).
@@ -720,6 +776,17 @@ session_read_pubranges(DSSessionId, LockKind) ->
     ),
     mnesia:select(?SESSION_PUBRANGE_TAB, MS, LockKind).
 
+session_read_offsets(DSSessionID) ->
+    session_read_offsets(DSSessionID, read).
+
+session_read_offsets(DSSessionId, LockKind) ->
+    MS = ets:fun2ms(
+        fun(#ds_committed_offset{id = {Sess, Type}}) when Sess =:= DSSessionId ->
+            {DSSessionId, Type}
+        end
+    ),
+    mnesia:select(?SESSION_COMMITTED_OFFSET_TAB, MS, LockKind).
+
 -spec new_subscription_id(id(), topic_filter()) -> {subscription_id(), integer()}.
 new_subscription_id(DSSessionId, TopicFilter) ->
     %% Note: here we use _milliseconds_ to match with the timestamp
@@ -832,6 +899,17 @@ session_drop_pubranges(DSSessionId) ->
         RangeIds
     ).
 
+%% must be called inside a transaction
+-spec session_drop_offsets(id()) -> ok.
+session_drop_offsets(DSSessionId) ->
+    OffsetIds = session_read_offsets(DSSessionId, write),
+    lists:foreach(
+        fun(OffsetId) ->
+            mnesia:delete(?SESSION_COMMITTED_OFFSET_TAB, OffsetId, write)
+        end,
+        OffsetIds
+    ).
+
 %%--------------------------------------------------------------------------------
 
 transaction(Fun) ->
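
Note (sketch, not part of the diff; Session0, ClientInfo and PacketId are placeholders): with the callbacks above, a QoS 2 message delivered to the client moves through the session in two commits, first the rec track, then the comp track:

%% The client answers the PUBLISH with PUBREC:
{ok, _, Session1} = pubrec(PacketId, Session0),
%% ... the broker replies with PUBREL; eventually the client sends PUBCOMP:
{ok, _, [], Session2} = pubcomp(ClientInfo, PacketId, Session1).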

+ 20 - 1
apps/emqx/src/emqx_persistent_session_ds.hrl

@@ -22,8 +22,12 @@
 -define(SESSION_SUBSCRIPTIONS_TAB, emqx_ds_session_subscriptions).
 -define(SESSION_STREAM_TAB, emqx_ds_stream_tab).
 -define(SESSION_PUBRANGE_TAB, emqx_ds_pubrange_tab).
+-define(SESSION_COMMITTED_OFFSET_TAB, emqx_ds_committed_offset_tab).
 -define(DS_MRIA_SHARD, emqx_ds_session_shard).
 
+-define(T_INFLIGHT, 1).
+-define(T_CHECKPOINT, 2).
+
 -record(ds_sub, {
     id :: emqx_persistent_session_ds:subscription_id(),
     start_time :: emqx_ds:time(),
@@ -56,7 +60,11 @@
     %% * Inflight range is a range of yet unacked messages from this stream.
     %% * Checkpoint range was already acked, its purpose is to keep track of the
     %%   very last iterator for this stream.
-    type :: inflight | checkpoint,
+    type :: ?T_INFLIGHT | ?T_CHECKPOINT,
+    %% What commit tracks this range is part of.
+    %% This is rarely stored: we only need to persist it when the range contains
+    %% QoS 2 messages.
+    tracks = 0 :: non_neg_integer(),
     %% Meaning of this depends on the type of the range:
     %% * For inflight range, this is the iterator pointing to the first message in
     %%   the range.
@@ -68,6 +76,17 @@
 }).
 -type ds_pubrange() :: #ds_pubrange{}.
 
+-record(ds_committed_offset, {
+    id :: {
+        %% What session this marker belongs to.
+        _Session :: emqx_persistent_session_ds:id(),
+        %% Marker name.
+        _CommitType
+    },
+    %% Where this marker is pointing to: the first seqno that is not marked.
+    until :: emqx_persistent_message_ds_replayer:seqno()
+}).
+
 -record(session, {
     %% same as clientid
     id :: emqx_persistent_session_ds:id(),
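
Note (illustrative, not part of the diff; SessionId is a placeholder): only the rec track is persisted in this table — ack/comp progress is reflected by discarding fully committed pubranges instead. An entry such as

#ds_committed_offset{id = {SessionId, rec}, until = 42}

means that seqnos below 42 have been PUBRECed by the client and 42 is the first one that has not.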

+ 4 - 5
apps/emqx/test/emqx_persistent_messages_SUITE.erl

@@ -233,7 +233,7 @@ t_session_subscription_iterators(Config) ->
     ),
     ok.
 
-t_qos0(Config) ->
+t_qos0(_Config) ->
     Sub = connect(<<?MODULE_STRING "1">>, true, 30),
     Pub = connect(<<?MODULE_STRING "2">>, true, 0),
     try
@@ -258,7 +258,7 @@ t_qos0(Config) ->
         emqtt:stop(Pub)
     end.
 
-t_publish_as_persistent(Config) ->
+t_publish_as_persistent(_Config) ->
     Sub = connect(<<?MODULE_STRING "1">>, true, 30),
     Pub = connect(<<?MODULE_STRING "2">>, true, 30),
     try
@@ -272,9 +272,8 @@ t_publish_as_persistent(Config) ->
         ?assertMatch(
             [
                 #{qos := 0, topic := <<"t/1">>, payload := <<"1">>},
-                #{qos := 1, topic := <<"t/1">>, payload := <<"2">>}
-                %% TODO: QoS 2
-                %% #{qos := 2, topic := <<"t/1">>, payload := <<"3">>}
+                #{qos := 1, topic := <<"t/1">>, payload := <<"2">>},
+                #{qos := 2, topic := <<"t/1">>, payload := <<"3">>}
             ],
             receive_messages(3)
         )

+ 222 - 73
apps/emqx/test/emqx_persistent_session_SUITE.erl

@@ -17,6 +17,7 @@
 -module(emqx_persistent_session_SUITE).
 
 -include_lib("stdlib/include/assert.hrl").
+-include_lib("emqx/include/asserts.hrl").
 -include_lib("common_test/include/ct.hrl").
 -include_lib("snabbkaffe/include/snabbkaffe.hrl").
 -include_lib("emqx/include/emqx_mqtt.hrl").
@@ -53,10 +54,10 @@ all() ->
 groups() ->
     TCs = emqx_common_test_helpers:all(?MODULE),
     TCsNonGeneric = [t_choose_impl],
+    TCGroups = [{group, tcp}, {group, quic}, {group, ws}],
     [
-        {persistence_disabled, [{group, no_kill_connection_process}]},
-        {persistence_enabled, [{group, no_kill_connection_process}]},
-        {no_kill_connection_process, [], [{group, tcp}, {group, quic}, {group, ws}]},
+        {persistence_disabled, TCGroups},
+        {persistence_enabled, TCGroups},
         {tcp, [], TCs},
         {quic, [], TCs -- TCsNonGeneric},
         {ws, [], TCs -- TCsNonGeneric}
@@ -74,7 +75,7 @@ init_per_group(persistence_enabled, Config) ->
         {persistence, ds}
         | Config
     ];
-init_per_group(Group, Config) when Group == tcp ->
+init_per_group(tcp, Config) ->
     Apps = emqx_cth_suite:start(
         [{emqx, ?config(emqx_config, Config)}],
         #{work_dir => emqx_cth_suite:work_dir(Config)}
@@ -85,7 +86,7 @@ init_per_group(Group, Config) when Group == tcp ->
         {group_apps, Apps}
         | Config
     ];
-init_per_group(Group, Config) when Group == ws ->
+init_per_group(ws, Config) ->
     Apps = emqx_cth_suite:start(
         [{emqx, ?config(emqx_config, Config)}],
         #{work_dir => emqx_cth_suite:work_dir(Config)}
@@ -99,7 +100,7 @@ init_per_group(Group, Config) when Group == ws ->
         {group_apps, Apps}
         | Config
     ];
-init_per_group(Group, Config) when Group == quic ->
+init_per_group(quic, Config) ->
     Apps = emqx_cth_suite:start(
         [
             {emqx,
@@ -118,11 +119,7 @@ init_per_group(Group, Config) when Group == quic ->
         {ssl, true},
         {group_apps, Apps}
         | Config
-    ];
-init_per_group(no_kill_connection_process, Config) ->
-    [{kill_connection_process, false} | Config];
-init_per_group(kill_connection_process, Config) ->
-    [{kill_connection_process, true} | Config].
+    ].
 
 get_listener_port(Type, Name) ->
     case emqx_config:get([listeners, Type, Name, bind]) of
@@ -194,6 +191,8 @@ receive_message_loop(Count, Deadline) ->
     receive
         {publish, Msg} ->
             [Msg | receive_message_loop(Count - 1, Deadline)];
+        {pubrel, Msg} ->
+            [{pubrel, Msg} | receive_message_loop(Count - 1, Deadline)];
         _Other ->
             receive_message_loop(Count, Deadline)
     after Timeout ->
@@ -201,39 +200,44 @@ receive_message_loop(Count, Deadline) ->
     end.
 
 maybe_kill_connection_process(ClientId, Config) ->
-    case ?config(kill_connection_process, Config) of
-        true ->
-            case emqx_cm:lookup_channels(ClientId) of
-                [] ->
-                    ok;
-                [ConnectionPid] ->
-                    ?assert(is_pid(ConnectionPid)),
-                    Ref = monitor(process, ConnectionPid),
-                    ConnectionPid ! die_if_test,
-                    receive
-                        {'DOWN', Ref, process, ConnectionPid, normal} -> ok
-                    after 3000 -> error(process_did_not_die)
-                    end,
-                    wait_for_cm_unregister(ClientId)
-            end;
-        false ->
+    Persistence = ?config(persistence, Config),
+    case emqx_cm:lookup_channels(ClientId) of
+        [] ->
+            ok;
+        [ConnectionPid] when Persistence == ds ->
+            Ref = monitor(process, ConnectionPid),
+            ConnectionPid ! die_if_test,
+            ?assertReceive(
+                {'DOWN', Ref, process, ConnectionPid, Reason} when
+                    Reason == normal orelse Reason == noproc,
+                3000
+            ),
+            wait_connection_process_unregistered(ClientId);
+        _ ->
             ok
     end.
 
-wait_for_cm_unregister(ClientId) ->
-    wait_for_cm_unregister(ClientId, 100).
-
-wait_for_cm_unregister(_ClientId, 0) ->
-    error(cm_did_not_unregister);
-wait_for_cm_unregister(ClientId, N) ->
+wait_connection_process_dies(ClientId) ->
     case emqx_cm:lookup_channels(ClientId) of
         [] ->
             ok;
-        [_] ->
-            timer:sleep(100),
-            wait_for_cm_unregister(ClientId, N - 1)
+        [ConnectionPid] ->
+            Ref = monitor(process, ConnectionPid),
+            ?assertReceive(
+                {'DOWN', Ref, process, ConnectionPid, Reason} when
+                    Reason == normal orelse Reason == noproc,
+                3000
+            ),
+            wait_connection_process_unregistered(ClientId)
     end.
 
+wait_connection_process_unregistered(ClientId) ->
+    ?retry(
+        _Timeout = 100,
+        _Retries = 20,
+        ?assertEqual([], emqx_cm:lookup_channels(ClientId))
+    ).
+
 messages(Topic, Payloads) ->
     messages(Topic, Payloads, ?QOS_2).
 
@@ -272,23 +276,7 @@ do_publish(Messages = [_ | _], PublishFun, WaitForUnregister) ->
                 lists:foreach(fun(Message) -> PublishFun(Client, Message) end, Messages),
                 ok = emqtt:disconnect(Client),
                 %% Snabbkaffe sometimes fails unless all processes are gone.
-                case WaitForUnregister of
-                    false ->
-                        ok;
-                    true ->
-                        case emqx_cm:lookup_channels(ClientID) of
-                            [] ->
-                                ok;
-                            [ConnectionPid] ->
-                                ?assert(is_pid(ConnectionPid)),
-                                Ref1 = monitor(process, ConnectionPid),
-                                receive
-                                    {'DOWN', Ref1, process, ConnectionPid, _} -> ok
-                                after 3000 -> error(process_did_not_die)
-                                end,
-                                wait_for_cm_unregister(ClientID)
-                        end
-                end
+                WaitForUnregister andalso wait_connection_process_dies(ClientID)
             end
         ),
     receive
@@ -475,7 +463,7 @@ t_cancel_on_disconnect(Config) ->
     {ok, _} = emqtt:ConnFun(Client1),
     ok = emqtt:disconnect(Client1, 0, #{'Session-Expiry-Interval' => 0}),
 
-    wait_for_cm_unregister(ClientId),
+    wait_connection_process_unregistered(ClientId),
 
     {ok, Client2} = emqtt:start_link([
         {clientid, ClientId},
@@ -507,7 +495,7 @@ t_persist_on_disconnect(Config) ->
     %% Strangely enough, the disconnect is reported as successful by emqtt.
     ok = emqtt:disconnect(Client1, 0, #{'Session-Expiry-Interval' => 30}),
 
-    wait_for_cm_unregister(ClientId),
+    wait_connection_process_unregistered(ClientId),
 
     {ok, Client2} = emqtt:start_link([
         {clientid, ClientId},
@@ -619,7 +607,7 @@ t_publish_many_while_client_is_gone_qos1(Config) ->
         {clientid, ClientId},
         {properties, #{'Session-Expiry-Interval' => 30}},
         {clean_start, true},
-        {auto_ack, false}
+        {auto_ack, never}
         | Config
     ]),
     {ok, _} = emqtt:ConnFun(Client1),
@@ -666,8 +654,7 @@ t_publish_many_while_client_is_gone_qos1(Config) ->
 
     ?assertEqual(
         get_topicwise_order(Pubs1),
-        get_topicwise_order(Msgs1),
-        Msgs1
+        get_topicwise_order(Msgs1)
     ),
 
     NAcked = 4,
@@ -725,21 +712,6 @@ t_publish_many_while_client_is_gone_qos1(Config) ->
 
     ok = emqtt:disconnect(Client2).
 
-get_topicwise_order(Msgs) ->
-    maps:groups_from_list(fun get_msgpub_topic/1, fun get_msgpub_payload/1, Msgs).
-
-get_msgpub_topic(#mqtt_msg{topic = Topic}) ->
-    Topic;
-get_msgpub_topic(#{topic := Topic}) ->
-    Topic.
-
-get_msgpub_payload(#mqtt_msg{payload = Payload}) ->
-    Payload;
-get_msgpub_payload(#{payload := Payload}) ->
-    Payload.
-
-t_publish_while_client_is_gone(init, Config) -> skip_ds_tc(Config);
-t_publish_while_client_is_gone('end', _Config) -> ok.
 t_publish_while_client_is_gone(Config) ->
     %% A persistent session should receive messages in its
     %% subscription even if the process owning the session dies.
@@ -782,6 +754,157 @@ t_publish_while_client_is_gone(Config) ->
 
     ok = emqtt:disconnect(Client2).
 
+t_publish_many_while_client_is_gone(Config) ->
+    %% A persistent session should receive all of the still unacked messages
+    %% for its subscriptions after the client dies or reconnects, in addition
+    %% to PUBRELs for the messages it has PUBRECed. While client must send
+    %% PUBACKs and PUBRECs in order, those orders are independent of each other.
+    ClientId = ?config(client_id, Config),
+    ConnFun = ?config(conn_fun, Config),
+    ClientOpts = [
+        {proto_ver, v5},
+        {clientid, ClientId},
+        {properties, #{'Session-Expiry-Interval' => 30}},
+        {auto_ack, never}
+        | Config
+    ],
+
+    {ok, Client1} = emqtt:start_link([{clean_start, true} | ClientOpts]),
+    {ok, _} = emqtt:ConnFun(Client1),
+    {ok, _, [?QOS_1]} = emqtt:subscribe(Client1, <<"t/+/foo">>, ?QOS_1),
+    {ok, _, [?QOS_2]} = emqtt:subscribe(Client1, <<"msg/feed/#">>, ?QOS_2),
+    {ok, _, [?QOS_2]} = emqtt:subscribe(Client1, <<"loc/+/+/+">>, ?QOS_2),
+
+    Pubs1 = [
+        #mqtt_msg{topic = <<"t/42/foo">>, payload = <<"M1">>, qos = 1},
+        #mqtt_msg{topic = <<"t/42/foo">>, payload = <<"M2">>, qos = 1},
+        #mqtt_msg{topic = <<"msg/feed/me">>, payload = <<"M3">>, qos = 2},
+        #mqtt_msg{topic = <<"loc/1/2/42">>, payload = <<"M4">>, qos = 2},
+        #mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M5">>, qos = 2},
+        #mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M6">>, qos = 1},
+        #mqtt_msg{topic = <<"loc/3/4/5">>, payload = <<"M7">>, qos = 2},
+        #mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M8">>, qos = 1},
+        #mqtt_msg{topic = <<"msg/feed/me">>, payload = <<"M9">>, qos = 2}
+    ],
+    ok = publish_many(Pubs1),
+    NPubs1 = length(Pubs1),
+
+    Msgs1 = receive_messages(NPubs1),
+    ct:pal("Msgs1 = ~p", [Msgs1]),
+    NMsgs1 = length(Msgs1),
+    ?assertEqual(NPubs1, NMsgs1),
+
+    ?assertEqual(
+        get_topicwise_order(Pubs1),
+        get_topicwise_order(Msgs1)
+    ),
+
+    %% PUBACK every QoS 1 message.
+    lists:foreach(
+        fun(PktId) -> ok = emqtt:puback(Client1, PktId) end,
+        [PktId || #{qos := 1, packet_id := PktId} <- Msgs1]
+    ),
+
+    %% PUBREC first `NRecs` QoS 2 messages.
+    NRecs = 3,
+    PubRecs1 = lists:sublist([PktId || #{qos := 2, packet_id := PktId} <- Msgs1], NRecs),
+    lists:foreach(
+        fun(PktId) -> ok = emqtt:pubrec(Client1, PktId) end,
+        PubRecs1
+    ),
+
+    %% Ensure that PUBACKs / PUBRECs are propagated to the channel.
+    pong = emqtt:ping(Client1),
+
+    %% Receive PUBRELs for the sent PUBRECs.
+    PubRels1 = receive_messages(NRecs),
+    ct:pal("PubRels1 = ~p", [PubRels1]),
+    ?assertEqual(
+        PubRecs1,
+        [PktId || {pubrel, #{packet_id := PktId}} <- PubRels1],
+        PubRels1
+    ),
+
+    ok = emqtt:disconnect(Client1),
+    maybe_kill_connection_process(ClientId, Config),
+
+    Pubs2 = [
+        #mqtt_msg{topic = <<"loc/3/4/5">>, payload = <<"M10">>, qos = 2},
+        #mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M11">>, qos = 1},
+        #mqtt_msg{topic = <<"msg/feed/friend">>, payload = <<"M12">>, qos = 2}
+    ],
+    ok = publish_many(Pubs2),
+    NPubs2 = length(Pubs2),
+
+    {ok, Client2} = emqtt:start_link([{clean_start, false} | ClientOpts]),
+    {ok, _} = emqtt:ConnFun(Client2),
+
+    %% Try to receive _at most_ `NPubs` messages.
+    %% There shouldn't be that much unacked messages in the replay anyway,
+    %% but it's an easy number to pick.
+    NPubs = NPubs1 + NPubs2,
+    Msgs2 = receive_messages(NPubs, _Timeout = 2000),
+    ct:pal("Msgs2 = ~p", [Msgs2]),
+
+    %% We should again receive PUBRELs for the PUBRECs we sent earlier.
+    ?assertEqual(
+        get_msgs_essentials(PubRels1),
+        [get_msg_essentials(PubRel) || PubRel = {pubrel, _} <- Msgs2]
+    ),
+
+    %% We should receive duplicates only for QoS 2 messages where PUBRELs were
+    %% not sent, in the same order as the original messages.
+    Msgs2Dups = [get_msg_essentials(M) || M = #{dup := true} <- Msgs2],
+    ?assertEqual(
+        Msgs2Dups,
+        [M || M = #{qos := 2} <- Msgs2Dups]
+    ),
+    ?assertEqual(
+        get_msgs_essentials(pick_respective_msgs(Msgs2Dups, Msgs1)),
+        Msgs2Dups
+    ),
+
+    %% Now complete all yet incomplete QoS 2 message flows instead.
+    PubRecs2 = [PktId || #{qos := 2, packet_id := PktId} <- Msgs2],
+    lists:foreach(
+        fun(PktId) -> ok = emqtt:pubrec(Client2, PktId) end,
+        PubRecs2
+    ),
+
+    PubRels2 = receive_messages(length(PubRecs2)),
+    ct:pal("PubRels2 = ~p", [PubRels2]),
+    ?assertEqual(
+        PubRecs2,
+        [PktId || {pubrel, #{packet_id := PktId}} <- PubRels2],
+        PubRels2
+    ),
+
+    %% PUBCOMP every PUBREL.
+    PubComps = [PktId || {pubrel, #{packet_id := PktId}} <- PubRels1 ++ PubRels2],
+    lists:foreach(
+        fun(PktId) -> ok = emqtt:pubcomp(Client2, PktId) end,
+        PubComps
+    ),
+
+    %% Ensure that PUBCOMPs are propagated to the channel.
+    pong = emqtt:ping(Client2),
+
+    ok = emqtt:disconnect(Client2),
+    maybe_kill_connection_process(ClientId, Config),
+
+    {ok, Client3} = emqtt:start_link([{clean_start, false} | ClientOpts]),
+    {ok, _} = emqtt:ConnFun(Client3),
+
+    %% Only the last unacked QoS 1 message should be retransmitted.
+    Msgs3 = receive_messages(NPubs, _Timeout = 2000),
+    ct:pal("Msgs3 = ~p", [Msgs3]),
+    ?assertMatch(
+        [#{topic := <<"t/100/foo">>, payload := <<"M11">>, qos := 1, dup := true}],
+        Msgs3
+    ),
+
+    ok = emqtt:disconnect(Client3).
+
 t_clean_start_drops_subscriptions(Config) ->
     %% 1. A persistent session is started and disconnected.
     %% 2. While disconnected, a message is published and persisted.
@@ -832,6 +955,7 @@ t_clean_start_drops_subscriptions(Config) ->
     [Msg1] = receive_messages(1),
     ?assertEqual({ok, iolist_to_binary(Payload2)}, maps:find(payload, Msg1)),
 
+    pong = emqtt:ping(Client2),
     ok = emqtt:disconnect(Client2),
     maybe_kill_connection_process(ClientId, Config),
 
@@ -849,6 +973,7 @@ t_clean_start_drops_subscriptions(Config) ->
     [Msg2] = receive_messages(1),
     ?assertEqual({ok, iolist_to_binary(Payload3)}, maps:find(payload, Msg2)),
 
+    pong = emqtt:ping(Client3),
     ok = emqtt:disconnect(Client3).
 
 t_unsubscribe(Config) ->
@@ -912,6 +1037,30 @@ t_multiple_subscription_matches(Config) ->
     ?assertEqual({ok, 2}, maps:find(qos, Msg2)),
     ok = emqtt:disconnect(Client2).
 
+get_topicwise_order(Msgs) ->
+    maps:groups_from_list(fun get_msgpub_topic/1, fun get_msgpub_payload/1, Msgs).
+
+get_msgpub_topic(#mqtt_msg{topic = Topic}) ->
+    Topic;
+get_msgpub_topic(#{topic := Topic}) ->
+    Topic.
+
+get_msgpub_payload(#mqtt_msg{payload = Payload}) ->
+    Payload;
+get_msgpub_payload(#{payload := Payload}) ->
+    Payload.
+
+get_msg_essentials(Msg = #{}) ->
+    maps:with([packet_id, topic, payload, qos], Msg);
+get_msg_essentials({pubrel, Msg}) ->
+    {pubrel, maps:with([packet_id, reason_code], Msg)}.
+
+get_msgs_essentials(Msgs) ->
+    [get_msg_essentials(M) || M <- Msgs].
+
+pick_respective_msgs(MsgRefs, Msgs) ->
+    [M || M <- Msgs, Ref <- MsgRefs, maps:get(packet_id, M) =:= maps:get(packet_id, Ref)].
+
 skip_ds_tc(Config) ->
     case ?config(persistence, Config) of
         ds ->

+ 1 - 1
apps/emqx_retainer/rebar.config

@@ -30,7 +30,7 @@
 {profiles, [
     {test, [
         {deps, [
-            {emqtt, {git, "https://github.com/emqx/emqtt", {tag, "1.9.1"}}}
+            {emqtt, {git, "https://github.com/emqx/emqtt", {tag, "1.9.6"}}}
         ]}
     ]}
 ]}.

+ 1 - 1
mix.exs

@@ -64,7 +64,7 @@ defmodule EMQXUmbrella.MixProject do
       {:pbkdf2, github: "emqx/erlang-pbkdf2", tag: "2.0.4", override: true},
       # maybe forbid to fetch quicer
       {:emqtt,
-       github: "emqx/emqtt", tag: "1.9.1", override: true, system_env: maybe_no_quic_env()},
+       github: "emqx/emqtt", tag: "1.9.6", override: true, system_env: maybe_no_quic_env()},
       {:rulesql, github: "emqx/rulesql", tag: "0.1.7"},
       {:observer_cli, "1.7.1"},
       {:system_monitor, github: "ieQu1/system_monitor", tag: "3.0.3"},

+ 1 - 1
rebar.config

@@ -69,7 +69,7 @@
     , {ecpool, {git, "https://github.com/emqx/ecpool", {tag, "0.5.4"}}}
     , {replayq, {git, "https://github.com/emqx/replayq.git", {tag, "0.3.7"}}}
     , {pbkdf2, {git, "https://github.com/emqx/erlang-pbkdf2.git", {tag, "2.0.4"}}}
-    , {emqtt, {git, "https://github.com/emqx/emqtt", {tag, "1.9.1"}}}
+    , {emqtt, {git, "https://github.com/emqx/emqtt", {tag, "1.9.6"}}}
     , {rulesql, {git, "https://github.com/emqx/rulesql", {tag, "0.1.7"}}}
     , {observer_cli, "1.7.1"} % NOTE: depends on recon 2.5.x
     , {system_monitor, {git, "https://github.com/ieQu1/system_monitor", {tag, "3.0.3"}}}