Просмотр исходного кода

refactor(sessds): move parts of message processing to replayer

This simplifies the processing flow and reduces the amount of back-and-forth
between the session and the replayer.
Andrew Mayorov 2 года назад
Родитель
Коммит
fd26e690b8

+ 113 - 39
apps/emqx/src/emqx_persistent_message_ds_replayer.erl

@@ -33,6 +33,8 @@
 -export_type([inflight/0, seqno/0]).
 
 -include_lib("emqx/include/logger.hrl").
+-include_lib("emqx/include/emqx_mqtt.hrl").
+-include_lib("emqx_utils/include/emqx_message.hrl").
 -include("emqx_persistent_session_ds.hrl").
 
 -ifdef(TEST).
@@ -46,6 +48,8 @@
 -define(COMP, 1).
 
 -define(TRACK_FLAG(WHICH), (1 bsl WHICH)).
+%% Bitmask with both the ?ACK and the ?COMP track flags set. The body is
+%% parenthesized so that the macro always expands as a single operand, no matter
+%% which operators surround it (`band` binds tighter than `bor`).
+-define(TRACK_FLAGS_ALL, (?TRACK_FLAG(?ACK) bor ?TRACK_FLAG(?COMP))).
+%% Bitmask with no track flags set.
+-define(TRACK_FLAGS_NONE, 0).
 
 %%================================================================================
 %% Type declarations
@@ -66,9 +70,10 @@
 
 -opaque inflight() :: #inflight{}.
 
--type replies() :: reply() | [replies()].
--type reply() :: emqx_session:reply() | fun((emqx_types:packet_id()) -> emqx_session:replies()).
--type reply_fun() :: fun((seqno(), emqx_types:message()) -> replies()).
+-type message() :: emqx_types:message().
+-type replies() :: [emqx_session:reply()].
+
+-type preproc_fun() :: fun((message()) -> message() | [message()]).
 
 %%================================================================================
 %% API funcions
@@ -115,11 +120,11 @@ n_inflight(#inflight{offset_ranges = Ranges}) ->
         Ranges
     ).
 
