Explorar o código

feat(ds): Add asynchronous poll API

ieQu1 hai 1 ano
pai
achega
cd69c21261
Modificáronse 29 ficheiros con 2522 adicións e 583 borrados
  1. 22 2
      apps/emqx/src/emqx_ds_schema.erl
  2. 266 190
      apps/emqx/src/emqx_persistent_session_ds.erl
  3. 27 27
      apps/emqx/src/emqx_persistent_session_ds/emqx_persistent_session_ds_inflight.erl
  4. 29 83
      apps/emqx/src/emqx_persistent_session_ds/emqx_persistent_session_ds_state.erl
  5. 461 111
      apps/emqx/src/emqx_persistent_session_ds/emqx_persistent_session_ds_stream_scheduler.erl
  6. 8 2
      apps/emqx/src/emqx_schema.erl
  7. 8 4
      apps/emqx/src/emqx_topic_index.erl
  8. 4 0
      apps/emqx/src/emqx_trie_search.erl
  9. 73 11
      apps/emqx_ds_builtin_local/src/emqx_ds_builtin_local.erl
  10. 17 1
      apps/emqx_ds_builtin_local/src/emqx_ds_builtin_local_db_sup.erl
  11. 2 31
      apps/emqx_ds_builtin_local/test/emqx_ds_builtin_local_SUITE.erl
  12. 6 0
      apps/emqx_ds_builtin_raft/src/emqx_ds_replication_layer.erl
  13. 3 2
      apps/emqx_durable_storage/include/emqx_ds.hrl
  14. 13 0
      apps/emqx_durable_storage/include/emqx_ds_metrics.hrl
  15. 54 28
      apps/emqx_durable_storage/src/emqx_ds.erl
  16. 827 0
      apps/emqx_durable_storage/src/emqx_ds_beamformer.erl
  17. 123 0
      apps/emqx_durable_storage/src/emqx_ds_beamformer_sup.erl
  18. 107 0
      apps/emqx_durable_storage/src/emqx_ds_beamformer_waitq.erl
  19. 30 1
      apps/emqx_durable_storage/src/emqx_ds_builtin_metrics.erl
  20. 5 4
      apps/emqx_durable_storage/src/emqx_ds_lib.erl
  21. 48 1
      apps/emqx_durable_storage/src/emqx_ds_storage_bitfield_lts.erl
  22. 104 6
      apps/emqx_durable_storage/src/emqx_ds_storage_layer.erl
  23. 40 1
      apps/emqx_durable_storage/src/emqx_ds_storage_reference.erl
  24. 85 15
      apps/emqx_durable_storage/src/emqx_ds_storage_skipstream_lts.erl
  25. 1 1
      apps/emqx_durable_storage/src/emqx_durable_storage.app.src
  26. 39 0
      apps/emqx_durable_storage/src/proto/emqx_ds_beamsplitter_proto_v1.erl
  27. 114 1
      apps/emqx_durable_storage/test/emqx_ds_storage_layout_SUITE.erl
  28. 0 60
      apps/emqx_durable_storage/test/props/emqx_ds_message_storage_bitmask_shim.erl
  29. 6 1
      apps/emqx_prometheus/src/emqx_prometheus.erl

+ 22 - 2
apps/emqx/src/emqx_ds_schema.erl

@@ -76,13 +76,17 @@ translate_builtin_local(
     #{
         backend := builtin_local,
         n_shards := NShards,
-        layout := Layout
+        layout := Layout,
+        poll_workers_per_shard := NPollers,
+        poll_batch_size := BatchSize
     }
 ) ->
     #{
         backend => builtin_local,
         n_shards => NShards,
-        storage => translate_layout(Layout)
+        storage => translate_layout(Layout),
+        poll_workers_per_shard => NPollers,
+        poll_batch_size => BatchSize
     }.
 
 %%================================================================================
@@ -337,6 +341,22 @@ common_builtin_fields() ->
                             <<"type">> => wildcard_optimized_v2
                         }
                 }
+            )},
+        {poll_workers_per_shard,
+            sc(
+                pos_integer(),
+                #{
+                    default => 10,
+                    importance => ?IMPORTANCE_HIDDEN
+                }
+            )},
+        {poll_batch_size,
+            sc(
+                pos_integer(),
+                #{
+                    default => 100,
+                    importance => ?IMPORTANCE_HIDDEN
+                }
             )}
     ].
 

+ 266 - 190
apps/emqx/src/emqx_persistent_session_ds.erl

@@ -14,6 +14,37 @@
 %% limitations under the License.
 %%--------------------------------------------------------------------
 
+%% @doc This module implements an MQTT session that can survive
+%% restart of EMQX node by backing up its state on disk. It consumes
+%% messages from a shared durable storage. This is in contrast to the
+%% regular "mem" sessions that store all received messages in their
+%% own memory queues.
+%%
+%% The main challenge of durable session is to replay sent, but
+%% unacked, messages in case of the client reconnect. This
+%% implementation approaches this problem by storing iterators, batch
+%% sizes and sequence numbers of MQTT packets for the consumed
+%% messages as an array of "stream replay state" records (`#srs'), in
+%% such a way, that messages and their corresponding packet IDs can be
+%% reconstructed by "replaying" the stored SRSes.
+%%
+%% The session logic is implemented as two mostly separate loops
+%% ("circuits") that operate on a transient message queue, serving as
+%% a buffer.
+%%
+%% - *Push circuit* polls durable storage, and pushes messages to the
+%% queue. It's assisted by the `stream_scheduler' module that decides
+%% which streams are eligible for pull. Push circuit is responsible
+%% for maintaining the size of the queue at the configured limit.
+%%
+%% - *Pull circuit* consumes messages from the buffer and publishes
+%% them to the client connection. It's responsible for maintaining the
+%% number of inflight packets as close to the negotiated
+%% `Receive-Maximum' as possible to maximize the throughput.
+%%
+%% These circuits interact simply by notifying each other via
+%% `pull_now' or `push_now' functions.
+
 -module(emqx_persistent_session_ds).
 
 -behaviour(emqx_session).
@@ -167,22 +198,23 @@
 -type shared_sub_state() :: term().
 
 -define(TIMER_PULL, timer_pull).
+-define(TIMER_PUSH, timer_push).
 -define(TIMER_GET_STREAMS, timer_get_streams).
 -define(TIMER_BUMP_LAST_ALIVE_AT, timer_bump_last_alive_at).
 -define(TIMER_RETRY_REPLAY, timer_retry_replay).
 
--type timer() :: ?TIMER_PULL | ?TIMER_GET_STREAMS | ?TIMER_BUMP_LAST_ALIVE_AT | ?TIMER_RETRY_REPLAY.
+-type timer() ::
+    ?TIMER_PULL
+    | ?TIMER_PUSH
+    | ?TIMER_GET_STREAMS
+    | ?TIMER_BUMP_LAST_ALIVE_AT
+    | ?TIMER_RETRY_REPLAY.
+
+-type timer_state() :: reference() | undefined.
 
 %% TODO: Needs configuration?
 -define(TIMEOUT_RETRY_REPLAY, 1000).
 
--record(pending_next, {
-    ref :: reference(),
-    stream_key :: emqx_persistent_session_ds_state:stream_key(),
-    it_begin :: emqx_ds:iterator(),
-    is_replay :: boolean()
-}).
-
 -type session() :: #{
     %% Client ID
     id := id(),
@@ -193,25 +225,27 @@
     %% Shared subscription state:
     shared_sub_s := shared_sub_state(),
     %% Buffer:
-    inflight := emqx_persistent_session_ds_inflight:t(),
-    %% Last fetched stream:
-    %% Used as a continuation point for fair stream scheduling.
-    last_fetched_stream => emqx_persistent_session_ds_state:stream_key(),
+    inflight := emqx_persistent_session_ds_buffer:t(),
+    stream_scheduler_s := emqx_persistent_session_ds_stream_scheduler:t(),
     %% In-progress replay:
     %% List of stream replay states to be added to the inflight buffer.
-    replay => [{_StreamKey, stream_state()}, ...],
+    replay := [{_StreamKey, stream_state()}, ...] | undefined,
     %% Timers:
-    timer() => reference()
+    ?TIMER_PULL := timer_state(),
+    ?TIMER_PUSH := timer_state(),
+    ?TIMER_GET_STREAMS := timer_state(),
+    ?TIMER_BUMP_LAST_ALIVE_AT := timer_state(),
+    ?TIMER_RETRY_REPLAY := timer_state()
 }.
 
--define(IS_REPLAY_ONGOING(SESS), is_map_key(replay, SESS)).
+-define(IS_REPLAY_ONGOING(REPLAY), is_list(REPLAY)).
 
 -record(req_sync, {
     from :: pid(),
     ref :: reference()
 }).
 
--type stream_state() :: #srs{}.
+-type stream_state() :: emqx_persistent_session_ds_stream_scheduler:srs().
 
 -type message() :: emqx_types:message().
 -type timestamp() :: emqx_utils_calendar:epoch_millisecond().
