Преглед изворни кода

feat(sessds): preserve acks / ranges in mnesia for replays

Andrew Mayorov пре 2 година
родитељ
комит
1246d714c5

+ 2 - 8
apps/emqx/integration_test/emqx_persistent_session_ds_SUITE.erl

@@ -357,18 +357,12 @@ do_t_session_discard(Params) ->
                 _Attempts0 = 50,
                 true = map_size(emqx_persistent_session_ds:list_all_streams()) > 0
             ),
-            ?retry(
-                _Sleep0 = 100,
-                _Attempts0 = 50,
-                true = map_size(emqx_persistent_session_ds:list_all_iterators()) > 0
-            ),
             ok = emqtt:stop(Client0),
             ?tp(notice, "disconnected", #{}),
 
             ?tp(notice, "reconnecting", #{}),
-            %% we still have iterators and streams
+            %% we still have streams
             ?assert(map_size(emqx_persistent_session_ds:list_all_streams()) > 0),
-            ?assert(map_size(emqx_persistent_session_ds:list_all_iterators()) > 0),
             Client1 = start_client(ReconnectOpts),
             {ok, _} = emqtt:connect(Client1),
             ?assertEqual([], emqtt:subscriptions(Client1)),
@@ -381,7 +375,7 @@ do_t_session_discard(Params) ->
             ?assertEqual(#{}, emqx_persistent_session_ds:list_all_subscriptions()),
             ?assertEqual([], emqx_persistent_session_ds_router:topics()),
             ?assertEqual(#{}, emqx_persistent_session_ds:list_all_streams()),
-            ?assertEqual(#{}, emqx_persistent_session_ds:list_all_iterators()),
+            ?assertEqual(#{}, emqx_persistent_session_ds:list_all_pubranges()),
             ok = emqtt:stop(Client1),
             ?tp(notice, "disconnected", #{}),
 

+ 308 - 125
apps/emqx/src/emqx_persistent_message_ds_replayer.erl

@@ -19,12 +19,12 @@
 -module(emqx_persistent_message_ds_replayer).
 
 %% API:
--export([new/0, next_packet_id/1, replay/2, commit_offset/3, poll/3, n_inflight/1]).
+-export([new/0, open/1, next_packet_id/1, replay/1, commit_offset/3, poll/3, n_inflight/1]).
 
 %% internal exports:
 -export([]).
 
--export_type([inflight/0]).
+-export_type([inflight/0, seqno/0]).
 
 -include_lib("emqx/include/logger.hrl").
 -include("emqx_persistent_session_ds.hrl").
@@ -42,17 +42,28 @@
 -type seqno() :: non_neg_integer().
 
 -record(range, {
-    stream :: emqx_ds:stream(),
+    stream :: _StreamRef,
     first :: seqno(),
-    last :: seqno(),
-    iterator_next :: emqx_ds:iterator() | undefined
+    until :: seqno(),
+    %% Type of a range:
+    %% * Inflight range is a range of yet unacked messages from this stream.
+    %% * Checkpoint range was already acked, its purpose is to keep track of the
+    %%   very last iterator for this stream.
+    type :: inflight | checkpoint,
+    %% Meaning of this depends on the type of the range:
+    %% * For inflight range, this is the iterator pointing to the first message in
+    %%   the range.
+    %% * For checkpoint range, this is the iterator pointing right past the last
+    %%   message in the range.
+    iterator :: emqx_ds:iterator()
 }).
 
 -type range() :: #range{}.
 
 -record(inflight, {
-    next_seqno = 0 :: seqno(),
-    acked_seqno = 0 :: seqno(),
+    next_seqno = 1 :: seqno(),
+    acked_until = 1 :: seqno(),
+    %% Ranges are sorted in ascending order of their sequence numbers.
     offset_ranges = [] :: [range()]
 }).
 
@@ -66,34 +77,37 @@
 new() ->
     #inflight{}.
 
+-spec open(emqx_persistent_session_ds:id()) -> inflight().
+open(SessionId) ->
+    Ranges = ro_transaction(fun() -> get_ranges(SessionId) end),
+    {AckedUntil, NextSeqno} = compute_inflight_range(Ranges),
+    #inflight{
+        acked_until = AckedUntil,
+        next_seqno = NextSeqno,
+        offset_ranges = Ranges
+    }.
+
 -spec next_packet_id(inflight()) -> {emqx_types:packet_id(), inflight()}.
-next_packet_id(Inflight0 = #inflight{next_seqno = LastSeqNo}) ->
-    Inflight = Inflight0#inflight{next_seqno = LastSeqNo + 1},
-    case LastSeqNo rem 16#10000 of
-        0 ->
-            %% We skip sequence numbers that lead to PacketId = 0 to
-            %% simplify math. Note: it leads to occasional gaps in the
-            %% sequence numbers.
-            next_packet_id(Inflight);
-        PacketId ->
-            {PacketId, Inflight}
-    end.
+next_packet_id(Inflight0 = #inflight{next_seqno = LastSeqno}) ->
+    Inflight = Inflight0#inflight{next_seqno = next_seqno(LastSeqno)},
+    {seqno_to_packet_id(LastSeqno), Inflight}.
 
 -spec n_inflight(inflight()) -> non_neg_integer().
-n_inflight(#inflight{next_seqno = NextSeqNo, acked_seqno = AckedSeqno}) ->
-    %% NOTE: this function assumes that gaps in the sequence ID occur
-    %% _only_ when the packet ID wraps:
-    case AckedSeqno >= ((NextSeqNo bsr 16) bsl 16) of
-        true ->
-            NextSeqNo - AckedSeqno;
-        false ->
-            NextSeqNo - AckedSeqno - 1
-    end.
+n_inflight(#inflight{next_seqno = NextSeqno, acked_until = AckedUntil}) ->
+    range_size(AckedUntil, NextSeqno).
 
--spec replay(emqx_persistent_session_ds:id(), inflight()) ->
-    emqx_session:replies().
-replay(_SessionId, _Inflight = #inflight{offset_ranges = _Ranges}) ->
-    [].
+-spec replay(inflight()) ->
+    {emqx_session:replies(), inflight()}.
+replay(Inflight0 = #inflight{acked_until = AckedUntil, offset_ranges = Ranges0}) ->
+    {Ranges, Replies} = lists:mapfoldr(
+        fun(Range, Acc) ->
+            replay_range(Range, AckedUntil, Acc)
+        end,
+        [],
+        Ranges0
+    ),
+    Inflight = Inflight0#inflight{offset_ranges = Ranges},
+    {Replies, Inflight}.
 
 -spec commit_offset(emqx_persistent_session_ds:id(), emqx_types:packet_id(), inflight()) ->
     {_IsValidOffset :: boolean(), inflight()}.
@@ -101,47 +115,34 @@ commit_offset(
     SessionId,
     PacketId,
     Inflight0 = #inflight{
-        acked_seqno = AckedSeqno0, next_seqno = NextSeqNo, offset_ranges = Ranges0
+        acked_until = AckedUntil, next_seqno = NextSeqno
     }
 ) ->
-    AckedSeqno =
-        case packet_id_to_seqno(NextSeqNo, PacketId) of
-            N when N > AckedSeqno0; AckedSeqno0 =:= 0 ->
-                N;
-            OutOfRange ->
-                ?SLOG(warning, #{
-                    msg => "out-of-order_ack",
-                    prev_seqno => AckedSeqno0,
-                    acked_seqno => OutOfRange,
-                    next_seqno => NextSeqNo,
-                    packet_id => PacketId
-                }),
-                AckedSeqno0
-        end,
-    Ranges = lists:filter(
-        fun(#range{stream = Stream, last = LastSeqno, iterator_next = ItNext}) ->
-            case LastSeqno =< AckedSeqno of
-                true ->
-                    %% This range has been fully
-                    %% acked. Remove it and replace saved
-                    %% iterator with the trailing iterator.
-                    update_iterator(SessionId, Stream, ItNext),
-                    false;
-                false ->
-                    %% This range still has unacked
-                    %% messages:
-                    true
-            end
-        end,
-        Ranges0
-    ),
-    Inflight = Inflight0#inflight{acked_seqno = AckedSeqno, offset_ranges = Ranges},
-    {true, Inflight}.
+    case packet_id_to_seqno(NextSeqno, PacketId) of
+        Seqno when Seqno >= AckedUntil andalso Seqno < NextSeqno ->
+            %% TODO
+            %% We do not preserve `acked_until` in the database. Instead, we discard
+            %% fully acked ranges from the database. In effect, this means that the
+            %% most recent `acked_until` the client has sent may be lost in case of a
+            %% crash or client loss.
+            Inflight1 = Inflight0#inflight{acked_until = next_seqno(Seqno)},
+            Inflight = discard_acked(SessionId, Inflight1),
+            {true, Inflight};
+        OutOfRange ->
+            ?SLOG(warning, #{
+                msg => "out-of-order_ack",
+                acked_until => AckedUntil,
+                acked_seqno => OutOfRange,
+                next_seqno => NextSeqno,
+                packet_id => PacketId
+            }),
+            {false, Inflight0}
+    end.
 
 -spec poll(emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
     {emqx_session:replies(), inflight()}.
 poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff ->
-    #inflight{next_seqno = NextSeqNo0, acked_seqno = AckedSeqno} =
+    #inflight{next_seqno = NextSeqNo0, acked_until = AckedSeqno} =
         Inflight0,
     FetchThreshold = max(1, WindowSize div 2),
     FreeSpace = AckedSeqno + WindowSize - NextSeqNo0,
@@ -153,6 +154,7 @@ poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff
             %% client get stuck even?
             {[], Inflight0};
         true ->
+            %% TODO: Wrap this in `mria:async_dirty/2`?
             Streams = shuffle(get_streams(SessionId)),
             fetch(SessionId, Inflight0, Streams, FreeSpace, [])
     end.
@@ -165,75 +167,206 @@ poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff
 %% Internal functions
 %%================================================================================
 
-fetch(_SessionId, Inflight, _Streams = [], _N, Acc) ->
-    {lists:reverse(Acc), Inflight};
-fetch(_SessionId, Inflight, _Streams, 0, Acc) ->
-    {lists:reverse(Acc), Inflight};
-fetch(SessionId, Inflight0, [Stream | Streams], N, Publishes0) ->
-    #inflight{next_seqno = FirstSeqNo, offset_ranges = Ranges0} = Inflight0,
-    ItBegin = get_last_iterator(SessionId, Stream, Ranges0),
+compute_inflight_range([]) ->
+    {1, 1};
+compute_inflight_range(Ranges) ->
+    _RangeLast = #range{until = LastSeqno} = lists:last(Ranges),
+    RangesUnacked = lists:dropwhile(fun(#range{type = T}) -> T == checkpoint end, Ranges),
+    case RangesUnacked of
+        [#range{first = AckedUntil} | _] ->
+            {AckedUntil, LastSeqno};
+        [] ->
+            {LastSeqno, LastSeqno}
+    end.
+
+get_ranges(SessionId) ->
+    DSRanges = mnesia:match_object(
+        ?SESSION_PUBRANGE_TAB,
+        #ds_pubrange{id = {SessionId, '_'}, _ = '_'},
+        read
+    ),
+    lists:map(fun export_range/1, DSRanges).
+
+export_range(#ds_pubrange{
+    type = Type, id = {_, First}, until = Until, stream = StreamRef, iterator = It
+}) ->
+    #range{type = Type, stream = StreamRef, first = First, until = Until, iterator = It}.
+
+fetch(SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
+    #inflight{next_seqno = FirstSeqno, offset_ranges = Ranges0} = Inflight0,
+    ItBegin = get_last_iterator(DSStream, Ranges0),
     {ok, ItEnd, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, ItBegin, N),
-    {NMessages, Publishes, Inflight1} =
-        lists:foldl(
-            fun(Msg, {N0, PubAcc0, InflightAcc0}) ->
-                {PacketId, InflightAcc} = next_packet_id(InflightAcc0),
-                PubAcc = [{PacketId, Msg} | PubAcc0],
-                {N0 + 1, PubAcc, InflightAcc}
-            end,
-            {0, Publishes0, Inflight0},
-            Messages
-        ),
-    #inflight{next_seqno = LastSeqNo} = Inflight1,
-    case NMessages > 0 of
-        true ->
-            Range = #range{
-                first = FirstSeqNo,
-                last = LastSeqNo - 1,
-                stream = Stream,
-                iterator_next = ItEnd
+    {Publishes, UntilSeqno} = publish(FirstSeqno, Messages),
+    case range_size(FirstSeqno, UntilSeqno) of
+        Size when Size > 0 ->
+            Range0 = #range{
+                type = inflight,
+                first = FirstSeqno,
+                until = UntilSeqno,
+                stream = DSStream#ds_stream.ref,
+                iterator = ItBegin
             },
-            Inflight = Inflight1#inflight{offset_ranges = Ranges0 ++ [Range]},
-            fetch(SessionId, Inflight, Streams, N - NMessages, Publishes);
-        false ->
-            fetch(SessionId, Inflight1, Streams, N, Publishes)
-    end.
+            %% We need to preserve the iterator pointing to the beginning of the
+            %% range, so that we can replay it if needed.
+            ok = preserve_range(SessionId, Range0),
+            %% ...Yet we need to keep the iterator pointing past the end of the
+            %% range, so that we can pick up where we left off: it will become
+            %% `ItBegin` of the next range for this stream.
+            Range = Range0#range{iterator = ItEnd},
+            Ranges = Ranges0 ++ [Range#range{iterator = ItEnd}],
+            Inflight = Inflight0#inflight{
+                next_seqno = UntilSeqno,
+                offset_ranges = Ranges
+            },
+            fetch(SessionId, Inflight, Streams, N - Size, [Publishes | Acc]);
+        0 ->
+            fetch(SessionId, Inflight0, Streams, N, Acc)
+    end;
+fetch(_SessionId, Inflight, _Streams, _N, Acc) ->
+    Publishes = lists:append(lists:reverse(Acc)),
+    {Publishes, Inflight}.
+
+discard_acked(
+    SessionId,
+    Inflight0 = #inflight{acked_until = AckedUntil, offset_ranges = Ranges0}
+) ->
+    %% TODO: This could be kept and incrementally updated in the inflight state.
+    Checkpoints = find_checkpoints(Ranges0),
+    %% TODO: Wrap this in `mria:async_dirty/2`?
+    Ranges = discard_acked_ranges(SessionId, AckedUntil, Checkpoints, Ranges0),
+    Inflight0#inflight{offset_ranges = Ranges}.
+
+find_checkpoints(Ranges) ->
+    lists:foldl(
+        fun(#range{stream = StreamRef, until = Until}, Acc) ->
+            %% For each stream, remember the last range over this stream.
+            Acc#{StreamRef => Until}
+        end,
+        #{},
+        Ranges
+    ).
+
+discard_acked_ranges(
+    SessionId,
+    AckedUntil,
+    Checkpoints,
+    [Range = #range{until = Until, stream = StreamRef} | Rest]
+) when Until =< AckedUntil ->
+    %% This range has been fully acked.
+    %% Either discard it completely, or preserve the iterator for the next range
+    %% over this stream (i.e. a checkpoint).
+    RangeKept =
+        case maps:get(StreamRef, Checkpoints) of
+            CP when CP > Until ->
+                discard_range(SessionId, Range),
+                [];
+            Until ->
+                checkpoint_range(SessionId, Range),
+                [Range#range{type = checkpoint}]
+        end,
+    %% Since we're (intentionally) not using transactions here, it's important to
+    %% issue database writes in the same order in which ranges are stored: from
+    %% the oldest to the newest. This is also why we need to compute which ranges
+    %% should become checkpoints before we start writing anything.
+    RangeKept ++ discard_acked_ranges(SessionId, AckedUntil, Checkpoints, Rest);
+discard_acked_ranges(_SessionId, _AckedUntil, _Checkpoints, Ranges) ->
+    %% The rest of ranges (if any) still have unacked messages.
+    Ranges.
+
+replay_range(
+    Range0 = #range{type = inflight, first = First, until = Until, iterator = It},
+    AckedUntil,
+    Acc
+) ->
+    Size = range_size(First, Until),
+    FirstUnacked = max(First, AckedUntil),
+    {ok, ItNext, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, It, Size),
+    MessagesUnacked =
+        case FirstUnacked of
+            First ->
+                Messages;
+            _ ->
+                lists:nthtail(range_size(First, FirstUnacked), Messages)
+        end,
+    %% Asserting that range is consistent with the message storage state.
+    {Replies, Until} = publish(FirstUnacked, MessagesUnacked),
+    Range = Range0#range{iterator = ItNext},
+    {Range, Replies ++ Acc};
+replay_range(Range0 = #range{type = checkpoint}, _AckedUntil, Acc) ->
+    {Range0, Acc}.
+
+publish(FirstSeqno, Messages) ->
+    lists:mapfoldl(
+        fun(Message, Seqno) ->
+            PacketId = seqno_to_packet_id(Seqno),
+            {{PacketId, Message}, next_seqno(Seqno)}
+        end,
+        FirstSeqno,
+        Messages
+    ).
 
--spec update_iterator(emqx_persistent_session_ds:id(), emqx_ds:stream(), emqx_ds:iterator()) -> ok.
-update_iterator(DSSessionId, Stream, Iterator) ->
-    %% Workaround: we convert `Stream' to a binary before attempting to store it in
-    %% mnesia(rocksdb) because of a bug in `mnesia_rocksdb' when trying to do
-    %% `mnesia:dirty_all_keys' later.
-    StreamBin = term_to_binary(Stream),
-    mria:dirty_write(?SESSION_ITER_TAB, #ds_iter{id = {DSSessionId, StreamBin}, iter = Iterator}).
+-spec preserve_range(emqx_persistent_session_ds:id(), range()) -> ok.
+preserve_range(
+    SessionId,
+    #range{first = First, until = Until, stream = StreamRef, iterator = It}
+) ->
+    DSRange = #ds_pubrange{
+        id = {SessionId, First},
+        until = Until,
+        stream = StreamRef,
+        type = inflight,
+        iterator = It
+    },
+    mria:dirty_write(?SESSION_PUBRANGE_TAB, DSRange).
+
+-spec discard_range(emqx_persistent_session_ds:id(), range()) -> ok.
+discard_range(SessionId, #range{first = First}) ->
+    mria:dirty_delete(?SESSION_PUBRANGE_TAB, {SessionId, First}).
+
+-spec checkpoint_range(emqx_persistent_session_ds:id(), range()) -> ok.
+checkpoint_range(
+    SessionId,
+    #range{type = inflight, first = First, until = Until, stream = StreamRef, iterator = ItNext}
+) ->
+    DSRange = #ds_pubrange{
+        id = {SessionId, First},
+        until = Until,
+        stream = StreamRef,
+        type = checkpoint,
+        iterator = ItNext
+    },
+    mria:dirty_write(?SESSION_PUBRANGE_TAB, DSRange);
+checkpoint_range(_SessionId, #range{type = checkpoint}) ->
+    %% This range should have been checkpointed already.
+    ok.
 
-get_last_iterator(SessionId, Stream, Ranges) ->
-    case lists:keyfind(Stream, #range.stream, lists:reverse(Ranges)) of
+get_last_iterator(DSStream = #ds_stream{ref = StreamRef}, Ranges) ->
+    case lists:keyfind(StreamRef, #range.stream, lists:reverse(Ranges)) of
         false ->
-            get_iterator(SessionId, Stream);
-        #range{iterator_next = Next} ->
-            Next
+            DSStream#ds_stream.beginning;
+        #range{iterator = ItNext} ->
+            ItNext
     end.
 
--spec get_iterator(emqx_persistent_session_ds:id(), emqx_ds:stream()) -> emqx_ds:iterator().
-get_iterator(DSSessionId, Stream) ->
-    %% See comment in `update_iterator'.
-    StreamBin = term_to_binary(Stream),
-    Id = {DSSessionId, StreamBin},
-    [#ds_iter{iter = It}] = mnesia:dirty_read(?SESSION_ITER_TAB, Id),
-    It.
-
--spec get_streams(emqx_persistent_session_ds:id()) -> [emqx_ds:stream()].
+-spec get_streams(emqx_persistent_session_ds:id()) -> [ds_stream()].
 get_streams(SessionId) ->
-    lists:map(
-        fun(#ds_stream{stream = Stream}) ->
-            Stream
-        end,
-        mnesia:dirty_read(?SESSION_STREAM_TAB, SessionId)
-    ).
+    mnesia:dirty_read(?SESSION_STREAM_TAB, SessionId).
+
+next_seqno(Seqno) ->
+    NextSeqno = Seqno + 1,
+    case seqno_to_packet_id(NextSeqno) of
+        0 ->
+            %% We skip sequence numbers that lead to PacketId = 0 to
+            %% simplify math. Note: it leads to occasional gaps in the
+            %% sequence numbers.
+            NextSeqno + 1;
+        _ ->
+            NextSeqno
+    end.
 
 %% Reconstruct session counter by adding most significant bits from
 %% the current counter to the packet id.
--spec packet_id_to_seqno(non_neg_integer(), emqx_types:packet_id()) -> non_neg_integer().
+-spec packet_id_to_seqno(_Next :: seqno(), emqx_types:packet_id()) -> seqno().
 packet_id_to_seqno(NextSeqNo, PacketId) ->
     Epoch = NextSeqNo bsr 16,
     case packet_id_to_seqno_(Epoch, PacketId) of
@@ -243,10 +376,20 @@ packet_id_to_seqno(NextSeqNo, PacketId) ->
             packet_id_to_seqno_(Epoch - 1, PacketId)
     end.
 
--spec packet_id_to_seqno_(non_neg_integer(), emqx_types:packet_id()) -> non_neg_integer().
+-spec packet_id_to_seqno_(non_neg_integer(), emqx_types:packet_id()) -> seqno().
 packet_id_to_seqno_(Epoch, PacketId) ->
     (Epoch bsl 16) + PacketId.
 
+-spec seqno_to_packet_id(seqno()) -> emqx_types:packet_id().
+seqno_to_packet_id(Seqno) ->
+    Seqno rem 16#10000.
+
+range_size(FirstSeqno, UntilSeqno) ->
+    %% This function assumes that gaps in the sequence ID occur _only_ when the
+    %% packet ID wraps.
+    Size = UntilSeqno - FirstSeqno,
+    Size + (FirstSeqno bsr 16) - (UntilSeqno bsr 16).
+
 -spec shuffle([A]) -> [A].
 shuffle(L0) ->
     L1 = lists:map(
@@ -259,6 +402,10 @@ shuffle(L0) ->
     {_, L} = lists:unzip(L2),
     L.
 
+ro_transaction(Fun) ->
+    {atomic, Res} = mria:ro_transaction(?DS_MRIA_SHARD, Fun),
+    Res.
+
 -ifdef(TEST).
 
 %% This test only tests boundary conditions (to make sure property-based test didn't skip them):
@@ -311,4 +458,40 @@ seqno_gen(NextSeqNo) ->
     Max = max(0, NextSeqNo - 1),
     range(Min, Max).
 
+range_size_test_() ->
+    [
+        ?_assertEqual(0, range_size(42, 42)),
+        ?_assertEqual(1, range_size(42, 43)),
+        ?_assertEqual(1, range_size(16#ffff, 16#10001)),
+        ?_assertEqual(16#ffff - 456 + 123, range_size(16#1f0000 + 456, 16#200000 + 123))
+    ].
+
+compute_inflight_range_test_() ->
+    [
+        ?_assertEqual(
+            {1, 1},
+            compute_inflight_range([])
+        ),
+        ?_assertEqual(
+            {12, 42},
+            compute_inflight_range([
+                #range{first = 1, until = 2, type = checkpoint},
+                #range{first = 4, until = 8, type = checkpoint},
+                #range{first = 11, until = 12, type = checkpoint},
+                #range{first = 12, until = 13, type = inflight},
+                #range{first = 13, until = 20, type = inflight},
+                #range{first = 20, until = 42, type = inflight}
+            ])
+        ),
+        ?_assertEqual(
+            {13, 13},
+            compute_inflight_range([
+                #range{first = 1, until = 2, type = checkpoint},
+                #range{first = 4, until = 8, type = checkpoint},
+                #range{first = 11, until = 12, type = checkpoint},
+                #range{first = 12, until = 13, type = checkpoint}
+            ])
+        )
+    ].
+
 -endif.

+ 91 - 79
apps/emqx/src/emqx_persistent_session_ds.erl

@@ -76,7 +76,7 @@
     list_all_sessions/0,
     list_all_subscriptions/0,
     list_all_streams/0,
-    list_all_iterators/0
+    list_all_pubranges/0
 ]).
 -endif.
 
@@ -359,15 +359,16 @@ handle_timeout(
         end,
     ensure_timer(pull, Timeout),
     {ok, Publishes, Session#{inflight => Inflight}};
-handle_timeout(_ClientInfo, get_streams, Session = #{id := Id}) ->
-    renew_streams(Id),
+handle_timeout(_ClientInfo, get_streams, Session) ->
+    renew_streams(Session),
     ensure_timer(get_streams),
     {ok, [], Session}.
 
 -spec replay(clientinfo(), [], session()) ->
     {ok, replies(), session()}.
-replay(_ClientInfo, [], Session = #{}) ->
-    {ok, [], Session}.
+replay(_ClientInfo, [], Session = #{inflight := Inflight0}) ->
+    {Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(Inflight0),
+    {ok, Replies, Session#{inflight := Inflight}}.
 
 %%--------------------------------------------------------------------
 
@@ -474,17 +475,20 @@ create_tables() ->
         ]
     ),
     ok = mria:create_table(
-        ?SESSION_ITER_TAB,
+        ?SESSION_PUBRANGE_TAB,
         [
             {rlog_shard, ?DS_MRIA_SHARD},
-            {type, set},
+            {type, ordered_set},
             {storage, storage()},
-            {record_name, ds_iter},
-            {attributes, record_info(fields, ds_iter)}
+            {record_name, ds_pubrange},
+            {attributes, record_info(fields, ds_pubrange)}
         ]
     ),
     ok = mria:wait_for_tables([
-        ?SESSION_TAB, ?SESSION_SUBSCRIPTIONS_TAB, ?SESSION_STREAM_TAB, ?SESSION_ITER_TAB
+        ?SESSION_TAB,
+        ?SESSION_SUBSCRIPTIONS_TAB,
+        ?SESSION_STREAM_TAB,
+        ?SESSION_PUBRANGE_TAB
     ]),
     ok.
 
@@ -512,9 +516,10 @@ session_open(SessionId) ->
                 Session = export_session(Record),
                 DSSubs = session_read_subscriptions(SessionId),
                 Subscriptions = export_subscriptions(DSSubs),
+                Inflight = emqx_persistent_message_ds_replayer:open(SessionId),
                 Session#{
                     subscriptions => Subscriptions,
-                    inflight => emqx_persistent_message_ds_replayer:new()
+                    inflight => Inflight
                 };
             [] ->
                 false
@@ -549,7 +554,7 @@ session_create(SessionId, Props) ->
 session_drop(DSSessionId) ->
     transaction(fun() ->
         ok = session_drop_subscriptions(DSSessionId),
-        ok = session_drop_iterators(DSSessionId),
+        ok = session_drop_pubranges(DSSessionId),
         ok = session_drop_streams(DSSessionId),
         ok = mnesia:delete(?SESSION_TAB, DSSessionId, write)
     end).
@@ -663,77 +668,82 @@ do_ensure_all_iterators_closed(_DSSessionID) ->
 %% Reading batches
 %%--------------------------------------------------------------------
 
--spec renew_streams(id()) -> ok.
-renew_streams(DSSessionId) ->
-    Subscriptions = ro_transaction(fun() -> session_read_subscriptions(DSSessionId) end),
-    ExistingStreams = ro_transaction(fun() -> mnesia:read(?SESSION_STREAM_TAB, DSSessionId) end),
-    lists:foreach(
-        fun(#ds_sub{id = {_, TopicFilter}, start_time = StartTime}) ->
-            renew_streams(DSSessionId, ExistingStreams, TopicFilter, StartTime)
+-spec renew_streams(session()) -> ok.
+renew_streams(#{id := SessionId, subscriptions := Subscriptions}) ->
+    transaction(fun() ->
+        ExistingStreams = mnesia:read(?SESSION_STREAM_TAB, SessionId, write),
+        maps:fold(
+            fun(TopicFilter, #{start_time := StartTime}, Streams) ->
+                TopicFilterWords = emqx_topic:words(TopicFilter),
+                renew_topic_streams(SessionId, TopicFilterWords, StartTime, Streams)
+            end,
+            ExistingStreams,
+            Subscriptions
+        )
+    end),
+    ok.
+
+-spec renew_topic_streams(id(), topic_filter_words(), emqx_ds:time(), _Acc :: [ds_stream()]) -> ok.
+renew_topic_streams(DSSessionId, TopicFilter, StartTime, ExistingStreams) ->
+    TopicStreams = emqx_ds:get_streams(?PERSISTENT_MESSAGE_DB, TopicFilter, StartTime),
+    lists:foldl(
+        fun({Rank, Stream}, Streams) ->
+            case lists:keymember(Stream, #ds_stream.stream, Streams) of
+                true ->
+                    Streams;
+                false ->
+                    StreamRef = length(Streams) + 1,
+                    DSStream = session_store_stream(
+                        DSSessionId,
+                        StreamRef,
+                        Stream,
+                        Rank,
+                        TopicFilter,
+                        StartTime
+                    ),
+                    [DSStream | Streams]
+            end
         end,
-        Subscriptions
+        ExistingStreams,
+        TopicStreams
     ).
 
--spec renew_streams(id(), [ds_stream()], topic_filter_words(), emqx_ds:time()) -> ok.
-renew_streams(DSSessionId, ExistingStreams, TopicFilter, StartTime) ->
-    AllStreams = emqx_ds:get_streams(?PERSISTENT_MESSAGE_DB, TopicFilter, StartTime),
-    transaction(
-        fun() ->
-            lists:foreach(
-                fun({Rank, Stream}) ->
-                    Rec = #ds_stream{
-                        session = DSSessionId,
-                        topic_filter = TopicFilter,
-                        stream = Stream,
-                        rank = Rank
-                    },
-                    case lists:member(Rec, ExistingStreams) of
-                        true ->
-                            ok;
-                        false ->
-                            mnesia:write(?SESSION_STREAM_TAB, Rec, write),
-                            {ok, Iterator} = emqx_ds:make_iterator(
-                                ?PERSISTENT_MESSAGE_DB, Stream, TopicFilter, StartTime
-                            ),
-                            %% Workaround: we convert `Stream' to a binary before
-                            %% attempting to store it in mnesia(rocksdb) because of a bug
-                            %% in `mnesia_rocksdb' when trying to do
-                            %% `mnesia:dirty_all_keys' later.
-                            StreamBin = term_to_binary(Stream),
-                            IterRec = #ds_iter{id = {DSSessionId, StreamBin}, iter = Iterator},
-                            mnesia:write(?SESSION_ITER_TAB, IterRec, write)
-                    end
-                end,
-                AllStreams
-            )
-        end
-    ).
+session_store_stream(DSSessionId, StreamRef, Stream, Rank, TopicFilter, StartTime) ->
+    {ok, ItBegin} = emqx_ds:make_iterator(
+        ?PERSISTENT_MESSAGE_DB,
+        Stream,
+        TopicFilter,
+        StartTime
+    ),
+    DSStream = #ds_stream{
+        session = DSSessionId,
+        ref = StreamRef,
+        stream = Stream,
+        rank = Rank,
+        beginning = ItBegin
+    },
+    mnesia:write(?SESSION_STREAM_TAB, DSStream, write),
+    DSStream.
 
 %% must be called inside a transaction
 -spec session_drop_streams(id()) -> ok.
 session_drop_streams(DSSessionId) ->
-    MS = ets:fun2ms(
-        fun(#ds_stream{session = DSSessionId0}) when DSSessionId0 =:= DSSessionId ->
-            DSSessionId0
-        end
-    ),
-    StreamIDs = mnesia:select(?SESSION_STREAM_TAB, MS, write),
-    lists:foreach(fun(Key) -> mnesia:delete(?SESSION_STREAM_TAB, Key, write) end, StreamIDs).
+    mnesia:delete(?SESSION_STREAM_TAB, DSSessionId, write).
 
 %% must be called inside a transaction
--spec session_drop_iterators(id()) -> ok.
-session_drop_iterators(DSSessionId) ->
+-spec session_drop_pubranges(id()) -> ok.
+session_drop_pubranges(DSSessionId) ->
     MS = ets:fun2ms(
-        fun(#ds_iter{id = {DSSessionId0, StreamBin}}) when DSSessionId0 =:= DSSessionId ->
-            StreamBin
+        fun(#ds_pubrange{id = {DSSessionId0, First}}) when DSSessionId0 =:= DSSessionId ->
+            {DSSessionId, First}
         end
     ),
-    StreamBins = mnesia:select(?SESSION_ITER_TAB, MS, write),
+    RangeIds = mnesia:select(?SESSION_PUBRANGE_TAB, MS, write),
     lists:foreach(
-        fun(StreamBin) ->
-            mnesia:delete(?SESSION_ITER_TAB, {DSSessionId, StreamBin}, write)
+        fun(RangeId) ->
+            mnesia:delete(?SESSION_PUBRANGE_TAB, RangeId, write)
         end,
-        StreamBins
+        RangeIds
     ).
 
 %%--------------------------------------------------------------------------------
@@ -758,7 +768,7 @@ export_subscriptions(DSSubs) ->
     ).
 
 export_session(#session{} = Record) ->
-    export_record(Record, #session.id, [id, created_at, expires_at, inflight, props], #{}).
+    export_record(Record, #session.id, [id, created_at, expires_at, props], #{}).
 
 export_subscription(#ds_sub{} = Record) ->
     export_record(Record, #ds_sub.start_time, [start_time, props, extra], #{}).
@@ -833,16 +843,18 @@ list_all_streams() ->
     ),
     maps:from_list(DSStreams).
 
-list_all_iterators() ->
-    DSIterIds = mnesia:dirty_all_keys(?SESSION_ITER_TAB),
-    DSIters = lists:map(
-        fun(DSIterId) ->
-            [Record] = mnesia:dirty_read(?SESSION_ITER_TAB, DSIterId),
-            {DSIterId, export_record(Record, #ds_iter.id, [id, iter], #{})}
+list_all_pubranges() ->
+    DSPubranges = mnesia:dirty_match_object(?SESSION_PUBRANGE_TAB, #ds_pubrange{_ = '_'}),
+    lists:foldl(
+        fun(Record = #ds_pubrange{id = {SessionId, First}}, Acc) ->
+            Range = export_record(
+                Record, #ds_pubrange.until, [until, stream, type, iterator], #{first => First}
+            ),
+            maps:put(SessionId, maps:get(SessionId, Acc, []) ++ [Range], Acc)
         end,
-        DSIterIds
-    ),
-    maps:from_list(DSIters).
+        #{},
+        DSPubranges
+    ).
 
 %% ifdef(TEST)
 -endif.

+ 15 - 8
apps/emqx/src/emqx_persistent_session_ds.hrl

@@ -21,7 +21,7 @@
 -define(SESSION_TAB, emqx_ds_session).
 -define(SESSION_SUBSCRIPTIONS_TAB, emqx_ds_session_subscriptions).
 -define(SESSION_STREAM_TAB, emqx_ds_stream_tab).
--define(SESSION_ITER_TAB, emqx_ds_iter_tab).
+-define(SESSION_PUBRANGE_TAB, emqx_ds_pubrange_tab).
 -define(DS_MRIA_SHARD, emqx_ds_session_shard).
 
 -record(ds_sub, {
@@ -34,17 +34,24 @@
 
 -record(ds_stream, {
     session :: emqx_persistent_session_ds:id(),
-    topic_filter :: emqx_ds:topic_filter(),
+    ref :: _StreamRef,
     stream :: emqx_ds:stream(),
-    rank :: emqx_ds:stream_rank()
+    rank :: emqx_ds:stream_rank(),
+    beginning :: emqx_ds:iterator()
 }).
 -type ds_stream() :: #ds_stream{}.
--type ds_stream_bin() :: binary().
 
--record(ds_iter, {
-    id :: {emqx_persistent_session_ds:id(), ds_stream_bin()},
-    iter :: emqx_ds:iterator()
+-record(ds_pubrange, {
+    id :: {
+        _Session :: emqx_persistent_session_ds:id(),
+        _First :: emqx_persistent_message_ds_replayer:seqno()
+    },
+    until :: emqx_persistent_message_ds_replayer:seqno(),
+    stream :: _StreamRef,
+    type :: inflight | checkpoint,
+    iterator :: emqx_ds:iterator()
 }).
+-type ds_pubrange() :: #ds_pubrange{}.
 
 -record(session, {
     %% same as clientid
@@ -52,7 +59,7 @@
     %% creation time
     created_at :: _Millisecond :: non_neg_integer(),
     expires_at = never :: _Millisecond :: non_neg_integer() | never,
-    inflight :: emqx_persistent_message_ds_replayer:inflight(),
+    % last_ack = 0 :: emqx_persistent_message_ds_replayer:seqno(),
     %% for future usage
     props = #{} :: map()
 }).