Просмотр исходного кода

Merge pull request #12055 from keynslug/ft/EMQX-11474/subopts

fix(sessds): respect subscription options when publishing
Andrew Mayorov 2 лет назад
Родитель
Сommit
286d483a3a

+ 1 - 1
apps/emqx/include/emqx_session_mem.hrl

@@ -28,7 +28,7 @@
     %% Max subscriptions allowed
     max_subscriptions :: non_neg_integer() | infinity,
     %% Upgrade QoS?
-    upgrade_qos :: boolean(),
+    upgrade_qos = false :: boolean(),
     %% Client <- Broker: QoS1/2 messages sent to the client but
     %% have not been unacked.
     inflight :: emqx_inflight:inflight(),

+ 11 - 7
apps/emqx/integration_test/emqx_persistent_session_ds_SUITE.erl

@@ -262,10 +262,12 @@ t_session_subscription_idempotency(Config) ->
         end,
         fun(Trace) ->
             ct:pal("trace:\n  ~p", [Trace]),
-            ConnInfo = #{},
+            Session = erpc:call(
+                Node1, emqx_persistent_session_ds, session_open, [ClientId, _ConnInfo = #{}]
+            ),
             ?assertMatch(
-                #{subscriptions := #{SubTopicFilter := #{}}},
-                erpc:call(Node1, emqx_persistent_session_ds, session_open, [ClientId, ConnInfo])
+                #{SubTopicFilter := #{}},
+                emqx_session:info(subscriptions, Session)
             )
         end
     ),
@@ -336,10 +338,12 @@ t_session_unsubscription_idempotency(Config) ->
         end,
         fun(Trace) ->
             ct:pal("trace:\n  ~p", [Trace]),
