Procházet zdrojové kódy

Merge pull request #12135 from keynslug/fix/ds-qos0-pubranges

fix(sessds): stop overwriting QoS0-only pubrange checkpoints
Andrew Mayorov před 2 roky
rodič
revize
abeb5e985f

+ 42 - 34
apps/emqx/src/emqx_persistent_message_ds_replayer.erl

@@ -113,8 +113,8 @@ n_inflight(#inflight{offset_ranges = Ranges}) ->
         fun
             (#ds_pubrange{type = ?T_CHECKPOINT}, N) ->
                 N;
-            (#ds_pubrange{type = ?T_INFLIGHT, id = {_, First}, until = Until}, N) ->
-                N + range_size(First, Until)
+            (#ds_pubrange{type = ?T_INFLIGHT} = Range, N) ->
+                N + range_size(Range)
         end,
         0,
         Ranges
@@ -186,7 +186,11 @@ poll(PreprocFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSi
         true ->
             %% TODO: Wrap this in `mria:async_dirty/2`?
             Streams = shuffle(get_streams(SessionId)),
-            fetch(PreprocFun, SessionId, Inflight0, Streams, FreeSpace, [])
+            Checkpoints = find_checkpoints(Inflight0#inflight.offset_ranges),
+            {Publihes, Inflight} =
+                fetch(PreprocFun, SessionId, Inflight0, Checkpoints, Streams, FreeSpace, []),
+            %% Discard now irrelevant QoS0-only ranges, if any.
+            {Publihes, discard_committed(SessionId, Inflight)}
     end.
 
 %% Which seqno this track is committed until.
@@ -238,7 +242,7 @@ find_committed_until(Track, Ranges) ->
         Ranges
     ),
     case RangesUncommitted of
-        [#ds_pubrange{id = {_, CommittedUntil}} | _] ->
+        [#ds_pubrange{id = {_, CommittedUntil, _StreamRef}} | _] ->
             CommittedUntil;
         [] ->
             undefined
@@ -249,28 +253,27 @@ get_ranges(SessionId) ->
     Pat = erlang:make_tuple(
         record_info(size, ds_pubrange),
         '_',
-        [{1, ds_pubrange}, {#ds_pubrange.id, {SessionId, '_'}}]
+        [{1, ds_pubrange}, {#ds_pubrange.id, {SessionId, '_', '_'}}]
     ),
     mnesia:match_object(?SESSION_PUBRANGE_TAB, Pat, read).
 
-fetch(PreprocFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
+fetch(PreprocFun, SessionId, Inflight0, CPs, [Stream | Streams], N, Acc) when N > 0 ->
     #inflight{next_seqno = FirstSeqno, offset_ranges = Ranges} = Inflight0,
-    ItBegin = get_last_iterator(DSStream, Ranges),
+    ItBegin = get_last_iterator(Stream, CPs),
     {ok, ItEnd, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, ItBegin, N),
     case Messages of
         [] ->
-            fetch(PreprocFun, SessionId, Inflight0, Streams, N, Acc);
+            fetch(PreprocFun, SessionId, Inflight0, CPs, Streams, N, Acc);
         _ ->
             %% We need to preserve the iterator pointing to the beginning of the
             %% range, so that we can replay it if needed.
             {Publishes, UntilSeqno} = publish_fetch(PreprocFun, FirstSeqno, Messages),
             Size = range_size(FirstSeqno, UntilSeqno),
             Range0 = #ds_pubrange{
-                id = {SessionId, FirstSeqno},
+                id = {SessionId, FirstSeqno, Stream#ds_stream.ref},
                 type = ?T_INFLIGHT,
                 tracks = compute_pub_tracks(Publishes),
                 until = UntilSeqno,
-                stream = DSStream#ds_stream.ref,
                 iterator = ItBegin
             },
             ok = preserve_range(Range0),
@@ -282,9 +285,9 @@ fetch(PreprocFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0
                 next_seqno = UntilSeqno,
                 offset_ranges = Ranges ++ [Range]
             },
-            fetch(PreprocFun, SessionId, Inflight, Streams, N - Size, [Publishes | Acc])
+            fetch(PreprocFun, SessionId, Inflight, CPs, Streams, N - Size, [Publishes | Acc])
     end;
-fetch(_ReplyFun, _SessionId, Inflight, _Streams, _N, Acc) ->
+fetch(_PreprocFun, _SessionId, Inflight, _CPs, _Streams, _N, Acc) ->
     Publishes = lists:append(lists:reverse(Acc)),
     {Publishes, Inflight}.
 
@@ -300,9 +303,9 @@ discard_committed(
 
 find_checkpoints(Ranges) ->
     lists:foldl(
-        fun(#ds_pubrange{stream = StreamRef, until = Until}, Acc) ->
+        fun(#ds_pubrange{id = {_SessionId, _, StreamRef}} = Range, Acc) ->
             %% For each stream, remember the last range over this stream.
-            Acc#{StreamRef => Until}
+            Acc#{StreamRef => Range}
         end,
         #{},
         Ranges
@@ -312,7 +315,7 @@ discard_committed_ranges(
     SessionId,
     Commits,
     Checkpoints,
-    Ranges = [Range = #ds_pubrange{until = Until, stream = StreamRef} | Rest]
+    Ranges = [Range = #ds_pubrange{id = {_SessionId, _, StreamRef}} | Rest]
 ) ->
     case discard_committed_range(Commits, Range) of
         discard ->
@@ -321,11 +324,11 @@ discard_committed_ranges(
             %% over this stream (i.e. a checkpoint).
             RangeKept =
                 case maps:get(StreamRef, Checkpoints) of
-                    CP when CP > Until ->
+                    Range ->
+                        [checkpoint_range(Range)];
+                    _Previous ->
                         discard_range(Range),
-                        [];
-                    Until ->
-                        [checkpoint_range(Range)]
+                        []
                 end,
             %% Since we're (intentionally) not using transactions here, it's important to
             %% issue database writes in the same order in which ranges are stored: from
@@ -381,7 +384,9 @@ discard_tracks(#{ack := AckedUntil, comp := CompUntil}, Until, Tracks) ->
 replay_range(
     PreprocFun,
     Commits,
-    Range0 = #ds_pubrange{type = ?T_INFLIGHT, id = {_, First}, until = Until, iterator = It},
+    Range0 = #ds_pubrange{
+        type = ?T_INFLIGHT, id = {_, First, _StreamRef}, until = Until, iterator = It
+    },
     Acc
 ) ->
     Size = range_size(First, Until),
@@ -545,10 +550,10 @@ checkpoint_range(Range = #ds_pubrange{type = ?T_CHECKPOINT}) ->
     %% This range should have been checkpointed already.
     Range.
 
-get_last_iterator(DSStream = #ds_stream{ref = StreamRef}, Ranges) ->
-    case lists:keyfind(StreamRef, #ds_pubrange.stream, lists:reverse(Ranges)) of
-        false ->
-            DSStream#ds_stream.beginning;
+get_last_iterator(Stream = #ds_stream{ref = StreamRef}, Checkpoints) ->
+    case maps:get(StreamRef, Checkpoints, none) of
+        none ->
+            Stream#ds_stream.beginning;
         #ds_pubrange{iterator = ItNext} ->
             ItNext
     end.
@@ -593,6 +598,9 @@ packet_id_to_seqno_(NextSeqno, PacketId) ->
             N - ?EPOCH_SIZE
     end.
 
+range_size(#ds_pubrange{id = {_, First, _StreamRef}, until = Until}) ->
+    range_size(First, Until).
+
 range_size(FirstSeqno, UntilSeqno) ->
     %% This function assumes that gaps in the sequence ID occur _only_ when the
     %% packet ID wraps.
@@ -697,23 +705,23 @@ compute_inflight_range_test_() ->
         ?_assertEqual(
             {#{ack => 12, comp => 13}, 42},
             compute_inflight_range([
-                #ds_pubrange{id = {<<>>, 1}, until = 2, type = ?T_CHECKPOINT},
-                #ds_pubrange{id = {<<>>, 4}, until = 8, type = ?T_CHECKPOINT},
-                #ds_pubrange{id = {<<>>, 11}, until = 12, type = ?T_CHECKPOINT},
+                #ds_pubrange{id = {<<>>, 1, 0}, until = 2, type = ?T_CHECKPOINT},
+                #ds_pubrange{id = {<<>>, 4, 0}, until = 8, type = ?T_CHECKPOINT},
+                #ds_pubrange{id = {<<>>, 11, 0}, until = 12, type = ?T_CHECKPOINT},
                 #ds_pubrange{
-                    id = {<<>>, 12},
+                    id = {<<>>, 12, 0},
                     until = 13,
                     type = ?T_INFLIGHT,
                     tracks = ?TRACK_FLAG(?ACK)
                 },
                 #ds_pubrange{
-                    id = {<<>>, 13},
+                    id = {<<>>, 13, 0},
                     until = 20,
                     type = ?T_INFLIGHT,
                     tracks = ?TRACK_FLAG(?COMP)
                 },
                 #ds_pubrange{
-                    id = {<<>>, 20},
+                    id = {<<>>, 20, 0},
                     until = 42,
                     type = ?T_INFLIGHT,
                     tracks = ?TRACK_FLAG(?ACK) bor ?TRACK_FLAG(?COMP)
@@ -723,10 +731,10 @@ compute_inflight_range_test_() ->
         ?_assertEqual(
             {#{ack => 13, comp => 13}, 13},
             compute_inflight_range([
-                #ds_pubrange{id = {<<>>, 1}, until = 2, type = ?T_CHECKPOINT},
-                #ds_pubrange{id = {<<>>, 4}, until = 8, type = ?T_CHECKPOINT},
-                #ds_pubrange{id = {<<>>, 11}, until = 12, type = ?T_CHECKPOINT},
-                #ds_pubrange{id = {<<>>, 12}, until = 13, type = ?T_CHECKPOINT}
+                #ds_pubrange{id = {<<>>, 1, 0}, until = 2, type = ?T_CHECKPOINT},
+                #ds_pubrange{id = {<<>>, 4, 0}, until = 8, type = ?T_CHECKPOINT},
+                #ds_pubrange{id = {<<>>, 11, 0}, until = 12, type = ?T_CHECKPOINT},
+                #ds_pubrange{id = {<<>>, 12, 0}, until = 13, type = ?T_CHECKPOINT}
             ])
         )
     ].

+ 13 - 8
apps/emqx/src/emqx_persistent_session_ds.erl

@@ -210,8 +210,8 @@ info(subscriptions_max, #{props := Conf}) ->
     maps:get(max_subscriptions, Conf);
 info(upgrade_qos, #{props := Conf}) ->
     maps:get(upgrade_qos, Conf);
-% info(inflight, #sessmem{inflight = Inflight}) ->
-%     Inflight;
+info(inflight, #{inflight := Inflight}) ->
+    Inflight;
 info(inflight_cnt, #{inflight := Inflight}) ->
     emqx_persistent_message_ds_replayer:n_inflight(Inflight);
 info(inflight_max, #{receive_maximum := ReceiveMaximum}) ->
@@ -788,8 +788,8 @@ session_read_pubranges(DSSessionID) ->
 
 session_read_pubranges(DSSessionId, LockKind) ->
     MS = ets:fun2ms(
-        fun(#ds_pubrange{id = {Sess, First}}) when Sess =:= DSSessionId ->
-            {DSSessionId, First}
+        fun(#ds_pubrange{id = ID}) when element(1, ID) =:= DSSessionId ->
+            ID
         end
     ),
     mnesia:select(?SESSION_PUBRANGE_TAB, MS, LockKind).
@@ -1080,10 +1080,15 @@ list_all_streams() ->
 list_all_pubranges() ->
     DSPubranges = mnesia:dirty_match_object(?SESSION_PUBRANGE_TAB, #ds_pubrange{_ = '_'}),
     lists:foldl(
-        fun(Record = #ds_pubrange{id = {SessionId, First}}, Acc) ->
-            Range = export_record(
-                Record, #ds_pubrange.until, [until, stream, type, iterator], #{first => First}
-            ),
+        fun(Record = #ds_pubrange{id = {SessionId, First, StreamRef}}, Acc) ->
+            Range = #{
+                session => SessionId,
+                stream => StreamRef,
+                first => First,
+                until => Record#ds_pubrange.until,
+                type => Record#ds_pubrange.type,
+                iterator => Record#ds_pubrange.iterator
+            },
             maps:put(SessionId, maps:get(SessionId, Acc, []) ++ [Range], Acc)
         end,
         #{},

+ 3 - 5
apps/emqx/src/emqx_persistent_session_ds.hrl

@@ -50,20 +50,18 @@
         %% What session this range belongs to.
         _Session :: emqx_persistent_session_ds:id(),
         %% Where this range starts.
-        _First :: emqx_persistent_message_ds_replayer:seqno()
+        _First :: emqx_persistent_message_ds_replayer:seqno(),
+        %% Which stream this range is over.
+        _StreamRef
     },
     %% Where this range ends: the first seqno that is not included in the range.
     until :: emqx_persistent_message_ds_replayer:seqno(),
-    %% Which stream this range is over.
-    stream :: _StreamRef,
     %% Type of a range:
     %% * Inflight range is a range of yet unacked messages from this stream.
     %% * Checkpoint range was already acked, its purpose is to keep track of the
     %%   very last iterator for this stream.
     type :: ?T_INFLIGHT | ?T_CHECKPOINT,
     %% What commit tracks this range is part of.
-    %% This is rarely stored: we only need to persist it when the range contains
-    %% QoS 2 messages.
     tracks = 0 :: non_neg_integer(),
     %% Meaning of this depends on the type of the range:
     %% * For inflight range, this is the iterator pointing to the first message in

+ 69 - 0
apps/emqx/test/emqx_persistent_messages_SUITE.erl

@@ -258,6 +258,75 @@ t_qos0(_Config) ->
         emqtt:stop(Pub)
     end.
 
+t_qos0_only_many_streams(_Config) ->
+    ClientId = <<?MODULE_STRING "_sub">>,
+    Sub = connect(ClientId, true, 30),
+    Pub = connect(<<?MODULE_STRING "_pub">>, true, 0),
+    [ConnPid] = emqx_cm:lookup_channels(ClientId),
+    try
+        {ok, _, [1]} = emqtt:subscribe(Sub, <<"t/#">>, qos1),
+
+        [
+            emqtt:publish(Pub, Topic, Payload, ?QOS_0)
+         || {Topic, Payload} <- [
+                {<<"t/1">>, <<"foo">>},
+                {<<"t/2">>, <<"bar">>},
+                {<<"t/3">>, <<"baz">>}
+            ]
+        ],
+        ?assertMatch(
+            [_, _, _],
+            receive_messages(3)
+        ),
+
+        Inflight0 = get_session_inflight(ConnPid),
+
+        [
+            emqtt:publish(Pub, Topic, Payload, ?QOS_0)
+         || {Topic, Payload} <- [
+                {<<"t/2">>, <<"foo">>},
+                {<<"t/2">>, <<"bar">>},
+                {<<"t/1">>, <<"baz">>}
+            ]
+        ],
+        ?assertMatch(
+            [_, _, _],
+            receive_messages(3)
+        ),
+
+        [
+            emqtt:publish(Pub, Topic, Payload, ?QOS_0)
+         || {Topic, Payload} <- [
+                {<<"t/3">>, <<"foo">>},
+                {<<"t/3">>, <<"bar">>},
+                {<<"t/2">>, <<"baz">>}
+            ]
+        ],
+        ?assertMatch(
+            [_, _, _],
+            receive_messages(3)
+        ),
+
+        ?assertMatch(
+            #{pubranges := [_, _, _]},
+            emqx_persistent_session_ds:print_session(ClientId)
+        ),
+
+        Inflight1 = get_session_inflight(ConnPid),
+
+        %% TODO: Kinda stupid way to verify that the runtime state is not growing.
+        ?assert(
+            erlang:external_size(Inflight1) - erlang:external_size(Inflight0) < 16,
+            Inflight1
+        )
+    after
+        emqtt:stop(Sub),
+        emqtt:stop(Pub)
+    end.
+
+get_session_inflight(ConnPid) ->
+    emqx_connection:info({channel, {session, inflight}}, sys:get_state(ConnPid)).
+
 t_publish_as_persistent(_Config) ->
     Sub = connect(<<?MODULE_STRING "1">>, true, 30),
     Pub = connect(<<?MODULE_STRING "2">>, true, 30),