فهرست منبع

Merge pull request #12251 from ieQu1/dev/refactor-sessds

Refactor and optimize persistent session
ieQu1 2 سال پیش
والد
کامیت
9e0bea098e

+ 71 - 110
apps/emqx/integration_test/emqx_persistent_session_ds_SUITE.erl

@@ -54,12 +54,12 @@ init_per_testcase(TestCase, Config) when
 init_per_testcase(t_session_gc = TestCase, Config) ->
     Opts = #{
         n => 3,
-        roles => [core, core, replicant],
+        roles => [core, core, core],
         extra_emqx_conf =>
             "\n session_persistence {"
             "\n   last_alive_update_interval = 500ms "
-            "\n   session_gc_interval = 2s "
-            "\n   session_gc_batch_size = 1 "
+            "\n   session_gc_interval = 1s "
+            "\n   session_gc_batch_size = 2 "
             "\n }"
     },
     Cluster = cluster(Opts),
@@ -91,7 +91,7 @@ end_per_testcase(_TestCase, _Config) ->
     ok.
 
 %%------------------------------------------------------------------------------
-%% Helper fns
+%% Helper functions
 %%------------------------------------------------------------------------------
 
 cluster(#{n := N} = Opts) ->
@@ -147,9 +147,10 @@ start_client(Opts0 = #{}) ->
         proto_ver => v5,
         properties => #{'Session-Expiry-Interval' => 300}
     },
-    Opts = maps:to_list(emqx_utils_maps:deep_merge(Defaults, Opts0)),
-    ct:pal("starting client with opts:\n  ~p", [Opts]),
-    {ok, Client} = emqtt:start_link(Opts),
+    Opts = emqx_utils_maps:deep_merge(Defaults, Opts0),
+    ?tp(notice, "starting client", Opts),
+    {ok, Client} = emqtt:start_link(maps:to_list(Opts)),
+    unlink(Client),
     on_exit(fun() -> catch emqtt:stop(Client) end),
     Client.
 
@@ -164,59 +165,27 @@ is_persistent_connect_opts(#{properties := #{'Session-Expiry-Interval' := EI}})
     EI > 0.
 
 list_all_sessions(Node) ->
-    erpc:call(Node, emqx_persistent_session_ds, list_all_sessions, []).
+    erpc:call(Node, emqx_persistent_session_ds_state, list_sessions, []).
 
 list_all_subscriptions(Node) ->
-    erpc:call(Node, emqx_persistent_session_ds, list_all_subscriptions, []).
+    Sessions = list_all_sessions(Node),
+    lists:flatmap(
+        fun(ClientId) ->
+            #{s := #{subscriptions := Subs}} = erpc:call(
+                Node, emqx_persistent_session_ds, print_session, [ClientId]
+            ),
+            maps:to_list(Subs)
+        end,
+        Sessions
+    ).
 
 list_all_pubranges(Node) ->
     erpc:call(Node, emqx_persistent_session_ds, list_all_pubranges, []).
 
-prop_only_cores_run_gc(CoreNodes) ->
-    {"only core nodes run gc", fun(Trace) -> ?MODULE:prop_only_cores_run_gc(Trace, CoreNodes) end}.
-prop_only_cores_run_gc(Trace, CoreNodes) ->
-    GCNodes = lists:usort([
-        N
-     || #{
-            ?snk_kind := K,
-            ?snk_meta := #{node := N}
-        } <- Trace,
-        lists:member(K, [ds_session_gc, ds_session_gc_lock_taken]),
-        N =/= node()
-    ]),
-    ?assertEqual(lists:usort(CoreNodes), GCNodes).
-
 %%------------------------------------------------------------------------------
 %% Testcases
 %%------------------------------------------------------------------------------
 