@@ -259,7 +293,7 @@ open(#{clientid := ClientID} = ClientInfo, ConnInfo, MaybeWillMsg, Conf) ->
     ok = emqx_cm:takeover_kick(ClientID),
     case session_open(ClientID, ClientInfo, ConnInfo, MaybeWillMsg) of
         Session0 = #{} ->
-            Session1 = Session0#{props => Conf},
+            Session1 = Session0#{props := Conf},
             Session = do_expire(ClientInfo, Session1),
             {true, ensure_timers(Session), []};
         false ->
@@ -314,13 +348,13 @@ info(upgrade_qos, #{props := Conf}) ->
 info(inflight, #{inflight := Inflight}) ->
     Inflight;
 info(inflight_cnt, #{inflight := Inflight}) ->
-    emqx_persistent_session_ds_inflight:n_inflight(Inflight);
+    emqx_persistent_session_ds_buffer:n_inflight(Inflight);
 info(inflight_max, #{inflight := Inflight}) ->
-    emqx_persistent_session_ds_inflight:receive_maximum(Inflight);
+    emqx_persistent_session_ds_buffer:receive_maximum(Inflight);
 info(retry_interval, #{props := Conf}) ->
     maps:get(retry_interval, Conf);
 info(mqueue_len, #{inflight := Inflight}) ->
-    emqx_persistent_session_ds_inflight:n_buffered(all, Inflight);
+    emqx_persistent_session_ds_buffer:n_buffered(all, Inflight);
 info(mqueue_dropped, _Session) ->
     0;
 %% info(next_pkt_id, #{s := S}) ->
@@ -395,7 +429,7 @@ subscribe(
     case emqx_persistent_session_ds_shared_subs:on_subscribe(TopicFilter, SubOpts, Session) of
         {ok, S0, SharedSubS} ->
             S = emqx_persistent_session_ds_state:commit(S0),
-            {ok, Session#{s => S, shared_sub_s => SharedSubS}};
+            {ok, Session#{s := S, shared_sub_s := SharedSubS}};
         Error = {error, _} ->
             Error
     end;
@@ -407,7 +441,7 @@ subscribe(
     case emqx_persistent_session_ds_subs:on_subscribe(TopicFilter, SubOpts, Session) of
         {ok, S1} ->
             S = emqx_persistent_session_ds_state:commit(S1),
-            {ok, Session#{s => S}};
+            {ok, Session#{s := S}};
         Error = {error, _} ->
             Error
     end.
@@ -419,7 +453,9 @@ subscribe(
     {ok, session(), emqx_types:subopts()} | {error, emqx_types:reason_code()}.
 unsubscribe(
     #share{} = TopicFilter,
-    Session = #{id := SessionId, s := S0, shared_sub_s := SharedSubS0}
+    Session0 = #{
+        id := SessionId, s := S0, shared_sub_s := SharedSubS0, stream_scheduler_s := SchedS0
+    }
 ) ->
     case
         emqx_persistent_session_ds_shared_subs:on_unsubscribe(
@@ -427,21 +463,27 @@ unsubscribe(
         )
     of
         {ok, S1, SharedSubS1, #{id := SubId, subopts := SubOpts}} ->
-            S2 = emqx_persistent_session_ds_stream_scheduler:on_unsubscribe(SubId, S1),
+            {S2, SchedS} = emqx_persistent_session_ds_stream_scheduler:on_unsubscribe(
+                SubId, S1, SchedS0
+            ),
             S = emqx_persistent_session_ds_state:commit(S2),
-            {ok, Session#{s => S, shared_sub_s => SharedSubS1}, SubOpts};
+            Session = Session0#{s := S, shared_sub_s := SharedSubS1, stream_scheduler_s := SchedS},
+            {ok, Session, SubOpts};
         Error = {error, _} ->
             Error
     end;
 unsubscribe(
     TopicFilter,
-    Session = #{id := SessionId, s := S0}
+    Session0 = #{id := SessionId, s := S0, stream_scheduler_s := SchedS0}
 ) ->
     case emqx_persistent_session_ds_subs:on_unsubscribe(SessionId, TopicFilter, S0) of
         {ok, S1, #{id := SubId, subopts := SubOpts}} ->
-            S2 = emqx_persistent_session_ds_stream_scheduler:on_unsubscribe(SubId, S1),
+            {S2, SchedS} = emqx_persistent_session_ds_stream_scheduler:on_unsubscribe(
+                SubId, S1, SchedS0
+            ),
             S = emqx_persistent_session_ds_state:commit(S2),
-            {ok, Session#{s => S}, SubOpts};
+            Session = Session0#{s := S, stream_scheduler_s := SchedS},
+            {ok, Session, SubOpts};
         Error = {error, _} ->
             Error
     end.
@@ -477,7 +519,7 @@ publish(
                 undefined ->
                     Results = emqx_broker:publish(Msg),
                     S = emqx_persistent_session_ds_state:put_awaiting_rel(PacketId, Ts, S0),
-                    {ok, Results, Session#{s => S}};
+                    {ok, Results, Session#{s := S}};
                 _Ts ->
                     {error, ?RC_PACKET_IDENTIFIER_IN_USE}
             end;
@@ -530,7 +572,7 @@ do_expire(ClientInfo, Session = #{s := S0, props := Props}) ->
         S0,
         ExpiredPacketIds
     ),
-    Session#{s => S}.
+    Session#{s := S}.
 
 %%--------------------------------------------------------------------
 %% Client -> Broker: PUBACK
@@ -574,7 +616,7 @@ pubrel(PacketId, Session = #{s := S0}) ->
             {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND};
         _TS ->
             S = emqx_persistent_session_ds_state:del_awaiting_rel(PacketId, S0),
-            {ok, Session#{s => S}}
+            {ok, Session#{s := S}}
     end.
 
 %%--------------------------------------------------------------------
@@ -612,53 +654,63 @@ deliver(ClientInfo, Delivers, Session0) ->
 
 -spec handle_timeout(clientinfo(), _Timeout, session()) ->
     {ok, replies(), session()} | {ok, replies(), timeout(), session()}.
-handle_timeout(ClientInfo, ?TIMER_PULL, Session0) ->
-    {Publishes, Session1} =
-        case ?IS_REPLAY_ONGOING(Session0) of
-            false ->
-                drain_buffer(fetch_new_messages(Session0, ClientInfo));
-            true ->
-                {[], Session0}
-        end,
-    Timeout =
-        case Publishes of
-            [] ->
-                get_config(ClientInfo, [idle_poll_interval]);
-            [_ | _] ->
-                0
-        end,
-    Session = emqx_session:ensure_timer(?TIMER_PULL, Timeout, Session1),
-    {ok, Publishes, Session};
+handle_timeout(_ClientInfo, ?TIMER_PULL, Session0) ->
+    %% Pull circuit loop:
+    ?tp(debug, sessds_pull, #{}),
+    Session1 = Session0#{?TIMER_PULL := undefined},
+    {Publishes, Session} = drain_buffer(Session1),
+    {ok, Publishes, push_now(Session)};
+handle_timeout(ClientInfo, ?TIMER_PUSH, Session0) ->
+    %% Push circuit loop:
+    ?tp(debug, sessds_push, #{}),
+    Session1 = Session0#{?TIMER_PUSH := undefined},
+    #{s := S, stream_scheduler_s := SchedS0, inflight := Inflight, replay := Replay} = Session0,
+    BatchSize = get_config(ClientInfo, [batch_size]),
+    IsFull = emqx_persistent_session_ds_buffer:n_buffered(all, Inflight) >= BatchSize,
+    case ?IS_REPLAY_ONGOING(Replay) orelse IsFull of
+        true ->
+            {ok, [], Session1};
+        false ->
+            Timeout = get_config(ClientInfo, [idle_poll_interval]),
+            PollOpts = #{timeout => Timeout},
+            SchedS = emqx_persistent_session_ds_stream_scheduler:poll(PollOpts, SchedS0, S),
+            {ok, [], Session1#{stream_scheduler_s := SchedS}}
+    end;
 handle_timeout(ClientInfo, ?TIMER_RETRY_REPLAY, Session0) ->
     Session = replay_streams(Session0, ClientInfo),
     {ok, [], Session};
-handle_timeout(ClientInfo, ?TIMER_GET_STREAMS, Session0 = #{s := S0, shared_sub_s := SharedSubS0}) ->
+handle_timeout(
+    ClientInfo,
+    ?TIMER_GET_STREAMS,
+    Session0 = #{s := S0, shared_sub_s := SharedSubS0, stream_scheduler_s := SchedS0}
+) ->
+    ?tp(debug, sessds_renew_streams, #{}),
     %% `gc` and `renew_streams` methods may drop unsubscribed streams.
     %% Shared subscription handler must have a chance to see unsubscribed streams
     %% in the fully replayed state.
     {S1, SharedSubS1} = emqx_persistent_session_ds_shared_subs:pre_renew_streams(S0, SharedSubS0),
     S2 = emqx_persistent_session_ds_subs:gc(S1),
-    S3 = emqx_persistent_session_ds_stream_scheduler:renew_streams(S2),
+    {S3, SchedS} = emqx_persistent_session_ds_stream_scheduler:renew_streams(S2, SchedS0),
     {S, SharedSubS} = emqx_persistent_session_ds_shared_subs:renew_streams(S3, SharedSubS1),
     Interval = get_config(ClientInfo, [renew_streams_interval]),
-    Session = emqx_session:ensure_timer(
+    Session = set_timer(
         ?TIMER_GET_STREAMS,
         Interval,
-        Session0#{s => S, shared_sub_s => SharedSubS}
+        Session0#{s := S, shared_sub_s := SharedSubS, stream_scheduler_s := SchedS}
     ),
-    {ok, [], Session};
+    {ok, [], push_now(Session)};
 handle_timeout(_ClientInfo, ?TIMER_BUMP_LAST_ALIVE_AT, Session0 = #{s := S0}) ->
     S = emqx_persistent_session_ds_state:commit(bump_last_alive(S0)),
-    Session = emqx_session:ensure_timer(
+    Session = set_timer(
         ?TIMER_BUMP_LAST_ALIVE_AT,
         bump_interval(),
-        Session0#{s => S}
+        Session0#{s := S}
     ),
     {ok, [], Session};
 handle_timeout(_ClientInfo, #req_sync{from = From, ref = Ref}, Session = #{s := S0}) ->
     S = emqx_persistent_session_ds_state:commit(S0),
     From ! Ref,
-    {ok, [], Session#{s => S}};
+    {ok, [], Session#{s := S}};
 handle_timeout(ClientInfo, expire_awaiting_rel, Session) ->
     expire(ClientInfo, Session);
 handle_timeout(_ClientInfo, Timeout, Session) ->
@@ -674,7 +726,9 @@ handle_info(
     ?shared_sub_message(Msg), Session = #{s := S0, shared_sub_s := SharedSubS0}, _ClientInfo
 ) ->
     {S, SharedSubS} = emqx_persistent_session_ds_shared_subs:on_info(S0, SharedSubS0, Msg),
-    Session#{s => S, shared_sub_s => SharedSubS};
+    Session#{s := S, shared_sub_s := SharedSubS};
+handle_info(AsyncReply = #poll_reply{}, Session, ClientInfo) ->
+    push_now(handle_ds_reply(AsyncReply, Session, ClientInfo));
 handle_info(Msg, Session, _ClientInfo) ->
     ?SLOG(warning, #{msg => emqx_session_ds_unknown_message, message => Msg}),
     Session.
@@ -698,7 +752,7 @@ bump_last_alive(S0) ->
     {ok, replies(), session()}.
 replay(ClientInfo, [], Session0 = #{s := S0}) ->
     Streams = emqx_persistent_session_ds_stream_scheduler:find_replay_streams(S0),
-    Session = replay_streams(Session0#{replay => Streams}, ClientInfo),
+    Session = replay_streams(Session0#{replay := Streams}, ClientInfo),
     {ok, [], Session}.
 
 replay_streams(Session0 = #{replay := [{StreamKey, Srs0} | Rest]}, ClientInfo) ->
@@ -714,22 +768,24 @@ replay_streams(Session0 = #{replay := [{StreamKey, Srs0} | Rest]}, ClientInfo) -
                 class => recoverable,
                 retry_in_ms => RetryTimeout
             }),
-            emqx_session:ensure_timer(?TIMER_RETRY_REPLAY, RetryTimeout, Session0);
+            set_timer(?TIMER_RETRY_REPLAY, RetryTimeout, Session0);
         {error, unrecoverable, Reason} ->
             Session1 = skip_batch(StreamKey, Srs0, Session0, ClientInfo, Reason),
             replay_streams(Session1#{replay := Rest}, ClientInfo)
     end;
-replay_streams(Session0 = #{replay := []}, _ClientInfo) ->
-    Session = maps:remove(replay, Session0),
+replay_streams(Session = #{replay := []}, _ClientInfo) ->
     %% Note: we filled the buffer with the historical messages, and
     %% from now on we'll rely on the normal inflight/flow control
     %% mechanisms to replay them:
-    pull_now(Session).
+    pull_now(Session#{replay := undefined}).
 
 -spec replay_batch(
-    emqx_persistent_session_ds_state:stream_key(), stream_state(), session(), clientinfo()
+    emqx_persistent_session_ds_stream_scheduler:stream_key(),
+    stream_state(),
+    session(),
+    clientinfo()
 ) ->
-    session() | emqx_ds:error(_).
+    {ok, stream_state(), session()} | emqx_ds:error(_).
 replay_batch(StreamKey, Srs0, Session0, ClientInfo) ->
     #srs{it_begin = ItBegin, batch_size = BatchSize} = Srs0,
     FetchResult = emqx_ds:next(?PERSISTENT_MESSAGE_DB, ItBegin, BatchSize),
@@ -794,7 +850,7 @@ disconnect(Session = #{id := Id, s := S0, shared_sub_s := SharedSubS0}, ConnInfo
         end,
     {S4, SharedSubS} = emqx_persistent_session_ds_shared_subs:on_disconnect(S3, SharedSubS0),
     S = emqx_persistent_session_ds_state:commit(S4),
-    {shutdown, Session#{s => S, shared_sub_s => SharedSubS}}.
+    {shutdown, Session#{s := S, shared_sub_s := SharedSubS}}.
 
 -spec terminate(Reason :: term(), session()) -> ok.
 terminate(_Reason, Session = #{id := Id, s := S}) ->
@@ -913,15 +969,23 @@ session_open(
                         S6, shared_sub_opts(SessionId)
                     ),
                     S = emqx_persistent_session_ds_state:commit(S7),
-                    Inflight = emqx_persistent_session_ds_inflight:new(
+                    Inflight = emqx_persistent_session_ds_buffer:new(
                         receive_maximum(NewConnInfo)
                     ),
+                    SSS = emqx_persistent_session_ds_stream_scheduler:init(S),
                     #{
                         id => SessionId,
                         s => S,
                         shared_sub_s => SharedSubS,
                         inflight => Inflight,
-                        props => #{}
+                        props => #{},
+                        stream_scheduler_s => SSS,
+                        replay => undefined,
+                        ?TIMER_PULL => undefined,
+                        ?TIMER_PUSH => undefined,
+                        ?TIMER_GET_STREAMS => undefined,
+                        ?TIMER_BUMP_LAST_ALIVE_AT => undefined,
+                        ?TIMER_RETRY_REPLAY => undefined
                     }
             end;
         undefined ->
@@ -969,7 +1033,14 @@ session_ensure_new(
         props => Conf,
         s => S,
         shared_sub_s => emqx_persistent_session_ds_shared_subs:new(shared_sub_opts(Id)),
-        inflight => emqx_persistent_session_ds_inflight:new(receive_maximum(ConnInfo))
+        inflight => emqx_persistent_session_ds_buffer:new(receive_maximum(ConnInfo)),
+        stream_scheduler_s => emqx_persistent_session_ds_stream_scheduler:init(S),
+        replay => undefined,
+        ?TIMER_PULL => undefined,
+        ?TIMER_PUSH => undefined,
+        ?TIMER_GET_STREAMS => undefined,
+        ?TIMER_BUMP_LAST_ALIVE_AT => undefined,
+        ?TIMER_RETRY_REPLAY => undefined
     }.
 
 %% @doc Called when a client reconnects with `clean session=true' or
@@ -1017,102 +1088,58 @@ do_ensure_all_iterators_closed(_DSSessionID) ->
 %% Normal replay:
 %%--------------------------------------------------------------------
 
-fetch_new_messages(Session0 = #{s := S0, shared_sub_s := SharedSubS0}, ClientInfo) ->
-    {S1, SharedSubS1} = emqx_persistent_session_ds_shared_subs:on_streams_replay(S0, SharedSubS0),
-    Session1 = Session0#{s => S1, shared_sub_s => SharedSubS1},
-    LFS = maps:get(last_fetched_stream, Session1, beginning),
-    ItStream = emqx_persistent_session_ds_stream_scheduler:iter_next_streams(LFS, S1),
-    BatchSize = get_config(ClientInfo, [batch_size]),
-    Session2 = fetch_new_messages(ItStream, BatchSize, Session1, ClientInfo),
-    Session2#{shared_sub_s => SharedSubS1}.
-
-fetch_new_messages(ItStream0, BatchSize, Session0, ClientInfo) ->
-    #{inflight := Inflight} = Session0,
-    case emqx_persistent_session_ds_inflight:n_buffered(all, Inflight) >= BatchSize of
-        true ->
-            %% Buffer is full:
-            Session0;
-        false ->
-            case emqx_persistent_session_ds_stream_scheduler:next_stream(ItStream0) of
-                {StreamKey, Srs, ItStream} ->
-                    Session1 = new_batch(StreamKey, Srs, BatchSize, Session0, ClientInfo),
-                    Session = Session1#{last_fetched_stream => StreamKey},
-                    fetch_new_messages(ItStream, BatchSize, Session, ClientInfo);
-                none ->
-                    Session0
+push_now(Session) ->
+    ensure_timer(?TIMER_PUSH, 0, Session).
+
+handle_ds_reply(AsyncReply, Session0 = #{s := S0, stream_scheduler_s := SchedS0}, ClientInfo) ->
+    case emqx_persistent_session_ds_stream_scheduler:on_ds_reply(AsyncReply, S0, SchedS0) of
+        {undefined, SchedS} ->
+            Session0#{stream_scheduler_s := SchedS};
+        {{StreamKey, ItBegin, FetchResult}, SchedS} ->
+            Session1 = Session0#{stream_scheduler_s := SchedS},
+            case enqueue_batch(false, Session1, ClientInfo, StreamKey, ItBegin, FetchResult) of
+                {ignore, _, Session} ->
+                    Session;
+                {ok, Srs, Session = #{s := S1}} ->
+                    S2 = emqx_persistent_session_ds_state:put_seqno(
+                        ?next(?QOS_1),
+                        Srs#srs.last_seqno_qos1,
+                        S1
+                    ),
+                    S3 = emqx_persistent_session_ds_state:put_seqno(
+                        ?next(?QOS_2),
+                        Srs#srs.last_seqno_qos2,
+                        S2
+                    ),
+                    S = emqx_persistent_session_ds_state:put_stream(StreamKey, Srs, S3),
+                    pull_now(Session#{s := S});
+                {{error, recoverable, Reason}, _Srs, Session} ->
+                    ?SLOG(debug, #{
+                        msg => "failed_to_fetch_batch",
+                        stream => StreamKey,
+                        reason => Reason,
+                        class => recoverable
+                    }),
+                    Session;
+                {{error, unrecoverable, Reason}, Srs, Session} ->
+                    skip_batch(StreamKey, Srs, Session, ClientInfo, Reason)
             end
     end.
 
-new_batch(StreamKey, Srs0, BatchSize, Session0 = #{s := S0}, ClientInfo) ->
-    Pending = fetch(false, StreamKey, Srs0, BatchSize),
-    case enqueue_batch(Session0, ClientInfo, Pending, receive_pending(Pending)) of
-        {ok, Srs, Session} ->
-            S1 = emqx_persistent_session_ds_state:put_seqno(
-                ?next(?QOS_1),
-                Srs#srs.last_seqno_qos1,
-                S0
-            ),
-            S2 = emqx_persistent_session_ds_state:put_seqno(
-                ?next(?QOS_2),
-                Srs#srs.last_seqno_qos2,
-                S1
-            ),
-            S = emqx_persistent_session_ds_state:put_stream(StreamKey, Srs, S2),
-            Session#{s => S};
-        {error, recoverable, Reason} ->
-            ?SLOG(debug, #{
-                msg => "failed_to_fetch_batch",
-                stream => StreamKey,
-                reason => Reason,
-                class => recoverable
-            }),
-            Session0;
-        {error, unrecoverable, Reason} ->
-            skip_batch(StreamKey, Srs0, Session0, ClientInfo, Reason)
-    end.
-
 %%--------------------------------------------------------------------
 %% Generic functions for fetching messages (during replay or normal
 %% operation):
 %% --------------------------------------------------------------------
 
-fetch(IsReplay, StreamKey, Srs0, DefaultBatchSize) ->
-    case IsReplay of
-        true ->
-            %% When we do replay we must use the same starting point
-            %% and batch size as initially:
-            BatchSize = Srs0#srs.batch_size,
-            ItBegin = Srs0#srs.it_begin;
-        false ->
-            BatchSize = DefaultBatchSize,
-            ItBegin = Srs0#srs.it_end
-    end,
-    {ok, Ref} = emqx_ds:anext(?PERSISTENT_MESSAGE_DB, ItBegin, BatchSize),
-    #pending_next{
-        ref = Ref,
-        is_replay = IsReplay,
-        it_begin = ItBegin,
-        stream_key = StreamKey
-    }.
-
-receive_pending(#pending_next{ref = Ref}) ->
-    receive
-        #ds_async_result{ref = Ref, data = Data} -> Data
-    end.
-
-enqueue_batch(Session, ClientInfo, Pending, FetchResult) ->
-    #pending_next{is_replay = IsReplay, stream_key = StreamKey, it_begin = ItBegin} = Pending,
-    enqueue_batch(IsReplay, Session, ClientInfo, StreamKey, ItBegin, FetchResult).
-
 -spec enqueue_batch(
     boolean(),
     session(),
     clientinfo(),
-    emqx_persistent_session_ds_state:stream_key(),
+    emqx_persistent_session_ds_stream_scheduler:stream_key(),
     emqx_ds:iterator(),
     emqx_ds:next_result()
 ) ->
-    {ok | emqx_ds:error(), #srs{}, session()}
+    {ok | emqx_ds:error(_), stream_state(), session()}
     | {ignore, undefined, session()}.
 enqueue_batch(IsReplay, Session = #{s := S}, ClientInfo, StreamKey, ItBegin, FetchResult) ->
     case emqx_persistent_session_ds_state:get_stream(StreamKey, S) of
@@ -1132,38 +1159,45 @@ enqueue_batch(IsReplay, Session = #{s := S}, ClientInfo, StreamKey, ItBegin, Fet
             }),
             {ignore, undefined, Session};
         Srs ->
-            do_enqueue_batch(IsReplay, Session, ClientInfo, Srs, ItBegin, FetchResult)
+            do_enqueue_batch(IsReplay, Session, ClientInfo, StreamKey, Srs, ItBegin, FetchResult)
     end.
 
-do_enqueue_batch(IsReplay, Session, ClientInfo, Srs0, ItBegin, FetchResult) ->
-    #{s := S0, inflight := Inflight0} = Session,
+do_enqueue_batch(IsReplay, Session, ClientInfo, StreamKey, Srs0, ItBegin, FetchResult) ->
+    #{s := S0, inflight := Inflight0, stream_scheduler_s := SchedS0} = Session,
     #srs{sub_state_id = SubStateId} = Srs0,
+    case IsReplay of
+        false ->
+            %% Normally we assign a new set of sequence
+            %% numbers to messages in the batch:
+            FirstSeqnoQos1 = emqx_persistent_session_ds_state:get_seqno(?next(?QOS_1), S0),
+            FirstSeqnoQos2 = emqx_persistent_session_ds_state:get_seqno(?next(?QOS_2), S0);
+        true ->
+            %% During replay we reuse the original sequence
+            %% numbers:
+            #srs{
+                first_seqno_qos1 = FirstSeqnoQos1,
+                first_seqno_qos2 = FirstSeqnoQos2
+            } = Srs0
+    end,
     case FetchResult of
         {error, _, _} = Error ->
             {Error, Srs0, Session};
         {ok, end_of_stream} ->
             %% No new messages; just update the end iterator:
             Srs = Srs0#srs{
+                first_seqno_qos1 = FirstSeqnoQos1,
+                first_seqno_qos2 = FirstSeqnoQos2,
+                last_seqno_qos1 = FirstSeqnoQos1,
+                last_seqno_qos2 = FirstSeqnoQos2,
                 it_begin = ItBegin,
                 it_end = end_of_stream,
                 batch_size = 0
             },
-            {ok, Srs, Session};
+            SchedS = emqx_persistent_session_ds_stream_scheduler:on_enqueue(
+                IsReplay, StreamKey, Srs, S0, SchedS0
+            ),
+            {ok, Srs, Session#{stream_scheduler_s := SchedS}};
         {ok, ItEnd, Messages} ->
-            case IsReplay of
-                false ->
-                    %% Normally we assign a new set of sequence
-                    %% numbers to messages in the batch:
-                    FirstSeqnoQos1 = emqx_persistent_session_ds_state:get_seqno(?next(?QOS_1), S0),
-                    FirstSeqnoQos2 = emqx_persistent_session_ds_state:get_seqno(?next(?QOS_2), S0);
-                true ->
-                    %% During replay we reuse the original sequence
-                    %% numbers:
-                    #srs{
-                        first_seqno_qos1 = FirstSeqnoQos1,
-                        first_seqno_qos2 = FirstSeqnoQos2
-                    } = Srs0
-            end,
             SubState = emqx_persistent_session_ds_state:get_subscription_state(SubStateId, S0),
             {Inflight, LastSeqnoQos1, LastSeqnoQos2} = process_batch(
                 IsReplay,
@@ -1184,7 +1218,10 @@ do_enqueue_batch(IsReplay, Session, ClientInfo, Srs0, ItBegin, FetchResult) ->
                 last_seqno_qos1 = LastSeqnoQos1,
                 last_seqno_qos2 = LastSeqnoQos2
             },
-            {ok, Srs, Session#{inflight := Inflight}}
+            SchedS = emqx_persistent_session_ds_stream_scheduler:on_enqueue(
+                IsReplay, StreamKey, Srs, S0, SchedS0
+            ),
+            {ok, Srs, Session#{inflight := Inflight, stream_scheduler_s := SchedS}}
     end.
 
 %% key_of_iter(#{3 := #{3 := #{5 := K}}}) ->
@@ -1235,7 +1272,7 @@ process_batch(
                         %% We ignore QoS 0 messages during replay:
                         Acc;
                     ?QOS_0 ->
-                        emqx_persistent_session_ds_inflight:push({undefined, Msg}, Acc);
+                        emqx_persistent_session_ds_buffer:push({undefined, Msg}, Acc);
                     ?QOS_1 when SeqNoQos1 =< Comm1 ->
                         %% QoS1 message has been acked by the client, ignore:
                         Acc;
@@ -1243,15 +1280,15 @@ process_batch(
                         %% QoS1 message has been sent but not
                         %% acked. Retransmit:
                         Msg1 = emqx_message:set_flag(dup, true, Msg),
-                        emqx_persistent_session_ds_inflight:push({SeqNoQos1, Msg1}, Acc);
+                        emqx_persistent_session_ds_buffer:push({SeqNoQos1, Msg1}, Acc);
                     ?QOS_1 ->
-                        emqx_persistent_session_ds_inflight:push({SeqNoQos1, Msg}, Acc);
+                        emqx_persistent_session_ds_buffer:push({SeqNoQos1, Msg}, Acc);
                     ?QOS_2 when SeqNoQos2 =< Comm2 ->
                         %% QoS2 message has been PUBCOMP'ed by the client, ignore:
                         Acc;
                     ?QOS_2 when SeqNoQos2 =< Rec ->
                         %% QoS2 message has been PUBREC'ed by the client, resend PUBREL:
-                        emqx_persistent_session_ds_inflight:push({pubrel, SeqNoQos2}, Acc);
+                        emqx_persistent_session_ds_buffer:push({pubrel, SeqNoQos2}, Acc);
                     ?QOS_2 when SeqNoQos2 =< Dup2 ->
                         %% QoS2 message has been sent, but we haven't received PUBREC.
                         %%
@@ -1259,9 +1296,9 @@ process_batch(
                         %% DUP flag is never set for QoS2 messages? We
                         %% do so for mem sessions, though.
                         Msg1 = emqx_message:set_flag(dup, true, Msg),
-                        emqx_persistent_session_ds_inflight:push({SeqNoQos2, Msg1}, Acc);
+                        emqx_persistent_session_ds_buffer:push({SeqNoQos2, Msg1}, Acc);
                     ?QOS_2 ->
-                        emqx_persistent_session_ds_inflight:push({SeqNoQos2, Msg}, Acc)
+                        emqx_persistent_session_ds_buffer:push({SeqNoQos2, Msg}, Acc)
                 end,
                 SeqNoQos1,
                 SeqNoQos2
@@ -1293,17 +1330,17 @@ enqueue_transient(
     case Qos of
         ?QOS_0 ->
             S = S0,
-            Inflight = emqx_persistent_session_ds_inflight:push({undefined, Msg}, Inflight0);
+            Inflight = emqx_persistent_session_ds_buffer:push({undefined, Msg}, Inflight0);
         QoS when QoS =:= ?QOS_1; QoS =:= ?QOS_2 ->
             SeqNo = inc_seqno(
                 QoS, emqx_persistent_session_ds_state:get_seqno(?next(QoS), S0)
             ),
             S = emqx_persistent_session_ds_state:put_seqno(?next(QoS), SeqNo, S0),
-            Inflight = emqx_persistent_session_ds_inflight:push({SeqNo, Msg}, Inflight0)
+            Inflight = emqx_persistent_session_ds_buffer:push({SeqNo, Msg}, Inflight0)
     end,
     Session#{
-        inflight => Inflight,
-        s => S
+        inflight := Inflight,
+        s := S
     }.
 
 %%--------------------------------------------------------------------
@@ -1312,10 +1349,10 @@ enqueue_transient(
 
 drain_buffer(Session = #{inflight := Inflight0, s := S0}) ->
     {Publishes, Inflight, S} = do_drain_buffer(Inflight0, S0, []),
-    {Publishes, Session#{inflight => Inflight, s := S}}.
+    {Publishes, Session#{inflight := Inflight, s := S}}.
 
 do_drain_buffer(Inflight0, S0, Acc) ->
-    case emqx_persistent_session_ds_inflight:pop(Inflight0) of
+    case emqx_persistent_session_ds_buffer:pop(Inflight0) of
         undefined ->
             {lists:reverse(Acc), Inflight0, S0};
         {{pubrel, SeqNo}, Inflight} ->
@@ -1338,13 +1375,20 @@ do_drain_buffer(Inflight0, S0, Acc) ->
 %% effects. Add `CBM:init' callback to the session behavior?
 -spec ensure_timers(session()) -> session().
 ensure_timers(Session0) ->
-    Session1 = emqx_session:ensure_timer(?TIMER_PULL, 100, Session0),
-    Session2 = emqx_session:ensure_timer(?TIMER_GET_STREAMS, 100, Session1),
-    emqx_session:ensure_timer(?TIMER_BUMP_LAST_ALIVE_AT, 100, Session2).
+    Session1 = set_timer(?TIMER_GET_STREAMS, 100, Session0),
+    set_timer(?TIMER_BUMP_LAST_ALIVE_AT, 100, Session1).
 
+%% This function triggers sending buffered packets to the client
+%% (provided there is something to send and the number of in-flight
+%% packets is less than `Receive-Maximum'). Normally, pull is
+%% triggered when:
+%%
+%% - New messages (durable or transient) are enqueued
+%%
+%% - When the client releases a packet ID (via PUBACK or PUBCOMP)
 -spec pull_now(session()) -> session().
 pull_now(Session) ->
-    emqx_session:reset_timer(?TIMER_PULL, 0, Session).
+    ensure_timer(?TIMER_PULL, 0, Session).
 
 -spec receive_maximum(conninfo()) -> pos_integer().
 receive_maximum(ConnInfo) ->
@@ -1411,26 +1455,44 @@ maybe_set_offline_info(S, Id) ->
 
 -spec update_seqno(puback | pubrec | pubcomp, emqx_types:packet_id(), session()) ->
     {ok, emqx_types:message(), session()} | {error, _}.
-update_seqno(Track, PacketId, Session = #{id := SessionId, s := S, inflight := Inflight0}) ->
+update_seqno(
+    Track,
+    PacketId,
+    Session = #{id := SessionId, s := S, inflight := Inflight0, stream_scheduler_s := SchedS0}
+) ->
     SeqNo = packet_id_to_seqno(PacketId, S),
     case Track of
         puback ->
             SeqNoKey = ?committed(?QOS_1),
-            Result = emqx_persistent_session_ds_inflight:puback(SeqNo, Inflight0);
+            Result = emqx_persistent_session_ds_buffer:puback(SeqNo, Inflight0);
         pubrec ->
             SeqNoKey = ?rec,
-            Result = emqx_persistent_session_ds_inflight:pubrec(SeqNo, Inflight0);
+            Result = emqx_persistent_session_ds_buffer:pubrec(SeqNo, Inflight0);
         pubcomp ->
             SeqNoKey = ?committed(?QOS_2),
-            Result = emqx_persistent_session_ds_inflight:pubcomp(SeqNo, Inflight0)
+            Result = emqx_persistent_session_ds_buffer:pubcomp(SeqNo, Inflight0)
     end,
     case Result of
         {ok, Inflight} ->
             %% TODO: we pass a bogus message into the hook:
             Msg = emqx_message:make(SessionId, <<>>, <<>>),
+            SchedS =
+                case Track of
+                    puback ->
+                        emqx_persistent_session_ds_stream_scheduler:on_seqno_release(
+                            ?QOS_1, SeqNo, SchedS0
+                        );
+                    pubcomp ->
+                        emqx_persistent_session_ds_stream_scheduler:on_seqno_release(
+                            ?QOS_2, SeqNo, SchedS0
+                        );
+                    _ ->
+                        SchedS0
+                end,
             {ok, Msg, Session#{
-                s => emqx_persistent_session_ds_state:put_seqno(SeqNoKey, SeqNo, S),
-                inflight => Inflight
+                s := emqx_persistent_session_ds_state:put_seqno(SeqNoKey, SeqNo, S),
+                inflight := Inflight,
+                stream_scheduler_s := SchedS
             }};
         {error, Expected} ->
             ?SLOG(warning, #{
@@ -1539,6 +1601,20 @@ maybe_set_will_message_timer(#{id := SessionId, s := S}) ->
             ok
     end.
 
+-spec ensure_timer(timer(), non_neg_integer(), session()) -> session().
+ensure_timer(Timer, Time, Session) ->
+    case Session of
+        #{Timer := undefined} ->
+            set_timer(Timer, Time, Session);
+        #{Timer := TRef} when is_reference(TRef) ->
+            Session
+    end.
+
+-spec set_timer(timer(), non_neg_integer(), session()) -> session().
+set_timer(Timer, Time, Session) ->
+    TRef = emqx_utils:start_timer(Time, {emqx_session, Timer}),
+    Session#{Timer := TRef}.
+
 %%--------------------------------------------------------------------
 %% Tests
 %%--------------------------------------------------------------------

+ 27 - 27
apps/emqx/src/emqx_persistent_session_ds/emqx_persistent_session_ds_inflight.erl

@@ -13,7 +13,7 @@
 %% See the License for the specific language governing permissions and
 %% limitations under the License.
 %%--------------------------------------------------------------------
--module(emqx_persistent_session_ds_inflight).
+-module(emqx_persistent_session_ds_buffer).
 
 %% API:
 -export([
@@ -49,7 +49,7 @@
     {emqx_persistent_session_ds:seqno() | undefined, emqx_types:message()}
     | {pubrel, emqx_persistent_session_ds:seqno()}.
 
--record(inflight, {
+-record(ds_buffer, {
     receive_maximum :: pos_integer(),
     %% Main queue:
     queue :: queue:queue(payload()),
@@ -64,7 +64,7 @@
     n_qos2 = 0 :: non_neg_integer()
 }).
 
--type t() :: #inflight{}.
+-type t() :: #ds_buffer{}.
 
 %%================================================================================
 %% API functions
@@ -72,7 +72,7 @@
 
 -spec new(non_neg_integer()) -> t().
 new(ReceiveMaximum) when ReceiveMaximum > 0 ->
-    #inflight{
+    #ds_buffer{
         receive_maximum = ReceiveMaximum,
         queue = queue:new(),
         puback_queue = iqueue_new(),
@@ -81,27 +81,27 @@ new(ReceiveMaximum) when ReceiveMaximum > 0 ->
     }.
 
 -spec receive_maximum(t()) -> pos_integer().
-receive_maximum(#inflight{receive_maximum = ReceiveMaximum}) ->
+receive_maximum(#ds_buffer{receive_maximum = ReceiveMaximum}) ->
     ReceiveMaximum.
 
 -spec push(payload(), t()) -> t().
-push(Payload = {pubrel, _SeqNo}, Rec = #inflight{queue = Q}) ->
-    Rec#inflight{queue = queue:in(Payload, Q)};
+push(Payload = {pubrel, _SeqNo}, Rec = #ds_buffer{queue = Q}) ->
+    Rec#ds_buffer{queue = queue:in(Payload, Q)};
 push(Payload = {_, Msg}, Rec) ->
-    #inflight{queue = Q0, n_qos0 = NQos0, n_qos1 = NQos1, n_qos2 = NQos2} = Rec,
+    #ds_buffer{queue = Q0, n_qos0 = NQos0, n_qos1 = NQos1, n_qos2 = NQos2} = Rec,
     Q = queue:in(Payload, Q0),
     case Msg#message.qos of
         ?QOS_0 ->
-            Rec#inflight{queue = Q, n_qos0 = NQos0 + 1};
+            Rec#ds_buffer{queue = Q, n_qos0 = NQos0 + 1};
         ?QOS_1 ->
-            Rec#inflight{queue = Q, n_qos1 = NQos1 + 1};
+            Rec#ds_buffer{queue = Q, n_qos1 = NQos1 + 1};
         ?QOS_2 ->
-            Rec#inflight{queue = Q, n_qos2 = NQos2 + 1}
+            Rec#ds_buffer{queue = Q, n_qos2 = NQos2 + 1}
     end.
 
 -spec pop(t()) -> {payload(), t()} | undefined.
 pop(Rec0) ->
-    #inflight{
+    #ds_buffer{
         receive_maximum = ReceiveMaximum,
         n_inflight = NInflight,
         queue = Q0,
@@ -117,20 +117,20 @@ pop(Rec0) ->
             Rec =
                 case Payload of
                     {pubrel, _} ->
-                        Rec0#inflight{queue = Q};
+                        Rec0#ds_buffer{queue = Q};
                     {SeqNo, #message{qos = Qos}} ->
                         case Qos of
                             ?QOS_0 ->
-                                Rec0#inflight{queue = Q, n_qos0 = NQos0 - 1};
+                                Rec0#ds_buffer{queue = Q, n_qos0 = NQos0 - 1};
                             ?QOS_1 ->
-                                Rec0#inflight{
+                                Rec0#ds_buffer{
                                     queue = Q,
                                     n_qos1 = NQos1 - 1,
                                     n_inflight = NInflight + 1,
                                     puback_queue = ipush(SeqNo, QAck)
                                 };
                             ?QOS_2 ->
-                                Rec0#inflight{
+                                Rec0#ds_buffer{
                                     queue = Q,
                                     n_qos2 = NQos2 - 1,
                                     n_inflight = NInflight + 1,
@@ -145,25 +145,25 @@ pop(Rec0) ->
     end.
 
 -spec n_buffered(?QOS_0..?QOS_2 | all, t()) -> non_neg_integer().
-n_buffered(?QOS_0, #inflight{n_qos0 = NQos0}) ->
+n_buffered(?QOS_0, #ds_buffer{n_qos0 = NQos0}) ->
     NQos0;
-n_buffered(?QOS_1, #inflight{n_qos1 = NQos1}) ->
+n_buffered(?QOS_1, #ds_buffer{n_qos1 = NQos1}) ->
     NQos1;
-n_buffered(?QOS_2, #inflight{n_qos2 = NQos2}) ->
+n_buffered(?QOS_2, #ds_buffer{n_qos2 = NQos2}) ->
     NQos2;
-n_buffered(all, #inflight{n_qos0 = NQos0, n_qos1 = NQos1, n_qos2 = NQos2}) ->
+n_buffered(all, #ds_buffer{n_qos0 = NQos0, n_qos1 = NQos1, n_qos2 = NQos2}) ->
     NQos0 + NQos1 + NQos2.
 
 -spec n_inflight(t()) -> non_neg_integer().
-n_inflight(#inflight{n_inflight = NInflight}) ->
+n_inflight(#ds_buffer{n_inflight = NInflight}) ->
     NInflight.
 
 -spec puback(emqx_persistent_session_ds:seqno(), t()) -> {ok, t()} | {error, Expected} when
     Expected :: emqx_persistent_session_ds:seqno() | undefined.
-puback(SeqNo, Rec = #inflight{puback_queue = Q0, n_inflight = N}) ->
+puback(SeqNo, Rec = #ds_buffer{puback_queue = Q0, n_inflight = N}) ->
     case ipop(Q0) of
         {{value, SeqNo}, Q} ->
-            {ok, Rec#inflight{
+            {ok, Rec#ds_buffer{
                 puback_queue = Q,
                 n_inflight = max(0, N - 1)
             }};
@@ -175,10 +175,10 @@ puback(SeqNo, Rec = #inflight{puback_queue = Q0, n_inflight = N}) ->
 
 -spec pubcomp(emqx_persistent_session_ds:seqno(), t()) -> {ok, t()} | {error, Expected} when
     Expected :: emqx_persistent_session_ds:seqno() | undefined.
-pubcomp(SeqNo, Rec = #inflight{pubcomp_queue = Q0, n_inflight = N}) ->
+pubcomp(SeqNo, Rec = #ds_buffer{pubcomp_queue = Q0, n_inflight = N}) ->
     case ipop(Q0) of
         {{value, SeqNo}, Q} ->
-            {ok, Rec#inflight{
+            {ok, Rec#ds_buffer{
                 pubcomp_queue = Q,
                 n_inflight = max(0, N - 1)
             }};
@@ -192,10 +192,10 @@ pubcomp(SeqNo, Rec = #inflight{pubcomp_queue = Q0, n_inflight = N}) ->
 %% https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Flow_Control
 -spec pubrec(emqx_persistent_session_ds:seqno(), t()) -> {ok, t()} | {error, Expected} when
     Expected :: emqx_persistent_session_ds:seqno() | undefined.
-pubrec(SeqNo, Rec = #inflight{pubrec_queue = Q0}) ->
+pubrec(SeqNo, Rec = #ds_buffer{pubrec_queue = Q0}) ->
     case ipop(Q0) of
         {{value, SeqNo}, Q} ->
-            {ok, Rec#inflight{
+            {ok, Rec#ds_buffer{
                 pubrec_queue = Q
             }};
         {{value, Expected}, _} ->

+ 29 - 83
apps/emqx/src/emqx_persistent_session_ds/emqx_persistent_session_ds_state.erl

@@ -51,7 +51,7 @@
 -export([get_peername/1, set_peername/2]).
 -export([get_protocol/1, set_protocol/2]).
 -export([new_id/1]).
--export([get_stream/2, put_stream/3, del_stream/2, fold_streams/3, iter_streams/2, n_streams/1]).
+-export([get_stream/2, put_stream/3, del_stream/2, fold_streams/3, n_streams/1]).
 -export([get_seqno/2, put_seqno/3]).
 -export([get_rank/2, put_rank/3, del_rank/2, fold_ranks/3]).
 -export([
@@ -78,16 +78,12 @@
     n_awaiting_rel/1
 ]).
 
--export([iter_next/1]).
-
 -export([make_session_iterator/0, session_iterator_next/2]).
 
 -export_type([
     t/0,
     metadata/0,
-    iter/2,
     seqno_type/0,
-    stream_key/0,
     rank_key/0,
     session_iterator/0,
     protocol/0
@@ -120,9 +116,6 @@
 -opaque iter(K, V) :: #{
     it := gb_trees:iter(internal_key(K), V), inv_key_mapping := #{internal_key(K) => K}
 }.
-%% ELSE ifdef(STORE_STATE_IN_DS).
--else.
--opaque iter(K, V) :: gb_trees:iter(K, V).
 %% END ifdef(STORE_STATE_IN_DS).
 -endif.
 
@@ -840,18 +833,20 @@ del_subscription_state(SStateId, Rec) ->
 
 %%
 
--type stream_key() :: {emqx_persistent_session_ds:subscription_id(), _StreamId}.
-
--spec get_stream(stream_key(), t()) ->
+-spec get_stream(emqx_persistent_session_ds_stream_scheduler:stream_key(), t()) ->
     emqx_persistent_session_ds:stream_state() | undefined.
 get_stream(Key, Rec) ->
     gen_get(?streams, Key, Rec).
 
--spec put_stream(stream_key(), emqx_persistent_session_ds:stream_state(), t()) -> t().
+-spec put_stream(
+    emqx_persistent_session_ds_stream_scheduler:stream_key(),
+    emqx_persistent_session_ds:stream_state(),
+    t()
+) -> t().
 put_stream(Key, Val, Rec) ->
     gen_put(?streams, Key, Val, Rec).
 
--spec del_stream(stream_key(), t()) -> t().
+-spec del_stream(emqx_persistent_session_ds_stream_scheduler:stream_key(), t()) -> t().
 del_stream(Key, Rec) ->
     gen_del(?streams, Key, Rec).
 
@@ -859,18 +854,11 @@ del_stream(Key, Rec) ->
 fold_streams(Fun, Acc, Rec) ->
     gen_fold(?streams, Fun, Acc, Rec).
 
+-ifdef(STORE_STATE_IN_DS).
 -spec iter_streams(_StartAfter :: stream_key() | beginning, t()) ->
     iter(stream_key(), emqx_persistent_session_ds:stream_state()).
--ifdef(STORE_STATE_IN_DS).
 iter_streams(After, Rec) ->
     gen_iter_after(?streams, After, Rec).
-%% ELSE ifdef(STORE_STATE_IN_DS).
--else.
-iter_streams(After, Rec) ->
-    %% NOTE
-    %% No special handling for `beginning', as it always compares less
-    %% than any `stream_key()'.
-    gen_iter_after(?streams, After, Rec).
 %% END ifdef(STORE_STATE_IN_DS).
 -endif.
 
@@ -932,8 +920,8 @@ n_awaiting_rel(Rec) ->
 
 %%
 
--spec iter_next(iter(K, V)) -> {K, V, iter(K, V)} | none.
 -ifdef(STORE_STATE_IN_DS).
+-spec iter_next(iter(K, V)) -> {K, V, iter(K, V)} | none.
 iter_next(#{it := InnerIt0, inv_key_mapping := InvKeyMapping} = It0) ->
     case gen_iter_next(InnerIt0) of
         none ->
@@ -942,10 +930,6 @@ iter_next(#{it := InnerIt0, inv_key_mapping := InvKeyMapping} = It0) ->
             Key = maps:get(IntKey, InvKeyMapping),
             {Key, Value, It0#{it := InnerIt}}
     end.
-%% ELSE ifdef(STORE_STATE_IN_DS).
--else.
-iter_next(It0) ->
-    gen_iter_next(It0).
 %% END ifdef(STORE_STATE_IN_DS).
 -endif.
 
@@ -1120,14 +1104,6 @@ gen_size(Field, Rec) ->
     check_sequence(Rec),
     pmap_size(maps:get(Field, Rec)).
 
-gen_iter_after(Field, After, Rec) ->
-    check_sequence(Rec),
-    pmap_iter_after(After, maps:get(Field, Rec)).
-
-gen_iter_next(It) ->
-    %% NOTE: Currently, gbt iterators is the only type of iterators.
-    gbt_iter_next(It).
-
 -spec update_pmaps(fun((pmap(_K, _V) | undefined, atom()) -> term()), map()) -> map().
 update_pmaps(Fun, Map) ->
     lists:foldl(
@@ -1360,12 +1336,6 @@ pmap_iter_after(AfterExt, #pmap{table = Table, key_mapping = KeyMapping, cache =
     It = gbt_iter_after(AfterInt, Cache),
     InvKeyMapping = invert_key_mapping(KeyMapping),
     #{it => It, inv_key_mapping => InvKeyMapping}.
-%% ELSE ifdef(STORE_STATE_IN_DS).
--else.
-pmap_iter_after(After, #pmap{table = Table, cache = Cache}) ->
-    %% NOTE: Only valid for gbt-backed PMAPs.
-    gbt = cache_data_type(Table),
-    gbt_iter_after(After, Cache).
 %% END ifdef(STORE_STATE_IN_DS).
 -endif.
 
@@ -1373,7 +1343,6 @@ pmap_iter_after(After, #pmap{table = Table, cache = Cache}) ->
 
 -ifdef(STORE_STATE_IN_DS).
 -define(stream_tab, ?stream_domain).
--endif.
 
 cache_data_type(?stream_tab) -> gbt;
 cache_data_type(_Table) -> map.
@@ -1382,23 +1351,14 @@ cache_from_list(?stream_tab, L) ->
     gbt_from_list(L);
 cache_from_list(_Table, L) ->
     maps:from_list(L).
-
-cache_get(?stream_tab, K, Cache) ->
-    gbt_get(K, Cache, undefined);
-cache_get(_Table, K, Cache) ->
-    maps:get(K, Cache, undefined).
-
 cache_put(?stream_tab, K, V, Cache) ->
     gbt_put(K, V, Cache);
 cache_put(_Table, K, V, Cache) ->
     maps:put(K, V, Cache).
-
 cache_remove(?stream_tab, K, Cache) ->
     gbt_remove(K, Cache);
 cache_remove(_Table, K, Cache) ->
     maps:remove(K, Cache).
-
--ifdef(STORE_STATE_IN_DS).
 cache_fold(?stream_tab, Fun, Acc, KeyMapping, Cache) ->
     gbt_fold(Fun, Acc, KeyMapping, Cache);
 cache_fold(_Table, FunIn, Acc, KeyMapping, Cache) ->
@@ -1413,16 +1373,6 @@ cache_has_key(?stream_tab, Key, Cache) ->
     gb_trees:is_defined(Key, Cache);
 cache_has_key(_Domain, Key, Cache) ->
     is_map_key(Key, Cache).
-%% ELSE ifdef(STORE_STATE_IN_DS).
--else.
-cache_fold(?stream_tab, Fun, Acc, Cache) ->
-    gbt_fold(Fun, Acc, Cache);
-cache_fold(_Table, Fun, Acc, Cache) ->
-    maps:fold(Fun, Acc, Cache).
-%% END ifdef(STORE_STATE_IN_DS).
--endif.
-
--ifdef(STORE_STATE_IN_DS).
 cache_format(?stream_tab, InvKeyMapping, Cache) ->
     lists:map(
         fun({IntK, V}) ->
@@ -1440,20 +1390,31 @@ cache_format(_Table, InvKeyMapping, Cache) ->
         #{},
         Cache
     ).
-%% ELSE ifdef(STORE_STATE_IN_DS).
+cache_size(?stream_tab, Cache) ->
+    gbt_size(Cache);
+cache_size(_Table, Cache) ->
+    maps:size(Cache).
+%% Below: -ifndef(STORE_STATE_IN_DS)
 -else.
-cache_format(?stream_tab, Cache) ->
-    gbt_format(Cache);
+cache_from_list(_Table, L) ->
+    maps:from_list(L).
+cache_put(_Table, K, V, Cache) ->
+    maps:put(K, V, Cache).
+cache_remove(_Table, K, Cache) ->
+    maps:remove(K, Cache).
+cache_fold(_Table, Fun, Acc, Cache) ->
+    maps:fold(Fun, Acc, Cache).
 cache_format(_Table, Cache) ->
     Cache.
+cache_size(_Table, Cache) ->
+    maps:size(Cache).
 %% END ifdef(STORE_STATE_IN_DS).
 -endif.
 
-cache_size(?stream_tab, Cache) ->
-    gbt_size(Cache);
-cache_size(_Table, Cache) ->
-    maps:size(Cache).
+cache_get(_Table, K, Cache) ->
+    maps:get(K, Cache, undefined).
 
+-ifdef(STORE_STATE_IN_DS).
 %% PMAP Cache implementation backed by `gb_trees'.
 %% Supports iteration starting from specific key.
 
@@ -1479,7 +1440,6 @@ gbt_remove(K, Cache) ->
 gbt_format(Cache) ->
     gb_trees:to_list(Cache).
 
--ifdef(STORE_STATE_IN_DS).
 gbt_fold(Fun, Acc, KeyMapping, Cache) ->
     InvKeyMapping = invert_key_mapping(KeyMapping),
     It = gb_trees:iterator(Cache),
@@ -1493,22 +1453,10 @@ gbt_fold_iter(Fun, Acc, InvKeyMapping, It0) ->
         _ ->
             Acc
     end.
-%% ELSE ifdef(STORE_STATE_IN_DS).
--else.
-gbt_fold(Fun, Acc, Cache) ->
-    It = gb_trees:iterator(Cache),
-    gbt_fold_iter(Fun, Acc, It).
-
-gbt_fold_iter(Fun, Acc, It0) ->
-    case gb_trees:next(It0) of
-        {K, V, It} ->
-            gbt_fold_iter(Fun, Fun(K, V, Acc), It);
-        _ ->
-            Acc
-    end.
 %% END ifdef(STORE_STATE_IN_DS).
 -endif.
 
+-ifdef(STORE_STATE_IN_DS).
 gbt_size(Cache) ->
     gb_trees:size(Cache).
 
@@ -1524,7 +1472,6 @@ gbt_iter_after(After, Cache) ->
 gbt_iter_next(It) ->
     gb_trees:next(It).
 
--ifdef(STORE_STATE_IN_DS).
 session_restore(SessionId) ->
     Empty = maps:from_keys(
         [
@@ -1549,7 +1496,6 @@ session_restore(SessionId) ->
 -else.
 
 %% Functions dealing with set tables:
-
 kv_persist(Tab, SessionId, Val0) ->
     Val = encoder(encode, Tab, Val0),
     mnesia:write(Tab, #kv{k = SessionId, v = Val}, write).

+ 461 - 111
apps/emqx/src/emqx_persistent_session_ds/emqx_persistent_session_ds_stream_scheduler.erl

@@ -13,12 +13,124 @@
 %% See the License for the specific language governing permissions and
 %% limitations under the License.
 %%--------------------------------------------------------------------
+
+%% @doc Stream scheduler is a helper module used by durable sessions
+%% to track states of the DS streams (Stream Replay States, or SRS for
+%% short). It has two main duties:
+%%
+%% - During normal operation, it polls DS iterators that are eligible
+%% for poll.
+%%
+%% - During session reconnect, it returns the list of SRS that must be
+%% replayed in order.
+%%
+%% ** Blocked streams
+%%
+%% For performance reasons we keep only one record of in-flight
+%% messages per stream, and we don't want to overwrite these records
+%% prematurely. So scheduler makes sure that streams that have
+%% un-acked QoS1 or QoS2 messages are not polled.
+%%
+%% ** Stream state machine
+%%
+%% During normal operation, state of each iterator can be described as
+%% a FSM. Implementation detail: unconventionally, iterators' states are
+%% tracked implicitly, by moving SRS ID between different buckets.
+%% This facilitates faster processing of iterators that have a certain
+%% state.
+%%
+%% There are the following stream replay states:
+%%
+%% - *(R)eady*: stream iterator can be polled. Ready SRS are stored in
+%% `#s.ready' bucket.
+%%
+%% - *(P)ending*: poll request for the iterator has been sent to DS,
+%% and we're awaiting the response. Such iterators are stored in
+%% `#s.pending' bucket.
+%%
+%% - *(S)erved*: poll reply has been received, and ownership over SRS
+%% has been handed over to the parent session. This state is implicit:
+%% *served* streams are not tracked by the scheduler. It's assumed
+%% that the session will process the batch and immediately hand SRS
+%% back via `on_enqueue' call.
+%%
+%% - *BQ1*, *BQ2* and *BQ12*: these three states correspond to the
+%% situations when stream cannot be polled, because it is blocked by
+%% un-acked QoS1, QoS2 or QoS1&2 messages respectively. Such streams
+%% are stored in `#s.bq1' or `#s.bq2' buckets (or both).
+%%
+%% - *(U)nsubscribed*: streams for unsubscribed topics can linger in
+%% the session state for a while until all queued messages are acked.
+%% This state is implicit: unsubscribed streams are simply removed
+%% from all buckets. Unsubscribed streams are ignored by the scheduler
+%% until the moment they can be garbage-collected. So this is a
+%% terminal state. Even if the client resubscribes, it will produce a
+%% new, totally separate set of SRS.
+%%
+%% *** State transitions
+%%
+%% New streams start in the *Ready* state, from which they follow one
+%% of these paths:
+%%
+%%      .--(`?MODULE:poll')--> *P* --(Poll reply)--> *S* --> ...
+%%     /
+%% *R* --(`?MODULE:on_unsubscribe')--> *U*
+%%  ^  \
+%%  |   `--(`?MODULE:poll')--->---.
+%%  |                              \
+%%  \        Idle longpoll loop    *P*
+%%   \                             ,
+%%    `---<--(Poll timeout)---<---'
+%%
+%% *Served* streams are returned to the parent session, which assigns
+%% QoS and sequence numbers to the batch messages according to its own
+%% logic, and enqueues batch to the buffer. Then it returns the
+%% updated SRS back to the scheduler, where it can undergo the
+%% following transitions:
+%%
+%%          .--(buffer is full)--> *R* --> ...
+%%         /
+%%        /--(`on_unsubscribe')--> *U*
+%%       /
+%%      /--(only QoS0 messages in the batch)--> *R* --> ...
+%%     /
+%% *S* --([QoS0] & QoS1 messages)--> *BQ1* --> ...
+%%    \
+%%     \--([QoS0] & QoS2 messages)--> *BQ2* --> ...
+%%      \
+%%       `--([QoS0] & QoS1 & QoS2)--> *BQ12* --> ...
+%%
+%% *BQ1* and *BQ2* are handled similarly. They transition to *Ready*
+%% once session calls `?MODULE:on_seqno_release' for the corresponding
+%% QoS track and sequence number equal to the SRS's last sequence
+%% number for the track:
+%%
+%% *BQX* --(`?MODULE:on_seqno_release(?QOS_X, LastSeqNoX)')--> *R* --> ...
+%%      \
+%%       `--(`on_unsubscribe')--> *U*
+%%
+%% *BQ12* is handled like this:
+%%
+%%        .--(`on_seqno_release(?QOS_1, LastSeqNo1)')--> *BQ2* --> ...
+%%       /
+%% *BQ12*--(`on_unsubscribe')--> *U*
+%%       \
+%%        `--(`on_seqno_release(?QOS_2, LastSeqNo2)')--> *BQ1* --> ...
+%%
+%%
 -module(emqx_persistent_session_ds_stream_scheduler).
 
 %% API:
--export([iter_next_streams/2, next_stream/1]).
--export([find_replay_streams/1, is_fully_acked/2]).
--export([renew_streams/1, on_unsubscribe/2]).
+-export([
+    init/1,
+    poll/3,
+    on_ds_reply/3,
+    on_enqueue/5,
+    on_seqno_release/3,
+    find_replay_streams/1,
+    is_fully_acked/2
+]).
+-export([renew_streams/2, on_unsubscribe/3]).
 
 %% behavior callbacks:
 -export([]).
@@ -26,9 +138,11 @@
 %% internal exports:
 -export([]).
 
--export_type([]).
+-export_type([t/0, stream_key/0, srs/0]).
 
 -include_lib("emqx/include/logger.hrl").
+-include_lib("snabbkaffe/include/trace.hrl").
+-include_lib("emqx_durable_storage/include/emqx_ds.hrl").
 -include("emqx_mqtt.hrl").
 -include("session_internals.hrl").
 
@@ -36,37 +150,76 @@
 %% Type declarations
 %%================================================================================
 
--type stream_key() :: emqx_persistent_session_ds_state:stream_key().
--type stream_state() :: emqx_persistent_session_ds:stream_state().
+-type stream_key() :: {emqx_persistent_session_ds:subscription_id(), _StreamId}.
+
+-type srs() :: #srs{}.
+
+%%%%%% Pending poll for the iterator:
+-record(pending_poll, {
+    %% Poll reference:
+    ref :: reference(),
+    %% Iterator at the beginning of poll:
+    it_begin :: emqx_ds:iterator()
+}).
+
+-type pending() :: #pending_poll{}.
+
+-record(block, {
+    id :: stream_key(),
+    last_seqno_qos1 :: emqx_persistent_session_ds:seqno(),
+    last_seqno_qos2 :: emqx_persistent_session_ds:seqno()
+}).
+
+-type block() :: #block{}.
+
+-type blocklist() :: gb_trees:tree(emqx_persistent_session_ds:seqno(), block()).
 
-%% Restartable iterator with a filter and an iteration limit.
--record(iter, {
-    limit :: non_neg_integer(),
-    filter,
-    it,
-    it_cont
+-type ready() :: [stream_key()].
+
+-record(s, {
+    %% Buckets:
+    ready :: ready(),
+    pending = #{} :: #{stream_key() => #pending_poll{}},
+    bq1 :: blocklist(),
+    bq2 :: blocklist()
 }).
 
--type iter(K, V, IterInner) :: #iter{
-    filter :: fun((K, V) -> boolean()),
-    it :: IterInner,
-    it_cont :: IterInner
-}.
+-opaque t() :: #s{}.
 
--type iter_stream() :: iter(
-    stream_key(),
-    stream_state(),
-    emqx_persistent_session_ds_state:iter(stream_key(), stream_state())
-).
+-type state() :: r | p | s | bq1 | bq2 | bq12 | u.
 
 %%================================================================================
 %% API functions
 %%================================================================================
 
+-spec init(emqx_persistent_session_ds_state:t()) -> t().
+init(S) ->
+    SchedS0 = #s{
+        ready = empty_ready(),
+        bq1 = gb_trees:empty(),
+        bq2 = gb_trees:empty()
+    },
+    Comm1 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_1), S),
+    Comm2 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_2), S),
+    %% Restore stream states:
+    emqx_persistent_session_ds_state:fold_streams(
+        fun(Key, Srs, Acc) ->
+            case derive_state(Comm1, Comm2, Srs) of
+                r -> to_R(Key, Acc);
+                u -> to_U(Key, Srs, Acc);
+                bq1 -> to_BQ1(Key, Srs, Acc);
+                bq2 -> to_BQ2(Key, Srs, Acc);
+                bq12 -> to_BQ12(Key, Srs, Acc)
+            end
+        end,
+        SchedS0,
+        S
+    ).
+
 %% @doc Find the streams that have uncommitted (in-flight) messages.
 %% Return them in the order they were previously replayed.
 -spec find_replay_streams(emqx_persistent_session_ds_state:t()) ->
-    [{emqx_persistent_session_ds_state:stream_key(), emqx_persistent_session_ds:stream_state()}].
+    [{stream_key(), emqx_persistent_session_ds:stream_state()}].
 find_replay_streams(S) ->
     Comm1 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_1), S),
     Comm2 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_2), S),
@@ -85,70 +238,147 @@ find_replay_streams(S) ->
     ),
     lists:sort(fun compare_streams/2, Streams).
 
-%% @doc Find streams from which the new messages can be fetched.
-%%
-%% Currently it amounts to the streams that don't have any inflight
-%% messages, since for performance reasons we keep only one record of
-%% in-flight messages per stream, and we don't want to overwrite these
-%% records prematurely.
-%%
-%% This function is non-detereministic: it randomizes the order of
-%% streams to ensure fair replay of different topics.
--spec iter_next_streams(_LastVisited :: stream_key(), emqx_persistent_session_ds_state:t()) ->
-    iter_stream().
-iter_next_streams(LastVisited, S) ->
-    %% FIXME: this function is currently very sensitive to the
-    %% consistency of the packet IDs on both broker and client side.
-    %%
-    %% If the client fails to properly ack packets due to a bug, or a
-    %% network issue, or if the state of streams and seqno tables ever
-    %% become de-synced, then this function will return an empty list,
-    %% and the replay cannot progress.
-    %%
-    %% In other words, this function is not robust, and we should find
-    %% some way to get the replays un-stuck at the cost of potentially
-    %% losing messages during replay (or just kill the stuck channel
-    %% after timeout?)
-    Comm1 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_1), S),
-    Comm2 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_2), S),
-    Filter = fun(_Key, Stream) -> is_fetchable(Comm1, Comm2, Stream) end,
-    #iter{
-        %% Limit the iteration to one round over all streams:
-        limit = emqx_persistent_session_ds_state:n_streams(S),
-        %% Filter out the streams not eligible for fetching:
-        filter = Filter,
-        %% Start the iteration right after the last visited stream:
-        it = emqx_persistent_session_ds_state:iter_streams(LastVisited, S),
-        %% Restart the iteration from the beginning:
-        it_cont = emqx_persistent_session_ds_state:iter_streams(beginning, S)
-    }.
+%% @doc Send poll request to DS for all iterators that are currently
+%% in ready state.
+-spec poll(emqx_ds:poll_opts(), t(), emqx_persistent_session_ds_state:t()) -> t().
+poll(PollOpts0, SchedS0 = #s{ready = Ready}, S) ->
+    %% Create an alias for replies:
+    Ref = alias([explicit_unalias]),
+    %% Scan ready streams and create poll requests:
+    {Iterators, SchedS} = fold_ready(
+        fun(StreamKey, {AccIt, SchedS1}) ->
+            SRS = emqx_persistent_session_ds_state:get_stream(StreamKey, S),
+            It = {StreamKey, SRS#srs.it_end},
+            Pending = #pending_poll{ref = Ref, it_begin = SRS#srs.it_begin},
+            {
+                [It | AccIt],
+                to_P(StreamKey, Pending, SchedS1)
+            }
+        end,
+        {[], SchedS0},
+        Ready
+    ),
+    case Iterators of
+        [] ->
+            %% Nothing to poll:
+            unalias(Ref),
+            ok;
+        _ ->
+            %% Send poll request:
+            PollOpts = PollOpts0#{reply_to => Ref},
+            {ok, Ref} = emqx_ds:poll(?PERSISTENT_MESSAGE_DB, Iterators, PollOpts),
+            ok
+    end,
+    %% Clean ready bucket at once, since we poll all ready streams at once:
+    SchedS#s{ready = empty_ready()}.
 
--spec next_stream(iter_stream()) -> {stream_key(), stream_state(), iter_stream()} | none.
-next_stream(#iter{limit = 0}) ->
-    none;
-next_stream(ItStream0 = #iter{limit = N, filter = Filter, it = It0, it_cont = ItCont}) ->
-    case emqx_persistent_session_ds_state:iter_next(It0) of
-        {Key, Stream, It} ->
-            ItStream = ItStream0#iter{it = It, limit = N - 1},
-            case Filter(Key, Stream) of
+on_ds_reply(#poll_reply{ref = Ref, payload = poll_timeout}, S, SchedS0 = #s{pending = P0}) ->
+    %% Poll request has timed out. All pending streams that match poll
+    %% reference can be moved to R state:
+    ?SLOG(debug, #{msg => sess_poll_timeout, ref => Ref}),
+    unalias(Ref),
+    SchedS = maps:fold(
+        fun(Key, #pending_poll{ref = R}, SchedS1 = #s{pending = P}) ->
+            case R =:= Ref of
                 true ->
-                    {Key, Stream, ItStream};
+                    SchedS2 = SchedS1#s{pending = maps:remove(Key, P)},
+                    case emqx_persistent_session_ds_state:get_stream(Key, S) of
+                        undefined ->
+                            SchedS2;
+                        #srs{unsubscribed = true} ->
+                            SchedS2;
+                        _ ->
+                            to_R(Key, SchedS2)
+                    end;
                 false ->
-                    next_stream(ItStream)
-            end;
-        none when It0 =/= ItCont ->
-            %% Restart the iteration from the beginning:
-            ItStream = ItStream0#iter{it = ItCont},
-            next_stream(ItStream);
-        none ->
-            %% No point in restarting the iteration, `ItCont` is empty:
-            none
+                    SchedS1
+            end
+        end,
+        SchedS0,
+        P0
+    ),
+    {undefined, SchedS};
+on_ds_reply(
+    #poll_reply{ref = Ref, userdata = StreamKey, payload = Payload},
+    _S,
+    SchedS0 = #s{pending = Pending0}
+) ->
+    case maps:take(StreamKey, Pending0) of
+        {#pending_poll{ref = Ref, it_begin = ItBegin}, Pending} ->
+            ?tp(debug, sess_poll_reply, #{ref => Ref, stream_key => StreamKey}),
+            SchedS = SchedS0#s{pending = Pending},
+            {{StreamKey, ItBegin, Payload}, to_S(StreamKey, SchedS)};
+        _ ->
+            ?SLOG(
+                info,
+                #{
+                    msg => "sessds_unexpected_msg",
+                    userdata => StreamKey,
+                    ref => Ref
+                }
+            ),
+            {undefined, SchedS0}
     end.
 
-is_fetchable(_Comm1, _Comm2, #srs{it_end = end_of_stream}) ->
-    false;
-is_fetchable(Comm1, Comm2, #srs{unsubscribed = Unsubscribed} = Stream) ->
-    is_fully_acked(Comm1, Comm2, Stream) andalso not Unsubscribed.
+on_enqueue(true, _Key, _Srs, _S, SchedS) ->
+    SchedS;
+on_enqueue(false, Key, Srs, S, SchedS) ->
+    Comm1 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_1), S),
+    Comm2 = emqx_persistent_session_ds_state:get_seqno(?committed(?QOS_2), S),
+    case derive_state(Comm1, Comm2, Srs) of
+        r ->
+            to_R(Key, SchedS);
+        u ->
+            to_U(Key, Srs, SchedS);
+        bq1 ->
+            to_BQ1(Key, Srs, SchedS);
+        bq2 ->
+            to_BQ2(Key, Srs, SchedS);
+        bq12 ->
+            to_BQ12(Key, Srs, SchedS)
+    end.
+
+on_seqno_release(?QOS_1, SnQ1, SchedS0 = #s{bq1 = PrimaryTab0, bq2 = SecondaryTab}) ->
+    case check_block_status(PrimaryTab0, SecondaryTab, SnQ1, #block.last_seqno_qos2) of
+        false ->
+            %% This seqno doesn't unlock anything:
+            SchedS0;
+        {false, Key, PrimaryTab} ->
+            %% It was BQ1:
+            to_R(Key, SchedS0#s{bq1 = PrimaryTab});
+        {true, Key, PrimaryTab} ->
+            %% It was BQ12:
+            ?tp(sessds_stream_state_trans, #{
+                key => Key,
+                to => bq2
+            }),
+            SchedS0#s{bq1 = PrimaryTab}
+    end;
+on_seqno_release(?QOS_2, SnQ2, SchedS0 = #s{bq2 = PrimaryTab0, bq1 = SecondaryTab}) ->
+    case check_block_status(PrimaryTab0, SecondaryTab, SnQ2, #block.last_seqno_qos1) of
+        false ->
+            %% This seqno doesn't unlock anything:
+            SchedS0;
+        {false, Key, PrimaryTab} ->
+            %% It was BQ2:
+            to_R(Key, SchedS0#s{bq2 = PrimaryTab});
+        {true, Key, PrimaryTab} ->
+            %% It was BQ12:
+            ?tp(sessds_stream_state_trans, #{
+                key => Key,
+                to => bq1
+            }),
+            SchedS0#s{bq2 = PrimaryTab}
+    end.
+
+check_block_status(PrimaryTab0, SecondaryTab, PrimaryKey, SecondaryIdx) ->
+    case gb_trees:take_any(PrimaryKey, PrimaryTab0) of
+        error ->
+            false;
+        {Block = #block{id = StreamKey}, PrimaryTab} ->
+            StillBlocked = gb_trees:is_defined(element(SecondaryIdx, Block), SecondaryTab),
+            {StillBlocked, StreamKey, PrimaryTab}
+    end.
 
 %% @doc This function makes the session aware of the new streams.
 %%
@@ -167,8 +397,9 @@ is_fetchable(Comm1, Comm2, #srs{unsubscribed = Unsubscribed} = Stream) ->
 %% with the smallest RankY.
 %%
 %% This way, messages from the same topic/shard are never reordered.
--spec renew_streams(emqx_persistent_session_ds_state:t()) -> emqx_persistent_session_ds_state:t().
-renew_streams(S0) ->
+-spec renew_streams(emqx_persistent_session_ds_state:t(), t()) ->
+    {emqx_persistent_session_ds_state:t(), t()}.
+renew_streams(S0, SchedS0) ->
     S1 = remove_unsubscribed_streams(S0),
     S2 = remove_fully_replayed_streams(S1),
     S3 = update_stream_subscription_state_ids(S2),
@@ -179,12 +410,16 @@ renew_streams(S0) ->
     %% out of the scheduler for complete symmetry?
     fold_proper_subscriptions(
         fun
-            (Key, #{start_time := StartTime, id := SubId, current_state := SStateId}, Acc) ->
+            (
+                Key,
+                #{start_time := StartTime, id := SubId, current_state := SStateId},
+                Acc = {S4, _}
+            ) ->
                 TopicFilter = emqx_topic:words(Key),
                 Streams = select_streams(
                     SubId,
                     emqx_ds:get_streams(?PERSISTENT_MESSAGE_DB, TopicFilter, StartTime),
-                    Acc
+                    S4
                 ),
                 lists:foldl(
                     fun(I, Acc1) ->
@@ -196,15 +431,15 @@ renew_streams(S0) ->
             (_Key, _DeletedSubscription, Acc) ->
                 Acc
         end,
-        S3,
+        {S3, SchedS0},
         S3
     ).
 
 -spec on_unsubscribe(
-    emqx_persistent_session_ds:subscription_id(), emqx_persistent_session_ds_state:t()
+    emqx_persistent_session_ds:subscription_id(), emqx_persistent_session_ds_state:t(), t()
 ) ->
-    emqx_persistent_session_ds_state:t().
-on_unsubscribe(SubId, S0) ->
+    {emqx_persistent_session_ds_state:t(), t()}.
+on_unsubscribe(SubId, S0, SchedS0) ->
     %% NOTE: this function only marks the streams for deletion,
     %% instead of outright deleting them.
     %%
@@ -224,19 +459,21 @@ on_unsubscribe(SubId, S0) ->
     %% `renew_streams', when it detects that all in-flight messages
     %% from the stream have been acked by the client.
     emqx_persistent_session_ds_state:fold_streams(
-        fun(Key, Srs, Acc) ->
+        fun(Key, Srs0, {S1, SchedS1}) ->
             case Key of
                 {SubId, _Stream} ->
                     %% This stream belongs to a deleted subscription.
                     %% Mark for deletion:
-                    emqx_persistent_session_ds_state:put_stream(
-                        Key, Srs#srs{unsubscribed = true}, Acc
-                    );
+                    Srs = Srs0#srs{unsubscribed = true},
+                    {
+                        emqx_persistent_session_ds_state:put_stream(Key, Srs, S1),
+                        to_U(Key, Srs, SchedS1)
+                    };
                 _ ->
-                    Acc
+                    {S1, SchedS1}
             end
         end,
-        S0,
+        {S0, SchedS0},
         S0
     ).
 
@@ -252,10 +489,120 @@ is_fully_acked(Srs, S) ->
 %% Internal functions
 %%================================================================================
 
-ensure_iterator(TopicFilter, StartTime, SubId, SStateId, {{RankX, RankY}, Stream}, S) ->
+%%--------------------------------------------------------------------------------
+%% SRS FSM
+%%--------------------------------------------------------------------------------
+
+-spec derive_state(
+    emqx_persistent_session_ds:seqno(), emqx_persistent_session_ds:seqno(), srs()
+) -> state().
+derive_state(_, _, #srs{unsubscribed = true}) ->
+    u;
+derive_state(Comm1, Comm2, SRS) ->
+    case {is_track_acked(?QOS_1, Comm1, SRS), is_track_acked(?QOS_2, Comm2, SRS)} of
+        {true, true} -> r;
+        {false, true} -> bq1;
+        {true, false} -> bq2;
+        {false, false} -> bq12
+    end.
+
+%% Note: `to_State' functions must be called from a correct state.
+%% They are NOT idempotent, and they don't do full cleanup.
+
+-spec to_R(stream_key(), t()) -> t().
+to_R(Key, S = #s{ready = R}) ->
+    ?tp(sessds_stream_state_trans, #{
+        key => Key,
+        to => r
+    }),
+    S#s{ready = push_to_ready(Key, R)}.
+
+-spec to_P(stream_key(), pending(), t()) -> t().
+to_P(Key, Pending, S = #s{pending = P}) ->
+    ?tp(sessds_stream_state_trans, #{
+        key => Key,
+        to => p
+    }),
+    S#s{pending = P#{Key => Pending}}.
+
+-spec to_BQ1(stream_key(), srs(), t()) -> t().
+to_BQ1(Key, SRS, S = #s{bq1 = BQ1}) ->
+    ?tp(sessds_stream_state_trans, #{
+        key => Key,
+        to => bq1
+    }),
+    Block = #block{last_seqno_qos1 = SN1} = block_of_srs(Key, SRS),
+    S#s{bq1 = gb_trees:insert(SN1, Block, BQ1)}.
+
+-spec to_BQ2(stream_key(), srs(), t()) -> t().
+to_BQ2(Key, SRS, S = #s{bq2 = BQ2}) ->
+    %% NOTE: trace must report `bq2' (was a copy-paste of `bq1'),
+    %% matching the state this function actually transitions to:
+    ?tp(sessds_stream_state_trans, #{key => Key, to => bq2}),
+    Block = #block{last_seqno_qos2 = SN2} = block_of_srs(Key, SRS),
+    S#s{bq2 = gb_trees:insert(SN2, Block, BQ2)}.
+
+-spec to_BQ12(stream_key(), srs(), t()) -> t().
+to_BQ12(Key, SRS, S = #s{bq1 = BQ1, bq2 = BQ2}) ->
+    ?tp(sessds_stream_state_trans, #{
+        key => Key,
+        to => bq12
+    }),
+    Block = #block{last_seqno_qos1 = SN1, last_seqno_qos2 = SN2} = block_of_srs(Key, SRS),
+    S#s{bq1 = gb_trees:insert(SN1, Block, BQ1), bq2 = gb_trees:insert(SN2, Block, BQ2)}.
+
+-spec to_U(stream_key(), srs(), t()) -> t().
+to_U(
+    Key,
+    #srs{last_seqno_qos1 = SN1, last_seqno_qos2 = SN2},
+    S = #s{ready = R, pending = P, bq1 = BQ1, bq2 = BQ2}
+) ->
+    ?tp(sessds_stream_state_trans, #{
+        key => Key,
+        to => u
+    }),
+    S#s{
+        ready = del_ready(Key, R),
+        pending = maps:remove(Key, P),
+        bq1 = gb_trees:delete_any(SN1, BQ1),
+        bq2 = gb_trees:delete_any(SN2, BQ2)
+    }.
+
+-spec to_S(stream_key(), t()) -> t().
+to_S(Key, S) ->
+    ?tp(sessds_stream_state_trans, #{
+        key => Key,
+        to => s
+    }),
+    S.
+
+-spec block_of_srs(stream_key(), srs()) -> block().
+block_of_srs(Key, #srs{last_seqno_qos1 = SN1, last_seqno_qos2 = SN2}) ->
+    #block{id = Key, last_seqno_qos1 = SN1, last_seqno_qos2 = SN2}.
+
+%%--------------------------------------------------------------------------------
+%% Misc.
+%%--------------------------------------------------------------------------------
+
+fold_ready(Fun, Acc, Ready) ->
+    lists:foldl(Fun, Acc, Ready).
+
+empty_ready() ->
+    [].
+
+push_to_ready(K, Ready) ->
+    [K | Ready].
+
+del_ready(K, Ready) ->
+    Ready -- [K].
+
+ensure_iterator(TopicFilter, StartTime, SubId, SStateId, {{RankX, RankY}, Stream}, {S, SchedS}) ->
     Key = {SubId, Stream},
     case emqx_persistent_session_ds_state:get_stream(Key, S) of
         undefined ->
+            %% This is a newly discovered stream. Create an iterator
+            %% for it, and mark it as ready:
             case emqx_ds:make_iterator(?PERSISTENT_MESSAGE_DB, Stream, TopicFilter, StartTime) of
                 {ok, Iterator} ->
                     NewStreamState = #srs{
@@ -265,18 +612,21 @@ ensure_iterator(TopicFilter, StartTime, SubId, SStateId, {{RankX, RankY}, Stream
                         it_end = Iterator,
                         sub_state_id = SStateId
                     },
-                    emqx_persistent_session_ds_state:put_stream(Key, NewStreamState, S);
-                {error, recoverable, Reason} ->
-                    ?SLOG(debug, #{
+                    {
+                        emqx_persistent_session_ds_state:put_stream(Key, NewStreamState, S),
+                        to_R(Key, SchedS)
+                    };
+                {error, Class, Reason} ->
+                    ?SLOG(info, #{
                         msg => "failed_to_initialize_stream_iterator",
                         stream => Stream,
-                        class => recoverable,
+                        class => Class,
                         reason => Reason
                     }),
-                    S
+                    {S, SchedS}
             end;
         #srs{} ->
-            S
+            {S, SchedS}
     end.
 
 select_streams(SubId, Streams0, S) ->
@@ -446,14 +796,14 @@ compare_streams(
 is_fully_replayed(Comm1, Comm2, S = #srs{it_end = It}) ->
     It =:= end_of_stream andalso is_fully_acked(Comm1, Comm2, S).
 
-is_fully_acked(_, _, #srs{
-    first_seqno_qos1 = Q1, last_seqno_qos1 = Q1, first_seqno_qos2 = Q2, last_seqno_qos2 = Q2
-}) ->
-    %% Streams where the last chunk doesn't contain any QoS1 and 2
-    %% messages are considered fully acked:
-    true;
-is_fully_acked(Comm1, Comm2, #srs{last_seqno_qos1 = S1, last_seqno_qos2 = S2}) ->
-    (Comm1 >= S1) andalso (Comm2 >= S2).
+is_fully_acked(Comm1, Comm2, SRS) ->
+    is_track_acked(?QOS_1, Comm1, SRS) andalso
+        is_track_acked(?QOS_2, Comm2, SRS).
+
+is_track_acked(?QOS_1, Committed, #srs{first_seqno_qos1 = First, last_seqno_qos1 = Last}) ->
+    First =:= Last orelse Committed >= Last;
+is_track_acked(?QOS_2, Committed, #srs{first_seqno_qos2 = First, last_seqno_qos2 = Last}) ->
+    First =:= Last orelse Committed >= Last.
 
 fold_proper_subscriptions(Fun, Acc, S) ->
     emqx_persistent_session_ds_state:fold_subscriptions(

+ 8 - 2
apps/emqx/src/emqx_schema.erl

@@ -1696,14 +1696,20 @@ fields("durable_sessions") ->
                 #{
                     default => 100,
                     desc => ?DESC(session_ds_batch_size),
-                    importance => ?IMPORTANCE_MEDIUM
+                    importance => ?IMPORTANCE_MEDIUM,
+                    %% Note: the same value is used for both sync
+                    %% `next' request and async polls. Since poll
+                    %% workers are global for the DS DB, this value is
+                    %% global and it cannot be overridden per
+                    %% listener:
+                    mapping => "emqx_durable_session.poll_batch_size"
                 }
             )},
         {"idle_poll_interval",
             sc(
                 timeout_duration(),
                 #{
-                    default => <<"100ms">>,
+                    default => <<"5s">>,
                     desc => ?DESC(session_ds_idle_poll_interval)
                 }
             )},

+ 8 - 4
apps/emqx/src/emqx_topic_index.erl

@@ -18,7 +18,7 @@
 
 -module(emqx_topic_index).
 
--export([new/0]).
+-export([new/0, new/1]).
 -export([insert/4]).
 -export([delete/3]).
 -export([match/2]).
@@ -37,11 +37,15 @@
 -type match(ID) :: key(ID).
 -type words() :: emqx_trie_search:words().
 
-%% @doc Create a new ETS table suitable for topic index.
-%% Usable mostly for testing purposes.
 -spec new() -> ets:table().
 new() ->
-    ets:new(?MODULE, [public, ordered_set, {read_concurrency, true}]).
+    new([public, {read_concurrency, true}]).
+
+%% @doc Create a new ETS table suitable for topic index.
+%% Usable mostly for testing purposes.
+-spec new(list()) -> ets:table().
+new(Options) ->
+    ets:new(?MODULE, [ordered_set | Options]).
 
 %% @doc Insert a new entry into the index that associates given topic filter to given
 %% record ID, and attaches arbitrary record to the entry. This allows users to choose

+ 4 - 0
apps/emqx/src/emqx_trie_search.erl

@@ -356,6 +356,8 @@ match_add(K, first) ->
     throw({first, K}).
 
 -spec filter_words(emqx_types:topic()) -> [word()].
+filter_words(Words) when is_list(Words) ->
+    Words;
 filter_words(Topic) when is_binary(Topic) ->
     % NOTE
     % This is almost identical to `emqx_topic:words/1`, but it doesn't convert empty
@@ -364,6 +366,8 @@ filter_words(Topic) when is_binary(Topic) ->
     [word(W, filter) || W <- emqx_topic:tokens(Topic)].
 
 -spec topic_words(emqx_types:topic()) -> [binary()].
+topic_words(Words) when is_list(Words) ->
+    Words;
 topic_words(Topic) when is_binary(Topic) ->
     [word(W, topic) || W <- emqx_topic:tokens(Topic)].
 

+ 73 - 11
apps/emqx_ds_builtin_local/src/emqx_ds_builtin_local.erl

@@ -17,6 +17,7 @@
 
 -behaviour(emqx_ds).
 -behaviour(emqx_ds_buffer).
+-behaviour(emqx_ds_beamformer).
 
 %% API:
 -export([]).
@@ -38,8 +39,13 @@
     make_delete_iterator/4,
     update_iterator/3,
     next/3,
+    poll/3,
     delete_next/4,
 
+    %% `beamformer':
+    unpack_iterator/2,
+    scan_stream/5,
+
     %% `emqx_ds_buffer':
     init_buffer/3,
     flush_buffer/4,
@@ -91,6 +97,7 @@
         backend := builtin_local,
         storage := emqx_ds_storage_layer:prototype(),
         n_shards := pos_integer(),
+        poll_workers_per_shard => pos_integer(),
         %% Inherited from `emqx_ds:generic_db_opts()`.
         force_monotonic_timestamps => boolean(),
         atomic_batches => boolean()
@@ -265,7 +272,10 @@ flush_buffer(DB, Shard, Messages, S0 = #bs{options = Options}) ->
     ShardId = {DB, Shard},
     ForceMonotonic = maps:get(force_monotonic_timestamps, Options),
     {Latest, Batch} = make_batch(ForceMonotonic, current_timestamp(ShardId), Messages),
-    Result = emqx_ds_storage_layer:store_batch(ShardId, Batch, _Options = #{}),
+    DispatchF = fun(Events) ->
+        emqx_ds_beamformer:shard_event({DB, Shard}, Events)
+    end,
+    Result = emqx_ds_storage_layer:store_batch(ShardId, Batch, _Options = #{}, DispatchF),
     emqx_ds_builtin_local_meta:set_current_timestamp(ShardId, Latest),
     {S0, Result}.
 
@@ -359,10 +369,10 @@ make_iterator(DB, ?stream(Shard, InnerStream), TopicFilter, StartTime) ->
             Error
     end.
 
--spec update_iterator(emqx_ds:db(), iterator(), emqx_ds:message_key()) ->
+-spec update_iterator(_Shard, emqx_ds:ds_specific_iterator(), emqx_ds:message_key()) ->
     emqx_ds:make_iterator_result(iterator()).
-update_iterator(DB, Iter0 = #{?tag := ?IT, ?shard := Shard, ?enc := StorageIter0}, Key) ->
-    case emqx_ds_storage_layer:update_iterator({DB, Shard}, StorageIter0, Key) of
+update_iterator(ShardId, Iter0 = #{?tag := ?IT, ?enc := StorageIter0}, Key) ->
+    case emqx_ds_storage_layer:update_iterator(ShardId, StorageIter0, Key) of
         {ok, StorageIter} ->
             {ok, Iter0#{?enc => StorageIter}};
         Err = {error, _, _} ->
@@ -371,15 +381,60 @@ update_iterator(DB, Iter0 = #{?tag := ?IT, ?shard := Shard, ?enc := StorageIter0
 
 -spec next(emqx_ds:db(), iterator(), pos_integer()) -> emqx_ds:next_result(iterator()).
 next(DB, Iter, N) ->
-    {ok, Ref} = anext(DB, Iter, N),
+    {ok, Ref} = emqx_ds_lib:with_worker(undefined, ?MODULE, do_next, [DB, Iter, N]),
     receive
-        #ds_async_result{ref = Ref, data = Data} ->
+        #poll_reply{ref = Ref, payload = Data} ->
             Data
     end.
 
--spec anext(emqx_ds:db(), iterator(), pos_integer()) -> {ok, reference()}.
-anext(DB, Iter, N) ->
-    emqx_ds_lib:anext_helper(?MODULE, do_next, [DB, Iter, N]).
+-spec poll(emqx_ds:db(), emqx_ds:poll_iterators(), emqx_ds:poll_opts()) -> {ok, reference()}.
+poll(DB, Iterators, PollOpts = #{timeout := Timeout}) ->
+    %% Create a new alias, if not already provided:
+    case PollOpts of
+        #{reply_to := ReplyTo} ->
+            ok;
+        _ ->
+            ReplyTo = alias([explicit_unalias])
+    end,
+    %% Spawn a helper process that will notify the caller when the
+    %% poll times out:
+    _Completion = spawn_link(
+        fun() ->
+            send_poll_timeout(ReplyTo, Timeout)
+        end
+    ),
+    %% Submit poll jobs:
+    lists:foreach(
+        fun({ItKey, It = #{?tag := ?IT, ?shard := Shard}}) ->
+            ShardId = {DB, Shard},
+            ReturnAddr = {ReplyTo, ItKey},
+            emqx_ds_beamformer:poll(node(), ReturnAddr, ShardId, It, PollOpts)
+        end,
+        Iterators
+    ),
+    {ok, ReplyTo}.
+
+unpack_iterator(Shard, #{?tag := ?IT, ?enc := Iterator}) ->
+    {Stream, TopicFilter, DSKey, TS} = emqx_ds_storage_layer:unpack_iterator(Shard, Iterator),
+    MsgMatcher = emqx_ds_storage_layer:message_matcher(Shard, Iterator),
+    #{
+        stream => Stream,
+        topic_filter => TopicFilter,
+        last_seen_key => DSKey,
+        timestamp => TS,
+        message_matcher => MsgMatcher
+    }.
+
+scan_stream(ShardId, Stream, TopicFilter, StartMsg, BatchSize) ->
+    {DB, _} = ShardId,
+    Now = current_timestamp(ShardId),
+    T0 = erlang:monotonic_time(microsecond),
+    Result = emqx_ds_storage_layer:scan_stream(
+        ShardId, Stream, TopicFilter, Now, StartMsg, BatchSize
+    ),
+    T1 = erlang:monotonic_time(microsecond),
+    emqx_ds_builtin_metrics:observe_next_time(DB, T1 - T0),
+    Result.
 
 -spec get_delete_streams(emqx_ds:db(), emqx_ds:topic_filter(), emqx_ds:time()) ->
     [emqx_ds:ds_specific_delete_stream()].
@@ -420,9 +475,9 @@ make_delete_iterator(DB, ?delete_stream(Shard, InnerStream), TopicFilter, StartT
 -spec delete_next(emqx_ds:db(), delete_iterator(), emqx_ds:delete_selector(), pos_integer()) ->
     emqx_ds:delete_next_result(emqx_ds:delete_iterator()).
 delete_next(DB, Iter, Selector, N) ->
-    {ok, Ref} = emqx_ds_lib:anext_helper(?MODULE, do_delete_next, [DB, Iter, Selector, N]),
+    {ok, Ref} = emqx_ds_lib:with_worker(undefined, ?MODULE, do_delete_next, [DB, Iter, Selector, N]),
     receive
-        #ds_async_result{ref = Ref, data = Data} -> Data
+        #poll_reply{ref = Ref, payload = Data} -> Data
     end.
 
 %%================================================================================
@@ -477,3 +532,10 @@ timeus_to_timestamp(undefined) ->
     undefined;
 timeus_to_timestamp(TimestampUs) ->
     TimestampUs div 1000.
+
+send_poll_timeout(ReplyTo, Timeout) ->
+    receive
+    after Timeout + 10 ->
+        logger:debug("Timeout for poll ~p", [ReplyTo]),
+        ReplyTo ! #poll_reply{ref = ReplyTo, payload = poll_timeout}
+    end.

+ 17 - 1
apps/emqx_ds_builtin_local/src/emqx_ds_builtin_local_db_sup.erl

@@ -160,7 +160,8 @@ init({#?shard_sup{db = DB, shard = Shard}, _}) ->
     Children = [
         shard_storage_spec(DB, Shard, Opts),
         shard_buffer_spec(DB, Shard, Opts),
-        shard_batch_serializer_spec(DB, Shard, Opts)
+        shard_batch_serializer_spec(DB, Shard, Opts),
+        shard_beamformers_spec(DB, Shard, Opts)
     ],
     {ok, {SupFlags, Children}}.
 
@@ -228,6 +229,21 @@ shard_batch_serializer_spec(DB, Shard, Opts) ->
         type => worker
     }.
 
+shard_beamformers_spec(DB, Shard, _Options) ->
+    %% TODO: don't hardcode value
+    BeamformerOpts = #{
+        n_workers => 5
+    },
+    #{
+        id => {Shard, beamformers},
+        type => supervisor,
+        shutdown => infinity,
+        start =>
+            {emqx_ds_beamformer_sup, start_link, [
+                emqx_ds_builtin_local, {DB, Shard}, BeamformerOpts
+            ]}
+    }.
+
 ensure_started(Res) ->
     case Res of
         {ok, _Pid} ->

+ 2 - 31
apps/emqx_ds_builtin_local/test/emqx_ds_builtin_local_SUITE.erl

@@ -122,35 +122,6 @@ t_drop_generation_with_used_once_iterator(Config) ->
         emqx_ds_test_helpers:consume_iter(DB, Iter1)
     ).
 
-t_drop_generation_update_iterator(Config) ->
-    %% This checks the behavior of `emqx_ds:update_iterator' after the generation
-    %% underlying the iterator has been dropped.
-
-    DB = ?FUNCTION_NAME,
-    ?assertMatch(ok, emqx_ds:open_db(DB, opts(Config))),
-    [GenId0] = maps:keys(emqx_ds:list_generations_with_lifetimes(DB)),
-
-    TopicFilter = emqx_topic:words(<<"foo/+">>),
-    StartTime = 0,
-    Msgs0 = [
-        message(<<"foo/bar">>, <<"1">>, 0),
-        message(<<"foo/baz">>, <<"2">>, 1)
-    ],
-    ?assertMatch(ok, emqx_ds:store_batch(DB, Msgs0)),
-
-    [{_, Stream0}] = emqx_ds:get_streams(DB, TopicFilter, StartTime),
-    {ok, Iter0} = emqx_ds:make_iterator(DB, Stream0, TopicFilter, StartTime),
-    {ok, Iter1, _Batch1} = emqx_ds:next(DB, Iter0, 1),
-    {ok, _Iter2, [{Key2, _Msg}]} = emqx_ds:next(DB, Iter1, 1),
-
-    ok = emqx_ds:add_generation(DB),
-    ok = emqx_ds:drop_generation(DB, GenId0),
-
-    ?assertEqual(
-        {error, unrecoverable, generation_not_found},
-        emqx_ds:update_iterator(DB, Iter1, Key2)
-    ).
-
 t_make_iterator_stale_stream(Config) ->
     %% This checks the behavior of `emqx_ds:make_iterator' after the generation underlying
     %% the stream has been dropped.
@@ -233,7 +204,7 @@ t_store_batch_fail(Config) ->
             ],
             ?assertMatch(ok, emqx_ds:store_batch(DB, Batch1, #{sync => true})),
             %% Inject unrecoverable error:
-            meck:expect(emqx_ds_storage_layer, store_batch, fun(_DB, _Shard, _Messages) ->
+            meck:expect(emqx_ds_storage_layer, store_batch, fun(_DB, _Shard, _Messages, _DispatchF) ->
                 {error, unrecoverable, mock}
             end),
             Batch2 = [
@@ -244,7 +215,7 @@ t_store_batch_fail(Config) ->
                 {error, unrecoverable, mock}, emqx_ds:store_batch(DB, Batch2, #{sync => true})
             ),
             %% Inject a recoveralbe error:
-            meck:expect(emqx_ds_storage_layer, store_batch, fun(_DB, _Shard, _Messages) ->
+            meck:expect(emqx_ds_storage_layer, store_batch, fun(_DB, _Shard, _Messages, _DispatchF) ->
                 {error, recoverable, mock}
             end),
             Batch3 = [

+ 6 - 0
apps/emqx_ds_builtin_raft/src/emqx_ds_replication_layer.erl

@@ -25,6 +25,7 @@
     make_delete_iterator/4,
     update_iterator/3,
     next/3,
+    poll/3,
     delete_next/4,
 
     current_timestamp/2,
@@ -382,6 +383,11 @@ next(DB, Iter0, BatchSize) ->
             Other
     end.
 
+-spec poll(emqx_ds:db(), emqx_ds:poll_iterators(), emqx_ds:poll_opts()) ->
+    {ok, reference()}.
+poll(_DB, _Iterators, _PollOpts) ->
+    error(not_implemented).
+
 -spec delete_next(emqx_ds:db(), delete_iterator(), emqx_ds:delete_selector(), pos_integer()) ->
     emqx_ds:delete_next_result(delete_iterator()).
 delete_next(DB, Iter0, Selector, BatchSize) ->

+ 3 - 2
apps/emqx_durable_storage/include/emqx_ds.hrl

@@ -40,9 +40,10 @@
     filters = #{}
 }).
 
--record(ds_async_result, {
+-record(poll_reply, {
     ref :: reference(),
-    data :: emqx_ds:next_result()
+    userdata,
+    payload :: emqx_ds:next_result() | poll_timeout
 }).
 
 -endif.

+ 13 - 0
apps/emqx_durable_storage/include/emqx_ds_metrics.hrl

@@ -55,4 +55,17 @@
 -define(DS_SKIPSTREAM_LTS_FUTURE, emqx_ds_storage_skipstream_lts_future).
 -define(DS_SKIPSTREAM_LTS_EOS, emqx_ds_storage_skipstream_lts_end_of_stream).
 
+%%%% Poll metrics:
+%% Total number of incoming poll requests:
+-define(DS_POLL_REQUESTS, emqx_ds_poll_requests).
+%% Number of fulfilled requests:
+-define(DS_POLL_REQUESTS_FULFILLED, emqx_ds_poll_requests_fulfilled).
+%% Number of requests dropped due to OLP:
+-define(DS_POLL_REQUESTS_DROPPED, emqx_ds_poll_requests_dropped).
+%% Number of requests that expired while waiting for new messages:
+-define(DS_POLL_REQUESTS_EXPIRED, emqx_ds_poll_requests_expired).
+%% Measure of "beam coherence": average number of requests fulfilled
+%% by a single beam:
+-define(DS_POLL_REQUEST_SHARING, emqx_ds_poll_request_sharing).
+
 -endif.

+ 54 - 28
apps/emqx_durable_storage/src/emqx_ds.erl

@@ -39,7 +39,7 @@
 -export([store_batch/2, store_batch/3]).
 
 %% Message replay API:
--export([get_streams/3, make_iterator/4, update_iterator/3, next/3, anext/3]).
+-export([get_streams/3, make_iterator/4, next/3, poll/3]).
 
 %% Message delete API:
 -export([get_delete_streams/3, make_delete_iterator/4, delete_next/4]).
@@ -84,7 +84,10 @@
     ds_specific_delete_stream/0,
     ds_specific_delete_iterator/0,
     generation_rank/0,
-    generation_info/0
+    generation_info/0,
+
+    poll_iterators/0,
+    poll_opts/0
 ]).
 
 %%================================================================================
@@ -180,6 +183,8 @@
 
 -type delete_next_result() :: delete_next_result(delete_iterator()).
 
+-type poll_iterators() :: [{_UserData, iterator()}].
+
 -type error(Reason) :: {error, recoverable | unrecoverable, Reason}.
 
 %% Timestamp
@@ -217,6 +222,17 @@
 
 -type create_db_opts() :: generic_db_opts().
 
+-type poll_opts() ::
+    #{
+        %% Expire poll request after this timeout
+        timeout := pos_integer(),
+        %% (Optional) Provide an explicit process alias for receiving
+        %% replies. It must be created with `explicit_unalias' flag,
+        %% otherwise replies will get lost. If not specified, DS will
+        %% create a new alias.
+        reply_to => reference()
+    }.
+
 %% An opaque term identifying a generation.  Each implementation will possibly add
 %% information to this term to match its inner structure (e.g.: by embedding the shard id,
 %% in the case of `emqx_ds_replication_layer').
@@ -258,18 +274,9 @@
 -callback make_iterator(db(), ds_specific_stream(), topic_filter(), time()) ->
     make_iterator_result(ds_specific_iterator()).
 
--callback update_iterator(db(), ds_specific_iterator(), message_key()) ->
-    make_iterator_result(ds_specific_iterator()).
-
 -callback next(db(), Iterator, pos_integer()) -> next_result(Iterator).
 
-%% Asynchronous next. Backend must reply to the calling process with
-%% `#ds_async_result{}' message, where `ref' field is equal to the
-%% returned reference.
-%%
-%% Reference is a process alias that can be unalised to ignore the
-%% result.
--callback anext(db(), _Iterator, pos_integer()) -> {ok, reference()}.
+-callback poll(db(), poll_iterators(), poll_opts()) -> {ok, reference()}.
 
 -callback get_delete_streams(db(), topic_filter(), time()) -> [ds_specific_delete_stream()].
 
@@ -289,8 +296,6 @@
     make_delete_iterator/4,
     delete_next/4,
 
-    anext/3,
-
     count/1
 ]).
 
@@ -415,24 +420,45 @@ get_streams(DB, TopicFilter, StartTime) ->
 make_iterator(DB, Stream, TopicFilter, StartTime) ->
     ?module(DB):make_iterator(DB, Stream, TopicFilter, StartTime).
 
--spec update_iterator(db(), iterator(), message_key()) ->
-    make_iterator_result().
-update_iterator(DB, OldIter, DSKey) ->
-    ?module(DB):update_iterator(DB, OldIter, DSKey).
-
 -spec next(db(), iterator(), pos_integer()) -> next_result().
 next(DB, Iter, BatchSize) ->
     ?module(DB):next(DB, Iter, BatchSize).
 
--spec anext(db(), iterator(), pos_integer()) -> {ok, reference()}.
-anext(DB, Iter, BatchSize) ->
-    Mod = ?module(DB),
-    case erlang:function_exported(Mod, anext, 3) of
-        true ->
-            Mod:anext(DB, Iter, BatchSize);
-        false ->
-            emqx_ds_lib:anext_helper(Mod, next, [DB, Iter, BatchSize])
-    end.
+%% @doc Schedule asynchronous long poll of the iterators and return
+%% immediately.
+%%
+%% Arguments:
+%% 1. Name of DS DB
+%% 2. List of tuples, where first element is an arbitrary tag that can
+%%    be used to identify replies, and the second one is iterator.
+%% 3. Poll options
+%%
+%% Return value: process alias that identifies the replies.
+%%
+%% Data will be sent to the caller process as messages wrapped in
+%% `#poll_reply' record:
+%% - `ref' field will be equal to the returned reference.
+%% - `userdata' field will be equal to the iterator tag.
+%% - `payload' will be of type `next_result()' or `poll_timeout' atom
+%%
+%% There are some important caveats:
+%%
+%% - Replies are sent on a best-effort basis. They may be lost for any
+%% reason. Caller must be designed to tolerate and retry missed poll
+%% replies.
+%%
+%% - There is no explicit lifetime management for poll workers. When
+%% caller dies, its poll requests survive. It's assumed that orphaned
+%% requests will naturally clean themselves out by timeout alone.
+%% Therefore, timeout must not be too long.
+%%
+%% - But not too short either: if no data arrives to the stream before
+%% timeout, the request is usually retried. This should not create a
+%% busy loop. Also DS may silently drop requests due to overload. So
+%% they should not be retried too early.
+-spec poll(db(), poll_iterators(), poll_opts()) -> {ok, reference()}.
+poll(DB, Iterators, PollOpts = #{timeout := Timeout}) when is_integer(Timeout), Timeout > 0 ->
+    ?module(DB):poll(DB, Iterators, PollOpts).
 
 -spec get_delete_streams(db(), topic_filter(), time()) -> [delete_stream()].
 get_delete_streams(DB, TopicFilter, StartTime) ->

+ 827 - 0
apps/emqx_durable_storage/src/emqx_ds_beamformer.erl

@@ -0,0 +1,827 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+%% @doc This process is responsible for processing async poll requests
+%% from the consumers.
+%%
+%% It serves as a pool for such requests, limiting the number of
+%% queries running in parallel. In addition, it tries to group
+%% "coherent" poll requests together, so they can be fulfilled as a
+%% group ("coherent beam").
+%%
+%% By "coherent" we mean requests to scan overlapping key ranges of
+%% the same DS stream. Grouping requests helps to reduce the number of
+%% storage queries and conserve throughput of the EMQX backplane
+%% network.
+%%
+%% Beamformer works as following:
+%%
+%% - Initially, requests are added to the "pending" queue.
+%%
+%% - Beamformer process spins in a "fulfill loop" that takes requests
+%% from the pending queue one at a time, and tries to fulfill them
+%% normally by querying the storage.
+%%
+%% - If storage returns a non-empty batch, beamformer then searches
+%% for pending poll requests that may be coherent with the current
+%% one. All matching requests are then packed into "beams" (one per
+%% destination node) and sent out accordingly.
+%%
+%% - If the query returns an empty batch, beamformer moves the request
+%% to the "wait" queue. Poll requests just linger there until they
+%% time out, or until beamformer receives a matching stream event from
+%% the storage. The storage backend can send requests to the
+%% beamformer by calling `shard_event' function.
+%%
+%% Storage event processing logic is following: if beamformer finds
+%% waiting poll requests matching the event, it queries the storage
+%% for a batch of data. If batch is non-empty, requests are served
+%% exactly as described above. If batch is empty again, request is
+%% moved back to the wait queue.
+
+%% WARNING: beamformer makes some implicit assumptions about the
+%% storage layout:
+%%
+%% - There's a bijection between iterator position and the message key
+%%
+%% - Message keys in the stream are monotonic
+%%
+%% - Querying a stream with non-wildcard topic-filter is equivalent to
+%% querying it with a wildcard topic filter and dropping messages in
+%% postprocessing, e.g.:
+%%
+%% ```
+%% next("foo/bar", StartTime) ==
+%%   filter(λ msg. msg.topic == "foo/bar",
+%%         next("#", StartTime))
+%% '''
+-module(emqx_ds_beamformer).
+
+-behaviour(gen_server).
+
+%% API:
+-export([poll/5, shard_event/2]).
+
+%% behavior callbacks:
+-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2]).
+
+%% internal exports:
+-export([start_link/4, do_dispatch/1]).
+
+-export_type([opts/0, beam/2, beam/0, return_addr/1, unpack_iterator_result/1]).
+
+-include_lib("snabbkaffe/include/trace.hrl").
+-include_lib("emqx_utils/include/emqx_message.hrl").
+
+-include("emqx_ds.hrl").
+
+-ifdef(TEST).
+-include_lib("eunit/include/eunit.hrl").
+-endif.
+
+%%================================================================================
+%% Type declarations
+%%================================================================================
+
+-type opts() :: #{
+    n_workers := non_neg_integer()
+}.
+
+%% Request:
+
+-type return_addr(ItKey) :: {reference(), ItKey}.
+
+-record(poll_req, {
+    key,
+    %% Node from which the poll request originates:
+    node,
+    %% Information about the process that created the request:
+    return_addr,
+    %% Iterator:
+    it,
+    %% Callback that filters messages that belong to the request:
+    msg_matcher,
+    opts,
+    deadline
+}).
+
+-type poll_req(ItKey, Iterator) ::
+    #poll_req{
+        key :: {_Stream, _TopicFilter, emqx_ds:message_key()},
+        node :: node(),
+        return_addr :: return_addr(ItKey),
+        it :: Iterator,
+        msg_matcher :: match_messagef(),
+        opts :: emqx_ds:poll_opts(),
+        deadline :: integer()
+    }.
+
+%% Response:
+
+-type dispatch_mask() :: bitstring().
+
+-record(beam, {iterators, pack, misc = #{}}).
+
+-opaque beam(ItKey, Iterator) ::
+    #beam{
+        iterators :: [{return_addr(ItKey), Iterator}],
+        pack ::
+            [{emqx_ds:message_key(), dispatch_mask(), emqx_types:message()}]
+            | end_of_stream
+            | emqx_ds:error(_),
+        misc :: #{}
+    }.
+
+-type beam() :: beam(_ItKey, _Iterator).
+
+-type stream_scan_return() ::
+    {ok, emqx_ds:message_key(), [{emqx_ds:message_key(), emqx_types:message()}]}
+    | {ok, end_of_stream}
+    | emqx_ds:error(_).
+
+-record(s, {
+    module :: module(),
+    metrics_id,
+    shard,
+    name,
+    pending_queue :: ets:tid(),
+    pending_request_limit :: non_neg_integer(),
+    wait_queue :: ets:tid(),
+    is_spinning = false :: boolean(),
+    batch_size :: non_neg_integer()
+}).
+
+-type s() :: #s{}.
+
+-record(shard_event, {
+    events :: [{_Stream, _Topic}]
+}).
+
+-define(fulfill_loop, fulfill_loop).
+-define(housekeeping_loop, housekeeping_loop).
+
+%%================================================================================
+%% Callbacks
+%%================================================================================
+
+-type match_messagef() :: fun((emqx_ds:message_key(), emqx_types:message()) -> boolean()).
+
+-type unpack_iterator_result(Stream) :: #{
+    stream := Stream,
+    topic_filter := _,
+    last_seen_key := emqx_ds:message_key(),
+    timestamp := emqx_ds:time(),
+    message_matcher := match_messagef()
+}.
+
+-callback unpack_iterator(_Shard, _Iterator) ->
+    unpack_iterator_result(_Stream) | emqx_ds:error(_).
+
+-callback update_iterator(_Shard, Iterator, emqx_ds:message_key()) ->
+    emqx_ds:make_iterator_result(Iterator).
+
+-callback scan_stream(_Shard, _Stream, _TopicFilter, _StartKey, _BatchSize :: non_neg_integer()) ->
+    stream_scan_return().
+
+%%================================================================================
+%% API functions
+%%================================================================================
+
+-spec poll(node(), return_addr(_ItKey), _Shard, _Iterator, emqx_ds:poll_opts()) ->
+    ok.
+poll(Node, ReturnAddr, Shard, Iterator, Opts = #{timeout := Timeout}) ->
+    ?tp(emqx_ds_beamformer_poll, #{
+        node => Node, return_addr => ReturnAddr, shard => Shard, it => Iterator, timeout => Timeout
+    }),
+    CBM = emqx_ds_beamformer_sup:cbm(Shard),
+    #{
+        stream := Stream,
+        topic_filter := TF,
+        last_seen_key := DSKey,
+        timestamp := Timestamp,
+        message_matcher := MsgMatcher
+    } = CBM:unpack_iterator(Shard, Iterator),
+    Deadline = erlang:monotonic_time(millisecond) + Timeout,
+    logger:debug(#{
+        msg => poll, shard => Shard, key => DSKey, timeout => Timeout, deadline => Deadline
+    }),
+    %% Try to maximize likelihood of sending similar iterators to the
+    %% same worker:
+    Token = {Stream, Timestamp div 10_000_000},
+    Worker = gproc_pool:pick_worker(
+        emqx_ds_beamformer_sup:pool(Shard),
+        Token
+    ),
+    %% Make request:
+    Req = #poll_req{
+        key = {Stream, TF, DSKey},
+        node = Node,
+        return_addr = ReturnAddr,
+        it = Iterator,
+        opts = Opts,
+        deadline = Deadline,
+        msg_matcher = MsgMatcher
+    },
+    emqx_ds_builtin_metrics:inc_poll_requests(shard_metrics_id(Shard), 1),
+    %% Currently we implement backpressure by ignoring transient
+    %% errors (gen_server timeouts, `too_many_requests'), and just
+    %% letting poll requests expire at the higher level. This should
+    %% hold back the caller.
+    try gen_server:call(Worker, Req, Timeout) of
+        ok -> ok;
+        {error, recoverable, too_many_requests} -> ok
+    catch
+        exit:timeout ->
+            ok
+    end.
+
+shard_event(Shard, Events) ->
+    Workers = gproc_pool:active_workers(emqx_ds_beamformer_sup:pool(Shard)),
+    lists:foreach(
+        fun({_, Pid}) ->
+            Pid ! #shard_event{events = Events}
+        end,
+        Workers
+    ).
+
+%%================================================================================
+%% behavior callbacks
+%%================================================================================
+
+init([CBM, ShardId, Name, _Opts]) ->
+    process_flag(trap_exit, true),
+    Pool = emqx_ds_beamformer_sup:pool(ShardId),
+    gproc_pool:add_worker(Pool, Name),
+    gproc_pool:connect_worker(Pool, Name),
+    PendingTab = ets:new(pending_polls, [duplicate_bag, private, {keypos, #poll_req.key}]),
+    WaitingTab = emqx_ds_beamformer_waitq:new(),
+    S = #s{
+        module = CBM,
+        shard = ShardId,
+        metrics_id = shard_metrics_id(ShardId),
+        name = Name,
+        pending_queue = PendingTab,
+        wait_queue = WaitingTab,
+        pending_request_limit = cfg_pending_request_limit(),
+        batch_size = cfg_batch_size()
+    },
+    self() ! ?housekeeping_loop,
+    {ok, S}.
+
+handle_call(
+    Req = #poll_req{},
+    _From,
+    S = #s{pending_queue = PendingTab, wait_queue = WaitingTab, metrics_id = Metrics}
+) ->
+    NQueued = ets:info(PendingTab, size) + ets:info(WaitingTab, size),
+    case NQueued >= S#s.pending_request_limit of
+        true ->
+            emqx_ds_builtin_metrics:inc_poll_requests_dropped(Metrics, 1),
+            Reply = {error, recoverable, too_many_requests},
+            {reply, Reply, S};
+        false ->
+            ets:insert(S#s.pending_queue, Req),
+            {reply, ok, start_fulfill_loop(S)}
+    end;
+handle_call(_Call, _From, S) ->
+    {reply, {error, unknown_call}, S}.
+
+handle_cast(_Cast, S) ->
+    {noreply, S}.
+
+handle_info(#shard_event{events = Events}, S) ->
+    ?tp(debug, emqx_ds_beamformer_event, #{events => Events}),
+    {noreply, maybe_fulfill_waiting(S, Events)};
+handle_info(?fulfill_loop, S0) ->
+    S1 = S0#s{is_spinning = false},
+    S = fulfill_pending(S1),
+    {noreply, S};
+handle_info(?housekeeping_loop, S0) ->
+    %% Reload configuration from environment variables:
+    S1 = S0#s{
+        batch_size = cfg_batch_size(),
+        pending_request_limit = cfg_pending_request_limit()
+    },
+    S = cleanup(S1),
+    erlang:send_after(cfg_housekeeping_interval(), self(), ?housekeeping_loop),
+    {noreply, S};
+handle_info(_Info, S) ->
+    {noreply, S}.
+
+terminate(_Reason, #s{shard = ShardId, name = Name}) ->
+    Pool = emqx_ds_beamformer_sup:pool(ShardId),
+    gproc_pool:disconnect_worker(Pool, Name),
+    gproc_pool:remove_worker(Pool, Name),
+    ok.
+
+%%================================================================================
+%% Internal exports
+%%================================================================================
+
+-spec start_link(module(), _Shard, integer(), opts()) -> {ok, pid()}.
+start_link(Mod, ShardId, Name, Opts) ->
+    gen_server:start_link(?MODULE, [Mod, ShardId, Name, Opts], []).
+
+%% @doc RPC target: split the beam and dispatch replies to local
+%% consumers.
+-spec do_dispatch(beam()) -> ok.
+do_dispatch(Beam = #beam{}) ->
+    lists:foreach(
+        fun({{Alias, ItKey}, Result}) ->
+            Alias ! #poll_reply{ref = Alias, userdata = ItKey, payload = Result}
+        end,
+        split(Beam)
+    ).
+
+%%================================================================================
+%% Internal functions
+%%================================================================================
+
+-spec start_fulfill_loop(s()) -> s().
+start_fulfill_loop(S = #s{is_spinning = true}) ->
+    S;
+start_fulfill_loop(S = #s{is_spinning = false}) ->
+    self() ! ?fulfill_loop,
+    S#s{is_spinning = true}.
+
+-spec cleanup(s()) -> s().
+cleanup(S = #s{pending_queue = PendingTab, wait_queue = WaitingTab, metrics_id = Metrics}) ->
+    do_cleanup(Metrics, PendingTab),
+    do_cleanup(Metrics, WaitingTab),
+    %% erlang:garbage_collect(),
+    S.
+
+do_cleanup(Metrics, Tab) ->
+    Now = erlang:monotonic_time(millisecond),
+    MS = {#poll_req{_ = '_', deadline = '$1'}, [{'<', '$1', Now}], [true]},
+    NDeleted = ets:select_delete(Tab, [MS]),
+    emqx_ds_builtin_metrics:inc_poll_requests_expired(Metrics, NDeleted).
+
+-spec fulfill_pending(s()) -> s().
+fulfill_pending(S = #s{pending_queue = PendingTab}) ->
+    %% debug_pending(S),
+    case find_older_request(PendingTab, 100) of
+        undefined ->
+            S;
+        Req ->
+            ?tp(emqx_ds_beamformer_fulfill_pending, #{req => Req}),
+            %% The function MUST destructively consume all requests
+            %% matching stream and MsgKey to avoid infinite loop:
+            do_fulfill_pending(S, Req),
+            start_fulfill_loop(S)
+    end.
+
+do_fulfill_pending(
+    S = #s{
+        shard = Shard,
+        module = CBM,
+        pending_queue = PendingTab,
+        batch_size = BatchSize
+    },
+    #poll_req{key = {Stream, TopicFilter, StartKey}}
+) ->
+    OnMatch = fun(_) -> ok end,
+    %% Here we only group requests with exact match of the topic
+    %% filter:
+    GetF = fun(MsgKey) ->
+        ets:take(PendingTab, {Stream, TopicFilter, MsgKey})
+    end,
+    Result = CBM:scan_stream(Shard, Stream, TopicFilter, StartKey, BatchSize),
+    form_beams(S, GetF, OnMatch, move_to_waiting(S), StartKey, Result).
+
+maybe_fulfill_waiting(S, []) ->
+    S;
+maybe_fulfill_waiting(
+    S = #s{wait_queue = WaitingTab, module = CBM, shard = Shard, batch_size = BatchSize},
+    [{Stream, UpdatedTopic} | Rest]
+) ->
+    case find_waiting(Stream, UpdatedTopic, WaitingTab) of
+        undefined ->
+            ?tp(emqx_ds_beamformer_fulfill_waiting, #{
+                stream => Stream, topic => UpdatedTopic, candidates => undefined
+            }),
+            maybe_fulfill_waiting(S, Rest);
+        {Candidates, TopicFilter, StartKey} ->
+            ?tp(emqx_ds_beamformer_fulfill_waiting, #{
+                stream => Stream,
+                topic => UpdatedTopic,
+                candidates => Candidates,
+                start_time => StartKey
+            }),
+            GetF = fun(Key) -> maps:get(Key, Candidates, []) end,
+            OnNomatch = fun(_) -> ok end,
+            OnMatch = fun(Reqs) ->
+                lists:foreach(
+                    fun(#poll_req{key = {Str, TF, _}, return_addr = Id}) ->
+                        emqx_ds_beamformer_waitq:delete(Str, TF, Id, WaitingTab)
+                    end,
+                    Reqs
+                )
+            end,
+            Result = CBM:scan_stream(Shard, Stream, TopicFilter, StartKey, BatchSize),
+            case form_beams(S, GetF, OnMatch, OnNomatch, StartKey, Result) of
+                true -> maybe_fulfill_waiting(S, [{Stream, UpdatedTopic} | Rest]);
+                false -> maybe_fulfill_waiting(S, Rest)
+            end
+    end.
+
+move_to_waiting(#s{wait_queue = WaitingTab}) ->
+    fun(NoMatch) ->
+        lists:foreach(
+            fun(Req = #poll_req{key = {Stream, TopicFilter, _Key}, return_addr = Id}) ->
+                emqx_ds_beamformer_waitq:insert(Stream, TopicFilter, Id, Req, WaitingTab)
+            end,
+            NoMatch
+        )
+    end.
+
+find_waiting(Stream, Topic, Tab) ->
+    case emqx_ds_beamformer_waitq:matches(Stream, Topic, Tab) of
+        [] ->
+            undefined;
+        [Fst | _] = Matches ->
+            %% 1. Find all poll requests that match the topic of the
+            %% event
+            %%
+            %% 2. Find most common topic filter for all these events
+            %%
+            %% 3. Find the smallest DS key
+            lists:foldl(
+                fun(Req, {Acc, AccTopic, AccKey}) ->
+                    ReqKey = ds_key_of_poll(Req),
+                    {
+                        map_pushl(ReqKey, Req, Acc),
+                        common_topic_filter(AccTopic, topic_of_poll(Req)),
+                        min(AccKey, ReqKey)
+                    }
+                end,
+                {#{}, topic_of_poll(Fst), ds_key_of_poll(Fst)},
+                Matches
+            )
+    end.
+
+ds_key_of_poll(#poll_req{key = {_, _, Key}}) -> Key.
+
+topic_of_poll(#poll_req{key = {_, Topic, _}}) -> Topic.
+
+common_topic_filter([], []) ->
+    [];
+common_topic_filter(['#'], _) ->
+    ['#'];
+common_topic_filter(_, ['#']) ->
+    ['#'];
+common_topic_filter(['+' | L1], [_ | L2]) ->
+    ['+' | common_topic_filter(L1, L2)];
+common_topic_filter([_ | L1], ['+' | L2]) ->
+    ['+' | common_topic_filter(L1, L2)];
+common_topic_filter([A | L1], [A | L2]) ->
+    [A | common_topic_filter(L1, L2)].
+
+map_pushl(Key, Elem, Map) ->
+    maps:update_with(Key, fun(L) -> [Elem | L] end, [Elem], Map).
+
+%% It's always worth trying to fulfill the oldest requests first,
+%% because they have a better chance of producing a batch that
+%% overlaps with other pending requests.
+%%
+%% This function implements a heuristic that tries to find such poll
+%% request. It simply compares the keys (and nothing else) within a
+%% small sample of pending polls, and picks request with the smallest
+%% key as the starting point.
+find_older_request(Tab, SampleSize) ->
+    MS = {'_', [], ['$_']},
+    case ets:select(Tab, [MS], SampleSize) of
+        '$end_of_table' ->
+            undefined;
+        {[Fst | Rest], _Cont} ->
+            %% Find poll request with the minimal key:
+            lists:foldl(
+                fun(E, Acc) ->
+                    case ds_key_of_poll(E) < ds_key_of_poll(Acc) of
+                        true -> E;
+                        false -> Acc
+                    end
+                end,
+                Fst,
+                Rest
+            )
+    end.
+
+%% @doc Split beam into individual batches
+-spec split(beam(ItKey, Iterator)) -> [{ItKey, emqx_ds:next_result(Iterator)}].
+split(#beam{iterators = Its, pack = end_of_stream}) ->
+    [{ItKey, {ok, end_of_stream}} || {ItKey, _Iter} <- Its];
+split(#beam{iterators = Its, pack = {error, _, _} = Err}) ->
+    [{ItKey, Err} || {ItKey, _Iter} <- Its];
+split(#beam{iterators = Its, pack = Pack}) ->
+    split(Its, Pack, 0, []).
+
+%% This function checks pending requests in the `FromTab' and either
+%% dispatches them as a beam or passes them to `OnNomatch' callback if
+%% there's nothing to dispatch.
+-spec form_beams(
+    s(),
+    fun((emqx_ds:message_key()) -> [Req]),
+    fun(([Req]) -> _),
+    fun(([Req]) -> _),
+    emqx_ds:message_key(),
+    stream_scan_return()
+) -> boolean() when Req :: poll_req(_ItKey, _It).
+form_beams(S, GetF, OnMatch, OnNomatch, StartKey, {ok, EndKey, Batch}) ->
+    do_form_beams(S, GetF, OnMatch, OnNomatch, StartKey, EndKey, Batch);
+form_beams(#s{metrics_id = Metrics}, GetF, OnMatch, _OnNomatch, StartKey, Result) ->
+    Pack =
+        case Result of
+            Err = {error, _, _} -> Err;
+            {ok, end_of_stream} -> end_of_stream
+        end,
+    MatchReqs = GetF(StartKey),
+    %% Report metrics:
+    NFulfilled = length(MatchReqs),
+    NFulfilled > 0 andalso
+        begin
+            emqx_ds_builtin_metrics:inc_poll_requests_fulfilled(Metrics, NFulfilled),
+            emqx_ds_builtin_metrics:observe_sharing(Metrics, NFulfilled)
+        end,
+    %% Execute callbacks:
+    OnMatch(MatchReqs),
+    %% Split matched requests by destination node:
+    ReqsByNode = maps:groups_from_list(
+        fun(#poll_req{node = Node}) -> Node end,
+        fun(#poll_req{return_addr = RAddr, it = It}) ->
+            {RAddr, It}
+        end,
+        MatchReqs
+    ),
+    %% Pack requests into beams and serve them:
+    maps:foreach(
+        fun(Node, Its) ->
+            Beam = #beam{
+                pack = Pack,
+                iterators = Its
+            },
+            send_out(Node, Beam)
+        end,
+        ReqsByNode
+    ),
+    NFulfilled > 0.
+
+-spec do_form_beams(
+    s(),
+    fun((emqx_ds:message_key()) -> [Req]),
+    fun(([Req]) -> _),
+    fun(([Req]) -> _),
+    emqx_ds:message_key(),
+    emqx_ds:message_key(),
+    [{emqx_ds:message_key(), emqx_types:message()}]
+) -> boolean() when Req :: poll_req(_ItKey, _It).
+do_form_beams(
+    #s{metrics_id = Metrics, module = CBM, shard = Shard},
+    GetF,
+    OnMatch,
+    OnNomatch,
+    StartKey,
+    EndKey,
+    Batch
+) ->
+    %% Find iterators that match the start message of the batch (to
+    %% handle iterators freshly created by `emqx_ds:make_iterator'):
+    Candidates0 = GetF(StartKey),
+    %% Search for iterators where `last_seen_key' is equal to key of
+    %% any message in the batch:
+    Candidates = lists:foldl(
+        fun({Key, _Msg}, Acc) ->
+            GetF(Key) ++ Acc
+        end,
+        Candidates0,
+        Batch
+    ),
+    %% Find what poll requests _actually_ have data in the batch. It's
+    %% important not to send empty batches to the consumers, so they
+    %% don't come back immediately, creating a busy loop:
+    {MatchReqs, NoMatchReqs} = filter_candidates(Candidates, Batch),
+    ?tp(emqx_ds_beamformer_form_beams, #{match => MatchReqs, no_match => NoMatchReqs}),
+    %% Report metrics:
+    NFulfilled = length(MatchReqs),
+    NFulfilled > 0 andalso
+        begin
+            emqx_ds_builtin_metrics:inc_poll_requests_fulfilled(Metrics, NFulfilled),
+            emqx_ds_builtin_metrics:observe_sharing(Metrics, NFulfilled)
+        end,
+    %% Execute callbacks:
+    OnMatch(MatchReqs),
+    OnNomatch(NoMatchReqs),
+    %% Split matched requests by destination node:
+    ReqsByNode = maps:groups_from_list(fun(#poll_req{node = Node}) -> Node end, MatchReqs),
+    %% Pack requests into beams and serve them:
+    UpdateIterator = fun(Iterator, NextKey) ->
+        CBM:update_iterator(Shard, Iterator, NextKey)
+    end,
+    maps:foreach(
+        fun(Node, Reqs) ->
+            Beam = pack(UpdateIterator, EndKey, Reqs, Batch),
+            send_out(Node, Beam)
+        end,
+        ReqsByNode
+    ),
+    NFulfilled > 0.
+
+-spec pack(
+    fun((Iterator, emqx_ds:message_key()) -> Iterator),
+    emqx_ds:message_key(),
+    [{ItKey, Iterator}],
+    [{emqx_ds:message_key(), emqx_types:message()}]
+) -> beam(ItKey, Iterator).
+pack(UpdateIterator, NextKey, Reqs, Batch) ->
+    Pack = [{Key, mk_mask(Reqs, Elem), Msg} || Elem = {Key, Msg} <- Batch],
+    UpdatedIterators =
+        lists:map(
+            fun(#poll_req{it = It0, return_addr = RAddr}) ->
+                {ok, It} = UpdateIterator(It0, NextKey),
+                {RAddr, It}
+            end,
+            Reqs
+        ),
+    #beam{
+        iterators = UpdatedIterators,
+        pack = Pack
+    }.
+
+split([], _Pack, _N, Acc) ->
+    Acc;
+split([{ItKey, It} | Rest], Pack, N, Acc0) ->
+    Msgs = [
+        {MsgKey, Msg}
+     || {MsgKey, Mask, Msg} <- Pack,
+        is_member(N, Mask)
+    ],
+    case Msgs of
+        [] -> logger:warning("Empty batch ~p", [ItKey]);
+        _ -> ok
+    end,
+    Acc = [{ItKey, {ok, It, Msgs}} | Acc0],
+    split(Rest, Pack, N + 1, Acc).
+
+-spec is_member(non_neg_integer(), dispatch_mask()) -> boolean().
+is_member(N, Mask) ->
+    <<_:N, Val:1, _/bitstring>> = Mask,
+    Val =:= 1.
+
+-spec mk_mask([poll_req(_ItKey, _Iterator)], {emqx_ds:message_key(), emqx_types:message()}) ->
+    dispatch_mask().
+mk_mask(Reqs, Elem) ->
+    mk_mask(Reqs, Elem, <<>>).
+
+mk_mask([], _Elem, Acc) ->
+    Acc;
+mk_mask([#poll_req{msg_matcher = Matcher} | Rest], {Key, Message} = Elem, Acc) ->
+    Val =
+        case Matcher(Key, Message) of
+            true -> 1;
+            false -> 0
+        end,
+    mk_mask(Rest, Elem, <<Acc/bitstring, Val:1>>).
+
+filter_candidates(Reqs, Messages) ->
+    lists:partition(
+        fun(#poll_req{msg_matcher = Matcher}) ->
+            lists:any(
+                fun({MsgKey, Msg}) -> Matcher(MsgKey, Msg) end,
+                Messages
+            )
+        end,
+        Reqs
+    ).
+
+send_out(Node, Beam) ->
+    ?tp(debug, beamformer_out, #{
+        dest_node => Node,
+        beam => Beam
+    }),
+    emqx_ds_beamsplitter_proto_v1:dispatch(Node, Beam),
+    ok.
+
+shard_metrics_id({DB, Shard}) ->
+    emqx_ds_builtin_metrics:shard_metric_id(DB, Shard).
+
+%% Dynamic config (currently it's global for all DBs):
+
+cfg_pending_request_limit() ->
+    application:get_env(emqx_durable_storage, poll_pending_request_limit, 100_000).
+
+cfg_batch_size() ->
+    application:get_env(emqx_durable_storage, poll_batch_size, 100).
+
+cfg_housekeeping_interval() ->
+    application:get_env(emqx_durable_storage, beamformer_housekeeping_interval, 1000).
+
+%%================================================================================
+%% Tests
+%%================================================================================
+
+-ifdef(TEST).
+
+is_member_test_() ->
+    [
+        ?_assert(is_member(0, <<1:1>>)),
+        ?_assertNot(is_member(0, <<0:1>>)),
+        ?_assertNot(is_member(5, <<0:10>>)),
+
+        ?_assert(is_member(0, <<255:8>>)),
+        ?_assert(is_member(7, <<255:8>>)),
+
+        ?_assertNot(is_member(7, <<0:8, 1:1, 0:10>>)),
+        ?_assert(is_member(8, <<0:8, 1:1, 0:10>>)),
+        ?_assertNot(is_member(9, <<0:8, 1:1, 0:10>>))
+    ].
+
+pack_test_() ->
+    UpdateIterator = fun(It0, NextKey) ->
+        {ok, setelement(2, It0, NextKey)}
+    end,
+    Raddr = raddr,
+    NextKey = <<"42">>,
+    Req1 = #poll_req{
+        return_addr = Raddr,
+        msg_matcher = fun(_, _) -> true end,
+        it = {it1, <<"0">>}
+    },
+    Req2 = #poll_req{
+        return_addr = Raddr,
+        msg_matcher = fun(_, _) -> false end,
+        it = {it2, <<"1">>}
+    },
+    Req3 = #poll_req{
+        return_addr = Raddr,
+        msg_matcher = fun(_, _) -> true end,
+        it = {it3, <<"2">>}
+    },
+    %% Messages:
+    M1 = {<<"1">>, #message{id = <<"1">>}},
+    M2 = {NextKey, #message{id = <<"2">>}},
+    Reqs = [Req1, Req2, Req3],
+    [
+        ?_assertMatch(
+            #beam{
+                iterators = [],
+                pack = [
+                    {<<"1">>, <<>>, #message{id = <<"1">>}},
+                    {Next, <<>>, #message{id = <<"2">>}}
+                ]
+            },
+            pack(UpdateIterator, NextKey, [], [M1, M2])
+        ),
+        ?_assertMatch(
+            #beam{
+                iterators = [
+                    {Raddr, {it1, NextKey}}, {Raddr, {it2, NextKey}}, {Raddr, {it3, NextKey}}
+                ],
+                pack = [
+                    {<<"1">>, <<1:1, 0:1, 1:1>>, #message{id = <<"1">>}},
+                    {NextKey, <<1:1, 0:1, 1:1>>, #message{id = <<"2">>}}
+                ]
+            },
+            pack(UpdateIterator, NextKey, Reqs, [M1, M2])
+        )
+    ].
+
+split_test_() ->
+    Always = fun(_It, _Msg) -> true end,
+    M1 = {<<"1">>, #message{id = <<"1">>}},
+    M2 = {<<"2">>, #message{id = <<"2">>}},
+    M3 = {<<"3">>, #message{id = <<"3">>}},
+    Its = [{<<"it1">>, it1}, {<<"it2">>, it2}],
+    Beam1 = #beam{
+        iterators = Its,
+        pack = [
+            {<<"1">>, <<1:1, 0:1>>, element(2, M1)},
+            {<<"2">>, <<0:1, 1:1>>, element(2, M2)},
+            {<<"3">>, <<1:1, 1:1>>, element(2, M3)}
+        ]
+    },
+    [{<<"it1">>, Result1}, {<<"it2">>, Result2}] = lists:sort(split(Beam1)),
+    [
+        ?_assertMatch(
+            {ok, it1, [M1, M3]},
+            Result1
+        ),
+        ?_assertMatch(
+            {ok, it2, [M2, M3]},
+            Result2
+        )
+    ].
+
+-endif.

+ 123 - 0
apps/emqx_durable_storage/src/emqx_ds_beamformer_sup.erl

@@ -0,0 +1,123 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+-module(emqx_ds_beamformer_sup).
+
+-behaviour(supervisor).
+
+%% API:
+-export([start_link/3, pool/1, cbm/1]).
+
+%% behavior callbacks:
+-export([init/1]).
+
+%% internal exports:
+-export([start_workers/3, init_pool_owner/3]).
+
+-export_type([]).
+
+%%================================================================================
+%% Type declarations
+%%================================================================================
+
+-define(SUP(SHARD), {n, l, {?MODULE, SHARD}}).
+
+-define(cbm(DB), {?MODULE, DB}).
+
+%%================================================================================
+%% API functions
+%%================================================================================
+
+-spec cbm(_Shard) -> module().
+cbm(DB) ->
+    persistent_term:get(?cbm(DB)).
+
+pool(Shard) ->
+    {?MODULE, Shard}.
+
+-spec start_link(module(), _Shard, emqx_ds_beamformer:opts()) -> supervisor:startlink_ret().
+start_link(CBM, ShardId, Opts) ->
+    supervisor:start_link(
+        {via, gproc, ?SUP(ShardId)}, ?MODULE, {top, CBM, ShardId, Opts}
+    ).
+
+%%================================================================================
+%% behavior callbacks
+%%================================================================================
+
+init({top, Module, ShardId, Opts}) ->
+    Children = [
+        #{
+            id => pool_owner,
+            type => worker,
+            start => {proc_lib, start_link, [?MODULE, init_pool_owner, [self(), ShardId, Module]]}
+        },
+        #{
+            id => workers,
+            type => supervisor,
+            shutdown => infinity,
+            start => {?MODULE, start_workers, [Module, ShardId, Opts]}
+        }
+    ],
+    SupFlags = #{
+        strategy => one_for_all,
+        intensity => 1,
+        period => 1
+    },
+    {ok, {SupFlags, Children}};
+init({workers, Module, ShardId, Opts}) ->
+    #{n_workers := InitialNWorkers} = Opts,
+    Children = [
+        #{
+            id => I,
+            type => worker,
+            shutdown => 5000,
+            start => {emqx_ds_beamformer, start_link, [Module, ShardId, I, Opts]}
+        }
+     || I <- lists:seq(1, InitialNWorkers)
+    ],
+    SupFlags = #{
+        strategy => one_for_one,
+        intensity => 10,
+        period => 10
+    },
+    {ok, {SupFlags, Children}}.
+
+%%================================================================================
+%% Internal exports
+%%================================================================================
+
+start_workers(Module, ShardId, InitialNWorkers) ->
+    supervisor:start_link(?MODULE, {workers, Module, ShardId, InitialNWorkers}).
+
+%% Helper process that automatically destroys gproc pool when
+%% supervisor is stopped:
+-spec init_pool_owner(pid(), _Shard, module()) -> no_return().
+init_pool_owner(Parent, ShardId, Module) ->
+    process_flag(trap_exit, true),
+    gproc_pool:new(pool(ShardId), hash, [{auto_size, true}]),
+    persistent_term:put(?cbm(ShardId), Module),
+    proc_lib:init_ack(Parent, {ok, self()}),
+    %% Automatic cleanup:
+    receive
+        {'EXIT', _Pid, Reason} ->
+            gproc_pool:force_delete(pool(ShardId)),
+            persistent_term:erase(?cbm(ShardId)),
+            exit(Reason)
+    end.
+
+%%================================================================================
+%% Internal functions
+%%================================================================================

+ 107 - 0
apps/emqx_durable_storage/src/emqx_ds_beamformer_waitq.erl

@@ -0,0 +1,107 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+%% @doc This helper module matches stream events to waiting poll
+%% requests in the beamformer.
+-module(emqx_ds_beamformer_waitq).
+
+%% API:
+-export([new/0, insert/5, delete/4, matches/3]).
+
+-ifdef(TEST).
+-include_lib("eunit/include/eunit.hrl").
+-endif.
+
+%%================================================================================
+%% API functions
+%%================================================================================
+
+new() ->
+    ets:new(?MODULE, [ordered_set, private]).
+
+insert(Stream, Filter, ID, Record, Tab) ->
+    Key = make_key(Stream, Filter, ID),
+    true = ets:insert(Tab, {Key, Record}).
+
+delete(Stream, Filter, ID, Tab) ->
+    ets:delete(Tab, make_key(Stream, Filter, ID)).
+
+matches(Stream, Topic, Tab) ->
+    Ids = emqx_trie_search:matches(Topic, make_nextf(Stream, Tab), []),
+    [Val || Id <- Ids, {_, Val} <- ets:lookup(Tab, {Stream, Id})].
+
+%%================================================================================
+%% Internal functions
+%%================================================================================
+
+make_key(Stream, TopicFilter, ID) ->
+    {Stream, emqx_trie_search:make_key(TopicFilter, ID)}.
+
+make_nextf(Stream, Tab) ->
+    fun(Key0) ->
+        case ets:next(Tab, {Stream, Key0}) of
+            '$end_of_table' -> '$end_of_table';
+            {Stream, Key} -> Key;
+            {_OtherStream, _Key} -> '$end_of_table'
+        end
+    end.
+
+%%================================================================================
+%% Tests
+%%================================================================================
+
+-ifdef(TEST).
+
+topic_match_test() ->
+    Tab = new(),
+    insert(s1, [<<"foo">>, '+'], 1, {val, 1}, Tab),
+    insert(s1, [<<"foo">>, <<"bar">>], 2, {val, 2}, Tab),
+    insert(s1, [<<"1">>, <<"2">>], 3, {val, 3}, Tab),
+
+    insert(s2, [<<"foo">>, '+'], 4, {val, 4}, Tab),
+    insert(s2, [<<"foo">>, <<"bar">>], 5, {val, 5}, Tab),
+    insert(s2, [<<"1">>, <<"2">>], 6, {val, 6}, Tab),
+
+    ?assertEqual(
+        [{val, 1}],
+        lists:sort(matches(s1, [<<"foo">>, <<"2">>], Tab))
+    ),
+    ?assertEqual(
+        [{val, 4}],
+        lists:sort(matches(s2, [<<"foo">>, <<"2">>], Tab))
+    ),
+    ?assertEqual(
+        [{val, 1}, {val, 2}],
+        lists:sort(matches(s1, [<<"foo">>, <<"bar">>], Tab))
+    ),
+    ?assertEqual(
+        [{val, 4}, {val, 5}],
+        lists:sort(matches(s2, [<<"foo">>, <<"bar">>], Tab))
+    ),
+    ?assertEqual(
+        [],
+        matches(s3, [<<"foo">>, <<"bar">>], Tab)
+    ),
+    ?assertEqual(
+        [{val, 3}],
+        matches(s1, [<<"1">>, <<"2">>], Tab)
+    ),
+    ?assertEqual(
+        [{val, 6}],
+        matches(s2, [<<"1">>, <<"2">>], Tab)
+    ).
+
+-endif.

+ 30 - 1
apps/emqx_durable_storage/src/emqx_ds_builtin_metrics.erl

@@ -34,6 +34,12 @@
 
     observe_next_time/2,
 
+    observe_sharing/2,
+    inc_poll_requests/2,
+    inc_poll_requests_fulfilled/2,
+    inc_poll_requests_dropped/2,
+    inc_poll_requests_expired/2,
+
     inc_lts_seek_counter/2,
     inc_lts_next_counter/2,
     inc_lts_collision_counter/2,
@@ -86,7 +92,15 @@
     {slide, ?DS_BUFFER_FLUSH_TIME}
 ]).
 
--define(SHARD_METRICS, ?BUFFER_METRICS).
+-define(BEAMFORMER_METRICS, [
+    {counter, ?DS_POLL_REQUESTS},
+    {counter, ?DS_POLL_REQUESTS_FULFILLED},
+    {counter, ?DS_POLL_REQUESTS_DROPPED},
+    {counter, ?DS_POLL_REQUESTS_EXPIRED},
+    {slide, ?DS_POLL_REQUEST_SHARING}
+]).
+
+-define(SHARD_METRICS, ?BEAMFORMER_METRICS ++ ?BUFFER_METRICS).
 
 -type shard_metrics_id() :: binary().
 
@@ -157,6 +171,21 @@ observe_store_batch_time({DB, _}, StoreTime) ->
 observe_next_time(DB, NextTime) ->
     catch emqx_metrics_worker:observe(?WORKER, DB, ?DS_BUILTIN_NEXT_TIME, NextTime).
 
+observe_sharing(Id, Sharing) ->
+    catch emqx_metrics_worker:observe(?WORKER, Id, ?DS_POLL_REQUEST_SHARING, Sharing).
+
+inc_poll_requests(Id, NPolls) ->
+    catch emqx_metrics_worker:inc(?WORKER, Id, ?DS_POLL_REQUESTS, NPolls).
+
+inc_poll_requests_fulfilled(Id, NPolls) ->
+    catch emqx_metrics_worker:inc(?WORKER, Id, ?DS_POLL_REQUESTS_FULFILLED, NPolls).
+
+inc_poll_requests_expired(Id, NPolls) ->
+    catch emqx_metrics_worker:inc(?WORKER, Id, ?DS_POLL_REQUESTS_EXPIRED, NPolls).
+
+inc_poll_requests_dropped(Id, N) ->
+    catch emqx_metrics_worker:inc(?WORKER, Id, ?DS_POLL_REQUESTS_DROPPED, N).
+
 -spec inc_lts_seek_counter(emqx_ds_storage_layer:shard_id(), non_neg_integer()) -> ok.
 inc_lts_seek_counter({DB, _}, Inc) ->
     catch emqx_metrics_worker:inc(?WORKER, DB, ?DS_BITFIELD_LTS_SEEK_COUNTER, Inc).

+ 5 - 4
apps/emqx_durable_storage/src/emqx_ds_lib.erl

@@ -18,7 +18,7 @@
 -include("emqx_ds.hrl").
 
 %% API:
--export([anext_helper/3]).
+-export([with_worker/4]).
 
 %% internal exports:
 -export([]).
@@ -33,9 +33,10 @@
 %% API functions
 %%================================================================================
 
-anext_helper(Mod, Function, Args) ->
+-spec with_worker(_UserData, module(), atom(), list()) -> {ok, reference()}.
+with_worker(UserData, Mod, Function, Args) ->
     ReplyTo = alias([reply]),
-    spawn_opt(
+    _ = spawn_opt(
         fun() ->
             Result =
                 try
@@ -48,7 +49,7 @@ anext_helper(Mod, Function, Args) ->
                             stacktrace => Stack
                         }}
                 end,
-            ReplyTo ! #ds_async_result{ref = ReplyTo, data = Result}
+            ReplyTo ! #poll_reply{userdata = UserData, ref = ReplyTo, payload = Result}
         end,
         [link, {min_heap_size, 10000}]
     ),

+ 48 - 1
apps/emqx_durable_storage/src/emqx_ds_storage_bitfield_lts.erl

@@ -39,7 +39,12 @@
     delete_next/7,
     lookup_message/3,
 
-    handle_event/4
+    handle_event/4,
+
+    unpack_iterator/3,
+    scan_stream/8,
+    message_matcher/3,
+    batch_events/2
 ]).
 
 %% internal exports:
@@ -628,6 +633,48 @@ handle_event(_ShardId, _Data, _Time, _Event) ->
     %% of `Time' in the replication layer.
     [].
 
+unpack_iterator(_Shard, _S, #{
+    ?tag := ?IT,
+    ?storage_key := Stream,
+    ?topic_filter := TF,
+    ?last_seen_key := <<>>,
+    ?start_time := StartTime
+}) ->
+    {Stream, TF, <<>>, StartTime};
+unpack_iterator(_Shard, #s{keymappers = Keymappers}, #{
+    ?tag := ?IT, ?storage_key := Stream, ?topic_filter := TF, ?last_seen_key := LSK
+}) ->
+    {_, Varying} = Stream,
+    NVarying = length(Varying),
+    Keymapper = array:get(NVarying, Keymappers),
+    Timestamp = emqx_ds_bitmask_keymapper:bin_key_to_coord(Keymapper, LSK, ?DIM_TS),
+    {Stream, TF, LSK, Timestamp}.
+
+scan_stream(Shard, S, Stream, TopicFilter, LastSeenKey, BatchSize, TMax, IsCurrent) ->
+    It = #{
+        ?tag => ?IT,
+        ?topic_filter => TopicFilter,
+        ?start_time => 0,
+        ?storage_key => Stream,
+        ?last_seen_key => LastSeenKey
+    },
+    case next(Shard, S, It, BatchSize, TMax, IsCurrent) of
+        {ok, #{?last_seen_key := LSK}, Batch} ->
+            {ok, LSK, Batch};
+        Other ->
+            Other
+    end.
+
+message_matcher(_Shard, #s{}, #{?tag := ?IT, ?last_seen_key := LSK, ?topic_filter := TF}) ->
+    fun(MsgKey, #message{topic = Topic}) ->
+        MsgKey > LSK andalso emqx_topic:match(emqx_topic:tokens(Topic), TF)
+    end.
+
+batch_events(_S, _CookedBatch) ->
+    %% FIXME: here we rely on the fact that bitfield_lts layout is
+    %% deprecated.
+    [].
+
 %%================================================================================
 %% Internal functions
 %%================================================================================

+ 104 - 6
apps/emqx_durable_storage/src/emqx_ds_storage_layer.erl

@@ -26,8 +26,10 @@
 
     %% Data
     store_batch/3,
+    store_batch/4,
     prepare_batch/3,
     commit_batch/3,
+    dispatch_events/3,
 
     get_streams/3,
     get_delete_streams/3,
@@ -35,6 +37,12 @@
     make_delete_iterator/4,
     update_iterator/3,
     next/4,
+
+    generation/1,
+    unpack_iterator/2,
+    scan_stream/6,
+    message_matcher/2,
+
     delete_next/5,
 
     %% Preconditions
@@ -75,7 +83,9 @@
     options/0,
     prototype/0,
     cooked_batch/0,
-    batch_store_opts/0
+    batch_store_opts/0,
+    poll_iterators/0,
+    event_dispatch_f/0
 ]).
 
 -include("emqx_ds.hrl").
@@ -179,6 +189,8 @@
         ?enc := term()
     }.
 
+-type event_dispatch_f() :: fun(([stream()]) -> ok).
+
 %%%% Generation:
 
 -define(GEN_KEY(GEN_ID), {generation, GEN_ID}).
@@ -228,6 +240,11 @@
 
 -type options() :: map().
 
+-type poll_iterators() :: [{_UserData, iterator()}].
+
+-define(ERR_GEN_GONE, {error, unrecoverable, generation_not_found}).
+-define(ERR_BUFF_FULL, {error, recoverable, reached_max}).
+
 %%================================================================================
 %% Generation callbacks
 %%================================================================================
@@ -311,7 +328,18 @@
 -callback handle_event(shard_id(), generation_data(), emqx_ds:time(), CustomEvent | tick) ->
     [CustomEvent].
 
--optional_callbacks([handle_event/4]).
+%% Stream event API:
+
+-callback batch_events(
+    generation_data(),
+    _CookedBatch
+) -> [_Stream].
+
+-optional_callbacks([
+    handle_event/4,
+    %% FIXME: should be mandatory:
+    batch_events/2
+]).
 
 %%================================================================================
 %% API for the replication layer
@@ -333,10 +361,25 @@ drop_shard(Shard) ->
 %% `commit' operations.
 -spec store_batch(shard_id(), batch(), batch_store_opts()) ->
     emqx_ds:store_batch_result().
-store_batch(Shard, Batch, Options) ->
+store_batch(Shard, Messages, Options) ->
+    DispatchF = fun(_) -> ok end,
+    store_batch(Shard, Messages, Options, DispatchF).
+
+%% @doc This is a convenience wrapper that combines the `prepare',
+%% `commit' and `dispatch_events' operations.
+-spec store_batch(
+    shard_id(),
+    batch(),
+    batch_store_opts(),
+    event_dispatch_f()
+) ->
+    emqx_ds:store_batch_result().
+store_batch(Shard, Batch, Options, DispatchF) ->
     case prepare_batch(Shard, Batch, #{}) of
         {ok, CookedBatch} ->
-            commit_batch(Shard, CookedBatch, Options);
+            Result = commit_batch(Shard, CookedBatch, Options),
+            dispatch_events(Shard, CookedBatch, DispatchF),
+            Result;
         ignore ->
             ok;
         Error = {error, _, _} ->
@@ -409,6 +452,18 @@ commit_batch(Shard, #{?tag := ?COOKED_BATCH, ?generation := GenId, ?enc := Cooke
     emqx_ds_builtin_metrics:observe_store_batch_time(Shard, T1 - T0),
     Result.
 
+-spec dispatch_events(
+    shard_id(),
+    cooked_batch(),
+    event_dispatch_f()
+) -> ok.
+dispatch_events(
+    Shard, #{?tag := ?COOKED_BATCH, ?generation := GenId, ?enc := CookedBatch}, DispatchF
+) ->
+    #{?GEN_KEY(GenId) := #{module := Mod, data := GenData}} = get_schema_runtime(Shard),
+    Events = Mod:batch_events(GenData, CookedBatch),
+    DispatchF([{?stream_v2(GenId, InnerStream), Topic} || {InnerStream, Topic} <- Events]).
+
 -spec get_streams(shard_id(), emqx_ds:topic_filter(), emqx_ds:time()) ->
     [{integer(), stream()}].
 get_streams(Shard, TopicFilter, StartTime) ->
@@ -517,9 +572,13 @@ update_iterator(
                     {error, unrecoverable, Err}
             end;
         not_found ->
-            {error, unrecoverable, generation_not_found}
+            ?ERR_GEN_GONE
     end.
 
+-spec generation(iterator()) -> gen_id().
+generation(#{?tag := ?IT, ?generation := GenId}) ->
+    GenId.
+
 -spec next(shard_id(), iterator(), pos_integer(), emqx_ds:time()) ->
     emqx_ds:next_result(iterator()).
 next(Shard, Iter = #{?tag := ?IT, ?generation := GenId, ?enc := GenIter0}, BatchSize, Now) ->
@@ -536,7 +595,46 @@ next(Shard, Iter = #{?tag := ?IT, ?generation := GenId, ?enc := GenIter0}, Batch
             end;
         not_found ->
             %% generation was possibly dropped by GC
-            {error, unrecoverable, generation_not_found}
+            ?ERR_GEN_GONE
+    end.
+
+%% Internal API for fetching data with multiple iterators in one
+%% sweep. This API does not guarantee a precise batch size.
+
+%%    When doing multi-next, we group iterators by stream:
+unpack_iterator(Shard, #{?tag := ?IT, ?generation := GenId, ?enc := Inner}) ->
+    case generation_get(Shard, GenId) of
+        #{module := Mod, data := GenData} ->
+            {InnerStream, TopicFilter, Key, TS} = Mod:unpack_iterator(Shard, GenData, Inner),
+            {?stream_v2(GenId, InnerStream), TopicFilter, Key, TS};
+        not_found ->
+            %% generation was possibly dropped by GC
+            ?ERR_GEN_GONE
+    end.
+
+%% @doc This callback is similar in nature to `next'. It is used by
+%% the beamformer module, and it allows to fetch data for multiple
+%% iterators at once.
+scan_stream(
+    Shard, ?stream_v2(GenId, Inner), TopicFilter, Now, StartMsg, BatchSize
+) ->
+    case generation_get(Shard, GenId) of
+        #{module := Mod, data := GenData} ->
+            IsCurrent = GenId =:= generation_current(Shard),
+            Mod:scan_stream(
+                Shard, GenData, Inner, TopicFilter, StartMsg, BatchSize, Now, IsCurrent
+            );
+        not_found ->
+            ?ERR_GEN_GONE
+    end.
+
+message_matcher(Shard, #{?tag := ?IT, ?generation := GenId, ?enc := Inner}) ->
+    %% logger:warning(?MODULE_STRING ++ ":match_message(~p, ~p, ~p)", [Shard, GenId, Inner]),
+    case generation_get(Shard, GenId) of
+        #{module := Mod, data := GenData} ->
+            Mod:message_matcher(Shard, GenData, Inner);
+        not_found ->
+            false
     end.
 
 -spec delete_next(

+ 40 - 1
apps/emqx_durable_storage/src/emqx_ds_storage_reference.erl

@@ -42,7 +42,12 @@
     update_iterator/4,
     next/6,
     delete_next/7,
-    lookup_message/3
+    lookup_message/3,
+
+    unpack_iterator/3,
+    scan_stream/8,
+    message_matcher/3,
+    batch_events/2
 ]).
 
 %% internal exports:
@@ -219,6 +224,40 @@ lookup_message(_ShardId, #s{db = DB, cf = CF}, #message_matcher{timestamp = TS})
             {error, unrecoverable, Reason}
     end.
 
+unpack_iterator(_Shard, _S, #it{topic_filter = TopicFilter, last_seen_message_key = LSK}) ->
+    Stream = #stream{},
+    case LSK of
+        first -> Timestamp = 0;
+        <<Timestamp:64>> -> ok
+    end,
+    {Stream, TopicFilter, LSK, Timestamp}.
+
+scan_stream(Shard, S, _Stream, TopicFilter, LastSeenKey, BatchSize, TMax, IsCurrent) ->
+    It0 = #it{topic_filter = TopicFilter, start_time = 0, last_seen_message_key = LastSeenKey},
+    case next(Shard, S, It0, BatchSize, TMax, IsCurrent) of
+        {ok, #it{last_seen_message_key = LSK}, Batch} ->
+            {ok, LSK, Batch};
+        Other ->
+            Other
+    end.
+
+message_matcher(_Shard, _S, #it{
+    start_time = StartTime, topic_filter = TF, last_seen_message_key = LSK
+}) ->
+    fun(MsgKey = <<TS:64>>, #message{topic = Topic}) ->
+        MsgKey > LSK andalso TS >= StartTime andalso emqx_topic:match(Topic, TF)
+    end.
+
+batch_events(_, Messages) ->
+    Topics = lists:foldl(
+        fun({_TS, #message{topic = Topic}}, Acc) ->
+            Acc#{Topic => 1}
+        end,
+        #{},
+        Messages
+    ),
+    [{#stream{}, T} || T <- maps:keys(Topics)].
+
 %%================================================================================
 %% Internal functions
 %%================================================================================

+ 85 - 15
apps/emqx_durable_storage/src/emqx_ds_storage_skipstream_lts.erl

@@ -34,7 +34,13 @@
     update_iterator/4,
     next/6,
     delete_next/7,
-    lookup_message/3
+    lookup_message/3,
+
+    unpack_iterator/3,
+    scan_stream/8,
+    message_matcher/3,
+
+    batch_events/2
 ]).
 
 %% internal exports:
@@ -269,6 +275,19 @@ commit_batch(
         rocksdb:release_batch(Batch)
     end.
 
+batch_events(#s{trie = _Trie}, #{?cooked_msg_ops := Payloads}) ->
+    EventMap = lists:foldl(
+        fun
+            (?cooked_msg_op(_Timestamp, _Static, _Varying, ?cooked_delete), Acc) ->
+                Acc;
+            (?cooked_msg_op(_Timestamp, Static, Varying, _ValBlob), Acc) ->
+                maps:put({Static, Varying}, 1, Acc)
+        end,
+        #{},
+        Payloads
+    ),
+    maps:keys(EventMap).
+
 get_streams(_Shard, #s{trie = Trie}, TopicFilter, _StartTime) ->
     get_streams(Trie, TopicFilter).
 
@@ -278,6 +297,9 @@ get_delete_streams(_Shard, #s{trie = Trie}, TopicFilter, _StartTime) ->
 make_iterator(_Shard, _State, _Stream, _TopicFilter, TS) when TS >= ?max_ts ->
     {error, unrecoverable, "Timestamp is too large"};
 make_iterator(_Shard, #s{trie = Trie}, #stream{static_index = StaticIdx}, TopicFilter, StartTime) ->
+    ?tp_ignore_side_effects_in_prod(emqx_ds_storage_skipstream_lts_make_iterator, #{
+        static_index => StaticIdx, topic_filter => TopicFilter, start_time => StartTime
+    }),
     {ok, TopicStructure} = emqx_ds_lts:reverse_lookup(Trie, StaticIdx),
     CompressedTF = emqx_ds_lts:compress_topic(StaticIdx, TopicStructure, TopicFilter),
     {ok, #it{
@@ -286,6 +308,53 @@ make_iterator(_Shard, #s{trie = Trie}, #stream{static_index = StaticIdx}, TopicF
         compressed_tf = emqx_topic:join(CompressedTF)
     }}.
 
+message_matcher(_Shard, #s{trie = Trie}, #it{
+    static_index = StaticIdx, ts = LastSeenTS, compressed_tf = CTF
+}) ->
+    {ok, TopicStructure} = emqx_ds_lts:reverse_lookup(Trie, StaticIdx),
+    TF = emqx_ds_lts:decompress_topic(TopicStructure, words(CTF)),
+    fun(MsgKey, #message{topic = Topic}) ->
+        case match_ds_key(StaticIdx, MsgKey) of
+            false ->
+                ?tp_ignore_side_effects_in_prod(emqx_ds_storage_skipstream_lts_matcher, #{
+                    static_index => StaticIdx,
+                    last_seen_ts => LastSeenTS,
+                    topic_filter => TF,
+                    its => false
+                }),
+                false;
+            TS ->
+                ?tp_ignore_side_effects_in_prod(emqx_ds_storage_skipstream_lts_matcher, #{
+                    static_index => StaticIdx,
+                    last_seen_ts => LastSeenTS,
+                    topic_filter => TF,
+                    its => TS
+                }),
+                %% Timestamp stored in the iterator follows modulo
+                %% 2^64 arithmetic, so in this context `?max_ts' means
+                %% 0.
+                (LastSeenTS =:= ?max_ts orelse TS > LastSeenTS) andalso
+                    emqx_topic:match(words(Topic), TF)
+        end
+    end.
+
+unpack_iterator(_Shard, #s{trie = _Trie}, #it{
+    static_index = StaticIdx, compressed_tf = CTF, ts = TS
+}) ->
+    StartKey = mk_key(StaticIdx, 0, <<>>, TS),
+    %% Structure = get_topic_structure(Trie, StaticIdx),
+    {StaticIdx, words(CTF), StartKey, TS}.
+
+scan_stream(Shard, S, StaticIdx, Varying, LastSeenKey, BatchSize, TMax, IsCurrent) ->
+    LastSeenTS = match_ds_key(StaticIdx, LastSeenKey),
+    It = #it{static_index = StaticIdx, compressed_tf = emqx_topic:join(Varying), ts = LastSeenTS},
+    case next(Shard, S, It, BatchSize, TMax, IsCurrent) of
+        {ok, #it{ts = TS, static_index = StaticIdx}, Batch} ->
+            {ok, mk_key(StaticIdx, 0, <<>>, TS), Batch};
+        Other ->
+            Other
+    end.
+
 make_delete_iterator(Shard, Data, Stream, TopicFilter, StartTime) ->
     make_iterator(Shard, Data, Stream, TopicFilter, StartTime).
 
@@ -509,16 +578,7 @@ next_loop(
     BatchSize,
     TMax
 ) ->
-    TopicStructure =
-        case emqx_ds_lts:reverse_lookup(Trie, StaticIdx) of
-            {ok, Rev} ->
-                Rev;
-            undefined ->
-                throw(#{
-                    msg => "LTS trie missing key",
-                    key => StaticIdx
-                })
-        end,
+    TopicStructure = get_topic_structure(Trie, StaticIdx),
     Ctx = #ctx{
         shard = Shard,
         s = S,
@@ -529,6 +589,17 @@ next_loop(
     },
     next_loop(Ctx, It, BatchSize, {seek, inc_ts(LastTS)}, []).
 
+get_topic_structure(Trie, StaticIdx) ->
+    case emqx_ds_lts:reverse_lookup(Trie, StaticIdx) of
+        {ok, Rev} ->
+            Rev;
+        undefined ->
+            throw(#{
+                msg => "LTS trie missing key",
+                key => StaticIdx
+            })
+    end.
+
 next_loop(_Ctx, It, 0, _Op, Acc) ->
     finalize_loop(It, Acc);
 next_loop(Ctx, It0, BatchSize, Op, Acc) ->
@@ -777,10 +848,9 @@ trie_cf(GenId) ->
 
 %%%%%%%% Topic encoding %%%%%%%%%%
 
-words(<<>>) ->
-    [];
-words(Bin) ->
-    emqx_topic:words(Bin).
+% words(L) when is_list(L) -> L;
+words(<<>>) -> [];
+words(Bin) -> emqx_topic:words(Bin).
 
 %%%%%%%% Counters %%%%%%%%%%
 

+ 1 - 1
apps/emqx_durable_storage/src/emqx_durable_storage.app.src

@@ -5,7 +5,7 @@
     {vsn, "0.3.0"},
     {modules, []},
     {registered, []},
-    {applications, [kernel, stdlib, rocksdb, gproc, mria, emqx_utils]},
+    {applications, [kernel, stdlib, rocksdb, gproc, mria, emqx_utils, gen_rpc]},
     {mod, {emqx_ds_app, []}},
     {env, []}
 ]}.

+ 39 - 0
apps/emqx_durable_storage/src/proto/emqx_ds_beamsplitter_proto_v1.erl

@@ -0,0 +1,39 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+-module(emqx_ds_beamsplitter_proto_v1).
+
+%-behavior(emqx_bpapi).
+
+%% API:
+-export([dispatch/2]).
+
+%% behavior callbacks:
+-export([introduced_in/0]).
+
+%%================================================================================
+%% API functions
+%%================================================================================
+
+-spec dispatch(node(), emqx_ds_beamformer:beam()) -> true.
+dispatch(Node, Beam) ->
+    emqx_rpc:cast(Node, emqx_ds_beamformer, do_dispatch, [Beam]).
+
+%%================================================================================
+%% behavior callbacks
+%%================================================================================
+
+introduced_in() ->
+    "5.8.0".

+ 114 - 1
apps/emqx_durable_storage/test/emqx_ds_storage_layout_SUITE.erl

@@ -23,6 +23,7 @@
 -include_lib("common_test/include/ct.hrl").
 -include_lib("snabbkaffe/include/snabbkaffe.hrl").
 -include_lib("stdlib/include/assert.hrl").
+-include("emqx_ds.hrl").
 
 -define(assertSameSet(A, B), ?assertEqual(lists:sort(A), lists:sort(B))).
 
@@ -45,6 +46,8 @@ all() ->
 init_per_group(Group, Config) ->
     LayoutConf =
         case Group of
+            reference ->
+                {emqx_ds_storage_reference, #{}};
             skipstream_lts ->
                 {emqx_ds_storage_skipstream_lts, #{with_guid => true}};
             bitfield_lts ->
@@ -289,6 +292,111 @@ t_replay(Config) ->
     ?assert(check(?SHARD, <<"#">>, 0, Messages)),
     ok.
 
+%% This testcase verifies poll functionality that doesn't involve events:
+t_poll(Config) ->
+    ?check_trace(
+        begin
+            Topics = [list_to_binary("t/" ++ integer_to_list(I)) || I <- lists:seq(1, 21)],
+            Values = lists:seq(1, 1_000, 500),
+            Batch1 = [
+                make_message(Val, Topic, bin(Val))
+             || Topic <- Topics, Val <- Values
+            ],
+            BatchSize = 1000,
+            Timeout = 1_000,
+            PollOpts = #{timeout => Timeout},
+            %% 1. Store a batch of data to create streams:
+            ok = emqx_ds:store_batch(?FUNCTION_NAME, Batch1),
+            timer:sleep(1000),
+            %% 2. Create a number of iterators for different topic
+            %% subscriptions. These iterators overlap, so the
+            %% beamformer is likely to group them.
+            Iterators0 =
+                [
+                    begin
+                        {ok, It} = emqx_ds:make_iterator(?FUNCTION_NAME, Stream, TopicFilter, 0),
+                        %% Create a reference to identify poll reply:
+                        {make_ref(), It}
+                    end
+                 || TopicFilter <- lists:map(fun emqx_topic:words/1, [<<"#">> | Topics]),
+                    {_Rank, Stream} <- emqx_ds:get_streams(?FUNCTION_NAME, TopicFilter, 0)
+                ],
+            ?assertMatch([_ | _], Iterators0, "List of iterators should be non-empty"),
+            %% 2. Fetch values via `next' API for reference:
+            Reference1 = [
+                {Ref, emqx_ds:next(?FUNCTION_NAME, It, BatchSize)}
+             || {Ref, It} <- Iterators0
+            ],
+            %% 3. Fetch the same data via poll API (we use initial values of
+            %% the iterator as tags):
+            {ok, Alias1} = emqx_ds:poll(?FUNCTION_NAME, Iterators0, PollOpts),
+            %% Collect the replies:
+            Got1 = collect_poll_replies(Alias1, Timeout),
+            unalias(Alias1),
+            %% 4. Compare data. Everything (batch contents and iterators) should be the same:
+            compare_poll_with_reference(Reference1, Got1),
+            %% 5. Create a new poll request with the new iterators,
+            %% these should be resolved via events:
+            Iterators1 = lists:map(
+                fun({ItRef, {ok, It, _}}) ->
+                    {ItRef, It}
+                end,
+                Got1
+            ),
+            {ok, Alias2} = emqx_ds:poll(?FUNCTION_NAME, Iterators1, PollOpts),
+            %% 5.1 Sleep to make sure poll requests are enqueued
+            %% _before_ the batch is published:
+            timer:sleep(10),
+            %% 6. Add new data and receive results:
+            emqx_ds:store_batch(?FUNCTION_NAME, Batch1),
+            case ?config(layout, Config) of
+                {emqx_ds_storage_bitfield_lts, _} ->
+                    %% Currently this layout doesn't support events:
+                    ok;
+                _ ->
+                    ?assertMatch(
+                        [{_, {ok, _, [_ | _]}} | _],
+                        collect_poll_replies(Alias2, Timeout),
+                        "Poll reply with non-empty batch should be received after "
+                        "data was published to the topic."
+                    )
+            end
+        end,
+        []
+    ).
+
+compare_poll_with_reference(Reference, PollRepliesL) ->
+    PollReplies = maps:from_list(PollRepliesL),
+    lists:foreach(
+        fun({ItRef, ReferenceReply}) ->
+            case ReferenceReply of
+                {ok, _, []} ->
+                    %% DS doesn't send empty replies back, so skip
+                    %% check here:
+                    ok;
+                _ ->
+                    compare_poll_reply(ReferenceReply, maps:get(ItRef, PollReplies, undefined))
+            end
+        end,
+        Reference
+    ).
+
+compare_poll_reply({ok, ReferenceIterator, BatchRef}, {ok, ReplyIterator, Batch}) ->
+    ?defer_assert(?assertEqual(ReferenceIterator, ReplyIterator, "Iterators should be equal")),
+    ?defer_assert(snabbkaffe_diff:assert_lists_eq(BatchRef, Batch));
+compare_poll_reply(A, B) ->
+    ?defer_assert(?assertEqual(A, B)).
+
+collect_poll_replies(Alias, Timeout) ->
+    receive
+        #poll_reply{payload = poll_timeout, ref = Alias} ->
+            [];
+        #poll_reply{userdata = ItRef, payload = Reply, ref = Alias} ->
+            [{ItRef, Reply} | collect_poll_replies(Alias, Timeout)]
+    after Timeout ->
+        []
+    end.
+
 t_atomic_store_batch(_Config) ->
     DB = ?FUNCTION_NAME,
     ?check_trace(
@@ -486,6 +594,7 @@ bin(X) ->
 groups() ->
     TCs = emqx_common_test_helpers:all(?MODULE),
     [
+        {reference, TCs},
         {bitfield_lts, TCs},
         {skipstream_lts, TCs}
     ].
@@ -494,8 +603,12 @@ suite() -> [{timetrap, {seconds, 20}}].
 
 init_per_suite(Config) ->
     WorkDir = emqx_cth_suite:work_dir(Config),
+    DSEnv = [{poll_batch_size, 1000}],
     Apps = emqx_cth_suite:start(
-        [emqx_ds_builtin_local],
+        [
+            {emqx_durable_storage, #{override_env => DSEnv}},
+            emqx_ds_builtin_local
+        ],
         #{work_dir => WorkDir}
     ),
     [{apps, Apps}, {work_dir, WorkDir} | Config].

+ 0 - 60
apps/emqx_durable_storage/test/props/emqx_ds_message_storage_bitmask_shim.erl

@@ -1,60 +0,0 @@
-%%--------------------------------------------------------------------
-%% Copyright (c) 2020-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
-%%
-%% Licensed under the Apache License, Version 2.0 (the "License");
-%% you may not use this file except in compliance with the License.
-%% You may obtain a copy of the License at
-%%
-%%     http://www.apache.org/licenses/LICENSE-2.0
-%%
-%% Unless required by applicable law or agreed to in writing, software
-%% distributed under the License is distributed on an "AS IS" BASIS,
-%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-%% See the License for the specific language governing permissions and
-%% limitations under the License.
-%%--------------------------------------------------------------------
-
--module(emqx_ds_message_storage_bitmask_shim).
-
--include("../../../emqx/include/emqx.hrl").
-
--export([open/0]).
--export([close/1]).
--export([store/2]).
--export([iterate/2]).
-
--opaque t() :: ets:tid().
-
--export_type([t/0]).
-
--spec open() -> t().
-open() ->
-    ets:new(?MODULE, [ordered_set, {keypos, 1}]).
-
--spec close(t()) -> ok.
-close(Tab) ->
-    true = ets:delete(Tab),
-    ok.
-
--spec store(t(), emqx_types:message()) ->
-    ok | {error, _TODO}.
-store(Tab, Msg = #message{id = MessageID, timestamp = PublishedAt}) ->
-    true = ets:insert(Tab, {{PublishedAt, MessageID}, Msg}),
-    ok.
-
--spec iterate(t(), emqx_ds:replay()) ->
-    [binary()].
-iterate(Tab, {TopicFilter0, StartTime}) ->
-    TopicFilter = iolist_to_binary(lists:join("/", TopicFilter0)),
-    ets:foldr(
-        fun({{PublishedAt, _}, Msg = #message{topic = Topic}}, Acc) ->
-            case emqx_topic:match(Topic, TopicFilter) of
-                true when PublishedAt >= StartTime ->
-                    [Msg | Acc];
-                _ ->
-                    Acc
-            end
-        end,
-        [],
-        Tab
-    ).

+ 6 - 1
apps/emqx_prometheus/src/emqx_prometheus.erl

@@ -525,7 +525,12 @@ emqx_collect(K = ?DS_SKIPSTREAM_LTS_HASH_COLLISION, D) -> counter_metrics(?MG(K,
 emqx_collect(K = ?DS_SKIPSTREAM_LTS_HIT, D) -> counter_metrics(?MG(K, D, []));
 emqx_collect(K = ?DS_SKIPSTREAM_LTS_MISS, D) -> counter_metrics(?MG(K, D, []));
 emqx_collect(K = ?DS_SKIPSTREAM_LTS_FUTURE, D) -> counter_metrics(?MG(K, D, []));
-emqx_collect(K = ?DS_SKIPSTREAM_LTS_EOS, D) -> counter_metrics(?MG(K, D, [])).
+emqx_collect(K = ?DS_SKIPSTREAM_LTS_EOS, D) -> counter_metrics(?MG(K, D, []));
+emqx_collect(K = ?DS_POLL_REQUESTS, D) -> counter_metrics(?MG(K, D, []));
+emqx_collect(K = ?DS_POLL_REQUESTS_FULFILLED, D) -> counter_metrics(?MG(K, D, []));
+emqx_collect(K = ?DS_POLL_REQUESTS_DROPPED, D) -> counter_metrics(?MG(K, D, []));
+emqx_collect(K = ?DS_POLL_REQUESTS_EXPIRED, D) -> counter_metrics(?MG(K, D, []));
+emqx_collect(K = ?DS_POLL_REQUEST_SHARING, D) -> gauge_metrics(?MG(K, D, [])).
 
 %%--------------------------------------------------------------------
 %% Indicators