-            ConnInfo = #{},
-            ?assertMatch(
-                #{subscriptions := Subs = #{}} when map_size(Subs) =:= 0,
-                erpc:call(Node1, emqx_persistent_session_ds, session_open, [ClientId, ConnInfo])
+            Session = erpc:call(
+                Node1, emqx_persistent_session_ds, session_open, [ClientId, _ConnInfo = #{}]
+            ),
+            ?assertEqual(
+                #{},
+                emqx_session:info(subscriptions, Session)
             ),
             ok
         end

+ 115 - 36
apps/emqx/src/emqx_persistent_message_ds_replayer.erl

@@ -33,6 +33,8 @@
 -export_type([inflight/0, seqno/0]).
 
 -include_lib("emqx/include/logger.hrl").
+-include_lib("emqx/include/emqx_mqtt.hrl").
+-include_lib("emqx_utils/include/emqx_message.hrl").
 -include("emqx_persistent_session_ds.hrl").
 
 -ifdef(TEST).
@@ -46,6 +48,8 @@
 -define(COMP, 1).
 
 -define(TRACK_FLAG(WHICH), (1 bsl WHICH)).
+-define(TRACK_FLAGS_ALL, ?TRACK_FLAG(?ACK) bor ?TRACK_FLAG(?COMP)).
+-define(TRACK_FLAGS_NONE, 0).
 
 %%================================================================================
 %% Type declarations
@@ -66,10 +70,10 @@
 
 -opaque inflight() :: #inflight{}.
 
--type reply_fun() :: fun(
-    (seqno(), emqx_types:message()) ->
-        emqx_session:replies() | {_AdvanceSeqno :: false, emqx_session:replies()}
-).
+-type message() :: emqx_types:message().
+-type replies() :: [emqx_session:reply()].
+
+-type preproc_fun() :: fun((message()) -> message() | [message()]).
 
 %%================================================================================
 %% API funcions
@@ -116,11 +120,11 @@ n_inflight(#inflight{offset_ranges = Ranges}) ->
         Ranges
     ).
 
--spec replay(reply_fun(), inflight()) -> {emqx_session:replies(), inflight()}.
-replay(ReplyFun, Inflight0 = #inflight{offset_ranges = Ranges0}) ->
+-spec replay(preproc_fun(), inflight()) -> {emqx_session:replies(), inflight()}.
+replay(PreprocFunFun, Inflight0 = #inflight{offset_ranges = Ranges0, commits = Commits}) ->
     {Ranges, Replies} = lists:mapfoldr(
         fun(Range, Acc) ->
-            replay_range(ReplyFun, Range, Acc)
+            replay_range(PreprocFunFun, Commits, Range, Acc)
         end,
         [],
         Ranges0
@@ -166,9 +170,9 @@ commit_offset(
             {false, Inflight0}
     end.
 
--spec poll(reply_fun(), emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
+-spec poll(preproc_fun(), emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
     {emqx_session:replies(), inflight()}.
-poll(ReplyFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < ?EPOCH_SIZE ->
+poll(PreprocFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < ?EPOCH_SIZE ->
     MinBatchSize = emqx_config:get([session_persistence, min_batch_size]),
     FetchThreshold = min(MinBatchSize, ceil(WindowSize / 2)),
     FreeSpace = WindowSize - n_inflight(Inflight0),
@@ -182,7 +186,7 @@ poll(ReplyFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize
         true ->
             %% TODO: Wrap this in `mria:async_dirty/2`?
             Streams = shuffle(get_streams(SessionId)),
-            fetch(ReplyFun, SessionId, Inflight0, Streams, FreeSpace, [])
+            fetch(PreprocFun, SessionId, Inflight0, Streams, FreeSpace, [])
     end.
 
 %% Which seqno this track is committed until.
@@ -249,22 +253,22 @@ get_ranges(SessionId) ->
     ),
     mnesia:match_object(?SESSION_PUBRANGE_TAB, Pat, read).
 
-fetch(ReplyFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
+fetch(PreprocFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
     #inflight{next_seqno = FirstSeqno, offset_ranges = Ranges} = Inflight0,
     ItBegin = get_last_iterator(DSStream, Ranges),
     {ok, ItEnd, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, ItBegin, N),
     case Messages of
         [] ->
-            fetch(ReplyFun, SessionId, Inflight0, Streams, N, Acc);
+            fetch(PreprocFun, SessionId, Inflight0, Streams, N, Acc);
         _ ->
             %% We need to preserve the iterator pointing to the beginning of the
             %% range, so that we can replay it if needed.
-            {Publishes, {UntilSeqno, Tracks}} = publish(ReplyFun, FirstSeqno, Messages),
+            {Publishes, UntilSeqno} = publish_fetch(PreprocFun, FirstSeqno, Messages),
             Size = range_size(FirstSeqno, UntilSeqno),
             Range0 = #ds_pubrange{
                 id = {SessionId, FirstSeqno},
                 type = ?T_INFLIGHT,
-                tracks = Tracks,
+                tracks = compute_pub_tracks(Publishes),
                 until = UntilSeqno,
                 stream = DSStream#ds_stream.ref,
                 iterator = ItBegin
@@ -278,7 +282,7 @@ fetch(ReplyFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 -
                 next_seqno = UntilSeqno,
                 offset_ranges = Ranges ++ [Range]
             },
-            fetch(ReplyFun, SessionId, Inflight, Streams, N - Size, [Publishes | Acc])
+            fetch(PreprocFun, SessionId, Inflight, Streams, N - Size, [Publishes | Acc])
     end;
 fetch(_ReplyFun, _SessionId, Inflight, _Streams, _N, Acc) ->
     Publishes = lists:append(lists:reverse(Acc)),
@@ -375,19 +379,20 @@ discard_tracks(#{ack := AckedUntil, comp := CompUntil}, Until, Tracks) ->
     TAck bor TComp.
 
 replay_range(
-    ReplyFun,
+    PreprocFun,
+    Commits,
     Range0 = #ds_pubrange{type = ?T_INFLIGHT, id = {_, First}, until = Until, iterator = It},
     Acc
 ) ->
     Size = range_size(First, Until),
     {ok, ItNext, MessagesUnacked} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, It, Size),
     %% Asserting that range is consistent with the message storage state.
-    {Replies, {Until, _TracksInitial}} = publish(ReplyFun, First, MessagesUnacked),
+    {Replies, Until} = publish_replay(PreprocFun, Commits, First, MessagesUnacked),
     %% Again, we need to keep the iterator pointing past the end of the
     %% range, so that we can pick up where we left off.
     Range = keep_next_iterator(ItNext, Range0),
     {Range, Replies ++ Acc};
-replay_range(_ReplyFun, Range0 = #ds_pubrange{type = ?T_CHECKPOINT}, Acc) ->
+replay_range(_PreprocFun, _Commits, Range0 = #ds_pubrange{type = ?T_CHECKPOINT}, Acc) ->
     {Range0, Acc}.
 
 validate_commit(
@@ -420,28 +425,89 @@ get_commit_next(rec, #inflight{next_seqno = NextSeqno}) ->
 get_commit_next(comp, #inflight{commits = Commits}) ->
     maps:get(rec, Commits).
 
-publish(ReplyFun, FirstSeqno, Messages) ->
-    lists:mapfoldl(
-        fun(Message, {Seqno, TAcc}) ->
-            case ReplyFun(Seqno, Message) of
-                {_Advance = false, Reply} ->
-                    {Reply, {Seqno, TAcc}};
-                Reply ->
-                    NextSeqno = next_seqno(Seqno),
-                    NextTAcc = add_msg_track(Message, TAcc),
-                    {Reply, {NextSeqno, NextTAcc}}
-            end
+publish_fetch(PreprocFun, FirstSeqno, Messages) ->
+    flatmapfoldl(
+        fun(MessageIn, Acc) ->
+            Message = PreprocFun(MessageIn),
+            publish_fetch(Message, Acc)
         end,
-        {FirstSeqno, 0},
+        FirstSeqno,
         Messages
     ).
 
-add_msg_track(Message, Tracks) ->
-    case emqx_message:qos(Message) of
-        1 -> ?TRACK_FLAG(?ACK) bor Tracks;
-        2 -> ?TRACK_FLAG(?COMP) bor Tracks;
-        _ -> Tracks
-    end.
+publish_fetch(#message{qos = ?QOS_0} = Message, Seqno) ->
+    {{undefined, Message}, Seqno};
+publish_fetch(#message{} = Message, Seqno) ->
+    PacketId = seqno_to_packet_id(Seqno),
+    {{PacketId, Message}, next_seqno(Seqno)};
+publish_fetch(Messages, Seqno) ->
+    flatmapfoldl(fun publish_fetch/2, Seqno, Messages).
+
+publish_replay(PreprocFun, Commits, FirstSeqno, Messages) ->
+    #{ack := AckedUntil, comp := CompUntil, rec := RecUntil} = Commits,
+    flatmapfoldl(
+        fun(MessageIn, Acc) ->
+            Message = PreprocFun(MessageIn),
+            publish_replay(Message, AckedUntil, CompUntil, RecUntil, Acc)
+        end,
+        FirstSeqno,
+        Messages
+    ).
+
+publish_replay(#message{qos = ?QOS_0}, _, _, _, Seqno) ->
+    %% QoS 0 (at most once) messages should not be replayed.
+    {[], Seqno};
+publish_replay(#message{qos = Qos} = Message, AckedUntil, CompUntil, RecUntil, Seqno) ->
+    case Qos of
+        ?QOS_1 when Seqno < AckedUntil ->
+            %% This message has already been acked, so we can skip it.
+            %% We still need to advance seqno, because previously we assigned this message
+            %% a unique Packet Id.
+            {[], next_seqno(Seqno)};
+        ?QOS_2 when Seqno < CompUntil ->
+            %% This message's flow has already been fully completed, so we can skip it.
+            %% We still need to advance seqno, because previously we assigned this message
+            %% a unique Packet Id.
+            {[], next_seqno(Seqno)};
+        ?QOS_2 when Seqno < RecUntil ->
+            %% This message's flow has been partially completed, we need to resend a PUBREL.
+            PacketId = seqno_to_packet_id(Seqno),
+            Pub = {pubrel, PacketId},
+            {Pub, next_seqno(Seqno)};
+        _ ->
+            %% This message flow hasn't been acked and/or received, we need to resend it.
+            PacketId = seqno_to_packet_id(Seqno),
+            Pub = {PacketId, emqx_message:set_flag(dup, true, Message)},
+            {Pub, next_seqno(Seqno)}
+    end;
+publish_replay([], _, _, _, Seqno) ->
+    {[], Seqno};
+publish_replay(Messages, AckedUntil, CompUntil, RecUntil, Seqno) ->
+    flatmapfoldl(
+        fun(Message, Acc) ->
+            publish_replay(Message, AckedUntil, CompUntil, RecUntil, Acc)
+        end,
+        Seqno,
+        Messages
+    ).
+
+-spec compute_pub_tracks(replies()) -> non_neg_integer().
+compute_pub_tracks(Pubs) ->
+    compute_pub_tracks(Pubs, ?TRACK_FLAGS_NONE).
+
+compute_pub_tracks(_Pubs, Tracks = ?TRACK_FLAGS_ALL) ->
+    Tracks;
+compute_pub_tracks([Pub | Rest], Tracks) ->
+    Track =
+        case Pub of
+            {_PacketId, #message{qos = ?QOS_1}} -> ?TRACK_FLAG(?ACK);
+            {_PacketId, #message{qos = ?QOS_2}} -> ?TRACK_FLAG(?COMP);
+            {pubrel, _PacketId} -> ?TRACK_FLAG(?COMP);
+            _ -> ?TRACK_FLAGS_NONE
+        end,
+    compute_pub_tracks(Rest, Track bor Tracks);
+compute_pub_tracks([], Tracks) ->
+    Tracks.
 
 keep_next_iterator(ItNext, Range = #ds_pubrange{iterator = ItFirst, misc = Misc}) ->
     Range#ds_pubrange{
@@ -545,6 +611,19 @@ shuffle(L0) ->
     {_, L} = lists:unzip(L2),
     L.
 
+-spec flatmapfoldl(fun((X, Acc) -> {Y | [Y], Acc}), Acc, [X]) -> {[Y], Acc}.
+flatmapfoldl(_Fun, Acc, []) ->
+    {[], Acc};
+flatmapfoldl(Fun, Acc, [X | Xs]) ->
+    {Ys, NAcc} = Fun(X, Acc),
+    {Zs, FAcc} = flatmapfoldl(Fun, NAcc, Xs),
+    case is_list(Ys) of
+        true ->
+            {Ys ++ Zs, FAcc};
+        _ ->
+            {[Ys | Zs], FAcc}
+    end.
+
 ro_transaction(Fun) ->
     {atomic, Res} = mria:ro_transaction(?DS_MRIA_SHARD, Fun),
     Res.

+ 129 - 89
apps/emqx/src/emqx_persistent_session_ds.erl

@@ -29,7 +29,7 @@
 %% Session API
 -export([
     create/3,
-    open/2,
+    open/3,
     destroy/1
 ]).
 
@@ -102,6 +102,8 @@
 -define(TIMER_BUMP_LAST_ALIVE_AT, timer_bump_last_alive_at).
 -type timer() :: ?TIMER_PULL | ?TIMER_GET_STREAMS | ?TIMER_BUMP_LAST_ALIVE_AT.
 
+-type subscriptions() :: emqx_topic_gbt:t(nil(), subscription()).
+
 -type session() :: #{
     %% Client ID
     id := id(),
@@ -110,7 +112,7 @@
     %% When the client was last considered alive
     last_alive_at := timestamp(),
     %% Client’s Subscriptions.
-    subscriptions := #{topic_filter() => subscription()},
+    subscriptions := subscriptions(),
     %% Inflight messages
     inflight := emqx_persistent_message_ds_replayer:inflight(),
     %% Receive maximum
@@ -150,12 +152,12 @@
 -spec create(clientinfo(), conninfo(), emqx_session:conf()) ->
     session().
 create(#{clientid := ClientID}, ConnInfo, Conf) ->
-    % TODO: expiration
-    ensure_timers(ensure_session(ClientID, ConnInfo, Conf)).
+    Session = session_ensure_new(ClientID, ConnInfo),
+    apply_conf(ConnInfo, Conf, ensure_timers(Session)).
 
--spec open(clientinfo(), conninfo()) ->
+-spec open(clientinfo(), conninfo(), emqx_session:conf()) ->
     {_IsPresent :: true, session(), []} | false.
-open(#{clientid := ClientID} = _ClientInfo, ConnInfo) ->
+open(#{clientid := ClientID} = _ClientInfo, ConnInfo, Conf) ->
     %% NOTE
     %% The fact that we need to concern about discarding all live channels here
     %% is essentially a consequence of the in-memory session design, where we
@@ -165,20 +167,16 @@ open(#{clientid := ClientID} = _ClientInfo, ConnInfo) ->
     ok = emqx_cm:discard_session(ClientID),
     case session_open(ClientID, ConnInfo) of
         Session0 = #{} ->
-            ReceiveMaximum = receive_maximum(ConnInfo),
-            Session = Session0#{receive_maximum => ReceiveMaximum},
+            Session = apply_conf(ConnInfo, Conf, Session0),
             {true, ensure_timers(Session), []};
         false ->
             false
     end.
 
-ensure_session(ClientID, ConnInfo, Conf) ->
-    Session = session_ensure_new(ClientID, ConnInfo, Conf),
-    ReceiveMaximum = receive_maximum(ConnInfo),
+apply_conf(ConnInfo, Conf, Session) ->
     Session#{
-        conninfo => ConnInfo,
-        receive_maximum => ReceiveMaximum,
-        subscriptions => #{}
+        receive_maximum => receive_maximum(ConnInfo),
+        props => Conf
     }.
 
 -spec destroy(session() | clientinfo()) -> ok.
@@ -204,10 +202,10 @@ info(created_at, #{created_at := CreatedAt}) ->
     CreatedAt;
 info(is_persistent, #{}) ->
     true;
-info(subscriptions, #{subscriptions := Iters}) ->
-    maps:map(fun(_, #{props := SubOpts}) -> SubOpts end, Iters);
-info(subscriptions_cnt, #{subscriptions := Iters}) ->
-    maps:size(Iters);
+info(subscriptions, #{subscriptions := Subs}) ->
+    subs_to_map(Subs);
+info(subscriptions_cnt, #{subscriptions := Subs}) ->
+    subs_size(Subs);
 info(subscriptions_max, #{props := Conf}) ->
     maps:get(max_subscriptions, Conf);
 info(upgrade_qos, #{props := Conf}) ->
@@ -274,41 +272,40 @@ subscribe(
     TopicFilter,
     SubOpts,
     Session = #{id := ID, subscriptions := Subs}
-) when is_map_key(TopicFilter, Subs) ->
-    Subscription = maps:get(TopicFilter, Subs),
-    NSubscription = update_subscription(TopicFilter, Subscription, SubOpts, ID),
-    {ok, Session#{subscriptions := Subs#{TopicFilter => NSubscription}}};
-subscribe(
-    TopicFilter,
-    SubOpts,
-    Session = #{id := ID, subscriptions := Subs}
 ) ->
-    % TODO: max_subscriptions
-    Subscription = add_subscription(TopicFilter, SubOpts, ID),
-    {ok, Session#{subscriptions := Subs#{TopicFilter => Subscription}}}.
+    case subs_lookup(TopicFilter, Subs) of
+        Subscription = #{} ->
+            NSubscription = update_subscription(TopicFilter, Subscription, SubOpts, ID),
+            NSubs = subs_insert(TopicFilter, NSubscription, Subs),
+            {ok, Session#{subscriptions := NSubs}};
+        undefined ->
+            % TODO: max_subscriptions
+            Subscription = add_subscription(TopicFilter, SubOpts, ID),
+            NSubs = subs_insert(TopicFilter, Subscription, Subs),
+            {ok, Session#{subscriptions := NSubs}}
+    end.
 
 -spec unsubscribe(topic_filter(), session()) ->
     {ok, session(), emqx_types:subopts()} | {error, emqx_types:reason_code()}.
 unsubscribe(
     TopicFilter,
     Session = #{id := ID, subscriptions := Subs}
-) when is_map_key(TopicFilter, Subs) ->
-    Subscription = maps:get(TopicFilter, Subs),
-    SubOpts = maps:get(props, Subscription),
-    ok = del_subscription(TopicFilter, ID),
-    {ok, Session#{subscriptions := maps:remove(TopicFilter, Subs)}, SubOpts};
-unsubscribe(
-    _TopicFilter,
-    _Session = #{}
 ) ->
-    {error, ?RC_NO_SUBSCRIPTION_EXISTED}.
+    case subs_lookup(TopicFilter, Subs) of
+        _Subscription = #{props := SubOpts} ->
+            ok = del_subscription(TopicFilter, ID),
+            NSubs = subs_delete(TopicFilter, Subs),
+            {ok, Session#{subscriptions := NSubs}, SubOpts};
+        undefined ->
+            {error, ?RC_NO_SUBSCRIPTION_EXISTED}
+    end.
 
 -spec get_subscription(topic_filter(), session()) ->
     emqx_types:subopts() | undefined.
 get_subscription(TopicFilter, #{subscriptions := Subs}) ->
-    case maps:get(TopicFilter, Subs, undefined) of
-        Subscription = #{} ->
-            maps:get(props, Subscription);
+    case subs_lookup(TopicFilter, Subs) of
+        _Subscription = #{props := SubOpts} ->
+            SubOpts;
         undefined ->
             undefined
     end.
@@ -329,9 +326,6 @@ publish(_PacketId, Msg, Session) ->
 %% Client -> Broker: PUBACK
 %%--------------------------------------------------------------------
 
-%% FIXME: parts of the commit offset function are mocked
--dialyzer({nowarn_function, puback/3}).
-
 -spec puback(clientinfo(), emqx_types:packet_id(), session()) ->
     {ok, emqx_types:message(), replies(), session()}
     | {error, emqx_types:reason_code()}.
@@ -403,20 +397,22 @@ deliver(_ClientInfo, _Delivers, Session) ->
 -spec handle_timeout(clientinfo(), _Timeout, session()) ->
     {ok, replies(), session()} | {ok, replies(), timeout(), session()}.
 handle_timeout(
-    _ClientInfo,
+    ClientInfo,
     ?TIMER_PULL,
-    Session0 = #{id := Id, inflight := Inflight0, receive_maximum := ReceiveMaximum}
+    Session0 = #{
+        id := Id,
+        inflight := Inflight0,
+        subscriptions := Subs,
+        props := Conf,
+        receive_maximum := ReceiveMaximum
+    }
 ) ->
     MaxBatchSize = emqx_config:get([session_persistence, max_batch_size]),
     BatchSize = min(ReceiveMaximum, MaxBatchSize),
+    UpgradeQoS = maps:get(upgrade_qos, Conf),
+    PreprocFun = make_preproc_fun(ClientInfo, Subs, UpgradeQoS),
     {Publishes, Inflight} = emqx_persistent_message_ds_replayer:poll(
-        fun
-            (_Seqno, Message = #message{qos = ?QOS_0}) ->
-                {false, {undefined, Message}};
-            (Seqno, Message) ->
-                PacketId = emqx_persistent_message_ds_replayer:seqno_to_packet_id(Seqno),
-                {PacketId, Message}
-        end,
+        PreprocFun,
         Id,
         Inflight0,
         BatchSize
@@ -442,30 +438,21 @@ handle_timeout(_ClientInfo, ?TIMER_BUMP_LAST_ALIVE_AT, Session0) ->
     BumpInterval = emqx_config:get([session_persistence, last_alive_update_interval]),
     EstimatedLastAliveAt = now_ms() + BumpInterval,
     Session = session_set_last_alive_at_trans(Session0, EstimatedLastAliveAt),
-    BumpInterval = emqx_config:get([session_persistence, last_alive_update_interval]),
-    {ok, [], emqx_session:ensure_timer(?TIMER_BUMP_LAST_ALIVE_AT, BumpInterval, Session)}.
+    {ok, [], emqx_session:ensure_timer(?TIMER_BUMP_LAST_ALIVE_AT, BumpInterval, Session)};
+handle_timeout(_ClientInfo, expire_awaiting_rel, Session) ->
+    %% TODO: stub
+    {ok, [], Session}.
 
 -spec replay(clientinfo(), [], session()) ->
     {ok, replies(), session()}.
-replay(_ClientInfo, [], Session = #{inflight := Inflight0}) ->
-    AckedUntil = emqx_persistent_message_ds_replayer:committed_until(ack, Inflight0),
-    RecUntil = emqx_persistent_message_ds_replayer:committed_until(rec, Inflight0),
-    CompUntil = emqx_persistent_message_ds_replayer:committed_until(comp, Inflight0),
-    ReplyFun = fun
-        (_Seqno, #message{qos = ?QOS_0}) ->
-            {false, []};
-        (Seqno, #message{qos = ?QOS_1}) when Seqno < AckedUntil ->
-            [];
-        (Seqno, #message{qos = ?QOS_2}) when Seqno < CompUntil ->
-            [];
-        (Seqno, #message{qos = ?QOS_2}) when Seqno < RecUntil ->
-            PacketId = emqx_persistent_message_ds_replayer:seqno_to_packet_id(Seqno),
-            {pubrel, PacketId};
-        (Seqno, Message) ->
-            PacketId = emqx_persistent_message_ds_replayer:seqno_to_packet_id(Seqno),
-            {PacketId, emqx_message:set_flag(dup, true, Message)}
-    end,
-    {Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(ReplyFun, Inflight0),
+replay(
+    ClientInfo,
+    [],
+    Session = #{inflight := Inflight0, subscriptions := Subs, props := Conf}
+) ->
+    UpgradeQoS = maps:get(upgrade_qos, Conf),
+    PreprocFun = make_preproc_fun(ClientInfo, Subs, UpgradeQoS),
+    {Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(PreprocFun, Inflight0),
     {ok, Replies, Session#{inflight := Inflight}}.
 
 %%--------------------------------------------------------------------
@@ -481,6 +468,19 @@ terminate(_Reason, _Session = #{}) ->
 
 %%--------------------------------------------------------------------
 
+make_preproc_fun(ClientInfo, Subs, UpgradeQoS) ->
+    fun(Message = #message{topic = Topic}) ->
+        emqx_utils:flattermap(
+            fun(Match) ->
+                #{props := SubOpts} = subs_get_match(Match, Subs),
+                emqx_session:enrich_message(ClientInfo, Message, SubOpts, UpgradeQoS)
+            end,
+            subs_matches(Topic, Subs)
+        )
+    end.
+
+%%--------------------------------------------------------------------
+
 -spec add_subscription(topic_filter(), emqx_types:subopts(), id()) ->
     subscription().
 add_subscription(TopicFilter, SubOpts, DSSessionID) ->
@@ -644,25 +644,24 @@ session_open(SessionId, NewConnInfo) ->
         end
     end).
 
--spec session_ensure_new(id(), emqx_types:conninfo(), _Props :: map()) ->
+-spec session_ensure_new(id(), emqx_types:conninfo()) ->
     session().
-session_ensure_new(SessionId, ConnInfo, Props) ->
+session_ensure_new(SessionId, ConnInfo) ->
     transaction(fun() ->
-        ok = session_drop_subscriptions(SessionId),
-        Session = export_session(session_create(SessionId, ConnInfo, Props)),
+        ok = session_drop_records(SessionId),
+        Session = export_session(session_create(SessionId, ConnInfo)),
         Session#{
-            subscriptions => #{},
+            subscriptions => subs_new(),
             inflight => emqx_persistent_message_ds_replayer:new()
         }
     end).
 
-session_create(SessionId, ConnInfo, Props) ->
+session_create(SessionId, ConnInfo) ->
     Session = #session{
         id = SessionId,
         created_at = now_ms(),
         last_alive_at = now_ms(),
-        conninfo = ConnInfo,
-        props = Props
+        conninfo = ConnInfo
     },
     ok = mnesia:write(?SESSION_TAB, Session, write),
     Session.
@@ -696,13 +695,17 @@ session_set_last_alive_at(SessionRecord0, LastAliveAt) ->
 -spec session_drop(id()) -> ok.
 session_drop(DSSessionId) ->
     transaction(fun() ->
-        ok = session_drop_subscriptions(DSSessionId),
-        ok = session_drop_pubranges(DSSessionId),
-        ok = session_drop_offsets(DSSessionId),
-        ok = session_drop_streams(DSSessionId),
+        ok = session_drop_records(DSSessionId),
         ok = mnesia:delete(?SESSION_TAB, DSSessionId, write)
     end).
 
+-spec session_drop_records(id()) -> ok.
+session_drop_records(DSSessionId) ->
+    ok = session_drop_subscriptions(DSSessionId),
+    ok = session_drop_pubranges(DSSessionId),
+    ok = session_drop_offsets(DSSessionId),
+    ok = session_drop_streams(DSSessionId).
+
 -spec session_drop_subscriptions(id()) -> ok.
 session_drop_subscriptions(DSSessionId) ->
     Subscriptions = session_read_subscriptions(DSSessionId, write),
@@ -844,7 +847,7 @@ do_ensure_all_iterators_closed(_DSSessionID) ->
 renew_streams(#{id := SessionId, subscriptions := Subscriptions}) ->
     transaction(fun() ->
         ExistingStreams = mnesia:read(?SESSION_STREAM_TAB, SessionId, write),
-        maps:fold(
+        subs_fold(
             fun(TopicFilter, #{start_time := StartTime}, Streams) ->
                 TopicFilterWords = emqx_topic:words(TopicFilter),
                 renew_topic_streams(SessionId, TopicFilterWords, StartTime, Streams)
@@ -926,6 +929,43 @@ session_drop_offsets(DSSessionId) ->
 
 %%--------------------------------------------------------------------------------
 
+subs_new() ->
+    emqx_topic_gbt:new().
+
+subs_lookup(TopicFilter, Subs) ->
+    emqx_topic_gbt:lookup(TopicFilter, [], Subs, undefined).
+
+subs_insert(TopicFilter, Subscription, Subs) ->
+    emqx_topic_gbt:insert(TopicFilter, [], Subscription, Subs).
+
+subs_delete(TopicFilter, Subs) ->
+    emqx_topic_gbt:delete(TopicFilter, [], Subs).
+
+subs_matches(Topic, Subs) ->
+    emqx_topic_gbt:matches(Topic, Subs, []).
+
+subs_get_match(M, Subs) ->
+    emqx_topic_gbt:get_record(M, Subs).
+
+subs_size(Subs) ->
+    emqx_topic_gbt:size(Subs).
+
+subs_to_map(Subs) ->
+    subs_fold(
+        fun(TopicFilter, #{props := Props}, Acc) -> Acc#{TopicFilter => Props} end,
+        #{},
+        Subs
+    ).
+
+subs_fold(Fun, AccIn, Subs) ->
+    emqx_topic_gbt:fold(
+        fun(Key, Sub, Acc) -> Fun(emqx_topic_gbt:get_topic(Key), Sub, Acc) end,
+        AccIn,
+        Subs
+    ).
+
+%%--------------------------------------------------------------------------------
+
 transaction(Fun) ->
     case mnesia:is_transaction() of
         true ->
@@ -944,9 +984,9 @@ ro_transaction(Fun) ->
 export_subscriptions(DSSubs) ->
     lists:foldl(
         fun(DSSub = #ds_sub{id = {_DSSessionId, TopicFilter}}, Acc) ->
-            Acc#{TopicFilter => export_subscription(DSSub)}
+            subs_insert(TopicFilter, export_subscription(DSSub), Acc)
         end,
-        #{},
+        subs_new(),
         DSSubs
     ).
 

+ 15 - 18
apps/emqx/src/emqx_session.erl

@@ -96,13 +96,16 @@
 ]).
 
 % Foreign session implementations
--export([enrich_delivers/3]).
+-export([
+    enrich_delivers/3,
+    enrich_message/4
+]).
 
 % Utilities
 -export([should_keep/1]).
 
 % Tests only
--export([get_session_conf/2]).
+-export([get_session_conf/1]).
 
 -export_type([
     t/0,
@@ -137,8 +140,6 @@
 -type conf() :: #{
     %% Max subscriptions allowed
     max_subscriptions := non_neg_integer() | infinity,
-    %% Max inflight messages allowed
-    max_inflight := non_neg_integer(),
     %% Maximum number of awaiting QoS2 messages allowed
     max_awaiting_rel := non_neg_integer() | infinity,
     %% Upgrade QoS?
@@ -171,7 +172,7 @@
 
 -callback create(clientinfo(), conninfo(), conf()) ->
     t().
--callback open(clientinfo(), conninfo()) ->
+-callback open(clientinfo(), conninfo(), conf()) ->
     {_IsPresent :: true, t(), _ReplayContext} | false.
 -callback destroy(t() | clientinfo()) -> ok.
 
@@ -181,7 +182,7 @@
 
 -spec create(clientinfo(), conninfo()) -> t().
 create(ClientInfo, ConnInfo) ->
-    Conf = get_session_conf(ClientInfo, ConnInfo),
+    Conf = get_session_conf(ClientInfo),
     create(ClientInfo, ConnInfo, Conf).
 
 create(ClientInfo, ConnInfo, Conf) ->
@@ -198,12 +199,12 @@ create(Mod, ClientInfo, ConnInfo, Conf) ->
 -spec open(clientinfo(), conninfo()) ->
     {_IsPresent :: true, t(), _ReplayContext} | {_IsPresent :: false, t()}.
 open(ClientInfo, ConnInfo) ->
-    Conf = get_session_conf(ClientInfo, ConnInfo),
+    Conf = get_session_conf(ClientInfo),
     Mods = [Default | _] = choose_impl_candidates(ConnInfo),
     %% NOTE
     %% Try to look the existing session up in session stores corresponding to the given
     %% `Mods` in order, starting from the last one.
-    case try_open(Mods, ClientInfo, ConnInfo) of
+    case try_open(Mods, ClientInfo, ConnInfo, Conf) of
         {_IsPresent = true, _, _} = Present ->
             Present;
         false ->
@@ -212,24 +213,20 @@ open(ClientInfo, ConnInfo) ->
             {false, create(Default, ClientInfo, ConnInfo, Conf)}
     end.
 
-try_open([Mod | Rest], ClientInfo, ConnInfo) ->
-    case try_open(Rest, ClientInfo, ConnInfo) of
+try_open([Mod | Rest], ClientInfo, ConnInfo, Conf) ->
+    case try_open(Rest, ClientInfo, ConnInfo, Conf) of
         {_IsPresent = true, _, _} = Present ->
             Present;
         false ->
-            Mod:open(ClientInfo, ConnInfo)
+            Mod:open(ClientInfo, ConnInfo, Conf)
     end;
-try_open([], _ClientInfo, _ConnInfo) ->
+try_open([], _ClientInfo, _ConnInfo, _Conf) ->
     false.
 
--spec get_session_conf(clientinfo(), conninfo()) -> conf().
-get_session_conf(
-    #{zone := Zone},
-    #{receive_maximum := MaxInflight}
-) ->
+-spec get_session_conf(clientinfo()) -> conf().
+get_session_conf(_ClientInfo = #{zone := Zone}) ->
     #{
         max_subscriptions => get_mqtt_conf(Zone, max_subscriptions),
-        max_inflight => MaxInflight,
         max_awaiting_rel => get_mqtt_conf(Zone, max_awaiting_rel),
         upgrade_qos => get_mqtt_conf(Zone, upgrade_qos),
         retry_interval => get_mqtt_conf(Zone, retry_interval),

+ 26 - 6
apps/emqx/src/emqx_session_mem.erl

@@ -59,7 +59,7 @@
 
 -export([
     create/3,
-    open/2,
+    open/3,
     destroy/1
 ]).
 
@@ -152,7 +152,11 @@
 
 -spec create(clientinfo(), conninfo(), emqx_session:conf()) ->
     session().
-create(#{zone := Zone, clientid := ClientId}, #{expiry_interval := EI}, Conf) ->
+create(
+    #{zone := Zone, clientid := ClientId},
+    #{expiry_interval := EI, receive_maximum := ReceiveMax},
+    Conf
+) ->
     QueueOpts = get_mqueue_conf(Zone),
     #session{
         id = emqx_guid:gen(),
@@ -160,7 +164,7 @@ create(#{zone := Zone, clientid := ClientId}, #{expiry_interval := EI}, Conf) ->
         created_at = erlang:system_time(millisecond),
         is_persistent = EI > 0,
         subscriptions = #{},
-        inflight = emqx_inflight:new(maps:get(max_inflight, Conf)),
+        inflight = emqx_inflight:new(ReceiveMax),
         mqueue = emqx_mqueue:init(QueueOpts),
         next_pkt_id = 1,
         awaiting_rel = #{},
@@ -195,14 +199,16 @@ destroy(_Session) ->
 %% Open a (possibly existing) Session
 %%--------------------------------------------------------------------
 
--spec open(clientinfo(), conninfo()) ->
+-spec open(clientinfo(), conninfo(), emqx_session:conf()) ->
     {_IsPresent :: true, session(), replayctx()} | _IsPresent :: false.
-open(ClientInfo = #{clientid := ClientId}, _ConnInfo) ->
+open(ClientInfo = #{clientid := ClientId}, ConnInfo, Conf) ->
     case emqx_cm:takeover_session_begin(ClientId) of
         {ok, SessionRemote, TakeoverState} ->
-            Session = resume(ClientInfo, SessionRemote),
+            Session0 = resume(ClientInfo, SessionRemote),
             case emqx_cm:takeover_session_end(TakeoverState) of
                 {ok, Pendings} ->
+                    Session1 = resize_inflight(ConnInfo, Session0),
+                    Session = apply_conf(Conf, Session1),
                     clean_session(ClientInfo, Session, Pendings);
                 {error, _} ->
                     % TODO log error?
@@ -212,6 +218,20 @@ open(ClientInfo = #{clientid := ClientId}, _ConnInfo) ->
             false
     end.
 
+resize_inflight(#{receive_maximum := ReceiveMax}, Session = #session{inflight = Inflight}) ->
+    Session#session{
+        inflight = emqx_inflight:resize(ReceiveMax, Inflight)
+    }.
+
+apply_conf(Conf, Session = #session{}) ->
+    Session#session{
+        max_subscriptions = maps:get(max_subscriptions, Conf),
+        max_awaiting_rel = maps:get(max_awaiting_rel, Conf),
+        upgrade_qos = maps:get(upgrade_qos, Conf),
+        retry_interval = maps:get(retry_interval, Conf),
+        await_rel_timeout = maps:get(await_rel_timeout, Conf)
+    }.
+
 clean_session(ClientInfo, Session = #session{mqueue = Q}, Pendings) ->
     Q1 = emqx_mqueue:filter(fun emqx_session:should_keep/1, Q),
     Session1 = Session#session{mqueue = Q1},

+ 56 - 41
apps/emqx/src/emqx_topic_gbt.erl

@@ -14,14 +14,17 @@
 %% limitations under the License.
 %%--------------------------------------------------------------------
 
-%% @doc Topic index implemetation with gb_trees stored in persistent_term.
-%% This is only suitable for a static set of topic or topic-filters.
+%% @doc Topic index implemetation with gb_trees as the underlying data
+%% structure.
 
 -module(emqx_topic_gbt).
 
--export([new/0, new/1]).
+-export([new/0]).
+-export([size/1]).
 -export([insert/4]).
 -export([delete/3]).
+-export([lookup/4]).
+-export([fold/3]).
 -export([match/2]).
 -export([matches/3]).
 
@@ -29,53 +32,74 @@
 -export([get_topic/1]).
 -export([get_record/2]).
 
+-export_type([t/0, t/2, match/1]).
+
 -type key(ID) :: emqx_trie_search:key(ID).
 -type words() :: emqx_trie_search:words().
 -type match(ID) :: key(ID).
--type name() :: any().
 
-%% @private Only for testing.
--spec new() -> name().
-new() ->
-    new(test).
+-opaque t(ID, Value) :: gb_trees:tree(key(ID), Value).
+-opaque t() :: t(_ID, _Value).
 
 %% @doc Create a new gb_tree and store it in the persitent_term with the
 %% given name.
--spec new(name()) -> name().
-new(Name) ->
-    T = gb_trees:from_orddict([]),
-    true = gbt_update(Name, T),
-    Name.
+-spec new() -> t().
+new() ->
+    gb_trees:empty().
+
+-spec size(t()) -> non_neg_integer().
+size(Gbt) ->
+    gb_trees:size(Gbt).
 
 %% @doc Insert a new entry into the index that associates given topic filter to given
 %% record ID, and attaches arbitrary record to the entry. This allows users to choose
 %% between regular and "materialized" indexes, for example.
--spec insert(emqx_types:topic() | words(), _ID, _Record, name()) -> true.
-insert(Filter, ID, Record, Name) ->
-    Tree = gbt(Name),
+-spec insert(emqx_types:topic() | words(), _ID, _Record, t()) -> t().
+insert(Filter, ID, Record, Gbt) ->
     Key = key(Filter, ID),
-    NewTree = gb_trees:enter(Key, Record, Tree),
-    true = gbt_update(Name, NewTree).
+    gb_trees:enter(Key, Record, Gbt).
 
 %% @doc Delete an entry from the index that associates given topic filter to given
 %% record ID. Deleting non-existing entry is not an error.
--spec delete(emqx_types:topic() | words(), _ID, name()) -> true.
-delete(Filter, ID, Name) ->
-    Tree = gbt(Name),
+-spec delete(emqx_types:topic() | words(), _ID, t()) -> t().
+delete(Filter, ID, Gbt) ->
     Key = key(Filter, ID),
-    NewTree = gb_trees:delete_any(Key, Tree),
-    true = gbt_update(Name, NewTree).
+    gb_trees:delete_any(Key, Gbt).
+
+-spec lookup(emqx_types:topic() | words(), _ID, t(), Default) -> _Record | Default.
+lookup(Filter, ID, Gbt, Default) ->
+    Key = key(Filter, ID),
+    case gb_trees:lookup(Key, Gbt) of
+        {value, Record} ->
+            Record;
+        none ->
+            Default
+    end.
+
+-spec fold(fun((key(_ID), _Record, Acc) -> Acc), Acc, t()) -> Acc.
+fold(Fun, Acc, Gbt) ->
+    Iter = gb_trees:iterator(Gbt),
+    fold_iter(Fun, Acc, Iter).
+
+fold_iter(Fun, Acc, Iter) ->
+    case gb_trees:next(Iter) of
+        {Key, Record, NIter} ->
+            fold_iter(Fun, Fun(Key, Record, Acc), NIter);
+        none ->
+            Acc
+    end.
 
 %% @doc Match given topic against the index and return the first match, or `false` if
 %% no match is found.
--spec match(emqx_types:topic(), name()) -> match(_ID) | false.
-match(Topic, Name) ->
-    emqx_trie_search:match(Topic, make_nextf(Name)).
+-spec match(emqx_types:topic(), t()) -> match(_ID) | false.
+match(Topic, Gbt) ->
+    emqx_trie_search:match(Topic, make_nextf(Gbt)).
 
 %% @doc Match given topic against the index and return _all_ matches.
 %% If `unique` option is given, return only unique matches by record ID.
-matches(Topic, Name, Opts) ->
-    emqx_trie_search:matches(Topic, make_nextf(Name), Opts).
+-spec matches(emqx_types:topic(), t(), emqx_trie_search:opts()) -> [match(_ID)].
+matches(Topic, Gbt, Opts) ->
+    emqx_trie_search:matches(Topic, make_nextf(Gbt), Opts).
 
 %% @doc Extract record ID from the match.
 -spec get_id(match(ID)) -> ID.
@@ -88,21 +112,13 @@ get_topic(Key) ->
     emqx_trie_search:get_topic(Key).
 
 %% @doc Fetch the record associated with the match.
--spec get_record(match(_ID), name()) -> _Record.
-get_record(Key, Name) ->
-    Gbt = gbt(Name),
+-spec get_record(match(_ID), t()) -> _Record.
+get_record(Key, Gbt) ->
     gb_trees:get(Key, Gbt).
 
 key(TopicOrFilter, ID) ->
     emqx_trie_search:make_key(TopicOrFilter, ID).
 
-gbt(Name) ->
-    persistent_term:get({?MODULE, Name}).
-
-gbt_update(Name, Tree) ->
-    persistent_term:put({?MODULE, Name}, Tree),
-    true.
-
 gbt_next(nil, _Input) ->
     '$end_of_table';
 gbt_next({P, _V, _Smaller, Bigger}, K) when K >= P ->
@@ -115,6 +131,5 @@ gbt_next({P, _V, Smaller, _Bigger}, K) ->
             NextKey
     end.
 
-make_nextf(Name) ->
-    {_SizeWeDontCare, TheTree} = gbt(Name),
-    fun(Key) -> gbt_next(TheTree, Key) end.
+make_nextf({_Size, Tree}) ->
+    fun(Key) -> gbt_next(Tree, Key) end.

+ 71 - 0
apps/emqx/src/emqx_topic_gbt_pterm.erl

@@ -0,0 +1,71 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2023 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+%% @doc Topic index implemetation with gb_tree as a persistent term.
+%% This is only suitable for a static set of topic or topic-filters.
+
+-module(emqx_topic_gbt_pterm).
+
+-export([new/0, new/1]).
+-export([insert/4]).
+-export([delete/3]).
+-export([match/2]).
+-export([matches/3]).
+
+-export([get_record/2]).
+
+-type name() :: any().
+-type match(ID) :: emqx_topic_gbt:match(ID).
+
+%% @private Only for testing.
+-spec new() -> name().
+new() ->
+    new(test).
+
+-spec new(name()) -> name().
+new(Name) ->
+    true = pterm_update(Name, emqx_topic_gbt:new()),
+    Name.
+
+-spec insert(emqx_types:topic() | emqx_trie_search:words(), _ID, _Record, name()) -> true.
+insert(Filter, ID, Record, Name) ->
+    pterm_update(Name, emqx_topic_gbt:insert(Filter, ID, Record, pterm(Name))).
+
+-spec delete(emqx_types:topic() | emqx_trie_search:words(), _ID, name()) -> name().
+delete(Filter, ID, Name) ->
+    pterm_update(Name, emqx_topic_gbt:delete(Filter, ID, pterm(Name))).
+
+-spec match(emqx_types:topic(), name()) -> match(_ID) | false.
+match(Topic, Name) ->
+    emqx_topic_gbt:match(Topic, pterm(Name)).
+
+-spec matches(emqx_types:topic(), name(), emqx_trie_search:opts()) -> [match(_ID)].
+matches(Topic, Name, Opts) ->
+    emqx_topic_gbt:matches(Topic, pterm(Name), Opts).
+
+%% @doc Fetch the record associated with the match.
+-spec get_record(match(_ID), name()) -> _Record.
+get_record(Key, Name) ->
+    emqx_topic_gbt:get_record(Key, pterm(Name)).
+
+%%
+
+pterm(Name) ->
+    persistent_term:get({?MODULE, Name}).
+
+pterm_update(Name, Tree) ->
+    persistent_term:put({?MODULE, Name}, Tree),
+    true.

+ 1 - 1
apps/emqx/test/emqx_persistent_messages_SUITE.erl

@@ -262,7 +262,7 @@ t_publish_as_persistent(_Config) ->
     Sub = connect(<<?MODULE_STRING "1">>, true, 30),
     Pub = connect(<<?MODULE_STRING "2">>, true, 30),
     try
-        {ok, _, [1]} = emqtt:subscribe(Sub, <<"t/#">>, qos1),
+        {ok, _, [?RC_GRANTED_QOS_2]} = emqtt:subscribe(Sub, <<"t/#">>, qos2),
         Messages = [
             {<<"t/1">>, <<"1">>, 0},
             {<<"t/1">>, <<"2">>, 1},

+ 2 - 6
apps/emqx/test/emqx_persistent_session_SUITE.erl

@@ -323,7 +323,8 @@ t_choose_impl(Config) ->
             ds -> emqx_persistent_session_ds
         end,
         emqx_connection:info({channel, {session, impl}}, sys:get_state(ChanPid))
-    ).
+    ),
+    ok = emqtt:disconnect(Client).
 
 t_connect_discards_existing_client(Config) ->
     ClientId = ?config(client_id, Config),
@@ -389,9 +390,6 @@ t_connect_session_expiry_interval(Config) ->
     ok = emqtt:disconnect(Client2).
 
 %% [MQTT-3.1.2-23]
-%% TODO: un-skip after QoS 2 support is implemented in DS.
-t_connect_session_expiry_interval_qos2(init, Config) -> skip_ds_tc(Config);
-t_connect_session_expiry_interval_qos2('end', _Config) -> ok.
 t_connect_session_expiry_interval_qos2(Config) ->
     ConnFun = ?config(conn_fun, Config),
     Topic = ?config(topic, Config),
@@ -1009,8 +1007,6 @@ t_unsubscribe(Config) ->
     ?assertMatch([], [Sub || {ST, _} = Sub <- emqtt:subscriptions(Client), ST =:= STopic]),
     ok = emqtt:disconnect(Client).
 
-t_multiple_subscription_matches(init, Config) -> skip_ds_tc(Config);
-t_multiple_subscription_matches('end', _Config) -> ok.
 t_multiple_subscription_matches(Config) ->
     ConnFun = ?config(conn_fun, Config),
     Topic = ?config(topic, Config),

+ 2 - 2
apps/emqx/test/emqx_session_mem_SUITE.erl

@@ -67,7 +67,7 @@ t_session_init(_) ->
     Session = emqx_session_mem:create(
         ClientInfo,
         ConnInfo,
-        emqx_session:get_session_conf(ClientInfo, ConnInfo)
+        emqx_session:get_session_conf(ClientInfo)
     ),
     ?assertEqual(#{}, emqx_session_mem:info(subscriptions, Session)),
     ?assertEqual(0, emqx_session_mem:info(subscriptions_cnt, Session)),
@@ -531,7 +531,7 @@ session(InitFields) when is_map(InitFields) ->
     Session = emqx_session_mem:create(
         ClientInfo,
         ConnInfo,
-        emqx_session:get_session_conf(ClientInfo, ConnInfo)
+        emqx_session:get_session_conf(ClientInfo)
     ),
     maps:fold(
         fun(Field, Value, SessionAcc) ->

+ 1 - 1
apps/emqx/test/emqx_topic_index_SUITE.erl

@@ -40,7 +40,7 @@ groups() ->
 init_per_group(ets, Config) ->
     [{index_module, emqx_topic_index} | Config];
 init_per_group(gb_tree, Config) ->
-    [{index_module, emqx_topic_gbt} | Config].
+    [{index_module, emqx_topic_gbt_pterm} | Config].
 
 end_per_group(_Group, _Config) ->
     ok.

+ 1 - 1
apps/emqx_gateway_mqttsn/src/emqx_gateway_mqttsn.app.src

@@ -1,7 +1,7 @@
 %% -*- mode: erlang -*-
 {application, emqx_gateway_mqttsn, [
     {description, "MQTT-SN Gateway"},
-    {vsn, "0.1.6"},
+    {vsn, "0.1.7"},
     {registered, []},
     {applications, [kernel, stdlib, emqx, emqx_gateway]},
     {env, []},

+ 1 - 1
apps/emqx_gateway_mqttsn/src/emqx_mqttsn_session.erl

@@ -54,7 +54,7 @@
 
 init(ClientInfo) ->
     ConnInfo = #{receive_maximum => 1, expiry_interval => 0},
-    SessionConf = emqx_session:get_session_conf(ClientInfo, ConnInfo),
+    SessionConf = emqx_session:get_session_conf(ClientInfo),
     #{
         registry => emqx_mqttsn_registry:init(),
         session => emqx_session_mem:create(ClientInfo, ConnInfo, SessionConf)

+ 23 - 0
apps/emqx_utils/src/emqx_utils.erl

@@ -60,6 +60,7 @@
     safe_filename/1,
     diff_lists/3,
     merge_lists/3,
+    flattermap/2,
     tcp_keepalive_opts/4,
     format/1,
     format_mfal/1,
@@ -999,6 +1000,28 @@ search(ExpectValue, KeyFunc, [Item | List]) ->
         false -> search(ExpectValue, KeyFunc, List)
     end.
 
+%% @doc Maps over a list of terms and flattens the result, giving back a flat
+%% list of terms. It's similar to `lists:flatmap/2`, but it also works on a
+%% single term as `Fun` output (thus, the wordplay on "flatter").
+%% The purpose of this function is to adapt to `Fun`s that return either a `[]`
+%% or a term, and to avoid costs of list construction and flattening when
+%% dealing with large lists.
+-spec flattermap(Fun, [X]) -> [X] when
+    Fun :: fun((X) -> [X] | X).
+flattermap(_Fun, []) ->
+    [];
+flattermap(Fun, [X | Xs]) ->
+    flatcomb(Fun(X), flattermap(Fun, Xs)).
+
+flatcomb([], Zs) ->
+    Zs;
+flatcomb(Ys = [_ | _], []) ->
+    Ys;
+flatcomb(Ys = [_ | _], Zs = [_ | _]) ->
+    Ys ++ Zs;
+flatcomb(Y, Zs) ->
+    [Y | Zs].
+
 -ifdef(TEST).
 -include_lib("eunit/include/eunit.hrl").
 

+ 47 - 12
apps/emqx_utils/test/emqx_utils_SUITE.erl

@@ -20,6 +20,7 @@
 -compile(nowarn_export_all).
 
 -include_lib("eunit/include/eunit.hrl").
+-include_lib("emqx/include/asserts.hrl").
 -include_lib("snabbkaffe/include/snabbkaffe.hrl").
 
 -define(SOCKOPTS, [
@@ -87,13 +88,13 @@ t_pipeline(_) ->
 t_start_timer(_) ->
     TRef = emqx_utils:start_timer(1, tmsg),
     timer:sleep(2),
-    ?assertEqual([{timeout, TRef, tmsg}], drain()),
+    ?assertEqual([{timeout, TRef, tmsg}], ?drainMailbox()),
     ok = emqx_utils:cancel_timer(TRef).
 
 t_cancel_timer(_) ->
     Timer = emqx_utils:start_timer(0, foo),
     ok = emqx_utils:cancel_timer(Timer),
-    ?assertEqual([], drain()),
+    ?assertEqual([], ?drainMailbox()),
     ok = emqx_utils:cancel_timer(undefined).
 
 t_proc_name(_) ->
@@ -153,16 +154,6 @@ t_check(_) ->
         emqx_utils:check_oom(Policy)
     ).
 
-drain() ->
-    drain([]).
-
-drain(Acc) ->
-    receive
-        Msg -> drain([Msg | Acc])
-    after 0 ->
-        lists:reverse(Acc)
-    end.
-
 t_rand_seed(_) ->
     ?assert(is_tuple(emqx_utils:rand_seed())).
 
@@ -240,3 +231,47 @@ t_pmap_late_reply(_) ->
         []
     ),
     ok.
+
+t_flattermap(_) ->
+    ?assertEqual(
+        [42],
+        emqx_utils:flattermap(fun identity/1, [42])
+    ),
+    ?assertEqual(
+        [42, 42],
+        emqx_utils:flattermap(fun duplicate/1, [42])
+    ),
+    ?assertEqual(
+        [],
+        emqx_utils:flattermap(fun nil/1, [42])
+    ),
+    ?assertEqual(
+        [1, 1, 2, 2, 3, 3],
+        emqx_utils:flattermap(fun duplicate/1, [1, 2, 3])
+    ),
+    ?assertEqual(
+        [],
+        emqx_utils:flattermap(fun nil/1, [1, 2, 3])
+    ),
+    ?assertEqual(
+        [1, 2, 2, 4, 5, 5],
+        emqx_utils:flattermap(
+            fun(X) ->
+                case X rem 3 of
+                    0 -> [];
+                    1 -> X;
+                    2 -> [X, X]
+                end
+            end,
+            [1, 2, 3, 4, 5]
+        )
+    ).
+
+duplicate(X) ->
+    [X, X].
+
+nil(_) ->
+    [].
+
+identity(X) ->
+    X.