Просмотр исходного кода

fix(sessds): respect subscription options when publishing

Andrew Mayorov 2 лет назад
Родитель
Сommit
3265d2f2aa

+ 11 - 7
apps/emqx/integration_test/emqx_persistent_session_ds_SUITE.erl

@@ -262,10 +262,12 @@ t_session_subscription_idempotency(Config) ->
         end,
         fun(Trace) ->
             ct:pal("trace:\n  ~p", [Trace]),
-            ConnInfo = #{},
+            Session = erpc:call(
+                Node1, emqx_persistent_session_ds, session_open, [ClientId, _ConnInfo = #{}]
+            ),
             ?assertMatch(
-                #{subscriptions := #{SubTopicFilter := #{}}},
-                erpc:call(Node1, emqx_persistent_session_ds, session_open, [ClientId, ConnInfo])
+                #{SubTopicFilter := #{}},
+                emqx_session:info(subscriptions, Session)
             )
         end
     ),
@@ -336,10 +338,12 @@ t_session_unsubscription_idempotency(Config) ->
         end,
         fun(Trace) ->
             ct:pal("trace:\n  ~p", [Trace]),
-            ConnInfo = #{},
-            ?assertMatch(
-                #{subscriptions := Subs = #{}} when map_size(Subs) =:= 0,
-                erpc:call(Node1, emqx_persistent_session_ds, session_open, [ClientId, ConnInfo])
+            Session = erpc:call(
+                Node1, emqx_persistent_session_ds, session_open, [ClientId, _ConnInfo = #{}]
+            ),
+            ?assertEqual(
+                #{},
+                emqx_session:info(subscriptions, Session)
             ),
             ok
         end

+ 20 - 15
apps/emqx/src/emqx_persistent_message_ds_replayer.erl

@@ -66,10 +66,9 @@
 
 -opaque inflight() :: #inflight{}.
 
--type reply_fun() :: fun(
-    (seqno(), emqx_types:message()) ->
-        emqx_session:replies() | {_AdvanceSeqno :: false, emqx_session:replies()}
-).
+-type replies() :: reply() | [replies()].
+-type reply() :: emqx_session:reply() | fun((emqx_types:packet_id()) -> emqx_session:replies()).
+-type reply_fun() :: fun((seqno(), emqx_types:message()) -> replies()).
 
 %%================================================================================
 %% API funcions
@@ -422,26 +421,32 @@ get_commit_next(comp, #inflight{commits = Commits}) ->
 
 publish(ReplyFun, FirstSeqno, Messages) ->
     lists:mapfoldl(
-        fun(Message, {Seqno, TAcc}) ->
-            case ReplyFun(Seqno, Message) of
-                {_Advance = false, Reply} ->
-                    {Reply, {Seqno, TAcc}};
-                Reply ->
-                    NextSeqno = next_seqno(Seqno),
-                    NextTAcc = add_msg_track(Message, TAcc),
-                    {Reply, {NextSeqno, NextTAcc}}
-            end
+        fun(Message, Acc = {Seqno, _Tracks}) ->
+            Reply = ReplyFun(Seqno, Message),
+            publish_reply(Reply, Acc)
         end,
         {FirstSeqno, 0},
         Messages
     ).
 
-add_msg_track(Message, Tracks) ->
+publish_reply(Replies = [_ | _], Acc) ->
+    lists:mapfoldl(fun publish_reply/2, Acc, Replies);
+publish_reply(Reply, {Seqno, Tracks}) when is_function(Reply) ->
+    Pub = Reply(seqno_to_packet_id(Seqno)),
+    NextSeqno = next_seqno(Seqno),
+    NextTracks = add_pub_track(Pub, Tracks),
+    {Pub, {NextSeqno, NextTracks}};
+publish_reply(Reply, Acc) ->
+    {Reply, Acc}.
+
+add_pub_track({PacketId, Message}, Tracks) when is_integer(PacketId) ->
     case emqx_message:qos(Message) of
         1 -> ?TRACK_FLAG(?ACK) bor Tracks;
         2 -> ?TRACK_FLAG(?COMP) bor Tracks;
         _ -> Tracks
-    end.
+    end;
+add_pub_track(_Pub, Tracks) ->
+    Tracks.
 
 keep_next_iterator(ItNext, Range = #ds_pubrange{iterator = ItFirst, misc = Misc}) ->
     Range#ds_pubrange{

+ 121 - 59
apps/emqx/src/emqx_persistent_session_ds.erl

@@ -102,6 +102,8 @@
 -define(TIMER_BUMP_LAST_ALIVE_AT, timer_bump_last_alive_at).
 -type timer() :: ?TIMER_PULL | ?TIMER_GET_STREAMS | ?TIMER_BUMP_LAST_ALIVE_AT.
 
+-type subscriptions() :: emqx_topic_gbt:t(nil(), subscription()).
+
 -type session() :: #{
     %% Client ID
     id := id(),
@@ -110,7 +112,7 @@
     %% When the client was last considered alive
     last_alive_at := timestamp(),
     %% Client’s Subscriptions.
-    subscriptions := #{topic_filter() => subscription()},
+    subscriptions := subscriptions(),
     %% Inflight messages
     inflight := emqx_persistent_message_ds_replayer:inflight(),
     %% Receive maximum
@@ -119,8 +121,6 @@
     conninfo := emqx_types:conninfo(),
     %% Timers
     timer() => reference(),
-    %% Upgrade QoS?
-    upgrade_qos := boolean(),
     %%
     props := map()
 }.
@@ -177,7 +177,7 @@ open(#{clientid := ClientID} = _ClientInfo, ConnInfo, Conf) ->
 preserve_conf(ConnInfo, Conf, Session) ->
     Session#{
         receive_maximum => receive_maximum(ConnInfo),
-        upgrade_qos => maps:get(upgrade_qos, Conf)
+        props => Conf
     }.
 
 -spec destroy(session() | clientinfo()) -> ok.
@@ -203,10 +203,10 @@ info(created_at, #{created_at := CreatedAt}) ->
     CreatedAt;
 info(is_persistent, #{}) ->
     true;
-info(subscriptions, #{subscriptions := Iters}) ->
-    maps:map(fun(_, #{props := SubOpts}) -> SubOpts end, Iters);
-info(subscriptions_cnt, #{subscriptions := Iters}) ->
-    maps:size(Iters);
+info(subscriptions, #{subscriptions := Subs}) ->
+    subs_to_map(Subs);
+info(subscriptions_cnt, #{subscriptions := Subs}) ->
+    subs_size(Subs);
 info(subscriptions_max, #{props := Conf}) ->
     maps:get(max_subscriptions, Conf);
 info(upgrade_qos, #{props := Conf}) ->
@@ -273,41 +273,40 @@ subscribe(
     TopicFilter,
     SubOpts,
     Session = #{id := ID, subscriptions := Subs}
-) when is_map_key(TopicFilter, Subs) ->
-    Subscription = maps:get(TopicFilter, Subs),
-    NSubscription = update_subscription(TopicFilter, Subscription, SubOpts, ID),
-    {ok, Session#{subscriptions := Subs#{TopicFilter => NSubscription}}};
-subscribe(
-    TopicFilter,
-    SubOpts,
-    Session = #{id := ID, subscriptions := Subs}
 ) ->
-    % TODO: max_subscriptions
-    Subscription = add_subscription(TopicFilter, SubOpts, ID),
-    {ok, Session#{subscriptions := Subs#{TopicFilter => Subscription}}}.
+    case subs_lookup(TopicFilter, Subs) of
+        Subscription = #{} ->
+            NSubscription = update_subscription(TopicFilter, Subscription, SubOpts, ID),
+            NSubs = subs_insert(TopicFilter, NSubscription, Subs),
+            {ok, Session#{subscriptions := NSubs}};
+        undefined ->
+            % TODO: max_subscriptions
+            Subscription = add_subscription(TopicFilter, SubOpts, ID),
+            NSubs = subs_insert(TopicFilter, Subscription, Subs),
+            {ok, Session#{subscriptions := NSubs}}
+    end.
 
 -spec unsubscribe(topic_filter(), session()) ->
     {ok, session(), emqx_types:subopts()} | {error, emqx_types:reason_code()}.
 unsubscribe(
     TopicFilter,
     Session = #{id := ID, subscriptions := Subs}
-) when is_map_key(TopicFilter, Subs) ->
-    Subscription = maps:get(TopicFilter, Subs),
-    SubOpts = maps:get(props, Subscription),
-    ok = del_subscription(TopicFilter, ID),
-    {ok, Session#{subscriptions := maps:remove(TopicFilter, Subs)}, SubOpts};
-unsubscribe(
-    _TopicFilter,
-    _Session = #{}
 ) ->
-    {error, ?RC_NO_SUBSCRIPTION_EXISTED}.
+    case subs_lookup(TopicFilter, Subs) of
+        _Subscription = #{props := SubOpts} ->
+            ok = del_subscription(TopicFilter, ID),
+            NSubs = subs_delete(TopicFilter, Subs),
+            {ok, Session#{subscriptions := NSubs}, SubOpts};
+        undefined ->
+            {error, ?RC_NO_SUBSCRIPTION_EXISTED}
+    end.
 
 -spec get_subscription(topic_filter(), session()) ->
     emqx_types:subopts() | undefined.
 get_subscription(TopicFilter, #{subscriptions := Subs}) ->
-    case maps:get(TopicFilter, Subs, undefined) of
-        Subscription = #{} ->
-            maps:get(props, Subscription);
+    case subs_lookup(TopicFilter, Subs) of
+        _Subscription = #{props := SubOpts} ->
+            SubOpts;
         undefined ->
             undefined
     end.
@@ -328,9 +327,6 @@ publish(_PacketId, Msg, Session) ->
 %% Client -> Broker: PUBACK
 %%--------------------------------------------------------------------
 
-%% FIXME: parts of the commit offset function are mocked
--dialyzer({nowarn_function, puback/3}).
-
 -spec puback(clientinfo(), emqx_types:packet_id(), session()) ->
     {ok, emqx_types:message(), replies(), session()}
     | {error, emqx_types:reason_code()}.
@@ -402,20 +398,27 @@ deliver(_ClientInfo, _Delivers, Session) ->
 -spec handle_timeout(clientinfo(), _Timeout, session()) ->
     {ok, replies(), session()} | {ok, replies(), timeout(), session()}.
 handle_timeout(
-    _ClientInfo,
-    ?TIMER_PULL,
-    Session0 = #{id := Id, inflight := Inflight0, receive_maximum := ReceiveMaximum}
+    ClientInfo,
+    pull,
+    Session0 = #{
+        id := Id,
+        inflight := Inflight0,
+        subscriptions := Subs,
+        props := Conf,
+        receive_maximum := ReceiveMaximum
+    }
 ) ->
     MaxBatchSize = emqx_config:get([session_persistence, max_batch_size]),
     BatchSize = min(ReceiveMaximum, MaxBatchSize),
+    UpgradeQoS = maps:get(upgrade_qos, Conf),
+    ReplyFun = make_reply_fun(ClientInfo, Subs, UpgradeQoS, fun
+        (_Seqno, Message = #message{qos = ?QOS_0}) ->
+            {undefined, Message};
+        (_Seqno, Message) ->
+            fun(PacketId) -> {PacketId, Message} end
+    end),
     {Publishes, Inflight} = emqx_persistent_message_ds_replayer:poll(
-        fun
-            (_Seqno, Message = #message{qos = ?QOS_0}) ->
-                {false, {undefined, Message}};
-            (Seqno, Message) ->
-                PacketId = emqx_persistent_message_ds_replayer:seqno_to_packet_id(Seqno),
-                {PacketId, Message}
-        end,
+        ReplyFun,
         Id,
         Inflight0,
         BatchSize
@@ -446,24 +449,27 @@ handle_timeout(_ClientInfo, ?TIMER_BUMP_LAST_ALIVE_AT, Session0) ->
 
 -spec replay(clientinfo(), [], session()) ->
     {ok, replies(), session()}.
-replay(_ClientInfo, [], Session = #{inflight := Inflight0}) ->
+replay(
+    ClientInfo,
+    [],
+    Session = #{inflight := Inflight0, subscriptions := Subs, props := Conf}
+) ->
+    UpgradeQoS = maps:get(upgrade_qos, Conf),
     AckedUntil = emqx_persistent_message_ds_replayer:committed_until(ack, Inflight0),
     RecUntil = emqx_persistent_message_ds_replayer:committed_until(rec, Inflight0),
     CompUntil = emqx_persistent_message_ds_replayer:committed_until(comp, Inflight0),
-    ReplyFun = fun
+    ReplyFun = make_reply_fun(ClientInfo, Subs, UpgradeQoS, fun
         (_Seqno, #message{qos = ?QOS_0}) ->
-            {false, []};
-        (Seqno, #message{qos = ?QOS_1}) when Seqno < AckedUntil ->
             [];
+        (Seqno, #message{qos = ?QOS_1}) when Seqno < AckedUntil ->
+            fun(_) -> [] end;
         (Seqno, #message{qos = ?QOS_2}) when Seqno < CompUntil ->
-            [];
+            fun(_) -> [] end;
         (Seqno, #message{qos = ?QOS_2}) when Seqno < RecUntil ->
-            PacketId = emqx_persistent_message_ds_replayer:seqno_to_packet_id(Seqno),
-            {pubrel, PacketId};
-        (Seqno, Message) ->
-            PacketId = emqx_persistent_message_ds_replayer:seqno_to_packet_id(Seqno),
-            {PacketId, emqx_message:set_flag(dup, true, Message)}
-    end,
+            fun(PacketId) -> {pubrel, PacketId} end;
+        (_Seqno, Message) ->
+            fun(PacketId) -> {PacketId, emqx_message:set_flag(dup, true, Message)} end
+    end),
     {Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(ReplyFun, Inflight0),
     {ok, Replies, Session#{inflight := Inflight}}.
 
@@ -480,6 +486,25 @@ terminate(_Reason, _Session = #{}) ->
 
 %%--------------------------------------------------------------------
 
+make_reply_fun(ClientInfo, Subs, UpgradeQoS, InnerFun) ->
+    fun(Seqno, Message0 = #message{topic = Topic}) ->
+        emqx_utils:flattermap(
+            fun(Match) ->
+                emqx_utils:flattermap(
+                    fun(Message) -> InnerFun(Seqno, Message) end,
+                    enrich_message(ClientInfo, Message0, Match, Subs, UpgradeQoS)
+                )
+            end,
+            subs_matches(Topic, Subs)
+        )
+    end.
+
+enrich_message(ClientInfo, Message, SubMatch, Subs, UpgradeQoS) ->
+    #{props := SubOpts} = subs_get_match(SubMatch, Subs),
+    emqx_session:enrich_message(ClientInfo, Message, SubOpts, UpgradeQoS).
+
+%%--------------------------------------------------------------------
+
 -spec add_subscription(topic_filter(), emqx_types:subopts(), id()) ->
     subscription().
 add_subscription(TopicFilter, SubOpts, DSSessionID) ->
@@ -650,7 +675,7 @@ session_ensure_new(SessionId, ConnInfo) ->
         ok = session_drop_subscriptions(SessionId),
         Session = export_session(session_create(SessionId, ConnInfo)),
         Session#{
-            subscriptions => #{},
+            subscriptions => subs_new(),
             inflight => emqx_persistent_message_ds_replayer:new()
         }
     end).
@@ -842,7 +867,7 @@ do_ensure_all_iterators_closed(_DSSessionID) ->
 renew_streams(#{id := SessionId, subscriptions := Subscriptions}) ->
     transaction(fun() ->
         ExistingStreams = mnesia:read(?SESSION_STREAM_TAB, SessionId, write),
-        maps:fold(
+        subs_fold(
             fun(TopicFilter, #{start_time := StartTime}, Streams) ->
                 TopicFilterWords = emqx_topic:words(TopicFilter),
                 renew_topic_streams(SessionId, TopicFilterWords, StartTime, Streams)
@@ -924,6 +949,43 @@ session_drop_offsets(DSSessionId) ->
 
 %%--------------------------------------------------------------------------------
 
+subs_new() ->
+    emqx_topic_gbt:new().
+
+subs_lookup(TopicFilter, Subs) ->
+    emqx_topic_gbt:lookup(TopicFilter, [], Subs, undefined).
+
+subs_insert(TopicFilter, Subscription, Subs) ->
+    emqx_topic_gbt:insert(TopicFilter, [], Subscription, Subs).
+
+subs_delete(TopicFilter, Subs) ->
+    emqx_topic_gbt:delete(TopicFilter, [], Subs).
+
+subs_matches(Topic, Subs) ->
+    emqx_topic_gbt:matches(Topic, Subs, []).
+
+subs_get_match(M, Subs) ->
+    emqx_topic_gbt:get_record(M, Subs).
+
+subs_size(Subs) ->
+    emqx_topic_gbt:size(Subs).
+
+subs_to_map(Subs) ->
+    subs_fold(
+        fun(TopicFilter, #{props := Props}, Acc) -> Acc#{TopicFilter => Props} end,
+        #{},
+        Subs
+    ).
+
+subs_fold(Fun, AccIn, Subs) ->
+    emqx_topic_gbt:fold(
+        fun(Key, Sub, Acc) -> Fun(emqx_topic_gbt:get_topic(Key), Sub, Acc) end,
+        AccIn,
+        Subs
+    ).
+
+%%--------------------------------------------------------------------------------
+
 transaction(Fun) ->
     case mnesia:is_transaction() of
         true ->
@@ -942,9 +1004,9 @@ ro_transaction(Fun) ->
 export_subscriptions(DSSubs) ->
     lists:foldl(
         fun(DSSub = #ds_sub{id = {_DSSessionId, TopicFilter}}, Acc) ->
-            Acc#{TopicFilter => export_subscription(DSSub)}
+            subs_insert(TopicFilter, export_subscription(DSSub), Acc)
         end,
-        #{},
+        subs_new(),
         DSSubs
     ).
 

+ 4 - 1
apps/emqx/src/emqx_session.erl

@@ -96,7 +96,10 @@
 ]).
 
 % Foreign session implementations
--export([enrich_delivers/3]).
+-export([
+    enrich_delivers/3,
+    enrich_message/4
+]).
 
 % Utilities
 -export([should_keep/1]).

+ 2 - 3
apps/emqx/test/emqx_persistent_session_SUITE.erl

@@ -323,7 +323,8 @@ t_choose_impl(Config) ->
             ds -> emqx_persistent_session_ds
         end,
         emqx_connection:info({channel, {session, impl}}, sys:get_state(ChanPid))
-    ).
+    ),
+    ok = emqtt:disconnect(Client).
 
 t_connect_discards_existing_client(Config) ->
     ClientId = ?config(client_id, Config),
@@ -1009,8 +1010,6 @@ t_unsubscribe(Config) ->
     ?assertMatch([], [Sub || {ST, _} = Sub <- emqtt:subscriptions(Client), ST =:= STopic]),
     ok = emqtt:disconnect(Client).
 
-t_multiple_subscription_matches(init, Config) -> skip_ds_tc(Config);
-t_multiple_subscription_matches('end', _Config) -> ok.
 t_multiple_subscription_matches(Config) ->
     ConnFun = ?config(conn_fun, Config),
     Topic = ?config(topic, Config),