-t_non_persistent_session_subscription(_Config) ->
-    ClientId = atom_to_binary(?FUNCTION_NAME),
-    SubTopicFilter = <<"t/#">>,
-    ?check_trace(
-        #{timetrap => 30_000},
-        begin
-            ?tp(notice, "starting", #{}),
-            Client = start_client(#{
-                clientid => ClientId,
-                properties => #{'Session-Expiry-Interval' => 0}
-            }),
-            {ok, _} = emqtt:connect(Client),
-            ?tp(notice, "subscribing", #{}),
-            {ok, _, [?RC_GRANTED_QOS_2]} = emqtt:subscribe(Client, SubTopicFilter, qos2),
-
-            ok = emqtt:stop(Client),
-
-            ok
-        end,
-        fun(Trace) ->
-            ct:pal("trace:\n  ~p", [Trace]),
-            ?assertEqual([], ?of_kind(ds_session_subscription_added, Trace)),
-            ok
-        end
-    ),
-    ok.
-
 t_session_subscription_idempotency(Config) ->
     [Node1Spec | _] = ?config(node_specs, Config),
     [Node1] = ?config(nodes, Config),
@@ -288,10 +257,10 @@ t_session_unsubscription_idempotency(Config) ->
     ?check_trace(
         #{timetrap => 30_000},
         begin
+            #{timetrap => 20_000},
             ?force_ordering(
                 #{
-                    ?snk_kind := persistent_session_ds_subscription_delete,
-                    ?snk_span := {complete, _}
+                    ?snk_kind := persistent_session_ds_subscription_delete
                 },
                 _NEvents0 = 1,
                 #{?snk_kind := will_restart_node},
@@ -409,27 +378,26 @@ do_t_session_discard(Params) ->
             ?retry(
                 _Sleep0 = 100,
                 _Attempts0 = 50,
-                true = map_size(emqx_persistent_session_ds:list_all_streams()) > 0
+                #{} = emqx_persistent_session_ds_state:print_session(ClientId)
             ),
             ok = emqtt:stop(Client0),
             ?tp(notice, "disconnected", #{}),
 
             ?tp(notice, "reconnecting", #{}),
-            %% we still have streams
-            ?assert(map_size(emqx_persistent_session_ds:list_all_streams()) > 0),
+            %% we still have the session:
+            ?assertMatch(#{}, emqx_persistent_session_ds_state:print_session(ClientId)),
             Client1 = start_client(ReconnectOpts),
             {ok, _} = emqtt:connect(Client1),
             ?assertEqual([], emqtt:subscriptions(Client1)),
             case is_persistent_connect_opts(ReconnectOpts) of
                 true ->
-                    ?assertMatch(#{ClientId := _}, emqx_persistent_session_ds:list_all_sessions());
+                    ?assertMatch(#{}, emqx_persistent_session_ds_state:print_session(ClientId));
                 false ->
-                    ?assertEqual(#{}, emqx_persistent_session_ds:list_all_sessions())
+                    ?assertEqual(
+                        undefined, emqx_persistent_session_ds_state:print_session(ClientId)
+                    )
             end,
-            ?assertEqual(#{}, emqx_persistent_session_ds:list_all_subscriptions()),
             ?assertEqual([], emqx_persistent_session_ds_router:topics()),
-            ?assertEqual(#{}, emqx_persistent_session_ds:list_all_streams()),
-            ?assertEqual(#{}, emqx_persistent_session_ds:list_all_pubranges()),
             ok = emqtt:stop(Client1),
             ?tp(notice, "disconnected", #{}),
 
@@ -443,6 +411,8 @@ do_t_session_discard(Params) ->
     ok.
 
 t_session_expiration1(Config) ->
+    %% This testcase verifies that the properties passed in the
+    %% CONNECT packet are respected by the GC process:
     ClientId = atom_to_binary(?FUNCTION_NAME),
     Opts = #{
         clientid => ClientId,
@@ -455,6 +425,9 @@ t_session_expiration1(Config) ->
     do_t_session_expiration(Config, Opts).
 
 t_session_expiration2(Config) ->
+    %% This testcase updates the expiry interval for the session in
+    %% the _DISCONNECT_ packet. This setting should be respected by GC
+    %% process:
     ClientId = atom_to_binary(?FUNCTION_NAME),
     Opts = #{
         clientid => ClientId,
@@ -469,6 +442,8 @@ t_session_expiration2(Config) ->
     do_t_session_expiration(Config, Opts).
 
 do_t_session_expiration(_Config, Opts) ->
+    %% Sequence is a list of pairs of properties passed through the
+    %% CONNECT and for the DISCONNECT for each session:
     #{
         clientid := ClientId,
         sequence := [
@@ -486,7 +461,7 @@ do_t_session_expiration(_Config, Opts) ->
             Client0 = start_client(Params0),
             {ok, _} = emqtt:connect(Client0),
             {ok, _, [?RC_GRANTED_QOS_2]} = emqtt:subscribe(Client0, Topic, ?QOS_2),
-            Subs0 = emqx_persistent_session_ds:list_all_subscriptions(),
+            #{s := #{subscriptions := Subs0}} = emqx_persistent_session_ds:print_session(ClientId),
             ?assertEqual(1, map_size(Subs0), #{subs => Subs0}),
             Info0 = maps:from_list(emqtt:info(Client0)),
             ?assertEqual(0, maps:get(session_present, Info0), #{info => Info0}),
@@ -501,7 +476,7 @@ do_t_session_expiration(_Config, Opts) ->
             ?assertEqual([], Subs1),
             emqtt:disconnect(Client1, ?RC_NORMAL_DISCONNECTION, SecondDisconn),
 
-            ct:sleep(1_500),
+            ct:sleep(2_500),
 
             Params2 = maps:merge(CommonParams, ThirdConn),
             Client2 = start_client(Params2),
@@ -513,9 +488,9 @@ do_t_session_expiration(_Config, Opts) ->
             emqtt:publish(Client2, Topic, <<"payload">>),
             ?assertNotReceive({publish, #{topic := Topic}}),
             %% ensure subscriptions are absent from table.
-            ?assertEqual(#{}, emqx_persistent_session_ds:list_all_subscriptions()),
+            #{s := #{subscriptions := Subs3}} = emqx_persistent_session_ds:print_session(ClientId),
+            ?assertEqual([], maps:to_list(Subs3)),
             emqtt:disconnect(Client2, ?RC_NORMAL_DISCONNECTION, ThirdDisconn),
-
             ok
         end,
         []
@@ -531,6 +506,7 @@ t_session_gc(Config) ->
         Port2,
         Port3
     ] = lists:map(fun(N) -> get_mqtt_port(N, tcp) end, Nodes),
+    ct:pal("Ports: ~p", [[Port1, Port2, Port3]]),
     CommonParams = #{
         clean_start => false,
         proto_ver => v5
@@ -549,14 +525,14 @@ t_session_gc(Config) ->
     ?check_trace(
         #{timetrap => 30_000},
         begin
-            ClientId0 = <<"session_gc0">>,
-            Client0 = StartClient(ClientId0, Port1, 30),
-
             ClientId1 = <<"session_gc1">>,
-            Client1 = StartClient(ClientId1, Port2, 1),
+            Client1 = StartClient(ClientId1, Port1, 30),
 
             ClientId2 = <<"session_gc2">>,
-            Client2 = StartClient(ClientId2, Port3, 1),
+            Client2 = StartClient(ClientId2, Port2, 1),
+
+            ClientId3 = <<"session_gc3">>,
+            Client3 = StartClient(ClientId3, Port3, 1),
 
             lists:foreach(
                 fun(Client) ->
@@ -566,55 +542,48 @@ t_session_gc(Config) ->
                     {ok, _} = emqtt:publish(Client, Topic, Payload, ?QOS_1),
                     ok
                 end,
-                [Client0, Client1, Client2]
+                [Client1, Client2, Client3]
             ),
 
             %% Clients are still alive; no session is garbage collected.
-            Res0 = ?block_until(
-                #{
-                    ?snk_kind := ds_session_gc,
-                    ?snk_span := {complete, _},
-                    ?snk_meta := #{node := N}
-                } when
-                    N =/= node(),
-                3 * GCInterval + 1_000
+            ?assertMatch(
+                {ok, _},
+                ?block_until(
+                    #{
+                        ?snk_kind := ds_session_gc,
+                        ?snk_span := {complete, _},
+                        ?snk_meta := #{node := N}
+                    } when N =/= node()
+                )
             ),
-            ?assertMatch({ok, _}, Res0),
-            {ok, #{?snk_meta := #{time := T0}}} = Res0,
-            Sessions0 = list_all_sessions(Node1),
-            Subs0 = list_all_subscriptions(Node1),
-            ?assertEqual(3, map_size(Sessions0), #{sessions => Sessions0}),
-            ?assertEqual(3, map_size(Subs0), #{subs => Subs0}),
+            ?assertMatch([_, _, _], list_all_sessions(Node1), sessions),
+            ?assertMatch([_, _, _], list_all_subscriptions(Node1), subscriptions),
 
             %% Now we disconnect 2 of them; only those should be GC'ed.
+
             ?assertMatch(
                 {ok, {ok, _}},
                 ?wait_async_action(
-                    emqtt:stop(Client1),
-                    #{?snk_kind := terminate},
-                    1_000
+                    emqtt:stop(Client2),
+                    #{?snk_kind := terminate}
                 )
             ),
-            ct:pal("disconnected client1"),
+            ?tp(notice, "disconnected client1", #{}),
             ?assertMatch(
                 {ok, {ok, _}},
                 ?wait_async_action(
-                    emqtt:stop(Client2),
-                    #{?snk_kind := terminate},
-                    1_000
+                    emqtt:stop(Client3),
+                    #{?snk_kind := terminate}
                 )
             ),
-            ct:pal("disconnected client2"),
+            ?tp(notice, "disconnected client2", #{}),
             ?assertMatch(
                 {ok, _},
                 ?block_until(
                     #{
                         ?snk_kind := ds_session_gc_cleaned,
-                        ?snk_meta := #{node := N, time := T},
-                        session_ids := [ClientId1]
-                    } when
-                        N =/= node() andalso T > T0,
-                    4 * GCInterval + 1_000
+                        session_id := ClientId2
+                    }
                 )
             ),
             ?assertMatch(
@@ -622,22 +591,14 @@ t_session_gc(Config) ->
                 ?block_until(
                     #{
                         ?snk_kind := ds_session_gc_cleaned,
-                        ?snk_meta := #{node := N, time := T},
-                        session_ids := [ClientId2]
-                    } when
-                        N =/= node() andalso T > T0,
-                    4 * GCInterval + 1_000
+                        session_id := ClientId3
+                    }
                 )
             ),
-            Sessions1 = list_all_sessions(Node1),
-            Subs1 = list_all_subscriptions(Node1),
-            ?assertEqual(1, map_size(Sessions1), #{sessions => Sessions1}),
-            ?assertEqual(1, map_size(Subs1), #{subs => Subs1}),
-
+            ?assertMatch([ClientId1], list_all_sessions(Node1), sessions),
+            ?assertMatch([_], list_all_subscriptions(Node1), subscriptions),
             ok
         end,
-        [
-            prop_only_cores_run_gc(CoreNodes)
-        ]
+        []
     ),
     ok.

+ 4 - 2
apps/emqx/src/emqx_channel.erl

@@ -1,5 +1,5 @@
 %%--------------------------------------------------------------------
-%% Copyright (c) 2019-2023 EMQ Technologies Co., Ltd. All Rights Reserved.
+%% Copyright (c) 2019-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
 %%
 %% Licensed under the Apache License, Version 2.0 (the "License");
 %% you may not use this file except in compliance with the License.
@@ -191,7 +191,9 @@ info(topic_aliases, #channel{topic_aliases = Aliases}) ->
 info(alias_maximum, #channel{alias_maximum = Limits}) ->
     Limits;
 info(timers, #channel{timers = Timers}) ->
-    Timers.
+    Timers;
+info(session_state, #channel{session = Session}) ->
+    Session.
 
 set_conn_state(ConnState, Channel) ->
     Channel#channel{conn_state = ConnState}.

+ 1 - 1
apps/emqx/src/emqx_cm.erl

@@ -1,5 +1,5 @@
 %%-------------------------------------------------------------------
-%% Copyright (c) 2017-2023 EMQ Technologies Co., Ltd. All Rights Reserved.
+%% Copyright (c) 2017-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
 %%
 %% Licensed under the Apache License, Version 2.0 (the "License");
 %% you may not use this file except in compliance with the License.

+ 0 - 795
apps/emqx/src/emqx_persistent_message_ds_replayer.erl

@@ -1,795 +0,0 @@
-%%--------------------------------------------------------------------
-%% Copyright (c) 2023 EMQ Technologies Co., Ltd. All Rights Reserved.
-%%
-%% Licensed under the Apache License, Version 2.0 (the "License");
-%% you may not use this file except in compliance with the License.
-%% You may obtain a copy of the License at
-%%
-%%     http://www.apache.org/licenses/LICENSE-2.0
-%%
-%% Unless required by applicable law or agreed to in writing, software
-%% distributed under the License is distributed on an "AS IS" BASIS,
-%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-%% See the License for the specific language governing permissions and
-%% limitations under the License.
-%%--------------------------------------------------------------------
-
-%% @doc This module implements the routines for replaying streams of
-%% messages.
--module(emqx_persistent_message_ds_replayer).
-
-%% API:
--export([new/0, open/1, next_packet_id/1, n_inflight/1]).
-
--export([poll/4, replay/2, commit_offset/4]).
-
--export([seqno_to_packet_id/1, packet_id_to_seqno/2]).
-
--export([committed_until/2]).
-
-%% internal exports:
--export([]).
-
--export_type([inflight/0, seqno/0]).
-
--include_lib("emqx/include/logger.hrl").
--include_lib("emqx/include/emqx_mqtt.hrl").
--include_lib("emqx_utils/include/emqx_message.hrl").
--include("emqx_persistent_session_ds.hrl").
-
--ifdef(TEST).
--include_lib("proper/include/proper.hrl").
--include_lib("eunit/include/eunit.hrl").
--endif.
-
--define(EPOCH_SIZE, 16#10000).
-
--define(ACK, 0).
--define(COMP, 1).
-
--define(TRACK_FLAG(WHICH), (1 bsl WHICH)).
--define(TRACK_FLAGS_ALL, ?TRACK_FLAG(?ACK) bor ?TRACK_FLAG(?COMP)).
--define(TRACK_FLAGS_NONE, 0).
-
-%%================================================================================
-%% Type declarations
-%%================================================================================
-
-%% Note: sequence numbers are monotonic; they don't wrap around:
--type seqno() :: non_neg_integer().
-
--type track() :: ack | comp.
--type commit_type() :: rec.
-
--record(inflight, {
-    next_seqno = 1 :: seqno(),
-    commits = #{ack => 1, comp => 1, rec => 1} :: #{track() | commit_type() => seqno()},
-    %% Ranges are sorted in ascending order of their sequence numbers.
-    offset_ranges = [] :: [ds_pubrange()]
-}).
-
--opaque inflight() :: #inflight{}.
-
--type message() :: emqx_types:message().
--type replies() :: [emqx_session:reply()].
-
--type preproc_fun() :: fun((message()) -> message() | [message()]).
-
-%%================================================================================
-%% API funcions
-%%================================================================================
-
--spec new() -> inflight().
-new() ->
-    #inflight{}.
-
--spec open(emqx_persistent_session_ds:id()) -> inflight().
-open(SessionId) ->
-    {Ranges, RecUntil} = ro_transaction(
-        fun() -> {get_ranges(SessionId), get_committed_offset(SessionId, rec)} end
-    ),
-    {Commits, NextSeqno} = compute_inflight_range(Ranges),
-    #inflight{
-        commits = Commits#{rec => RecUntil},
-        next_seqno = NextSeqno,
-        offset_ranges = Ranges
-    }.
-
--spec next_packet_id(inflight()) -> {emqx_types:packet_id(), inflight()}.
-next_packet_id(Inflight0 = #inflight{next_seqno = LastSeqno}) ->
-    Inflight = Inflight0#inflight{next_seqno = next_seqno(LastSeqno)},
-    {seqno_to_packet_id(LastSeqno), Inflight}.
-
--spec n_inflight(inflight()) -> non_neg_integer().
-n_inflight(#inflight{offset_ranges = Ranges}) ->
-    %% TODO
-    %% This is not very efficient. Instead, we can take the maximum of
-    %% `range_size(AckedUntil, NextSeqno)` and `range_size(CompUntil, NextSeqno)`.
-    %% This won't be exact number but a pessimistic estimate, but this way we
-    %% will penalize clients that PUBACK QoS 1 messages but don't PUBCOMP QoS 2
-    %% messages for some reason. For that to work, we need to additionally track
-    %% actual `AckedUntil` / `CompUntil` during `commit_offset/4`.
-    lists:foldl(
-        fun
-            (#ds_pubrange{type = ?T_CHECKPOINT}, N) ->
-                N;
-            (#ds_pubrange{type = ?T_INFLIGHT} = Range, N) ->
-                N + range_size(Range)
-        end,
-        0,
-        Ranges
-    ).
-
--spec replay(preproc_fun(), inflight()) -> {emqx_session:replies(), inflight()}.
-replay(PreprocFunFun, Inflight0 = #inflight{offset_ranges = Ranges0, commits = Commits}) ->
-    {Ranges, Replies} = lists:mapfoldr(
-        fun(Range, Acc) ->
-            replay_range(PreprocFunFun, Commits, Range, Acc)
-        end,
-        [],
-        Ranges0
-    ),
-    Inflight = Inflight0#inflight{offset_ranges = Ranges},
-    {Replies, Inflight}.
-
--spec commit_offset(emqx_persistent_session_ds:id(), Offset, emqx_types:packet_id(), inflight()) ->
-    {_IsValidOffset :: boolean(), inflight()}
-when
-    Offset :: track() | commit_type().
-commit_offset(
-    SessionId,
-    Track,
-    PacketId,
-    Inflight0 = #inflight{commits = Commits}
-) when Track == ack orelse Track == comp ->
-    case validate_commit(Track, PacketId, Inflight0) of
-        CommitUntil when is_integer(CommitUntil) ->
-            %% TODO
-            %% We do not preserve `CommitUntil` in the database. Instead, we discard
-            %% fully acked ranges from the database. In effect, this means that the
-            %% most recent `CommitUntil` the client has sent may be lost in case of a
-            %% crash or client loss.
-            Inflight1 = Inflight0#inflight{commits = Commits#{Track := CommitUntil}},
-            Inflight = discard_committed(SessionId, Inflight1),
-            {true, Inflight};
-        false ->
-            {false, Inflight0}
-    end;
-commit_offset(
-    SessionId,
-    CommitType = rec,
-    PacketId,
-    Inflight0 = #inflight{commits = Commits}
-) ->
-    case validate_commit(CommitType, PacketId, Inflight0) of
-        CommitUntil when is_integer(CommitUntil) ->
-            update_committed_offset(SessionId, CommitType, CommitUntil),
-            Inflight = Inflight0#inflight{commits = Commits#{CommitType := CommitUntil}},
-            {true, Inflight};
-        false ->
-            {false, Inflight0}
-    end.
-
--spec poll(preproc_fun(), emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
-    {emqx_session:replies(), inflight()}.
-poll(PreprocFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < ?EPOCH_SIZE ->
-    MinBatchSize = emqx_config:get([session_persistence, min_batch_size]),
-    FetchThreshold = min(MinBatchSize, ceil(WindowSize / 2)),
-    FreeSpace = WindowSize - n_inflight(Inflight0),
-    case FreeSpace >= FetchThreshold of
-        false ->
-            %% TODO: this branch is meant to avoid fetching data from
-            %% the DB in chunks that are too small. However, this
-            %% logic is not exactly good for the latency. Can the
-            %% client get stuck even?
-            {[], Inflight0};
-        true ->
-            %% TODO: Wrap this in `mria:async_dirty/2`?
-            Checkpoints = find_checkpoints(Inflight0#inflight.offset_ranges),
-            StreamGroups = group_streams(get_streams(SessionId)),
-            {Publihes, Inflight} =
-                fetch(PreprocFun, SessionId, Inflight0, Checkpoints, StreamGroups, FreeSpace, []),
-            %% Discard now irrelevant QoS0-only ranges, if any.
-            {Publihes, discard_committed(SessionId, Inflight)}
-    end.
-
-%% Which seqno this track is committed until.
-%% "Until" means this is first seqno that is _not yet committed_ for this track.
--spec committed_until(track() | commit_type(), inflight()) -> seqno().
-committed_until(Track, #inflight{commits = Commits}) ->
-    maps:get(Track, Commits).
-
--spec seqno_to_packet_id(seqno()) -> emqx_types:packet_id() | 0.
-seqno_to_packet_id(Seqno) ->
-    Seqno rem ?EPOCH_SIZE.
-
-%% Reconstruct session counter by adding most significant bits from
-%% the current counter to the packet id.
--spec packet_id_to_seqno(emqx_types:packet_id(), inflight()) -> seqno().
-packet_id_to_seqno(PacketId, #inflight{next_seqno = NextSeqno}) ->
-    packet_id_to_seqno_(NextSeqno, PacketId).
-
-%%================================================================================
-%% Internal exports
-%%================================================================================
-
-%%================================================================================
-%% Internal functions
-%%================================================================================
-
-compute_inflight_range([]) ->
-    {#{ack => 1, comp => 1}, 1};
-compute_inflight_range(Ranges) ->
-    _RangeLast = #ds_pubrange{until = LastSeqno} = lists:last(Ranges),
-    AckedUntil = find_committed_until(ack, Ranges),
-    CompUntil = find_committed_until(comp, Ranges),
-    Commits = #{
-        ack => emqx_maybe:define(AckedUntil, LastSeqno),
-        comp => emqx_maybe:define(CompUntil, LastSeqno)
-    },
-    {Commits, LastSeqno}.
-
-find_committed_until(Track, Ranges) ->
-    RangesUncommitted = lists:dropwhile(
-        fun(Range) ->
-            case Range of
-                #ds_pubrange{type = ?T_CHECKPOINT} ->
-                    true;
-                #ds_pubrange{type = ?T_INFLIGHT, tracks = Tracks} ->
-                    not has_track(Track, Tracks)
-            end
-        end,
-        Ranges
-    ),
-    case RangesUncommitted of
-        [#ds_pubrange{id = {_, CommittedUntil, _StreamRef}} | _] ->
-            CommittedUntil;
-        [] ->
-            undefined
-    end.
-
--spec get_ranges(emqx_persistent_session_ds:id()) -> [ds_pubrange()].
-get_ranges(SessionId) ->
-    Pat = erlang:make_tuple(
-        record_info(size, ds_pubrange),
-        '_',
-        [{1, ds_pubrange}, {#ds_pubrange.id, {SessionId, '_', '_'}}]
-    ),
-    mnesia:match_object(?SESSION_PUBRANGE_TAB, Pat, read).
-
-fetch(PreprocFun, SessionId, Inflight0, CPs, Groups, N, Acc) when N > 0, Groups =/= [] ->
-    #inflight{next_seqno = FirstSeqno, offset_ranges = Ranges} = Inflight0,
-    {Stream, Groups2} = get_the_first_stream(Groups),
-    case get_next_n_messages_from_stream(Stream, CPs, N) of
-        [] ->
-            fetch(PreprocFun, SessionId, Inflight0, CPs, Groups2, N, Acc);
-        {ItBegin, ItEnd, Messages} ->
-            %% We need to preserve the iterator pointing to the beginning of the
-            %% range, so that we can replay it if needed.
-            {Publishes, UntilSeqno} = publish_fetch(PreprocFun, FirstSeqno, Messages),
-            Size = range_size(FirstSeqno, UntilSeqno),
-            Range0 = #ds_pubrange{
-                id = {SessionId, FirstSeqno, Stream#ds_stream.ref},
-                type = ?T_INFLIGHT,
-                tracks = compute_pub_tracks(Publishes),
-                until = UntilSeqno,
-                iterator = ItBegin
-            },
-            ok = preserve_range(Range0),
-            %% ...Yet we need to keep the iterator pointing past the end of the
-            %% range, so that we can pick up where we left off: it will become
-            %% `ItBegin` of the next range for this stream.
-            Range = keep_next_iterator(ItEnd, Range0),
-            Inflight = Inflight0#inflight{
-                next_seqno = UntilSeqno,
-                offset_ranges = Ranges ++ [Range]
-            },
-            fetch(PreprocFun, SessionId, Inflight, CPs, Groups2, N - Size, [Publishes | Acc])
-    end;
-fetch(_ReplyFun, _SessionId, Inflight, _CPs, _Groups, _N, Acc) ->
-    Publishes = lists:append(lists:reverse(Acc)),
-    {Publishes, Inflight}.
-
-discard_committed(
-    SessionId,
-    Inflight0 = #inflight{commits = Commits, offset_ranges = Ranges0}
-) ->
-    %% TODO: This could be kept and incrementally updated in the inflight state.
-    Checkpoints = find_checkpoints(Ranges0),
-    %% TODO: Wrap this in `mria:async_dirty/2`?
-    Ranges = discard_committed_ranges(SessionId, Commits, Checkpoints, Ranges0),
-    Inflight0#inflight{offset_ranges = Ranges}.
-
-find_checkpoints(Ranges) ->
-    lists:foldl(
-        fun(#ds_pubrange{id = {_SessionId, _, StreamRef}} = Range, Acc) ->
-            %% For each stream, remember the last range over this stream.
-            Acc#{StreamRef => Range}
-        end,
-        #{},
-        Ranges
-    ).
-
-discard_committed_ranges(
-    SessionId,
-    Commits,
-    Checkpoints,
-    Ranges = [Range = #ds_pubrange{id = {_SessionId, _, StreamRef}} | Rest]
-) ->
-    case discard_committed_range(Commits, Range) of
-        discard ->
-            %% This range has been fully committed.
-            %% Either discard it completely, or preserve the iterator for the next range
-            %% over this stream (i.e. a checkpoint).
-            RangeKept =
-                case maps:get(StreamRef, Checkpoints) of
-                    Range ->
-                        [checkpoint_range(Range)];
-                    _Previous ->
-                        discard_range(Range),
-                        []
-                end,
-            %% Since we're (intentionally) not using transactions here, it's important to
-            %% issue database writes in the same order in which ranges are stored: from
-            %% the oldest to the newest. This is also why we need to compute which ranges
-            %% should become checkpoints before we start writing anything.
-            RangeKept ++ discard_committed_ranges(SessionId, Commits, Checkpoints, Rest);
-        keep ->
-            %% This range has not been fully committed.
-            [Range | discard_committed_ranges(SessionId, Commits, Checkpoints, Rest)];
-        keep_all ->
-            %% The rest of ranges (if any) still have uncommitted messages.
-            Ranges;
-        TracksLeft ->
-            %% Only some track has been committed.
-            %% Preserve the uncommitted tracks in the database.
-            RangeKept = Range#ds_pubrange{tracks = TracksLeft},
-            preserve_range(restore_first_iterator(RangeKept)),
-            [RangeKept | discard_committed_ranges(SessionId, Commits, Checkpoints, Rest)]
-    end;
-discard_committed_ranges(_SessionId, _Commits, _Checkpoints, []) ->
-    [].
-
-discard_committed_range(_Commits, #ds_pubrange{type = ?T_CHECKPOINT}) ->
-    discard;
-discard_committed_range(
-    #{ack := AckedUntil, comp := CompUntil},
-    #ds_pubrange{until = Until}
-) when Until > AckedUntil andalso Until > CompUntil ->
-    keep_all;
-discard_committed_range(Commits, #ds_pubrange{until = Until, tracks = Tracks}) ->
-    case discard_tracks(Commits, Until, Tracks) of
-        0 ->
-            discard;
-        Tracks ->
-            keep;
-        TracksLeft ->
-            TracksLeft
-    end.
-
-discard_tracks(#{ack := AckedUntil, comp := CompUntil}, Until, Tracks) ->
-    TAck =
-        case Until > AckedUntil of
-            true -> ?TRACK_FLAG(?ACK) band Tracks;
-            false -> 0
-        end,
-    TComp =
-        case Until > CompUntil of
-            true -> ?TRACK_FLAG(?COMP) band Tracks;
-            false -> 0
-        end,
-    TAck bor TComp.
-
-replay_range(
-    PreprocFun,
-    Commits,
-    Range0 = #ds_pubrange{
-        type = ?T_INFLIGHT, id = {_, First, _StreamRef}, until = Until, iterator = It
-    },
-    Acc
-) ->
-    Size = range_size(First, Until),
-    {ok, ItNext, MessagesUnacked} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, It, Size),
-    %% Asserting that range is consistent with the message storage state.
-    {Replies, Until} = publish_replay(PreprocFun, Commits, First, MessagesUnacked),
-    %% Again, we need to keep the iterator pointing past the end of the
-    %% range, so that we can pick up where we left off.
-    Range = keep_next_iterator(ItNext, Range0),
-    {Range, Replies ++ Acc};
-replay_range(_PreprocFun, _Commits, Range0 = #ds_pubrange{type = ?T_CHECKPOINT}, Acc) ->
-    {Range0, Acc}.
-
-validate_commit(
-    Track,
-    PacketId,
-    Inflight = #inflight{commits = Commits, next_seqno = NextSeqno}
-) ->
-    Seqno = packet_id_to_seqno_(NextSeqno, PacketId),
-    CommittedUntil = maps:get(Track, Commits),
-    CommitNext = get_commit_next(Track, Inflight),
-    case Seqno >= CommittedUntil andalso Seqno < CommitNext of
-        true ->
-            next_seqno(Seqno);
-        false ->
-            ?SLOG(warning, #{
-                msg => "out-of-order_commit",
-                track => Track,
-                packet_id => PacketId,
-                commit_seqno => Seqno,
-                committed_until => CommittedUntil,
-                commit_next => CommitNext
-            }),
-            false
-    end.
-
-get_commit_next(ack, #inflight{next_seqno = NextSeqno}) ->
-    NextSeqno;
-get_commit_next(rec, #inflight{next_seqno = NextSeqno}) ->
-    NextSeqno;
-get_commit_next(comp, #inflight{commits = Commits}) ->
-    maps:get(rec, Commits).
-
-publish_fetch(PreprocFun, FirstSeqno, Messages) ->
-    flatmapfoldl(
-        fun({_DSKey, MessageIn}, Acc) ->
-            Message = PreprocFun(MessageIn),
-            publish_fetch(Message, Acc)
-        end,
-        FirstSeqno,
-        Messages
-    ).
-
-publish_fetch(#message{qos = ?QOS_0} = Message, Seqno) ->
-    {{undefined, Message}, Seqno};
-publish_fetch(#message{} = Message, Seqno) ->
-    PacketId = seqno_to_packet_id(Seqno),
-    {{PacketId, Message}, next_seqno(Seqno)};
-publish_fetch(Messages, Seqno) ->
-    flatmapfoldl(fun publish_fetch/2, Seqno, Messages).
-
-publish_replay(PreprocFun, Commits, FirstSeqno, Messages) ->
-    #{ack := AckedUntil, comp := CompUntil, rec := RecUntil} = Commits,
-    flatmapfoldl(
-        fun({_DSKey, MessageIn}, Acc) ->
-            Message = PreprocFun(MessageIn),
-            publish_replay(Message, AckedUntil, CompUntil, RecUntil, Acc)
-        end,
-        FirstSeqno,
-        Messages
-    ).
-
-publish_replay(#message{qos = ?QOS_0}, _, _, _, Seqno) ->
-    %% QoS 0 (at most once) messages should not be replayed.
-    {[], Seqno};
-publish_replay(#message{qos = Qos} = Message, AckedUntil, CompUntil, RecUntil, Seqno) ->
-    case Qos of
-        ?QOS_1 when Seqno < AckedUntil ->
-            %% This message has already been acked, so we can skip it.
-            %% We still need to advance seqno, because previously we assigned this message
-            %% a unique Packet Id.
-            {[], next_seqno(Seqno)};
-        ?QOS_2 when Seqno < CompUntil ->
-            %% This message's flow has already been fully completed, so we can skip it.
-            %% We still need to advance seqno, because previously we assigned this message
-            %% a unique Packet Id.
-            {[], next_seqno(Seqno)};
-        ?QOS_2 when Seqno < RecUntil ->
-            %% This message's flow has been partially completed, we need to resend a PUBREL.
-            PacketId = seqno_to_packet_id(Seqno),
-            Pub = {pubrel, PacketId},
-            {Pub, next_seqno(Seqno)};
-        _ ->
-            %% This message flow hasn't been acked and/or received, we need to resend it.
-            PacketId = seqno_to_packet_id(Seqno),
-            Pub = {PacketId, emqx_message:set_flag(dup, true, Message)},
-            {Pub, next_seqno(Seqno)}
-    end;
-publish_replay([], _, _, _, Seqno) ->
-    {[], Seqno};
-publish_replay(Messages, AckedUntil, CompUntil, RecUntil, Seqno) ->
-    flatmapfoldl(
-        fun(Message, Acc) ->
-            publish_replay(Message, AckedUntil, CompUntil, RecUntil, Acc)
-        end,
-        Seqno,
-        Messages
-    ).
-
--spec compute_pub_tracks(replies()) -> non_neg_integer().
-compute_pub_tracks(Pubs) ->
-    compute_pub_tracks(Pubs, ?TRACK_FLAGS_NONE).
-
-compute_pub_tracks(_Pubs, Tracks = ?TRACK_FLAGS_ALL) ->
-    Tracks;
-compute_pub_tracks([Pub | Rest], Tracks) ->
-    Track =
-        case Pub of
-            {_PacketId, #message{qos = ?QOS_1}} -> ?TRACK_FLAG(?ACK);
-            {_PacketId, #message{qos = ?QOS_2}} -> ?TRACK_FLAG(?COMP);
-            {pubrel, _PacketId} -> ?TRACK_FLAG(?COMP);
-            _ -> ?TRACK_FLAGS_NONE
-        end,
-    compute_pub_tracks(Rest, Track bor Tracks);
-compute_pub_tracks([], Tracks) ->
-    Tracks.
-
-keep_next_iterator(ItNext, Range = #ds_pubrange{iterator = ItFirst, misc = Misc}) ->
-    Range#ds_pubrange{
-        iterator = ItNext,
-        %% We need to keep the first iterator around, in case we need to preserve
-        %% this range again, updating still uncommitted tracks it's part of.
-        misc = Misc#{iterator_first => ItFirst}
-    }.
-
-restore_first_iterator(Range = #ds_pubrange{misc = Misc = #{iterator_first := ItFirst}}) ->
-    Range#ds_pubrange{
-        iterator = ItFirst,
-        misc = maps:remove(iterator_first, Misc)
-    }.
-
--spec preserve_range(ds_pubrange()) -> ok.
-preserve_range(Range = #ds_pubrange{type = ?T_INFLIGHT}) ->
-    mria:dirty_write(?SESSION_PUBRANGE_TAB, Range).
-
-has_track(ack, Tracks) ->
-    (?TRACK_FLAG(?ACK) band Tracks) > 0;
-has_track(comp, Tracks) ->
-    (?TRACK_FLAG(?COMP) band Tracks) > 0.
-
--spec discard_range(ds_pubrange()) -> ok.
-discard_range(#ds_pubrange{id = RangeId}) ->
-    mria:dirty_delete(?SESSION_PUBRANGE_TAB, RangeId).
-
--spec checkpoint_range(ds_pubrange()) -> ds_pubrange().
-checkpoint_range(Range0 = #ds_pubrange{type = ?T_INFLIGHT}) ->
-    Range = Range0#ds_pubrange{type = ?T_CHECKPOINT, misc = #{}},
-    ok = mria:dirty_write(?SESSION_PUBRANGE_TAB, Range),
-    Range;
-checkpoint_range(Range = #ds_pubrange{type = ?T_CHECKPOINT}) ->
-    %% This range should have been checkpointed already.
-    Range.
-
-get_last_iterator(Stream = #ds_stream{ref = StreamRef}, Checkpoints) ->
-    case maps:get(StreamRef, Checkpoints, none) of
-        none ->
-            Stream#ds_stream.beginning;
-        #ds_pubrange{iterator = ItNext} ->
-            ItNext
-    end.
-
--spec get_streams(emqx_persistent_session_ds:id()) -> [ds_stream()].
-get_streams(SessionId) ->
-    mnesia:dirty_read(?SESSION_STREAM_TAB, SessionId).
-
--spec get_committed_offset(emqx_persistent_session_ds:id(), _Name) -> seqno().
-get_committed_offset(SessionId, Name) ->
-    case mnesia:read(?SESSION_COMMITTED_OFFSET_TAB, {SessionId, Name}) of
-        [] ->
-            1;
-        [#ds_committed_offset{until = Seqno}] ->
-            Seqno
-    end.
-
--spec update_committed_offset(emqx_persistent_session_ds:id(), _Name, seqno()) -> ok.
-update_committed_offset(SessionId, Name, Until) ->
-    mria:dirty_write(?SESSION_COMMITTED_OFFSET_TAB, #ds_committed_offset{
-        id = {SessionId, Name}, until = Until
-    }).
-
-next_seqno(Seqno) ->
-    NextSeqno = Seqno + 1,
-    case seqno_to_packet_id(NextSeqno) of
-        0 ->
-            %% We skip sequence numbers that lead to PacketId = 0 to
-            %% simplify math. Note: it leads to occasional gaps in the
-            %% sequence numbers.
-            NextSeqno + 1;
-        _ ->
-            NextSeqno
-    end.
-
-packet_id_to_seqno_(NextSeqno, PacketId) ->
-    Epoch = NextSeqno bsr 16,
-    case (Epoch bsl 16) + PacketId of
-        N when N =< NextSeqno ->
-            N;
-        N ->
-            N - ?EPOCH_SIZE
-    end.
-
-range_size(#ds_pubrange{id = {_, First, _StreamRef}, until = Until}) ->
-    range_size(First, Until).
-
-range_size(FirstSeqno, UntilSeqno) ->
-    %% This function assumes that gaps in the sequence ID occur _only_ when the
-    %% packet ID wraps.
-    Size = UntilSeqno - FirstSeqno,
-    Size + (FirstSeqno bsr 16) - (UntilSeqno bsr 16).
-
-%%================================================================================
-%% stream scheduler
-
-%% group streams by the first position in the rank
--spec group_streams(list(ds_stream())) -> list(list(ds_stream())).
-group_streams(Streams) ->
-    Groups = maps:groups_from_list(
-        fun(#ds_stream{rank = {RankX, _}}) -> RankX end,
-        Streams
-    ),
-    shuffle(maps:values(Groups)).
-
--spec shuffle([A]) -> [A].
-shuffle(L0) ->
-    L1 = lists:map(
-        fun(A) ->
-            %% maybe topic/stream prioritization could be introduced here?
-            {rand:uniform(), A}
-        end,
-        L0
-    ),
-    L2 = lists:sort(L1),
-    {_, L} = lists:unzip(L2),
-    L.
-
-get_the_first_stream([Group | Groups]) ->
-    case get_next_stream_from_group(Group) of
-        {Stream, {sorted, []}} ->
-            {Stream, Groups};
-        {Stream, Group2} ->
-            {Stream, [Group2 | Groups]};
-        undefined ->
-            get_the_first_stream(Groups)
-    end;
-get_the_first_stream([]) ->
-    %% how this possible ?
-    throw(#{reason => no_valid_stream}).
-
-%% the scheduler is simple, try to get messages from the same shard, but it's okay to take turns
-get_next_stream_from_group({sorted, [H | T]}) ->
-    {H, {sorted, T}};
-get_next_stream_from_group({sorted, []}) ->
-    undefined;
-get_next_stream_from_group(Streams) ->
-    [Stream | T] = lists:sort(
-        fun(#ds_stream{rank = {_, RankA}}, #ds_stream{rank = {_, RankB}}) ->
-            RankA < RankB
-        end,
-        Streams
-    ),
-    {Stream, {sorted, T}}.
-
-get_next_n_messages_from_stream(Stream, CPs, N) ->
-    ItBegin = get_last_iterator(Stream, CPs),
-    case emqx_ds:next(?PERSISTENT_MESSAGE_DB, ItBegin, N) of
-        {ok, _ItEnd, []} ->
-            [];
-        {ok, ItEnd, Messages} ->
-            {ItBegin, ItEnd, Messages};
-        {ok, end_of_stream} ->
-            %% TODO: how to skip this closed stream or it should be taken over by lower level layer
-            []
-    end.
-
-%%================================================================================
-
--spec flatmapfoldl(fun((X, Acc) -> {Y | [Y], Acc}), Acc, [X]) -> {[Y], Acc}.
-flatmapfoldl(_Fun, Acc, []) ->
-    {[], Acc};
-flatmapfoldl(Fun, Acc, [X | Xs]) ->
-    {Ys, NAcc} = Fun(X, Acc),
-    {Zs, FAcc} = flatmapfoldl(Fun, NAcc, Xs),
-    case is_list(Ys) of
-        true ->
-            {Ys ++ Zs, FAcc};
-        _ ->
-            {[Ys | Zs], FAcc}
-    end.
-
-ro_transaction(Fun) ->
-    {atomic, Res} = mria:ro_transaction(?DS_MRIA_SHARD, Fun),
-    Res.
-
--ifdef(TEST).
-
-%% This test only tests boundary conditions (to make sure property-based test didn't skip them):
-packet_id_to_seqno_test() ->
-    %% Packet ID = 1; first epoch:
-    ?assertEqual(1, packet_id_to_seqno_(1, 1)),
-    ?assertEqual(1, packet_id_to_seqno_(10, 1)),
-    ?assertEqual(1, packet_id_to_seqno_(1 bsl 16 - 1, 1)),
-    ?assertEqual(1, packet_id_to_seqno_(1 bsl 16, 1)),
-    %% Packet ID = 1; second and 3rd epochs:
-    ?assertEqual(1 bsl 16 + 1, packet_id_to_seqno_(1 bsl 16 + 1, 1)),
-    ?assertEqual(1 bsl 16 + 1, packet_id_to_seqno_(2 bsl 16, 1)),
-    ?assertEqual(2 bsl 16 + 1, packet_id_to_seqno_(2 bsl 16 + 1, 1)),
-    %% Packet ID = 16#ffff:
-    PID = 1 bsl 16 - 1,
-    ?assertEqual(PID, packet_id_to_seqno_(PID, PID)),
-    ?assertEqual(PID, packet_id_to_seqno_(1 bsl 16, PID)),
-    ?assertEqual(1 bsl 16 + PID, packet_id_to_seqno_(2 bsl 16, PID)),
-    ok.
-
-packet_id_to_seqno_test_() ->
-    Opts = [{numtests, 1000}, {to_file, user}],
-    {timeout, 30, fun() -> ?assert(proper:quickcheck(packet_id_to_seqno_prop(), Opts)) end}.
-
-packet_id_to_seqno_prop() ->
-    ?FORALL(
-        NextSeqNo,
-        next_seqno_gen(),
-        ?FORALL(
-            SeqNo,
-            seqno_gen(NextSeqNo),
-            begin
-                PacketId = seqno_to_packet_id(SeqNo),
-                ?assertEqual(SeqNo, packet_id_to_seqno_(NextSeqNo, PacketId)),
-                true
-            end
-        )
-    ).
-
-next_seqno_gen() ->
-    ?LET(
-        {Epoch, Offset},
-        {non_neg_integer(), non_neg_integer()},
-        Epoch bsl 16 + Offset
-    ).
-
-seqno_gen(NextSeqNo) ->
-    WindowSize = 1 bsl 16 - 1,
-    Min = max(0, NextSeqNo - WindowSize),
-    Max = max(0, NextSeqNo - 1),
-    range(Min, Max).
-
-range_size_test_() ->
-    [
-        ?_assertEqual(0, range_size(42, 42)),
-        ?_assertEqual(1, range_size(42, 43)),
-        ?_assertEqual(1, range_size(16#ffff, 16#10001)),
-        ?_assertEqual(16#ffff - 456 + 123, range_size(16#1f0000 + 456, 16#200000 + 123))
-    ].
-
-compute_inflight_range_test_() ->
-    [
-        ?_assertEqual(
-            {#{ack => 1, comp => 1}, 1},
-            compute_inflight_range([])
-        ),
-        ?_assertEqual(
-            {#{ack => 12, comp => 13}, 42},
-            compute_inflight_range([
-                #ds_pubrange{id = {<<>>, 1, 0}, until = 2, type = ?T_CHECKPOINT},
-                #ds_pubrange{id = {<<>>, 4, 0}, until = 8, type = ?T_CHECKPOINT},
-                #ds_pubrange{id = {<<>>, 11, 0}, until = 12, type = ?T_CHECKPOINT},
-                #ds_pubrange{
-                    id = {<<>>, 12, 0},
-                    until = 13,
-                    type = ?T_INFLIGHT,
-                    tracks = ?TRACK_FLAG(?ACK)
-                },
-                #ds_pubrange{
-                    id = {<<>>, 13, 0},
-                    until = 20,
-                    type = ?T_INFLIGHT,
-                    tracks = ?TRACK_FLAG(?COMP)
-                },
-                #ds_pubrange{
-                    id = {<<>>, 20, 0},
-                    until = 42,
-                    type = ?T_INFLIGHT,
-                    tracks = ?TRACK_FLAG(?ACK) bor ?TRACK_FLAG(?COMP)
-                }
-            ])
-        ),
-        ?_assertEqual(
-            {#{ack => 13, comp => 13}, 13},
-            compute_inflight_range([
-                #ds_pubrange{id = {<<>>, 1, 0}, until = 2, type = ?T_CHECKPOINT},
-                #ds_pubrange{id = {<<>>, 4, 0}, until = 8, type = ?T_CHECKPOINT},
-                #ds_pubrange{id = {<<>>, 11, 0}, until = 12, type = ?T_CHECKPOINT},
-                #ds_pubrange{id = {<<>>, 12, 0}, until = 13, type = ?T_CHECKPOINT}
-            ])
-        )
-    ].
-
--endif.

تفاوت فایلی نمایش داده نمی شود زیرا این فایل بسیار بزرگ است
+ 717 - 698
apps/emqx/src/emqx_persistent_session_ds.erl


+ 42 - 66
apps/emqx/src/emqx_persistent_session_ds.hrl

@@ -1,5 +1,5 @@
 %%--------------------------------------------------------------------
-%% Copyright (c) 2023 EMQ Technologies Co., Ltd. All Rights Reserved.
+%% Copyright (c) 2023-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
 %%
 %% Licensed under the Apache License, Version 2.0 (the "License");
 %% you may not use this file except in compliance with the License.
@@ -25,75 +25,51 @@
 -define(SESSION_COMMITTED_OFFSET_TAB, emqx_ds_committed_offset_tab).
 -define(DS_MRIA_SHARD, emqx_ds_session_shard).
 
--define(T_INFLIGHT, 1).
--define(T_CHECKPOINT, 2).
+%%%%% Session sequence numbers:
 
--record(ds_sub, {
-    id :: emqx_persistent_session_ds:subscription_id(),
-    start_time :: emqx_ds:time(),
-    props = #{} :: map(),
-    extra = #{} :: map()
-}).
--type ds_sub() :: #ds_sub{}.
-
--record(ds_stream, {
-    session :: emqx_persistent_session_ds:id(),
-    ref :: _StreamRef,
-    stream :: emqx_ds:stream(),
-    rank :: emqx_ds:stream_rank(),
-    beginning :: emqx_ds:iterator()
-}).
--type ds_stream() :: #ds_stream{}.
+%%
+%%   -----|----------|-----|-----|------> seqno
+%%        |          |     |     |
+%%   committed      dup   rec   next
+%%                       (Qos2)
 
--record(ds_pubrange, {
-    id :: {
-        %% What session this range belongs to.
-        _Session :: emqx_persistent_session_ds:id(),
-        %% Where this range starts.
-        _First :: emqx_persistent_message_ds_replayer:seqno(),
-        %% Which stream this range is over.
-        _StreamRef
-    },
-    %% Where this range ends: the first seqno that is not included in the range.
-    until :: emqx_persistent_message_ds_replayer:seqno(),
-    %% Type of a range:
-    %% * Inflight range is a range of yet unacked messages from this stream.
-    %% * Checkpoint range was already acked, its purpose is to keep track of the
-    %%   very last iterator for this stream.
-    type :: ?T_INFLIGHT | ?T_CHECKPOINT,
-    %% What commit tracks this range is part of.
-    tracks = 0 :: non_neg_integer(),
-    %% Meaning of this depends on the type of the range:
-    %% * For inflight range, this is the iterator pointing to the first message in
-    %%   the range.
-    %% * For checkpoint range, this is the iterator pointing right past the last
-    %%   message in the range.
-    iterator :: emqx_ds:iterator(),
-    %% Reserved for future use.
-    misc = #{} :: map()
-}).
--type ds_pubrange() :: #ds_pubrange{}.
+%% Seqno becomes committed after receiving PUBACK for QoS1 or PUBCOMP
+%% for QoS2.
+-define(committed(QOS), QOS).
+%% Seqno becomes dup after broker sends QoS1 or QoS2 message to the
+%% client. Upon session reconnect, messages with seqno in the
+%% committed..dup range are retransmitted with DUP flag.
+%%
+-define(dup(QOS), (10 + QOS)).
+%% Rec flag is specific for the QoS2. It contains seqno of the last
+%% PUBREC received from the client. When the session reconnects,
+%% PUBREL packages for the dup..rec range are retransmitted.
+-define(rec, 22).
+%% Last seqno assigned to a message (it may not be sent yet).
+-define(next(QOS), (30 + QOS)).
 
--record(ds_committed_offset, {
-    id :: {
-        %% What session this marker belongs to.
-        _Session :: emqx_persistent_session_ds:id(),
-        %% Marker name.
-        _CommitType
-    },
-    %% Where this marker is pointing to: the first seqno that is not marked.
-    until :: emqx_persistent_message_ds_replayer:seqno()
+%%%%% Stream Replay State:
+-record(srs, {
+    rank_x :: emqx_ds:rank_x(),
+    rank_y :: emqx_ds:rank_y(),
+    %% Iterators at the beginning and the end of the last batch:
+    it_begin :: emqx_ds:iterator() | undefined,
+    it_end :: emqx_ds:iterator() | end_of_stream,
+    %% Size of the last batch:
+    batch_size = 0 :: non_neg_integer(),
+    %% Session sequence numbers at the time when the batch was fetched:
+    first_seqno_qos1 = 0 :: emqx_persistent_session_ds:seqno(),
+    first_seqno_qos2 = 0 :: emqx_persistent_session_ds:seqno(),
+    %% Sequence numbers that have to be committed for the batch:
+    last_seqno_qos1 = 0 :: emqx_persistent_session_ds:seqno(),
+    last_seqno_qos2 = 0 :: emqx_persistent_session_ds:seqno()
 }).
 
--record(session, {
-    %% same as clientid
-    id :: emqx_persistent_session_ds:id(),
-    %% creation time
-    created_at :: _Millisecond :: non_neg_integer(),
-    last_alive_at :: _Millisecond :: non_neg_integer(),
-    conninfo :: emqx_types:conninfo(),
-    %% for future usage
-    props = #{} :: map()
-}).
+%% Session metadata keys:
+-define(created_at, created_at).
+-define(last_alive_at, last_alive_at).
+-define(expiry_interval, expiry_interval).
+%% Unique integer used to create unique identities
+-define(last_id, last_id).
 
 -endif.

+ 27 - 52
apps/emqx/src/emqx_persistent_session_ds_gc_worker.erl

@@ -1,5 +1,5 @@
 %%--------------------------------------------------------------------
-%% Copyright (c) 2023 EMQ Technologies Co., Ltd. All Rights Reserved.
+%% Copyright (c) 2023-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
 %%
 %% Licensed under the Apache License, Version 2.0 (the "License");
 %% you may not use this file except in compliance with the License.
@@ -69,7 +69,7 @@ handle_info(_Info, State) ->
     {noreply, State}.
 
 %%--------------------------------------------------------------------------------
-%% Internal fns
+%% Internal functions
 %%--------------------------------------------------------------------------------
 
 ensure_gc_timer() ->
@@ -104,58 +104,33 @@ now_ms() ->
     erlang:system_time(millisecond).
 
 start_gc() ->
-    do_gc(more).
-
-zombie_session_ms() ->
-    NowMS = now_ms(),
     GCInterval = emqx_config:get([session_persistence, session_gc_interval]),
     BumpInterval = emqx_config:get([session_persistence, last_alive_update_interval]),
     TimeThreshold = max(GCInterval, BumpInterval) * 3,
-    ets:fun2ms(
-        fun(
-            #session{
-                id = DSSessionId,
-                last_alive_at = LastAliveAt,
-                conninfo = #{expiry_interval := EI}
-            }
-        ) when
-            LastAliveAt + EI + TimeThreshold =< NowMS
-        ->
-            DSSessionId
-        end
-    ).
-
-do_gc(more) ->
+    MinLastAlive = now_ms() - TimeThreshold,
+    gc_loop(MinLastAlive, emqx_persistent_session_ds_state:make_session_iterator()).
+
+gc_loop(MinLastAlive, It0) ->
     GCBatchSize = emqx_config:get([session_persistence, session_gc_batch_size]),
-    MS = zombie_session_ms(),
-    {atomic, Next} = mria:transaction(?DS_MRIA_SHARD, fun() ->
-        Res = mnesia:select(?SESSION_TAB, MS, GCBatchSize, write),
-        case Res of
-            '$end_of_table' ->
-                done;
-            {[], Cont} ->
-                %% since `GCBatchsize' is just a "recommendation" for `select', we try only
-                %% _once_ the continuation and then stop if it yields nothing, to avoid a
-                %% dead loop.
-                case mnesia:select(Cont) of
-                    '$end_of_table' ->
-                        done;
-                    {[], _Cont} ->
-                        done;
-                    {DSSessionIds0, _Cont} ->
-                        do_gc_(DSSessionIds0),
-                        more
-                end;
-            {DSSessionIds0, _Cont} ->
-                do_gc_(DSSessionIds0),
-                more
-        end
-    end),
-    do_gc(Next);
-do_gc(done) ->
-    ok.
+    case emqx_persistent_session_ds_state:session_iterator_next(It0, GCBatchSize) of
+        {[], _It} ->
+            ok;
+        {Sessions, It} ->
+            [do_gc(SessionId, MinLastAlive, Metadata) || {SessionId, Metadata} <- Sessions],
+            gc_loop(MinLastAlive, It)
+    end.
 
-do_gc_(DSSessionIds) ->
-    lists:foreach(fun emqx_persistent_session_ds:destroy_session/1, DSSessionIds),
-    ?tp(ds_session_gc_cleaned, #{session_ids => DSSessionIds}),
-    ok.
+do_gc(SessionId, MinLastAlive, Metadata) ->
+    #{?last_alive_at := LastAliveAt, ?expiry_interval := EI} = Metadata,
+    case LastAliveAt + EI < MinLastAlive of
+        true ->
+            emqx_persistent_session_ds:destroy_session(SessionId),
+            ?tp(debug, ds_session_gc_cleaned, #{
+                session_id => SessionId,
+                last_alive_at => LastAliveAt,
+                expiry_interval => EI,
+                min_last_alive => MinLastAlive
+            });
+        false ->
+            ok
+    end.

+ 132 - 0
apps/emqx/src/emqx_persistent_session_ds_inflight.erl

@@ -0,0 +1,132 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2023-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+-module(emqx_persistent_session_ds_inflight).
+
+%% API:
+-export([new/1, push/2, pop/1, n_buffered/2, n_inflight/1, inc_send_quota/1, receive_maximum/1]).
+
+%% internal exports:
+-export([]).
+
+-export_type([t/0]).
+
+-include("emqx.hrl").
+-include("emqx_mqtt.hrl").
+
+%%================================================================================
+%% Type declarations
+%%================================================================================
+
+-record(inflight, {
+    queue :: queue:queue(),
+    receive_maximum :: pos_integer(),
+    n_inflight = 0 :: non_neg_integer(),
+    n_qos0 = 0 :: non_neg_integer(),
+    n_qos1 = 0 :: non_neg_integer(),
+    n_qos2 = 0 :: non_neg_integer()
+}).
+
+-type t() :: #inflight{}.
+
+-type payload() ::
+    {emqx_persistent_session_ds:seqno() | undefined, emqx_types:message()}
+    | {pubrel, emqx_persistent_session_ds:seqno()}.
+
+%%================================================================================
+%% API funcions
+%%================================================================================
+
+-spec new(non_neg_integer()) -> t().
+new(ReceiveMaximum) when ReceiveMaximum > 0 ->
+    #inflight{queue = queue:new(), receive_maximum = ReceiveMaximum}.
+
+-spec receive_maximum(t()) -> pos_integer().
+receive_maximum(#inflight{receive_maximum = ReceiveMaximum}) ->
+    ReceiveMaximum.
+
+-spec push(payload(), t()) -> t().
+push(Payload = {pubrel, _SeqNo}, Rec = #inflight{queue = Q}) ->
+    Rec#inflight{queue = queue:in(Payload, Q)};
+push(Payload = {_, Msg}, Rec) ->
+    #inflight{queue = Q0, n_qos0 = NQos0, n_qos1 = NQos1, n_qos2 = NQos2} = Rec,
+    Q = queue:in(Payload, Q0),
+    case Msg#message.qos of
+        ?QOS_0 ->
+            Rec#inflight{queue = Q, n_qos0 = NQos0 + 1};
+        ?QOS_1 ->
+            Rec#inflight{queue = Q, n_qos1 = NQos1 + 1};
+        ?QOS_2 ->
+            Rec#inflight{queue = Q, n_qos2 = NQos2 + 1}
+    end.
+
+-spec pop(t()) -> {payload(), t()} | undefined.
+pop(Rec0) ->
+    #inflight{
+        receive_maximum = ReceiveMaximum,
+        n_inflight = NInflight,
+        queue = Q0,
+        n_qos0 = NQos0,
+        n_qos1 = NQos1,
+        n_qos2 = NQos2
+    } = Rec0,
+    case NInflight < ReceiveMaximum andalso queue:out(Q0) of
+        {{value, Payload}, Q} ->
+            Rec =
+                case Payload of
+                    {pubrel, _} ->
+                        Rec0#inflight{queue = Q};
+                    {_, #message{qos = Qos}} ->
+                        case Qos of
+                            ?QOS_0 ->
+                                Rec0#inflight{queue = Q, n_qos0 = NQos0 - 1};
+                            ?QOS_1 ->
+                                Rec0#inflight{
+                                    queue = Q, n_qos1 = NQos1 - 1, n_inflight = NInflight + 1
+                                };
+                            ?QOS_2 ->
+                                Rec0#inflight{
+                                    queue = Q, n_qos2 = NQos2 - 1, n_inflight = NInflight + 1
+                                }
+                        end
+                end,
+            {Payload, Rec};
+        _ ->
+            undefined
+    end.
+
+-spec n_buffered(0..2 | all, t()) -> non_neg_integer().
+n_buffered(?QOS_0, #inflight{n_qos0 = NQos0}) ->
+    NQos0;
+n_buffered(?QOS_1, #inflight{n_qos1 = NQos1}) ->
+    NQos1;
+n_buffered(?QOS_2, #inflight{n_qos2 = NQos2}) ->
+    NQos2;
+n_buffered(all, #inflight{n_qos0 = NQos0, n_qos1 = NQos1, n_qos2 = NQos2}) ->
+    NQos0 + NQos1 + NQos2.
+
+-spec n_inflight(t()) -> non_neg_integer().
+n_inflight(#inflight{n_inflight = NInflight}) ->
+    NInflight.
+
+%% https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Flow_Control
+-spec inc_send_quota(t()) -> t().
+inc_send_quota(Rec = #inflight{n_inflight = NInflight0}) ->
+    NInflight = max(NInflight0 - 1, 0),
+    Rec#inflight{n_inflight = NInflight}.
+
+%%================================================================================
+%% Internal functions
+%%================================================================================

+ 584 - 0
apps/emqx/src/emqx_persistent_session_ds_state.erl

@@ -0,0 +1,584 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2023-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+%% @doc CRUD interface for the persistent session
+%%
+%% This module encapsulates the data related to the state of the
+%% inflight messages for the persistent session based on DS.
+%%
+%% It is responsible for saving, caching, and restoring session state.
+%% It is completely devoid of business logic. Not even the default
+%% values should be set in this module.
+-module(emqx_persistent_session_ds_state).
+
+-export([create_tables/0]).
+
+-export([open/1, create_new/1, delete/1, commit/1, format/1, print_session/1, list_sessions/0]).
+-export([get_created_at/1, set_created_at/2]).
+-export([get_last_alive_at/1, set_last_alive_at/2]).
+-export([get_expiry_interval/1, set_expiry_interval/2]).
+-export([new_id/1]).
+-export([get_stream/2, put_stream/3, del_stream/2, fold_streams/3]).
+-export([get_seqno/2, put_seqno/3]).
+-export([get_rank/2, put_rank/3, del_rank/2, fold_ranks/3]).
+-export([get_subscriptions/1, put_subscription/4, del_subscription/3]).
+
+-export([make_session_iterator/0, session_iterator_next/2]).
+
+-export_type([
+    t/0, metadata/0, subscriptions/0, seqno_type/0, stream_key/0, rank_key/0, session_iterator/0
+]).
+
+-include("emqx_mqtt.hrl").
+-include("emqx_persistent_session_ds.hrl").
+-include_lib("snabbkaffe/include/trace.hrl").
+-include_lib("stdlib/include/qlc.hrl").
+
+%%================================================================================
+%% Type declarations
+%%================================================================================
+
+-type subscriptions() :: emqx_topic_gbt:t(_SubId, emqx_persistent_session_ds:subscription()).
+
+-opaque session_iterator() :: emqx_persistent_session_ds:id() | '$end_of_table'.
+
+%% Generic key-value wrapper that is used for exporting arbitrary
+%% terms to mnesia:
+-record(kv, {k, v}).
+
+%% Persistent map.
+%%
+%% Pmap accumulates the updates in a term stored in the heap of a
+%% process, so they can be committed all at once in a single
+%% transaction.
+%%
+%% It should be possible to make frequent changes to the pmap without
+%% stressing Mria.
+%%
+%% It's implemented as two maps: `cache', and `dirty'. `cache' stores
+%% the data, and `dirty' contains information about dirty and deleted
+%% keys. When `commit/1' is called, dirty keys are dumped to the
+%% tables, and deleted keys are removed from the tables.
+-record(pmap, {table, cache, dirty}).
+
+-type pmap(K, V) ::
+    #pmap{
+        table :: atom(),
+        cache :: #{K => V},
+        dirty :: #{K => dirty | del}
+    }.
+
+-type metadata() ::
+    #{
+        ?created_at => emqx_persistent_session_ds:timestamp(),
+        ?last_alive_at => emqx_persistent_session_ds:timestamp(),
+        ?expiry_interval => non_neg_integer(),
+        ?last_id => integer()
+    }.
+
+-type seqno_type() ::
+    ?next(?QOS_1)
+    | ?dup(?QOS_1)
+    | ?committed(?QOS_1)
+    | ?next(?QOS_2)
+    | ?dup(?QOS_2)
+    | ?rec
+    | ?committed(?QOS_2).
+
+-opaque t() :: #{
+    id := emqx_persistent_session_ds:id(),
+    dirty := boolean(),
+    metadata := metadata(),
+    subscriptions := subscriptions(),
+    seqnos := pmap(seqno_type(), emqx_persistent_session_ds:seqno()),
+    streams := pmap(emqx_ds:stream(), emqx_persistent_session_ds:stream_state()),
+    ranks := pmap(term(), integer())
+}.
+
+-define(session_tab, emqx_ds_session_tab).
+-define(subscription_tab, emqx_ds_session_subscriptions).
+-define(stream_tab, emqx_ds_session_streams).
+-define(seqno_tab, emqx_ds_session_seqnos).
+-define(rank_tab, emqx_ds_session_ranks).
+-define(pmap_tables, [?stream_tab, ?seqno_tab, ?rank_tab, ?subscription_tab]).
+
+%% Enable this flag if you suspect some code breaks the sequence:
+-ifndef(CHECK_SEQNO).
+-define(set_dirty, dirty => true).
+-define(unset_dirty, dirty => false).
+-else.
+-define(set_dirty, dirty => true, '_' => do_seqno()).
+-define(unset_dirty, dirty => false, '_' => do_seqno()).
+-endif.
+
+%%================================================================================
+%% API funcions
+%%================================================================================
+
+-spec create_tables() -> ok.
+create_tables() ->
+    ok = mria:create_table(
+        ?session_tab,
+        [
+            {rlog_shard, ?DS_MRIA_SHARD},
+            {type, ordered_set},
+            {storage, rocksdb_copies},
+            {record_name, kv},
+            {attributes, record_info(fields, kv)}
+        ]
+    ),
+    [create_kv_pmap_table(Table) || Table <- ?pmap_tables],
+    mria:wait_for_tables([?session_tab | ?pmap_tables]).
+
+-spec open(emqx_persistent_session_ds:id()) -> {ok, t()} | undefined.
+open(SessionId) ->
+    ro_transaction(fun() ->
+        case kv_restore(?session_tab, SessionId) of
+            [Metadata] ->
+                Rec = #{
+                    id => SessionId,
+                    metadata => Metadata,
+                    subscriptions => read_subscriptions(SessionId),
+                    streams => pmap_open(?stream_tab, SessionId),
+                    seqnos => pmap_open(?seqno_tab, SessionId),
+                    ranks => pmap_open(?rank_tab, SessionId),
+                    ?unset_dirty
+                },
+                {ok, Rec};
+            [] ->
+                undefined
+        end
+    end).
+
+-spec print_session(emqx_persistent_session_ds:id()) -> map() | undefined.
+print_session(SessionId) ->
+    case open(SessionId) of
+        undefined ->
+            undefined;
+        {ok, Session} ->
+            format(Session)
+    end.
+
+-spec format(t()) -> map().
+format(#{
+    metadata := Metadata,
+    subscriptions := SubsGBT,
+    streams := Streams,
+    seqnos := Seqnos,
+    ranks := Ranks
+}) ->
+    Subs = emqx_topic_gbt:fold(
+        fun(Key, Sub, Acc) -> maps:put(Key, Sub, Acc) end,
+        #{},
+        SubsGBT
+    ),
+    #{
+        metadata => Metadata,
+        subscriptions => Subs,
+        streams => pmap_format(Streams),
+        seqnos => pmap_format(Seqnos),
+        ranks => pmap_format(Ranks)
+    }.
+
+-spec list_sessions() -> [emqx_persistent_session_ds:id()].
+list_sessions() ->
+    mnesia:dirty_all_keys(?session_tab).
+
+-spec delete(emqx_persistent_session_ds:id()) -> ok.
+delete(Id) ->
+    transaction(
+        fun() ->
+            [kv_pmap_delete(Table, Id) || Table <- ?pmap_tables],
+            mnesia:delete(?session_tab, Id, write)
+        end
+    ).
+
+-spec commit(t()) -> t().
+commit(Rec = #{dirty := false}) ->
+    Rec;
+commit(
+    Rec = #{
+        id := SessionId,
+        metadata := Metadata,
+        streams := Streams,
+        seqnos := SeqNos,
+        ranks := Ranks
+    }
+) ->
+    check_sequence(Rec),
+    transaction(fun() ->
+        kv_persist(?session_tab, SessionId, Metadata),
+        Rec#{
+            streams => pmap_commit(SessionId, Streams),
+            seqnos => pmap_commit(SessionId, SeqNos),
+            ranks => pmap_commit(SessionId, Ranks),
+            ?unset_dirty
+        }
+    end).
+
+-spec create_new(emqx_persistent_session_ds:id()) -> t().
+create_new(SessionId) ->
+    transaction(fun() ->
+        delete(SessionId),
+        #{
+            id => SessionId,
+            metadata => #{},
+            subscriptions => emqx_topic_gbt:new(),
+            streams => pmap_open(?stream_tab, SessionId),
+            seqnos => pmap_open(?seqno_tab, SessionId),
+            ranks => pmap_open(?rank_tab, SessionId),
+            ?set_dirty
+        }
+    end).
+
+%%
+
+-spec get_created_at(t()) -> emqx_persistent_session_ds:timestamp() | undefined.
+get_created_at(Rec) ->
+    get_meta(?created_at, Rec).
+
+-spec set_created_at(emqx_persistent_session_ds:timestamp(), t()) -> t().
+set_created_at(Val, Rec) ->
+    set_meta(?created_at, Val, Rec).
+
+-spec get_last_alive_at(t()) -> emqx_persistent_session_ds:timestamp() | undefined.
+get_last_alive_at(Rec) ->
+    get_meta(?last_alive_at, Rec).
+
+-spec set_last_alive_at(emqx_persistent_session_ds:timestamp(), t()) -> t().
+set_last_alive_at(Val, Rec) ->
+    set_meta(?last_alive_at, Val, Rec).
+
+-spec get_expiry_interval(t()) -> non_neg_integer() | undefined.
+get_expiry_interval(Rec) ->
+    get_meta(?expiry_interval, Rec).
+
+-spec set_expiry_interval(non_neg_integer(), t()) -> t().
+set_expiry_interval(Val, Rec) ->
+    set_meta(?expiry_interval, Val, Rec).
+
+-spec new_id(t()) -> {emqx_persistent_session_ds:subscription_id(), t()}.
+new_id(Rec) ->
+    LastId =
+        case get_meta(?last_id, Rec) of
+            undefined -> 0;
+            N when is_integer(N) -> N
+        end,
+    {LastId, set_meta(?last_id, LastId + 1, Rec)}.
+
+%%
+
+-spec get_subscriptions(t()) -> subscriptions().
+get_subscriptions(#{subscriptions := Subs}) ->
+    Subs.
+
+-spec put_subscription(
+    emqx_persistent_session_ds:topic_filter(),
+    _SubId,
+    emqx_persistent_session_ds:subscription(),
+    t()
+) -> t().
+put_subscription(TopicFilter, SubId, Subscription, Rec = #{id := Id, subscriptions := Subs0}) ->
+    %% Note: currently changes to the subscriptions are persisted immediately.
+    Key = {TopicFilter, SubId},
+    transaction(fun() -> kv_pmap_persist(?subscription_tab, Id, Key, Subscription) end),
+    Subs = emqx_topic_gbt:insert(TopicFilter, SubId, Subscription, Subs0),
+    Rec#{subscriptions => Subs}.
+
+-spec del_subscription(emqx_persistent_session_ds:topic_filter(), _SubId, t()) -> t().
+del_subscription(TopicFilter, SubId, Rec = #{id := Id, subscriptions := Subs0}) ->
+    %% Note: currently the subscriptions are persisted immediately.
+    Key = {TopicFilter, SubId},
+    transaction(fun() -> kv_pmap_delete(?subscription_tab, Id, Key) end),
+    Subs = emqx_topic_gbt:delete(TopicFilter, SubId, Subs0),
+    Rec#{subscriptions => Subs}.
+
+%%
+
+-type stream_key() :: {emqx_persistent_session_ds:subscription_id(), _StreamId}.
+
+-spec get_stream(stream_key(), t()) ->
+    emqx_persistent_session_ds:stream_state() | undefined.
+get_stream(Key, Rec) ->
+    gen_get(streams, Key, Rec).
+
+-spec put_stream(stream_key(), emqx_persistent_session_ds:stream_state(), t()) -> t().
+put_stream(Key, Val, Rec) ->
+    gen_put(streams, Key, Val, Rec).
+
+-spec del_stream(stream_key(), t()) -> t().
+del_stream(Key, Rec) ->
+    gen_del(streams, Key, Rec).
+
+-spec fold_streams(fun(), Acc, t()) -> Acc.
+fold_streams(Fun, Acc, Rec) ->
+    gen_fold(streams, Fun, Acc, Rec).
+
+%%
+
+-spec get_seqno(seqno_type(), t()) -> emqx_persistent_session_ds:seqno() | undefined.
+get_seqno(Key, Rec) ->
+    gen_get(seqnos, Key, Rec).
+
+-spec put_seqno(seqno_type(), emqx_persistent_session_ds:seqno(), t()) -> t().
+put_seqno(Key, Val, Rec) ->
+    gen_put(seqnos, Key, Val, Rec).
+
+%%
+
+-type rank_key() :: {emqx_persistent_session_ds:subscription_id(), emqx_ds:rank_x()}.
+
+-spec get_rank(rank_key(), t()) -> integer() | undefined.
+get_rank(Key, Rec) ->
+    gen_get(ranks, Key, Rec).
+
+-spec put_rank(rank_key(), integer(), t()) -> t().
+put_rank(Key, Val, Rec) ->
+    gen_put(ranks, Key, Val, Rec).
+
+-spec del_rank(rank_key(), t()) -> t().
+del_rank(Key, Rec) ->
+    gen_del(ranks, Key, Rec).
+
+-spec fold_ranks(fun(), Acc, t()) -> Acc.
+fold_ranks(Fun, Acc, Rec) ->
+    gen_fold(ranks, Fun, Acc, Rec).
+
+-spec make_session_iterator() -> session_iterator().
+make_session_iterator() ->
+    case mnesia:dirty_first(?session_tab) of
+        '$end_of_table' ->
+            '$end_of_table';
+        Key ->
+            Key
+    end.
+
+-spec session_iterator_next(session_iterator(), pos_integer()) ->
+    {[{emqx_persistent_session_ds:id(), metadata()}], session_iterator()}.
+session_iterator_next(Cursor, 0) ->
+    {[], Cursor};
+session_iterator_next('$end_of_table', _N) ->
+    {[], '$end_of_table'};
+session_iterator_next(Cursor0, N) ->
+    ThisVal = [
+        {Cursor0, Metadata}
+     || #kv{v = Metadata} <- mnesia:dirty_read(?session_tab, Cursor0)
+    ],
+    {NextVals, Cursor} = session_iterator_next(mnesia:dirty_next(?session_tab, Cursor0), N - 1),
+    {ThisVal ++ NextVals, Cursor}.
+
+%%================================================================================
+%% Internal functions
+%%================================================================================
+
+%% All mnesia reads and writes are passed through this function.
+%% Backward compatiblity issues can be handled here.
+encoder(encode, _Table, Term) ->
+    Term;
+encoder(decode, _Table, Term) ->
+    Term.
+
+%%
+
+get_meta(K, #{metadata := Meta}) ->
+    maps:get(K, Meta, undefined).
+
+set_meta(K, V, Rec = #{metadata := Meta}) ->
+    check_sequence(Rec#{metadata => maps:put(K, V, Meta), ?set_dirty}).
+
+%%
+
+gen_get(Field, Key, Rec) ->
+    check_sequence(Rec),
+    pmap_get(Key, maps:get(Field, Rec)).
+
+gen_fold(Field, Fun, Acc, Rec) ->
+    check_sequence(Rec),
+    pmap_fold(Fun, Acc, maps:get(Field, Rec)).
+
+gen_put(Field, Key, Val, Rec) ->
+    check_sequence(Rec),
+    maps:update_with(
+        Field,
+        fun(PMap) -> pmap_put(Key, Val, PMap) end,
+        Rec#{?set_dirty}
+    ).
+
+gen_del(Field, Key, Rec) ->
+    check_sequence(Rec),
+    maps:update_with(
+        Field,
+        fun(PMap) -> pmap_del(Key, PMap) end,
+        Rec#{?set_dirty}
+    ).
+
+%%
+
+read_subscriptions(SessionId) ->
+    Records = kv_pmap_restore(?subscription_tab, SessionId),
+    lists:foldl(
+        fun({{TopicFilter, SubId}, Subscription}, Acc) ->
+            emqx_topic_gbt:insert(TopicFilter, SubId, Subscription, Acc)
+        end,
+        emqx_topic_gbt:new(),
+        Records
+    ).
+
+%%
+
+%% @doc Open a PMAP and fill the clean area with the data from DB.
+%% This functtion should be ran in a transaction.
+-spec pmap_open(atom(), emqx_persistent_session_ds:id()) -> pmap(_K, _V).
+pmap_open(Table, SessionId) ->
+    Clean = maps:from_list(kv_pmap_restore(Table, SessionId)),
+    #pmap{
+        table = Table,
+        cache = Clean,
+        dirty = #{}
+    }.
+
+-spec pmap_get(K, pmap(K, V)) -> V | undefined.
+pmap_get(K, #pmap{cache = Cache}) ->
+    maps:get(K, Cache, undefined).
+
+-spec pmap_put(K, V, pmap(K, V)) -> pmap(K, V).
+pmap_put(K, V, Pmap = #pmap{dirty = Dirty, cache = Cache}) ->
+    Pmap#pmap{
+        cache = maps:put(K, V, Cache),
+        dirty = Dirty#{K => dirty}
+    }.
+
+-spec pmap_del(K, pmap(K, V)) -> pmap(K, V).
+pmap_del(
+    Key,
+    Pmap = #pmap{dirty = Dirty, cache = Cache}
+) ->
+    Pmap#pmap{
+        cache = maps:remove(Key, Cache),
+        dirty = Dirty#{Key => del}
+    }.
+
+-spec pmap_fold(fun((K, V, A) -> A), A, pmap(K, V)) -> A.
+pmap_fold(Fun, Acc, #pmap{cache = Cache}) ->
+    maps:fold(Fun, Acc, Cache).
+
+-spec pmap_commit(emqx_persistent_session_ds:id(), pmap(K, V)) -> pmap(K, V).
+pmap_commit(
+    SessionId, Pmap = #pmap{table = Tab, dirty = Dirty, cache = Cache}
+) ->
+    maps:foreach(
+        fun
+            (K, del) ->
+                kv_pmap_delete(Tab, SessionId, K);
+            (K, dirty) ->
+                V = maps:get(K, Cache),
+                kv_pmap_persist(Tab, SessionId, K, V)
+        end,
+        Dirty
+    ),
+    Pmap#pmap{
+        dirty = #{}
+    }.
+
+-spec pmap_format(pmap(_K, _V)) -> map().
+pmap_format(#pmap{cache = Cache}) ->
+    Cache.
+
+%% Functions dealing with set tables:
+
+kv_persist(Tab, SessionId, Val0) ->
+    Val = encoder(encode, Tab, Val0),
+    mnesia:write(Tab, #kv{k = SessionId, v = Val}, write).
+
+kv_restore(Tab, SessionId) ->
+    [encoder(decode, Tab, V) || #kv{v = V} <- mnesia:read(Tab, SessionId)].
+
+%% Functions dealing with bags:
+
+%% @doc Create a mnesia table for the PMAP:
+-spec create_kv_pmap_table(atom()) -> ok.
+create_kv_pmap_table(Table) ->
+    mria:create_table(Table, [
+        {type, ordered_set},
+        {rlog_shard, ?DS_MRIA_SHARD},
+        {storage, rocksdb_copies},
+        {record_name, kv},
+        {attributes, record_info(fields, kv)}
+    ]).
+
+kv_pmap_persist(Tab, SessionId, Key, Val0) ->
+    %% Write data to mnesia:
+    Val = encoder(encode, Tab, Val0),
+    mnesia:write(Tab, #kv{k = {SessionId, Key}, v = Val}, write).
+
+kv_pmap_restore(Table, SessionId) ->
+    MS = [{#kv{k = {SessionId, '$1'}, v = '$2'}, [], [{{'$1', '$2'}}]}],
+    Objs = mnesia:select(Table, MS, read),
+    [{K, encoder(decode, Table, V)} || {K, V} <- Objs].
+
+kv_pmap_delete(Table, SessionId) ->
+    MS = [{#kv{k = {SessionId, '$1'}, _ = '_'}, [], ['$1']}],
+    Keys = mnesia:select(Table, MS, read),
+    [mnesia:delete(Table, {SessionId, K}, write) || K <- Keys],
+    ok.
+
+kv_pmap_delete(Table, SessionId, Key) ->
+    %% Note: this match spec uses a fixed primary key, so it doesn't
+    %% require a table scan, and the transaction doesn't grab the
+    %% whole table lock:
+    mnesia:delete(Table, {SessionId, Key}, write).
+
+%%
+
+transaction(Fun) ->
+    mria:async_dirty(?DS_MRIA_SHARD, Fun).
+
+ro_transaction(Fun) ->
+    mria:async_dirty(?DS_MRIA_SHARD, Fun).
+
+%% transaction(Fun) ->
+%%     case mnesia:is_transaction() of
+%%         true ->
+%%             Fun();
+%%         false ->
+%%             {atomic, Res} = mria:transaction(?DS_MRIA_SHARD, Fun),
+%%             Res
+%%     end.
+
+%% ro_transaction(Fun) ->
+%%     {atomic, Res} = mria:ro_transaction(?DS_MRIA_SHARD, Fun),
+%%     Res.
+
+-compile({inline, check_sequence/1}).
+
+-ifdef(CHECK_SEQNO).
+do_seqno() ->
+    case erlang:get(?MODULE) of
+        undefined ->
+            put(?MODULE, 0),
+            0;
+        N ->
+            put(?MODULE, N + 1),
+            N + 1
+    end.
+
+check_sequence(A = #{'_' := N}) ->
+    N = erlang:get(?MODULE),
+    A.
+-else.
+check_sequence(A) ->
+    A.
+-endif.

+ 312 - 0
apps/emqx/src/emqx_persistent_session_ds_stream_scheduler.erl

@@ -0,0 +1,312 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2023-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+-module(emqx_persistent_session_ds_stream_scheduler).
+
+%% API:
+-export([find_new_streams/1, find_replay_streams/1]).
+-export([renew_streams/1, del_subscription/2]).
+
+%% behavior callbacks:
+-export([]).
+
+%% internal exports:
+-export([]).
+
+-export_type([]).
+
+-include_lib("emqx/include/logger.hrl").
+-include("emqx_mqtt.hrl").
+-include("emqx_persistent_session_ds.hrl").
+
+%%================================================================================
+%% Type declarations
+%%================================================================================
+
+%%================================================================================
+%% API functions
+%%================================================================================
+
+%% @doc Find the streams that have uncommitted (in-flight) messages.
+%% Return them in the order they were previously replayed.
+-spec find_replay_streams(emqx_persistent_session_ds_state:t()) ->
+    [{emqx_persistent_session_ds_state:stream_key(), emqx_persistent_session_ds:stream_state()}].
+find_replay_streams(S) ->
+    Comm1 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_1), S),
+    Comm2 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_2), S),
+    %% 1. Find the streams that aren't fully acked
+    Streams = emqx_persistent_session_ds_state:fold_streams(
+        fun(Key, Stream, Acc) ->
+            case is_fully_acked(Comm1, Comm2, Stream) of
+                false ->
+                    [{Key, Stream} | Acc];
+                true ->
+                    Acc
+            end
+        end,
+        [],
+        S
+    ),
+    lists:sort(fun compare_streams/2, Streams).
+
+%% @doc Find streams from which the new messages can be fetched.
+%%
+%% Currently it amounts to the streams that don't have any inflight
+%% messages, since for performance reasons we keep only one record of
+%% in-flight messages per stream, and we don't want to overwrite these
+%% records prematurely.
+%%
+%% This function is non-detereministic: it randomizes the order of
+%% streams to ensure fair replay of different topics.
+-spec find_new_streams(emqx_persistent_session_ds_state:t()) ->
+    [{emqx_persistent_session_ds_state:stream_key(), emqx_persistent_session_ds:stream_state()}].
+find_new_streams(S) ->
+    %% FIXME: this function is currently very sensitive to the
+    %% consistency of the packet IDs on both broker and client side.
+    %%
+    %% If the client fails to properly ack packets due to a bug, or a
+    %% network issue, or if the state of streams and seqno tables ever
+    %% become de-synced, then this function will return an empty list,
+    %% and the replay cannot progress.
+    %%
+    %% In other words, this function is not robust, and we should find
+    %% some way to get the replays un-stuck at the cost of potentially
+    %% losing messages during replay (or just kill the stuck channel
+    %% after timeout?)
+    Comm1 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_1), S),
+    Comm2 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_2), S),
+    shuffle(
+        emqx_persistent_session_ds_state:fold_streams(
+            fun(Key, Stream, Acc) ->
+                case is_fully_acked(Comm1, Comm2, Stream) of
+                    true ->
+                        [{Key, Stream} | Acc];
+                    false ->
+                        Acc
+                end
+            end,
+            [],
+            S
+        )
+    ).
+
+%% @doc This function makes the session aware of the new streams.
+%%
+%% It has the following properties:
+%%
+%% 1. For each RankX, it keeps only the streams with the same RankY.
+%%
+%% 2. For each RankX, it never advances RankY until _all_ streams with
+%% the same RankX are replayed.
+%%
+%% 3. Once all streams with the given rank are replayed, it advances
+%% the RankY to the smallest known RankY that is greater than replayed
+%% RankY.
+%%
+%% 4. If the RankX has never been replayed, it selects the streams
+%% with the smallest RankY.
+%%
+%% This way, messages from the same topic/shard are never reordered.
+-spec renew_streams(emqx_persistent_session_ds_state:t()) -> emqx_persistent_session_ds_state:t().
+renew_streams(S0) ->
+    S1 = remove_fully_replayed_streams(S0),
+    emqx_topic_gbt:fold(
+        fun(Key, _Subscription = #{start_time := StartTime, id := SubId}, S2) ->
+            TopicFilter = emqx_topic:words(emqx_trie_search:get_topic(Key)),
+            Streams = select_streams(
+                SubId,
+                emqx_ds:get_streams(?PERSISTENT_MESSAGE_DB, TopicFilter, StartTime),
+                S2
+            ),
+            lists:foldl(
+                fun(I, Acc) ->
+                    ensure_iterator(TopicFilter, StartTime, SubId, I, Acc)
+                end,
+                S2,
+                Streams
+            )
+        end,
+        S1,
+        emqx_persistent_session_ds_state:get_subscriptions(S1)
+    ).
+
+-spec del_subscription(
+    emqx_persistent_session_ds:subscription_id(), emqx_persistent_session_ds_state:t()
+) ->
+    emqx_persistent_session_ds_state:t().
+del_subscription(SubId, S0) ->
+    emqx_persistent_session_ds_state:fold_streams(
+        fun(Key, _, Acc) ->
+            case Key of
+                {SubId, _Stream} ->
+                    emqx_persistent_session_ds_state:del_stream(Key, Acc);
+                _ ->
+                    Acc
+            end
+        end,
+        S0,
+        S0
+    ).
+
+%%================================================================================
+%% Internal functions
+%%================================================================================
+
+ensure_iterator(TopicFilter, StartTime, SubId, {{RankX, RankY}, Stream}, S) ->
+    %% TODO: hash collisions
+    Key = {SubId, erlang:phash2(Stream)},
+    case emqx_persistent_session_ds_state:get_stream(Key, S) of
+        undefined ->
+            ?SLOG(debug, #{
+                msg => new_stream, key => Key, stream => Stream
+            }),
+            {ok, Iterator} = emqx_ds:make_iterator(
+                ?PERSISTENT_MESSAGE_DB, Stream, TopicFilter, StartTime
+            ),
+            NewStreamState = #srs{
+                rank_x = RankX,
+                rank_y = RankY,
+                it_begin = Iterator,
+                it_end = Iterator
+            },
+            emqx_persistent_session_ds_state:put_stream(Key, NewStreamState, S);
+        #srs{} ->
+            S
+    end.
+
+select_streams(SubId, Streams0, S) ->
+    TopicStreamGroups = maps:groups_from_list(fun({{X, _}, _}) -> X end, Streams0),
+    maps:fold(
+        fun(RankX, Streams, Acc) ->
+            select_streams(SubId, RankX, Streams, S) ++ Acc
+        end,
+        [],
+        TopicStreamGroups
+    ).
+
+select_streams(SubId, RankX, Streams0, S) ->
+    %% 1. Find the streams with the rank Y greater than the recorded one:
+    Streams1 =
+        case emqx_persistent_session_ds_state:get_rank({SubId, RankX}, S) of
+            undefined ->
+                Streams0;
+            ReplayedY ->
+                [I || I = {{_, Y}, _} <- Streams0, Y > ReplayedY]
+        end,
+    %% 2. Sort streams by rank Y:
+    Streams = lists:sort(
+        fun({{_, Y1}, _}, {{_, Y2}, _}) ->
+            Y1 =< Y2
+        end,
+        Streams1
+    ),
+    %% 3. Select streams with the least rank Y:
+    case Streams of
+        [] ->
+            [];
+        [{{_, MinRankY}, _} | _] ->
+            lists:takewhile(fun({{_, Y}, _}) -> Y =:= MinRankY end, Streams)
+    end.
+
+%% @doc Advance RankY for each RankX that doesn't have any unreplayed
+%% streams.
+%%
+%% Drop streams with the fully replayed rank. This function relies on
+%% the fact that all streams with the same RankX have also the same
+%% RankY.
+-spec remove_fully_replayed_streams(emqx_persistent_session_ds_state:t()) ->
+    emqx_persistent_session_ds_state:t().
+remove_fully_replayed_streams(S0) ->
+    CommQos1 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_1), S0),
+    CommQos2 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_2), S0),
+    %% 1. For each subscription, find the X ranks that were fully replayed:
+    Groups = emqx_persistent_session_ds_state:fold_streams(
+        fun({SubId, _Stream}, StreamState = #srs{rank_x = RankX, rank_y = RankY}, Acc) ->
+            Key = {SubId, RankX},
+            case is_fully_replayed(CommQos1, CommQos2, StreamState) of
+                true when is_map_key(Key, Acc) ->
+                    Acc#{Key => {true, RankY}};
+                true ->
+                    Acc#{Key => false};
+                _ ->
+                    Acc
+            end
+        end,
+        #{},
+        S0
+    ),
+    %% 2. Advance rank y for each fully replayed set of streams:
+    S1 = maps:fold(
+        fun
+            (Key, {true, RankY}, Acc) ->
+                emqx_persistent_session_ds_state:put_rank(Key, RankY, Acc);
+            (_, _, Acc) ->
+                Acc
+        end,
+        S0,
+        Groups
+    ),
+    %% 3. Remove the fully replayed streams:
+    emqx_persistent_session_ds_state:fold_streams(
+        fun(Key = {SubId, _Stream}, #srs{rank_x = RankX, rank_y = RankY}, Acc) ->
+            case emqx_persistent_session_ds_state:get_rank({SubId, RankX}, Acc) of
+                undefined ->
+                    Acc;
+                MinRankY when RankY < MinRankY ->
+                    ?SLOG(debug, #{
+                        msg => del_fully_preplayed_stream,
+                        key => Key,
+                        rank => {RankX, RankY},
+                        min => MinRankY
+                    }),
+                    emqx_persistent_session_ds_state:del_stream(Key, Acc);
+                _ ->
+                    Acc
+            end
+        end,
+        S1,
+        S1
+    ).
+
+%% @doc Compare the streams by the order in which they were replayed.
+compare_streams(
+    {_KeyA, #srs{first_seqno_qos1 = A1, first_seqno_qos2 = A2}},
+    {_KeyB, #srs{first_seqno_qos1 = B1, first_seqno_qos2 = B2}}
+) ->
+    case A1 =:= B1 of
+        true ->
+            A2 =< B2;
+        false ->
+            A1 < B1
+    end.
+
+is_fully_replayed(Comm1, Comm2, S = #srs{it_end = It}) ->
+    It =:= end_of_stream andalso is_fully_acked(Comm1, Comm2, S).
+
+is_fully_acked(Comm1, Comm2, #srs{last_seqno_qos1 = S1, last_seqno_qos2 = S2}) ->
+    (Comm1 >= S1) andalso (Comm2 >= S2).
+
+-spec shuffle([A]) -> [A].
+shuffle(L0) ->
+    L1 = lists:map(
+        fun(A) ->
+            %% maybe topic/stream prioritization could be introduced here?
+            {rand:uniform(), A}
+        end,
+        L0
+    ),
+    L2 = lists:sort(L1),
+    {_, L} = lists:unzip(L2),
+    L.

+ 1 - 1
apps/emqx/src/emqx_persistent_session_ds_sup.erl

@@ -1,5 +1,5 @@
 %%--------------------------------------------------------------------
-%% Copyright (c) 2023 EMQ Technologies Co., Ltd. All Rights Reserved.
+%% Copyright (c) 2023-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
 %%
 %% Licensed under the Apache License, Version 2.0 (the "License");
 %% you may not use this file except in compliance with the License.

+ 1 - 1
apps/emqx/src/emqx_schema.erl

@@ -1810,7 +1810,7 @@ fields("session_persistence") ->
             sc(
                 pos_integer(),
                 #{
-                    default => 1000,
+                    default => 100,
                     desc => ?DESC(session_ds_max_batch_size)
                 }
             )},

+ 11 - 13
apps/emqx/src/emqx_session.erl

@@ -1,5 +1,5 @@
 %%--------------------------------------------------------------------
-%% Copyright (c) 2017-2023 EMQ Technologies Co., Ltd. All Rights Reserved.
+%% Copyright (c) 2017-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
 %%
 %% Licensed under the Apache License, Version 2.0 (the "License");
 %% you may not use this file except in compliance with the License.
@@ -409,12 +409,8 @@ enrich_delivers(ClientInfo, Delivers, Session) ->
 enrich_delivers(_ClientInfo, [], _UpgradeQoS, _Session) ->
     [];
 enrich_delivers(ClientInfo, [D | Rest], UpgradeQoS, Session) ->
-    case enrich_deliver(ClientInfo, D, UpgradeQoS, Session) of
-        [] ->
-            enrich_delivers(ClientInfo, Rest, UpgradeQoS, Session);
-        Msg ->
-            [Msg | enrich_delivers(ClientInfo, Rest, UpgradeQoS, Session)]
-    end.
+    enrich_deliver(ClientInfo, D, UpgradeQoS, Session) ++
+        enrich_delivers(ClientInfo, Rest, UpgradeQoS, Session).
 
 enrich_deliver(ClientInfo, {deliver, Topic, Msg}, UpgradeQoS, Session) ->
     SubOpts =
@@ -435,13 +431,15 @@ enrich_message(
     _ = emqx_session_events:handle_event(ClientInfo, {dropped, Msg, no_local}),
     [];
 enrich_message(_ClientInfo, MsgIn, SubOpts = #{}, UpgradeQoS) ->
-    maps:fold(
-        fun(SubOpt, V, Msg) -> enrich_subopts(SubOpt, V, Msg, UpgradeQoS) end,
-        MsgIn,
-        SubOpts
-    );
+    [
+        maps:fold(
+            fun(SubOpt, V, Msg) -> enrich_subopts(SubOpt, V, Msg, UpgradeQoS) end,
+            MsgIn,
+            SubOpts
+        )
+    ];
 enrich_message(_ClientInfo, Msg, undefined, _UpgradeQoS) ->
-    Msg.
+    [Msg].
 
 enrich_subopts(nl, 1, Msg, _) ->
     emqx_message:set_flag(nl, Msg);

+ 10 - 10
apps/emqx/src/emqx_topic_gbt.erl

@@ -1,5 +1,5 @@
 %%--------------------------------------------------------------------
-%% Copyright (c) 2023 EMQ Technologies Co., Ltd. All Rights Reserved.
+%% Copyright (c) 2023-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
 %%
 %% Licensed under the Apache License, Version 2.0 (the "License");
 %% you may not use this file except in compliance with the License.
@@ -39,11 +39,11 @@
 -type match(ID) :: key(ID).
 
 -opaque t(ID, Value) :: gb_trees:tree(key(ID), Value).
--opaque t() :: t(_ID, _Value).
+-type t() :: t(_ID, _Value).
 
 %% @doc Create a new gb_tree and store it in the persitent_term with the
 %% given name.
--spec new() -> t().
+-spec new() -> t(_ID, _Value).
 new() ->
     gb_trees:empty().
 
@@ -54,19 +54,19 @@ size(Gbt) ->
 %% @doc Insert a new entry into the index that associates given topic filter to given
 %% record ID, and attaches arbitrary record to the entry. This allows users to choose
 %% between regular and "materialized" indexes, for example.
--spec insert(emqx_types:topic() | words(), _ID, _Record, t()) -> t().
+-spec insert(emqx_types:topic() | words(), ID, Record, t(ID, Record)) -> t(ID, Record).
 insert(Filter, ID, Record, Gbt) ->
     Key = key(Filter, ID),
     gb_trees:enter(Key, Record, Gbt).
 
 %% @doc Delete an entry from the index that associates given topic filter to given
 %% record ID. Deleting non-existing entry is not an error.
--spec delete(emqx_types:topic() | words(), _ID, t()) -> t().
+-spec delete(emqx_types:topic() | words(), ID, t(ID, Record)) -> t(ID, Record).
 delete(Filter, ID, Gbt) ->
     Key = key(Filter, ID),
     gb_trees:delete_any(Key, Gbt).
 
--spec lookup(emqx_types:topic() | words(), _ID, t(), Default) -> _Record | Default.
+-spec lookup(emqx_types:topic() | words(), ID, t(ID, Record), Default) -> Record | Default.
 lookup(Filter, ID, Gbt, Default) ->
     Key = key(Filter, ID),
     case gb_trees:lookup(Key, Gbt) of
@@ -76,7 +76,7 @@ lookup(Filter, ID, Gbt, Default) ->
             Default
     end.
 
--spec fold(fun((key(_ID), _Record, Acc) -> Acc), Acc, t()) -> Acc.
+-spec fold(fun((key(ID), Record, Acc) -> Acc), Acc, t(ID, Record)) -> Acc.
 fold(Fun, Acc, Gbt) ->
     Iter = gb_trees:iterator(Gbt),
     fold_iter(Fun, Acc, Iter).
@@ -91,13 +91,13 @@ fold_iter(Fun, Acc, Iter) ->
 
 %% @doc Match given topic against the index and return the first match, or `false` if
 %% no match is found.
--spec match(emqx_types:topic(), t()) -> match(_ID) | false.
+-spec match(emqx_types:topic(), t(ID, _Record)) -> match(ID) | false.
 match(Topic, Gbt) ->
     emqx_trie_search:match(Topic, make_nextf(Gbt)).
 
 %% @doc Match given topic against the index and return _all_ matches.
 %% If `unique` option is given, return only unique matches by record ID.
--spec matches(emqx_types:topic(), t(), emqx_trie_search:opts()) -> [match(_ID)].
+-spec matches(emqx_types:topic(), t(ID, _Record), emqx_trie_search:opts()) -> [match(ID)].
 matches(Topic, Gbt, Opts) ->
     emqx_trie_search:matches(Topic, make_nextf(Gbt), Opts).
 
@@ -112,7 +112,7 @@ get_topic(Key) ->
     emqx_trie_search:get_topic(Key).
 
 %% @doc Fetch the record associated with the match.
--spec get_record(match(_ID), t()) -> _Record.
+-spec get_record(match(ID), t(ID, Record)) -> Record.
 get_record(Key, Gbt) ->
     gb_trees:get(Key, Gbt).
 

+ 2 - 31
apps/emqx/test/emqx_persistent_messages_SUITE.erl

@@ -216,31 +216,7 @@ t_session_subscription_iterators(Config) ->
                 messages => [Message1, Message2, Message3, Message4]
             }
         end,
-        fun(Trace) ->
-            ct:pal("trace:\n  ~p", [Trace]),
-            case ?of_kind(ds_session_subscription_added, Trace) of
-                [] ->
-                    %% Since `emqx_durable_storage' is a dependency of `emqx', it gets
-                    %% compiled in "prod" mode when running emqx standalone tests.
-                    ok;
-                [_ | _] ->
-                    ?assertMatch(
-                        [
-                            #{?snk_kind := ds_session_subscription_added},
-                            #{?snk_kind := ds_session_subscription_present}
-                        ],
-                        ?of_kind(
-                            [
-                                ds_session_subscription_added,
-                                ds_session_subscription_present
-                            ],
-                            Trace
-                        )
-                    ),
-                    ok
-            end,
-            ok
-        end
+        []
     ),
     ok.
 
@@ -318,11 +294,6 @@ t_qos0_only_many_streams(_Config) ->
             receive_messages(3)
         ),
 
-        ?assertMatch(
-            #{pubranges := [_, _, _]},
-            emqx_persistent_session_ds:print_session(ClientId)
-        ),
-
         Inflight1 = get_session_inflight(ConnPid),
 
         %% TODO: Kinda stupid way to verify that the runtime state is not growing.
@@ -524,7 +495,7 @@ consume(It) ->
     end.
 
 receive_messages(Count) ->
-    receive_messages(Count, 5_000).
+    receive_messages(Count, 10_000).
 
 receive_messages(Count, Timeout) ->
     lists:reverse(receive_messages(Count, [], Timeout)).

+ 89 - 66
apps/emqx/test/emqx_persistent_session_SUITE.erl

@@ -71,7 +71,12 @@ init_per_group(persistence_disabled, Config) ->
     ];
 init_per_group(persistence_enabled, Config) ->
     [
-        {emqx_config, "session_persistence { enable = true }"},
+        {emqx_config,
+            "session_persistence {\n"
+            "  enable = true\n"
+            "  last_alive_update_interval = 100ms\n"
+            "  renew_streams_interval = 100ms\n"
+            "}"},
         {persistence, ds}
         | Config
     ];
@@ -530,42 +535,47 @@ t_process_dies_session_expires(Config) ->
     %% Emulate an error in the connect process,
     %% or that the node of the process goes down.
     %% A persistent session should eventually expire.
-    ConnFun = ?config(conn_fun, Config),
-    ClientId = ?config(client_id, Config),
-    Topic = ?config(topic, Config),
-    STopic = ?config(stopic, Config),
-    Payload = <<"test">>,
-    {ok, Client1} = emqtt:start_link([
-        {proto_ver, v5},
-        {clientid, ClientId},
-        {properties, #{'Session-Expiry-Interval' => 1}},
-        {clean_start, true}
-        | Config
-    ]),
-    {ok, _} = emqtt:ConnFun(Client1),
-    {ok, _, [2]} = emqtt:subscribe(Client1, STopic, qos2),
-    ok = emqtt:disconnect(Client1),
-
-    maybe_kill_connection_process(ClientId, Config),
-
-    ok = publish(Topic, Payload),
-
-    timer:sleep(1100),
-
-    {ok, Client2} = emqtt:start_link([
-        {proto_ver, v5},
-        {clientid, ClientId},
-        {properties, #{'Session-Expiry-Interval' => 30}},
-        {clean_start, false}
-        | Config
-    ]),
-    {ok, _} = emqtt:ConnFun(Client2),
-    ?assertEqual(0, client_info(session_present, Client2)),
-
-    %% We should not receive the pending message
-    ?assertEqual([], receive_messages(1)),
-
-    emqtt:disconnect(Client2).
+    ?check_trace(
+        begin
+            ConnFun = ?config(conn_fun, Config),
+            ClientId = ?config(client_id, Config),
+            Topic = ?config(topic, Config),
+            STopic = ?config(stopic, Config),
+            Payload = <<"test">>,
+            {ok, Client1} = emqtt:start_link([
+                {proto_ver, v5},
+                {clientid, ClientId},
+                {properties, #{'Session-Expiry-Interval' => 1}},
+                {clean_start, true}
+                | Config
+            ]),
+            {ok, _} = emqtt:ConnFun(Client1),
+            {ok, _, [2]} = emqtt:subscribe(Client1, STopic, qos2),
+            ok = emqtt:disconnect(Client1),
+
+            maybe_kill_connection_process(ClientId, Config),
+
+            ok = publish(Topic, Payload),
+
+            timer:sleep(1500),
+
+            {ok, Client2} = emqtt:start_link([
+                {proto_ver, v5},
+                {clientid, ClientId},
+                {properties, #{'Session-Expiry-Interval' => 30}},
+                {clean_start, false}
+                | Config
+            ]),
+            {ok, _} = emqtt:ConnFun(Client2),
+            ?assertEqual(0, client_info(session_present, Client2)),
+
+            %% We should not receive the pending message
+            ?assertEqual([], receive_messages(1)),
+
+            emqtt:disconnect(Client2)
+        end,
+        []
+    ).
 
 t_publish_while_client_is_gone_qos1(Config) ->
     %% A persistent session should receive messages in its
@@ -672,6 +682,7 @@ t_publish_many_while_client_is_gone_qos1(Config) ->
     ),
 
     NAcked = 4,
+    ?assert(NMsgs1 >= NAcked),
     [ok = emqtt:puback(Client1, PktId) || #{packet_id := PktId} <- lists:sublist(Msgs1, NAcked)],
 
     %% Ensure that PUBACKs are propagated to the channel.
@@ -681,7 +692,7 @@ t_publish_many_while_client_is_gone_qos1(Config) ->
     maybe_kill_connection_process(ClientId, Config),
 
     Pubs2 = [
-        #mqtt_msg{topic = <<"loc/3/4/5">>, payload = <<"M8">>, qos = 1},
+        #mqtt_msg{topic = <<"loc/3/4/6">>, payload = <<"M8">>, qos = 1},
         #mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M9">>, qos = 1},
         #mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M10">>, qos = 1},
         #mqtt_msg{topic = <<"msg/feed/friend">>, payload = <<"M11">>, qos = 1},
@@ -690,27 +701,30 @@ t_publish_many_while_client_is_gone_qos1(Config) ->
     ok = publish_many(Pubs2),
     NPubs2 = length(Pubs2),
 
+    %% Now reconnect with auto ack to make sure all streams are
+    %% replayed till the end:
     {ok, Client2} = emqtt:start_link([
         {proto_ver, v5},
         {clientid, ClientId},
         {properties, #{'Session-Expiry-Interval' => 30}},
-        {clean_start, false},
-        {auto_ack, false}
+        {clean_start, false}
         | Config
     ]),
+
     {ok, _} = emqtt:ConnFun(Client2),
 
     %% Try to receive _at most_ `NPubs` messages.
     %% There shouldn't be that much unacked messages in the replay anyway,
     %% but it's an easy number to pick.
     NPubs = NPubs1 + NPubs2,
+
     Msgs2 = receive_messages(NPubs, _Timeout = 2000),
     NMsgs2 = length(Msgs2),
 
     ct:pal("Msgs2 = ~p", [Msgs2]),
 
-    ?assert(NMsgs2 < NPubs, Msgs2),
-    ?assert(NMsgs2 > NPubs2, Msgs2),
+    ?assert(NMsgs2 < NPubs, {NMsgs2, '<', NPubs}),
+    ?assert(NMsgs2 > NPubs2, {NMsgs2, '>', NPubs2}),
     ?assert(NMsgs2 >= NPubs - NAcked, Msgs2),
     NSame = NMsgs2 - NPubs2,
     ?assert(
@@ -773,6 +787,11 @@ t_publish_many_while_client_is_gone(Config) ->
     %% for its subscriptions after the client dies or reconnects, in addition
     %% to PUBRELs for the messages it has PUBRECed. While client must send
     %% PUBACKs and PUBRECs in order, those orders are independent of each other.
+    %%
+    %% Developer's note: for simplicity we publish all messages to the
+    %% same topic, since persistent session ds may reorder messages
+    %% that belong to different streams, and this particular test is
+    %% very sensitive the order.
     ClientId = ?config(client_id, Config),
     ConnFun = ?config(conn_fun, Config),
     ClientOpts = [
@@ -785,20 +804,18 @@ t_publish_many_while_client_is_gone(Config) ->
 
     {ok, Client1} = emqtt:start_link([{clean_start, true} | ClientOpts]),
     {ok, _} = emqtt:ConnFun(Client1),
-    {ok, _, [?QOS_1]} = emqtt:subscribe(Client1, <<"t/+/foo">>, ?QOS_1),
-    {ok, _, [?QOS_2]} = emqtt:subscribe(Client1, <<"msg/feed/#">>, ?QOS_2),
-    {ok, _, [?QOS_2]} = emqtt:subscribe(Client1, <<"loc/+/+/+">>, ?QOS_2),
+    {ok, _, [?QOS_2]} = emqtt:subscribe(Client1, <<"t">>, ?QOS_2),
 
     Pubs1 = [
-        #mqtt_msg{topic = <<"t/42/foo">>, payload = <<"M1">>, qos = 1},
-        #mqtt_msg{topic = <<"t/42/foo">>, payload = <<"M2">>, qos = 1},
-        #mqtt_msg{topic = <<"msg/feed/me">>, payload = <<"M3">>, qos = 2},
-        #mqtt_msg{topic = <<"loc/1/2/42">>, payload = <<"M4">>, qos = 2},
-        #mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M5">>, qos = 2},
-        #mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M6">>, qos = 1},
-        #mqtt_msg{topic = <<"loc/3/4/5">>, payload = <<"M7">>, qos = 2},
-        #mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M8">>, qos = 1},
-        #mqtt_msg{topic = <<"msg/feed/me">>, payload = <<"M9">>, qos = 2}
+        #mqtt_msg{topic = <<"t">>, payload = <<"M1">>, qos = 1},
+        #mqtt_msg{topic = <<"t">>, payload = <<"M2">>, qos = 1},
+        #mqtt_msg{topic = <<"t">>, payload = <<"M3">>, qos = 2},
+        #mqtt_msg{topic = <<"t">>, payload = <<"M4">>, qos = 2},
+        #mqtt_msg{topic = <<"t">>, payload = <<"M5">>, qos = 2},
+        #mqtt_msg{topic = <<"t">>, payload = <<"M6">>, qos = 1},
+        #mqtt_msg{topic = <<"t">>, payload = <<"M7">>, qos = 2},
+        #mqtt_msg{topic = <<"t">>, payload = <<"M8">>, qos = 1},
+        #mqtt_msg{topic = <<"t">>, payload = <<"M9">>, qos = 2}
     ],
     ok = publish_many(Pubs1),
     NPubs1 = length(Pubs1),
@@ -806,11 +823,12 @@ t_publish_many_while_client_is_gone(Config) ->
     Msgs1 = receive_messages(NPubs1),
     ct:pal("Msgs1 = ~p", [Msgs1]),
     NMsgs1 = length(Msgs1),
-    ?assertEqual(NPubs1, NMsgs1),
+    ?assertEqual(NPubs1, NMsgs1, emqx_persistent_session_ds:print_session(ClientId)),
 
     ?assertEqual(
         get_topicwise_order(Pubs1),
-        get_topicwise_order(Msgs1)
+        get_topicwise_order(Msgs1),
+        emqx_persistent_session_ds:print_session(ClientId)
     ),
 
     %% PUBACK every QoS 1 message.
@@ -819,7 +837,7 @@ t_publish_many_while_client_is_gone(Config) ->
         [PktId || #{qos := 1, packet_id := PktId} <- Msgs1]
     ),
 
-    %% PUBREC first `NRecs` QoS 2 messages.
+    %% PUBREC first `NRecs` QoS 2 messages (up to "M5")
     NRecs = 3,
     PubRecs1 = lists:sublist([PktId || #{qos := 2, packet_id := PktId} <- Msgs1], NRecs),
     lists:foreach(
@@ -843,9 +861,9 @@ t_publish_many_while_client_is_gone(Config) ->
     maybe_kill_connection_process(ClientId, Config),
 
     Pubs2 = [
-        #mqtt_msg{topic = <<"loc/3/4/5">>, payload = <<"M10">>, qos = 2},
-        #mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M11">>, qos = 1},
-        #mqtt_msg{topic = <<"msg/feed/friend">>, payload = <<"M12">>, qos = 2}
+        #mqtt_msg{topic = <<"t">>, payload = <<"M10">>, qos = 2},
+        #mqtt_msg{topic = <<"t">>, payload = <<"M11">>, qos = 1},
+        #mqtt_msg{topic = <<"t">>, payload = <<"M12">>, qos = 2}
     ],
     ok = publish_many(Pubs2),
     NPubs2 = length(Pubs2),
@@ -878,8 +896,8 @@ t_publish_many_while_client_is_gone(Config) ->
         Msgs2Dups
     ),
 
-    %% Now complete all yet incomplete QoS 2 message flows instead.
-    PubRecs2 = [PktId || #{qos := 2, packet_id := PktId} <- Msgs2],
+    %% Ack more messages:
+    PubRecs2 = lists:sublist([PktId || #{qos := 2, packet_id := PktId} <- Msgs2], 2),
     lists:foreach(
         fun(PktId) -> ok = emqtt:pubrec(Client2, PktId) end,
         PubRecs2
@@ -895,6 +913,7 @@ t_publish_many_while_client_is_gone(Config) ->
 
     %% PUBCOMP every PUBREL.
     PubComps = [PktId || {pubrel, #{packet_id := PktId}} <- PubRels1 ++ PubRels2],
+    ct:pal("PubComps: ~p", [PubComps]),
     lists:foreach(
         fun(PktId) -> ok = emqtt:pubcomp(Client2, PktId) end,
         PubComps
@@ -902,19 +921,19 @@ t_publish_many_while_client_is_gone(Config) ->
 
     %% Ensure that PUBCOMPs are propagated to the channel.
     pong = emqtt:ping(Client2),
-
+    %% Reconnect for the last time
     ok = disconnect_client(Client2),
     maybe_kill_connection_process(ClientId, Config),
 
     {ok, Client3} = emqtt:start_link([{clean_start, false} | ClientOpts]),
     {ok, _} = emqtt:ConnFun(Client3),
 
-    %% Only the last unacked QoS 1 message should be retransmitted.
+    %% Check that we receive the rest of the messages:
     Msgs3 = receive_messages(NPubs, _Timeout = 2000),
     ct:pal("Msgs3 = ~p", [Msgs3]),
     ?assertMatch(
-        [#{topic := <<"t/100/foo">>, payload := <<"M11">>, qos := 1, dup := true}],
-        Msgs3
+        [<<"M10">>, <<"M11">>, <<"M12">>],
+        [I || #{payload := I} <- Msgs3]
     ),
 
     ok = disconnect_client(Client3).
@@ -1080,3 +1099,7 @@ skip_ds_tc(Config) ->
         _ ->
             Config
     end.
+
+debug_info(ClientId) ->
+    Info = emqx_persistent_session_ds:print_session(ClientId),
+    ct:pal("*** State:~n~p", [Info]).

+ 373 - 0
apps/emqx/test/emqx_persistent_session_ds_state_tests.erl

@@ -0,0 +1,373 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2023-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+-module(emqx_persistent_session_ds_state_tests).
+
+-compile(nowarn_export_all).
+-compile(export_all).
+
+-include_lib("proper/include/proper.hrl").
+-include_lib("eunit/include/eunit.hrl").
+
+-define(tab, ?MODULE).
+
+%%================================================================================
+%% Type declarations
+%%================================================================================
+
+%% Note: here `committed' != `dirty'. It means "has been committed at
+%% least once since the creation", and it's used by the iteration
+%% test.
+-record(s, {subs = #{}, metadata = #{}, streams = #{}, seqno = #{}, committed = false}).
+
+-type state() :: #{emqx_persistent_session_ds:id() => #s{}}.
+
+%%================================================================================
+%% Properties
+%%================================================================================
+
+seqno_proper_test_() ->
+    Props = [prop_consistency()],
+    Opts = [{numtests, 10}, {to_file, user}, {max_size, 100}],
+    {timeout, 300, [?_assert(proper:quickcheck(Prop, Opts)) || Prop <- Props]}.
+
+prop_consistency() ->
+    ?FORALL(
+        Cmds,
+        commands(?MODULE),
+        begin
+            init(),
+            {_History, State, Result} = run_commands(?MODULE, Cmds),
+            clean(),
+            ?WHENFAIL(
+                io:format(
+                    user,
+                    "Operations: ~p~nState: ~p\nResult: ~p~n",
+                    [Cmds, State, Result]
+                ),
+                aggregate(command_names(Cmds), Result =:= ok)
+            )
+        end
+    ).
+
+%%================================================================================
+%% Generators
+%%================================================================================
+
+-define(n_sessions, 10).
+
+session_id() ->
+    oneof([integer_to_binary(I) || I <- lists:seq(1, ?n_sessions)]).
+
+topic() ->
+    oneof([<<"foo">>, <<"bar">>, <<"foo/#">>, <<"//+/#">>]).
+
+subid() ->
+    oneof([[]]).
+
+subscription() ->
+    oneof([#{}]).
+
+session_id(S) ->
+    oneof(maps:keys(S)).
+
+batch_size() ->
+    range(1, ?n_sessions).
+
+put_metadata() ->
+    oneof([
+        ?LET(
+            Val,
+            range(0, 100),
+            {last_alive_at, set_last_alive_at, Val}
+        ),
+        ?LET(
+            Val,
+            range(0, 100),
+            {created_at, set_created_at, Val}
+        )
+    ]).
+
+get_metadata() ->
+    oneof([
+        {last_alive_at, get_last_alive_at},
+        {created_at, get_created_at}
+    ]).
+
+seqno_track() ->
+    range(0, 1).
+
+seqno() ->
+    range(1, 100).
+
+stream_id() ->
+    range(1, 1).
+
+stream() ->
+    oneof([#{}]).
+
+put_req() ->
+    oneof([
+        ?LET(
+            {Id, Stream},
+            {stream_id(), stream()},
+            {#s.streams, put_stream, Id, Stream}
+        ),
+        ?LET(
+            {Track, Seqno},
+            {seqno_track(), seqno()},
+            {#s.seqno, put_seqno, Track, Seqno}
+        )
+    ]).
+
+get_req() ->
+    oneof([
+        {#s.streams, get_stream, stream_id()},
+        {#s.seqno, get_seqno, seqno_track()}
+    ]).
+
+del_req() ->
+    oneof([
+        {#s.streams, del_stream, stream_id()}
+    ]).
+
+command(S) ->
+    case maps:size(S) > 0 of
+        true ->
+            frequency([
+                %% Global CRUD operations:
+                {1, {call, ?MODULE, create_new, [session_id()]}},
+                {1, {call, ?MODULE, delete, [session_id(S)]}},
+                {2, {call, ?MODULE, reopen, [session_id(S)]}},
+                {2, {call, ?MODULE, commit, [session_id(S)]}},
+
+                %% Subscriptions:
+                {3,
+                    {call, ?MODULE, put_subscription, [
+                        session_id(S), topic(), subid(), subscription()
+                    ]}},
+                {3, {call, ?MODULE, del_subscription, [session_id(S), topic(), subid()]}},
+
+                %% Metadata:
+                {3, {call, ?MODULE, put_metadata, [session_id(S), put_metadata()]}},
+                {3, {call, ?MODULE, get_metadata, [session_id(S), get_metadata()]}},
+
+                %% Key-value:
+                {3, {call, ?MODULE, gen_put, [session_id(S), put_req()]}},
+                {3, {call, ?MODULE, gen_get, [session_id(S), get_req()]}},
+                {3, {call, ?MODULE, gen_del, [session_id(S), del_req()]}},
+
+                %% Getters:
+                {4, {call, ?MODULE, get_subscriptions, [session_id(S)]}},
+                {1, {call, ?MODULE, iterate_sessions, [batch_size()]}}
+            ]);
+        false ->
+            frequency([
+                {1, {call, ?MODULE, create_new, [session_id()]}},
+                {1, {call, ?MODULE, iterate_sessions, [batch_size()]}}
+            ])
+    end.
+
+precondition(_, _) ->
+    true.
+
+postcondition(S, {call, ?MODULE, iterate_sessions, [_]}, Result) ->
+    {Sessions, _} = lists:unzip(Result),
+    %% No lingering sessions:
+    ?assertMatch([], Sessions -- maps:keys(S)),
+    %% All committed sessions are visited by the iterator:
+    CommittedSessions = lists:sort([K || {K, #s{committed = true}} <- maps:to_list(S)]),
+    ?assertMatch([], CommittedSessions -- Sessions),
+    true;
+postcondition(S, {call, ?MODULE, get_metadata, [SessionId, {MetaKey, _Fun}]}, Result) ->
+    #{SessionId := #s{metadata = Meta}} = S,
+    ?assertEqual(
+        maps:get(MetaKey, Meta, undefined),
+        Result,
+        #{session_id => SessionId, meta => MetaKey}
+    ),
+    true;
+postcondition(S, {call, ?MODULE, gen_get, [SessionId, {Idx, Fun, Key}]}, Result) ->
+    #{SessionId := Record} = S,
+    ?assertEqual(
+        maps:get(Key, element(Idx, Record), undefined),
+        Result,
+        #{session_id => SessionId, key => Key, 'fun' => Fun}
+    ),
+    true;
+postcondition(S, {call, ?MODULE, get_subscriptions, [SessionId]}, Result) ->
+    #{SessionId := #s{subs = Subs}} = S,
+    ?assertEqual(maps:size(Subs), emqx_topic_gbt:size(Result)),
+    maps:foreach(
+        fun({TopicFilter, Id}, Expected) ->
+            ?assertEqual(
+                Expected,
+                emqx_topic_gbt:lookup(TopicFilter, Id, Result, default)
+            )
+        end,
+        Subs
+    ),
+    true;
+postcondition(_, _, _) ->
+    true.
+
+next_state(S, _V, {call, ?MODULE, create_new, [SessionId]}) ->
+    S#{SessionId => #s{}};
+next_state(S, _V, {call, ?MODULE, delete, [SessionId]}) ->
+    maps:remove(SessionId, S);
+next_state(S, _V, {call, ?MODULE, put_subscription, [SessionId, TopicFilter, SubId, Subscription]}) ->
+    Key = {TopicFilter, SubId},
+    update(
+        SessionId,
+        #s.subs,
+        fun(Subs) -> Subs#{Key => Subscription} end,
+        S
+    );
+next_state(S, _V, {call, ?MODULE, del_subscription, [SessionId, TopicFilter, SubId]}) ->
+    Key = {TopicFilter, SubId},
+    update(
+        SessionId,
+        #s.subs,
+        fun(Subs) -> maps:remove(Key, Subs) end,
+        S
+    );
+next_state(S, _V, {call, ?MODULE, put_metadata, [SessionId, {Key, _Fun, Val}]}) ->
+    update(
+        SessionId,
+        #s.metadata,
+        fun(Map) -> Map#{Key => Val} end,
+        S
+    );
+next_state(S, _V, {call, ?MODULE, gen_put, [SessionId, {Idx, _Fun, Key, Val}]}) ->
+    update(
+        SessionId,
+        Idx,
+        fun(Map) -> Map#{Key => Val} end,
+        S
+    );
+next_state(S, _V, {call, ?MODULE, gen_del, [SessionId, {Idx, _Fun, Key}]}) ->
+    update(
+        SessionId,
+        Idx,
+        fun(Map) -> maps:remove(Key, Map) end,
+        S
+    );
+next_state(S, _V, {call, ?MODULE, commit, [SessionId]}) ->
+    update(
+        SessionId,
+        #s.committed,
+        fun(_) -> true end,
+        S
+    );
+next_state(S, _V, {call, ?MODULE, _, _}) ->
+    S.
+
+initial_state() ->
+    #{}.
+
+%%================================================================================
+%% Operations
+%%================================================================================
+
+create_new(SessionId) ->
+    put_state(SessionId, emqx_persistent_session_ds_state:create_new(SessionId)).
+
+delete(SessionId) ->
+    emqx_persistent_session_ds_state:delete(SessionId),
+    ets:delete(?tab, SessionId).
+
+commit(SessionId) ->
+    put_state(SessionId, emqx_persistent_session_ds_state:commit(get_state(SessionId))).
+
+reopen(SessionId) ->
+    _ = emqx_persistent_session_ds_state:commit(get_state(SessionId)),
+    {ok, S} = emqx_persistent_session_ds_state:open(SessionId),
+    put_state(SessionId, S).
+
+put_subscription(SessionId, TopicFilter, SubId, Subscription) ->
+    S = emqx_persistent_session_ds_state:put_subscription(
+        TopicFilter, SubId, Subscription, get_state(SessionId)
+    ),
+    put_state(SessionId, S).
+
+del_subscription(SessionId, TopicFilter, SubId) ->
+    S = emqx_persistent_session_ds_state:del_subscription(TopicFilter, SubId, get_state(SessionId)),
+    put_state(SessionId, S).
+
+get_subscriptions(SessionId) ->
+    emqx_persistent_session_ds_state:get_subscriptions(get_state(SessionId)).
+
+put_metadata(SessionId, {_MetaKey, Fun, Value}) ->
+    S = apply(emqx_persistent_session_ds_state, Fun, [Value, get_state(SessionId)]),
+    put_state(SessionId, S).
+
+get_metadata(SessionId, {_MetaKey, Fun}) ->
+    apply(emqx_persistent_session_ds_state, Fun, [get_state(SessionId)]).
+
+gen_put(SessionId, {_Idx, Fun, Key, Value}) ->
+    S = apply(emqx_persistent_session_ds_state, Fun, [Key, Value, get_state(SessionId)]),
+    put_state(SessionId, S).
+
+gen_del(SessionId, {_Idx, Fun, Key}) ->
+    S = apply(emqx_persistent_session_ds_state, Fun, [Key, get_state(SessionId)]),
+    put_state(SessionId, S).
+
+gen_get(SessionId, {_Idx, Fun, Key}) ->
+    apply(emqx_persistent_session_ds_state, Fun, [Key, get_state(SessionId)]).
+
+iterate_sessions(BatchSize) ->
+    Fun = fun F(It0) ->
+        case emqx_persistent_session_ds_state:session_iterator_next(It0, BatchSize) of
+            {[], _} ->
+                [];
+            {Sessions, It} ->
+                Sessions ++ F(It)
+        end
+    end,
+    Fun(emqx_persistent_session_ds_state:make_session_iterator()).
+
+%%================================================================================
+%% Misc.
+%%================================================================================
+
+update(SessionId, Key, Fun, S) ->
+    maps:update_with(
+        SessionId,
+        fun(SS) ->
+            setelement(Key, SS, Fun(erlang:element(Key, SS)))
+        end,
+        S
+    ).
+
+get_state(SessionId) ->
+    case ets:lookup(?tab, SessionId) of
+        [{_, S}] ->
+            S;
+        [] ->
+            error({not_found, SessionId})
+    end.
+
+put_state(SessionId, S) ->
+    ets:insert(?tab, {SessionId, S}).
+
+init() ->
+    _ = ets:new(?tab, [named_table, public, {keypos, 1}]),
+    mria:start(),
+    emqx_persistent_session_ds_state:create_tables().
+
+clean() ->
+    ets:delete(?tab),
+    mria:stop(),
+    mria_mnesia:delete_schema().

+ 1 - 1
apps/emqx_conf/src/emqx_conf_schema.erl

@@ -1,5 +1,5 @@
 %%--------------------------------------------------------------------
-%% Copyright (c) 2021-2023 EMQ Technologies Co., Ltd. All Rights Reserved.
+%% Copyright (c) 2021-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
 %%
 %% Licensed under the Apache License, Version 2.0 (the "License");
 %% you may not use this file except in compliance with the License.

+ 7 - 1
apps/emqx_durable_storage/src/emqx_ds.erl

@@ -47,6 +47,8 @@
     topic_filter/0,
     topic/0,
     stream/0,
+    rank_x/0,
+    rank_y/0,
     stream_rank/0,
     iterator/0,
     iterator_id/0,
@@ -77,7 +79,11 @@
 %% Parsed topic filter.
 -type topic_filter() :: list(binary() | '+' | '#' | '').
 
--type stream_rank() :: {term(), integer()}.
+-type rank_x() :: term().
+
+-type rank_y() :: integer().
+
+-type stream_rank() :: {rank_x(), rank_y()}.
 
 %% TODO: Not implemented
 -type iterator_id() :: term().

+ 5 - 1
apps/emqx_durable_storage/src/emqx_ds_lts.erl

@@ -1,5 +1,5 @@
 %%--------------------------------------------------------------------
-%% Copyright (c) 2023 EMQ Technologies Co., Ltd. All Rights Reserved.
+%% Copyright (c) 2023-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
 %%
 %% Licensed under the Apache License, Version 2.0 (the "License");
 %% you may not use this file except in compliance with the License.
@@ -213,6 +213,10 @@ trie_next(#trie{trie = Trie}, State, ?EOT) ->
         [] -> undefined
     end;
 trie_next(#trie{trie = Trie}, State, Token) ->
+    %% NOTE: it's crucial to return the original (non-wildcard) index
+    %% for the topic, if found. Otherwise messages from the same topic
+    %% will end up in different streams, once the wildcard is learned,
+    %% and their replay order will become undefined:
     case ets:lookup(Trie, {State, Token}) of
         [#trans{next = Next}] ->
             {false, Next};

+ 1 - 1
apps/emqx_exhook/test/emqx_exhook_SUITE.erl

@@ -1,5 +1,5 @@
 %%--------------------------------------------------------------------
-%% Copyright (c) 2020-2023 EMQ Technologies Co., Ltd. All Rights Reserved.
+%% Copyright (c) 2020-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
 %%
 %% Licensed under the Apache License, Version 2.0 (the "License");
 %% you may not use this file except in compliance with the License.

+ 7 - 0
changes/ce/feat-12251.en.md

@@ -0,0 +1,7 @@
+Optimize performance of the RocksDB-based persistent session.
+Reduce RAM usage and frequency of database requests.
+
+- Introduce dirty session state to avoid frequent mria transactions
+- Introduce an intermediate buffer for the persistent messages
+- Use separate tracks of PacketIds for QoS1 and QoS2 messages
+- Limit the number of continuous ranges of inflight messages to one per stream