--spec replay(reply_fun(), inflight()) -> {emqx_session:replies(), inflight()}.
-replay(ReplyFun, Inflight0 = #inflight{offset_ranges = Ranges0}) ->
+-spec replay(preproc_fun(), inflight()) -> {emqx_session:replies(), inflight()}.
+replay(PreprocFunFun, Inflight0 = #inflight{offset_ranges = Ranges0, commits = Commits}) ->
     {Ranges, Replies} = lists:mapfoldr(
         fun(Range, Acc) ->
-            replay_range(ReplyFun, Range, Acc)
+            replay_range(PreprocFunFun, Commits, Range, Acc)
         end,
         [],
         Ranges0
@@ -165,9 +170,9 @@ commit_offset(
             {false, Inflight0}
     end.
 
--spec poll(reply_fun(), emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
+-spec poll(preproc_fun(), emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
     {emqx_session:replies(), inflight()}.
-poll(ReplyFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < ?EPOCH_SIZE ->
+poll(PreprocFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < ?EPOCH_SIZE ->
     MinBatchSize = emqx_config:get([session_persistence, min_batch_size]),
     FetchThreshold = min(MinBatchSize, ceil(WindowSize / 2)),
     FreeSpace = WindowSize - n_inflight(Inflight0),
@@ -181,7 +186,7 @@ poll(ReplyFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize
         true ->
             %% TODO: Wrap this in `mria:async_dirty/2`?
             Streams = shuffle(get_streams(SessionId)),
-            fetch(ReplyFun, SessionId, Inflight0, Streams, FreeSpace, [])
+            fetch(PreprocFun, SessionId, Inflight0, Streams, FreeSpace, [])
     end.
 
 %% Which seqno this track is committed until.
@@ -248,22 +253,22 @@ get_ranges(SessionId) ->
     ),
     mnesia:match_object(?SESSION_PUBRANGE_TAB, Pat, read).
 
-fetch(ReplyFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
+fetch(PreprocFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
     #inflight{next_seqno = FirstSeqno, offset_ranges = Ranges} = Inflight0,
     ItBegin = get_last_iterator(DSStream, Ranges),
     {ok, ItEnd, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, ItBegin, N),
     case Messages of
         [] ->
-            fetch(ReplyFun, SessionId, Inflight0, Streams, N, Acc);
+            fetch(PreprocFun, SessionId, Inflight0, Streams, N, Acc);
         _ ->
             %% We need to preserve the iterator pointing to the beginning of the
             %% range, so that we can replay it if needed.
-            {Publishes, {UntilSeqno, Tracks}} = publish(ReplyFun, FirstSeqno, Messages),
+            {Publishes, UntilSeqno} = publish_fetch(PreprocFun, FirstSeqno, Messages),
             Size = range_size(FirstSeqno, UntilSeqno),
             Range0 = #ds_pubrange{
                 id = {SessionId, FirstSeqno},
                 type = ?T_INFLIGHT,
-                tracks = Tracks,
+                tracks = compute_pub_tracks(Publishes),
                 until = UntilSeqno,
                 stream = DSStream#ds_stream.ref,
                 iterator = ItBegin
@@ -277,7 +282,7 @@ fetch(ReplyFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 -
                 next_seqno = UntilSeqno,
                 offset_ranges = Ranges ++ [Range]
             },
-            fetch(ReplyFun, SessionId, Inflight, Streams, N - Size, [Publishes | Acc])
+            fetch(PreprocFun, SessionId, Inflight, Streams, N - Size, [Publishes | Acc])
     end;
 fetch(_ReplyFun, _SessionId, Inflight, _Streams, _N, Acc) ->
     Publishes = lists:append(lists:reverse(Acc)),
@@ -374,19 +379,20 @@ discard_tracks(#{ack := AckedUntil, comp := CompUntil}, Until, Tracks) ->
     TAck bor TComp.
 
 replay_range(
-    ReplyFun,
+    PreprocFun,
+    Commits,
     Range0 = #ds_pubrange{type = ?T_INFLIGHT, id = {_, First}, until = Until, iterator = It},
     Acc
 ) ->
     Size = range_size(First, Until),
     {ok, ItNext, MessagesUnacked} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, It, Size),
     %% Asserting that range is consistent with the message storage state.
-    {Replies, {Until, _TracksInitial}} = publish(ReplyFun, First, MessagesUnacked),
+    {Replies, Until} = publish_replay(PreprocFun, Commits, First, MessagesUnacked),
     %% Again, we need to keep the iterator pointing past the end of the
     %% range, so that we can pick up where we left off.
     Range = keep_next_iterator(ItNext, Range0),
     {Range, Replies ++ Acc};
-replay_range(_ReplyFun, Range0 = #ds_pubrange{type = ?T_CHECKPOINT}, Acc) ->
+replay_range(_PreprocFun, _Commits, Range0 = #ds_pubrange{type = ?T_CHECKPOINT}, Acc) ->
     {Range0, Acc}.
 
 validate_commit(
@@ -419,33 +425,88 @@ get_commit_next(rec, #inflight{next_seqno = NextSeqno}) ->
 get_commit_next(comp, #inflight{commits = Commits}) ->
     maps:get(rec, Commits).
 
-publish(ReplyFun, FirstSeqno, Messages) ->
-    lists:mapfoldl(
-        fun(Message, Acc = {Seqno, _Tracks}) ->
-            Reply = ReplyFun(Seqno, Message),
-            publish_reply(Reply, Acc)
+%% Runs every freshly fetched message through `PreprocFun` and converts the
+%% outcome into publishes, threading the sequence number from `FirstSeqno`.
+%% Returns `{Publishes, UntilSeqno}`.
+publish_fetch(PreprocFun, FirstSeqno, Messages) ->
+    Step = fun(MessageIn, Seqno) ->
+        publish_fetch(PreprocFun(MessageIn), Seqno)
+    end,
+    flatmapfoldl(Step, FirstSeqno, Messages).
+
+%% Turns one preprocessed message, or a list of them, into publishes.
+%% A QoS 0 message is paired with `undefined` instead of a packet id and leaves
+%% the seqno untouched; any other message is paired with the packet id derived
+%% from the current seqno, and the seqno is advanced.
+publish_fetch(#message{qos = QoS} = Message, Seqno) ->
+    case QoS of
+        ?QOS_0 ->
+            {{undefined, Message}, Seqno};
+        _ ->
+            {{seqno_to_packet_id(Seqno), Message}, next_seqno(Seqno)}
+    end;
+publish_fetch(Messages, Seqno) ->
+    flatmapfoldl(fun publish_fetch/2, Seqno, Messages).
+
+publish_replay(PreprocFun, Commits, FirstSeqno, Messages) ->
+    #{ack := AckedUntil, comp := CompUntil, rec := RecUntil} = Commits,
+    flatmapfoldl(
+        fun(MessageIn, Acc) ->
+            Message = PreprocFun(MessageIn),
+            publish_replay(Message, AckedUntil, CompUntil, RecUntil, Acc)
         end,
-        {FirstSeqno, 0},
+        FirstSeqno,
         Messages
     ).
 
-publish_reply(Replies = [_ | _], Acc) ->
-    lists:mapfoldl(fun publish_reply/2, Acc, Replies);
-publish_reply(Reply, {Seqno, Tracks}) when is_function(Reply) ->
-    Pub = Reply(seqno_to_packet_id(Seqno)),
-    NextSeqno = next_seqno(Seqno),
-    NextTracks = add_pub_track(Pub, Tracks),
-    {Pub, {NextSeqno, NextTracks}};
-publish_reply(Reply, Acc) ->
-    {Reply, Acc}.
-
-add_pub_track({PacketId, Message}, Tracks) when is_integer(PacketId) ->
-    case emqx_message:qos(Message) of
-        1 -> ?TRACK_FLAG(?ACK) bor Tracks;
-        2 -> ?TRACK_FLAG(?COMP) bor Tracks;
-        _ -> Tracks
+publish_replay(#message{qos = ?QOS_0}, _, _, _, Seqno) ->
+    %% QoS 0 (at most once) messages should not be replayed.
+    {[], Seqno};
+publish_replay(#message{qos = Qos} = Message, AckedUntil, CompUntil, RecUntil, Seqno) ->
+    case Qos of
+        ?QOS_1 when Seqno < AckedUntil ->
+            %% This message has already been acked, so we can skip it.
+            %% We still need to advance seqno, because previously we assigned this message
+            %% a unique Packet Id.
+            {[], next_seqno(Seqno)};
+        ?QOS_2 when Seqno < CompUntil ->
+            %% This message's flow has already been fully completed, so we can skip it.
+            %% We still need to advance seqno, because previously we assigned this message
+            %% a unique Packet Id.
+            {[], next_seqno(Seqno)};
+        ?QOS_2 when Seqno < RecUntil ->
+            %% This message's flow has been partially completed, we need to resend a PUBREL.
+            PacketId = seqno_to_packet_id(Seqno),
+            Pub = {pubrel, PacketId},
+            {Pub, next_seqno(Seqno)};
+        _ ->
+            %% This message flow hasn't been acked and/or received, we need to resend it.
+            PacketId = seqno_to_packet_id(Seqno),
+            Pub = {PacketId, emqx_message:set_flag(dup, true, Message)},
+            {Pub, next_seqno(Seqno)}
     end;
-add_pub_track(_Pub, Tracks) ->
+publish_replay([], _, _, _, Seqno) ->
+    {[], Seqno};
+publish_replay(Messages, AckedUntil, CompUntil, RecUntil, Seqno) ->
+    flatmapfoldl(
+        fun(Message, Acc) ->
+            publish_replay(Message, AckedUntil, CompUntil, RecUntil, Acc)
+        end,
+        Seqno,
+        Messages
+    ).
+
+%% Computes the bitmask of tracks a batch of publishes takes part in: the ?ACK
+%% flag for QoS 1 messages, the ?COMP flag for QoS 2 messages and PUBRELs, and
+%% no flag for any other reply.
+-spec compute_pub_tracks(replies()) -> non_neg_integer().
+compute_pub_tracks(Pubs) ->
+    compute_pub_tracks(Pubs, ?TRACK_FLAGS_NONE).
+
+%% Once both flags are set the rest of the list cannot change the result,
+%% so the scan stops early.
+compute_pub_tracks(_Pubs, Tracks = ?TRACK_FLAGS_ALL) ->
+    Tracks;
+compute_pub_tracks([{_PacketId, #message{qos = ?QOS_1}} | Rest], Tracks) ->
+    compute_pub_tracks(Rest, ?TRACK_FLAG(?ACK) bor Tracks);
+compute_pub_tracks([{_PacketId, #message{qos = ?QOS_2}} | Rest], Tracks) ->
+    compute_pub_tracks(Rest, ?TRACK_FLAG(?COMP) bor Tracks);
+compute_pub_tracks([{pubrel, _PacketId} | Rest], Tracks) ->
+    compute_pub_tracks(Rest, ?TRACK_FLAG(?COMP) bor Tracks);
+compute_pub_tracks([_Pub | Rest], Tracks) ->
+    compute_pub_tracks(Rest, Tracks);
+compute_pub_tracks([], Tracks) ->
     Tracks.
 
 keep_next_iterator(ItNext, Range = #ds_pubrange{iterator = ItFirst, misc = Misc}) ->
@@ -550,6 +611,19 @@ shuffle(L0) ->
     {_, L} = lists:unzip(L2),
     L.
 
+%% Like `lists:mapfoldl/3`, with a flattening step: for each input `Fun` may
+%% return either a single element or a list of elements. A list is spliced into
+%% the result, anything else is added as one element. The accumulator is
+%% threaded through the inputs from left to right.
+-spec flatmapfoldl(fun((X, Acc) -> {Y | [Y], Acc}), Acc, [X]) -> {[Y], Acc}.
+flatmapfoldl(_Fun, Acc, []) ->
+    {[], Acc};
+flatmapfoldl(Fun, Acc0, [Head | Tail]) ->
+    {Mapped, Acc1} = Fun(Head, Acc0),
+    {Rest, Acc} = flatmapfoldl(Fun, Acc1, Tail),
+    Result =
+        if
+            is_list(Mapped) -> Mapped ++ Rest;
+            true -> [Mapped | Rest]
+        end,
+    {Result, Acc}.
+
 ro_transaction(Fun) ->
     {atomic, Res} = mria:ro_transaction(?DS_MRIA_SHARD, Fun),
     Res.

+ 13 - 39
apps/emqx/src/emqx_persistent_session_ds.erl

@@ -152,9 +152,8 @@
 -spec create(clientinfo(), conninfo(), emqx_session:conf()) ->
     session().
 create(#{clientid := ClientID}, ConnInfo, Conf) ->
-    % TODO: expiration
-    Session = ensure_timers(session_ensure_new(ClientID, ConnInfo)),
-    preserve_conf(ConnInfo, Conf, Session).
+    Session = session_ensure_new(ClientID, ConnInfo),
+    apply_conf(ConnInfo, Conf, ensure_timers(Session)).
 
 -spec open(clientinfo(), conninfo(), emqx_session:conf()) ->
     {_IsPresent :: true, session(), []} | false.
@@ -168,13 +167,13 @@ open(#{clientid := ClientID} = _ClientInfo, ConnInfo, Conf) ->
     ok = emqx_cm:discard_session(ClientID),
     case session_open(ClientID, ConnInfo) of
         Session0 = #{} ->
-            Session = preserve_conf(ConnInfo, Conf, Session0),
+            Session = apply_conf(ConnInfo, Conf, Session0),
             {true, ensure_timers(Session), []};
         false ->
             false
     end.
 
-preserve_conf(ConnInfo, Conf, Session) ->
+apply_conf(ConnInfo, Conf, Session) ->
     Session#{
         receive_maximum => receive_maximum(ConnInfo),
         props => Conf
@@ -399,7 +398,7 @@ deliver(_ClientInfo, _Delivers, Session) ->
     {ok, replies(), session()} | {ok, replies(), timeout(), session()}.
 handle_timeout(
     ClientInfo,
-    pull,
+    ?TIMER_PULL,
     Session0 = #{
         id := Id,
         inflight := Inflight0,
@@ -411,14 +410,9 @@ handle_timeout(
     MaxBatchSize = emqx_config:get([session_persistence, max_batch_size]),
     BatchSize = min(ReceiveMaximum, MaxBatchSize),
     UpgradeQoS = maps:get(upgrade_qos, Conf),
-    ReplyFun = make_reply_fun(ClientInfo, Subs, UpgradeQoS, fun
-        (_Seqno, Message = #message{qos = ?QOS_0}) ->
-            {undefined, Message};
-        (_Seqno, Message) ->
-            fun(PacketId) -> {PacketId, Message} end
-    end),
+    PreprocFun = make_preproc_fun(ClientInfo, Subs, UpgradeQoS),
     {Publishes, Inflight} = emqx_persistent_message_ds_replayer:poll(
-        ReplyFun,
+        PreprocFun,
         Id,
         Inflight0,
         BatchSize
@@ -455,22 +449,8 @@ replay(
     Session = #{inflight := Inflight0, subscriptions := Subs, props := Conf}
 ) ->
     UpgradeQoS = maps:get(upgrade_qos, Conf),
-    AckedUntil = emqx_persistent_message_ds_replayer:committed_until(ack, Inflight0),
-    RecUntil = emqx_persistent_message_ds_replayer:committed_until(rec, Inflight0),
-    CompUntil = emqx_persistent_message_ds_replayer:committed_until(comp, Inflight0),
-    ReplyFun = make_reply_fun(ClientInfo, Subs, UpgradeQoS, fun
-        (_Seqno, #message{qos = ?QOS_0}) ->
-            [];
-        (Seqno, #message{qos = ?QOS_1}) when Seqno < AckedUntil ->
-            fun(_) -> [] end;
-        (Seqno, #message{qos = ?QOS_2}) when Seqno < CompUntil ->
-            fun(_) -> [] end;
-        (Seqno, #message{qos = ?QOS_2}) when Seqno < RecUntil ->
-            fun(PacketId) -> {pubrel, PacketId} end;
-        (_Seqno, Message) ->
-            fun(PacketId) -> {PacketId, emqx_message:set_flag(dup, true, Message)} end
-    end),
-    {Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(ReplyFun, Inflight0),
+    PreprocFun = make_preproc_fun(ClientInfo, Subs, UpgradeQoS),
+    {Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(PreprocFun, Inflight0),
     {ok, Replies, Session#{inflight := Inflight}}.
 
 %%--------------------------------------------------------------------
@@ -486,23 +466,17 @@ terminate(_Reason, _Session = #{}) ->
 
 %%--------------------------------------------------------------------
 
-make_reply_fun(ClientInfo, Subs, UpgradeQoS, InnerFun) ->
-    fun(Seqno, Message0 = #message{topic = Topic}) ->
+make_preproc_fun(ClientInfo, Subs, UpgradeQoS) ->
+    fun(Message = #message{topic = Topic}) ->
         emqx_utils:flattermap(
             fun(Match) ->
-                emqx_utils:flattermap(
-                    fun(Message) -> InnerFun(Seqno, Message) end,
-                    enrich_message(ClientInfo, Message0, Match, Subs, UpgradeQoS)
-                )
+                #{props := SubOpts} = subs_get_match(Match, Subs),
+                emqx_session:enrich_message(ClientInfo, Message, SubOpts, UpgradeQoS)
             end,
             subs_matches(Topic, Subs)
         )
     end.
 
-enrich_message(ClientInfo, Message, SubMatch, Subs, UpgradeQoS) ->
-    #{props := SubOpts} = subs_get_match(SubMatch, Subs),
-    emqx_session:enrich_message(ClientInfo, Message, SubOpts, UpgradeQoS).
-
 %%--------------------------------------------------------------------
 
 -spec add_subscription(topic_filter(), emqx_types:subopts(), id()) ->