
Merge pull request #12998 from ieQu1/dev/improve-latency

Use leader's clock when calculating LTS cutoff timestamp
ieQu1 1 year ago
Parent
Commit
8506ca7919

+ 1 - 1
apps/emqx/rebar.config

@@ -34,7 +34,7 @@
     {emqx_http_lib, {git, "https://github.com/emqx/emqx_http_lib.git", {tag, "0.5.3"}}},
     {pbkdf2, {git, "https://github.com/emqx/erlang-pbkdf2.git", {tag, "2.0.4"}}},
     {recon, {git, "https://github.com/ferd/recon", {tag, "2.5.1"}}},
-    {snabbkaffe, {git, "https://github.com/kafka4beam/snabbkaffe.git", {tag, "1.0.8"}}},
+    {snabbkaffe, {git, "https://github.com/kafka4beam/snabbkaffe.git", {tag, "1.0.10"}}},
     {ra, "2.7.3"}
 ]}.
 

+ 1 - 0
apps/emqx_durable_storage/src/emqx_ds_builtin_db_sup.erl

@@ -118,6 +118,7 @@ which_dbs() ->
 init({#?db_sup{db = DB}, DefaultOpts}) ->
     %% Spec for the top-level supervisor for the database:
     logger:notice("Starting DS DB ~p", [DB]),
+    emqx_ds_builtin_sup:clean_gvars(DB),
     emqx_ds_builtin_metrics:init_for_db(DB),
     Opts = emqx_ds_replication_layer_meta:open_db(DB, DefaultOpts),
     ok = start_ra_system(DB, Opts),

+ 1 - 1
apps/emqx_durable_storage/src/emqx_ds_builtin_metrics.erl

@@ -190,7 +190,7 @@ prometheus_per_db(NodeOrAggr) ->
 %%  ...
 %% '''
 %%
-%% If `NodeOrAggr' = `node' then node name is appended to the list of
+%% If `NodeOrAggr' = `aggr' then node name is appended to the list of
 %% labels.
 prometheus_per_db(NodeOrAggr, DB, Acc0) ->
     Labels = [

+ 30 - 1
apps/emqx_durable_storage/src/emqx_ds_builtin_sup.erl

@@ -23,6 +23,7 @@
 
 %% API:
 -export([start_db/2, stop_db/1]).
+-export([set_gvar/3, get_gvar/3, clean_gvars/1]).
 
 %% behavior callbacks:
 -export([init/1]).
@@ -39,6 +40,13 @@
 -define(top, ?MODULE).
 -define(databases, emqx_ds_builtin_databases_sup).
 
+-define(gvar_tab, emqx_ds_builtin_gvar).
+
+-record(gvar, {
+    k :: {emqx_ds:db(), _Key},
+    v :: _Value
+}).
+
 %%================================================================================
 %% API functions
 %%================================================================================
@@ -61,11 +69,31 @@ stop_db(DB) ->
         Pid when is_pid(Pid) ->
             _ = supervisor:terminate_child(?databases, DB),
             _ = supervisor:delete_child(?databases, DB),
-            ok;
+            clean_gvars(DB);
         undefined ->
             ok
     end.
 
+%% @doc Set a DB-global variable. Please don't abuse this API.
+-spec set_gvar(emqx_ds:db(), _Key, _Val) -> ok.
+set_gvar(DB, Key, Val) ->
+    ets:insert(?gvar_tab, #gvar{k = {DB, Key}, v = Val}),
+    ok.
+
+-spec get_gvar(emqx_ds:db(), _Key, Val) -> Val.
+get_gvar(DB, Key, Default) ->
+    case ets:lookup(?gvar_tab, {DB, Key}) of
+        [#gvar{v = Val}] ->
+            Val;
+        [] ->
+            Default
+    end.
+
+-spec clean_gvars(emqx_ds:db()) -> ok.
+clean_gvars(DB) ->
+    ets:match_delete(?gvar_tab, #gvar{k = {DB, '_'}, _ = '_'}),
+    ok.
+
 %%================================================================================
 %% behavior callbacks
 %%================================================================================
@@ -96,6 +124,7 @@ init(?top) ->
         type => supervisor,
         shutdown => infinity
     },
+    _ = ets:new(?gvar_tab, [named_table, set, public, {keypos, #gvar.k}, {read_concurrency, true}]),
     %%
     SupFlags = #{
         strategy => one_for_all,

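The new gvar API gives each DB a node-local key-value store backed by a public named ETS table, created by the top-level supervisor and cleaned both on DB start and stop. A minimal usage sketch, assuming the builtin supervisor is running; the `my_counter' key is hypothetical, not one used by EMQX:

    demo_gvars(DB) ->
        %% The default is returned while the key is unset:
        0 = emqx_ds_builtin_sup:get_gvar(DB, my_counter, 0),
        ok = emqx_ds_builtin_sup:set_gvar(DB, my_counter, 42),
        42 = emqx_ds_builtin_sup:get_gvar(DB, my_counter, 0),
        %% All gvars belonging to a DB are removed in one call:
        ok = emqx_ds_builtin_sup:clean_gvars(DB),
        0 = emqx_ds_builtin_sup:get_gvar(DB, my_counter, 0).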
+ 13 - 1
apps/emqx_durable_storage/src/emqx_ds_lts.erl

@@ -19,7 +19,9 @@
 %% API:
 -export([
     trie_create/1, trie_create/0,
+    destroy/1,
     trie_restore/2,
+    trie_update/2,
     trie_copy_learned_paths/2,
     topic_key/3,
     match_topics/2,
@@ -116,10 +118,20 @@ trie_create(UserOpts) ->
 trie_create() ->
     trie_create(#{}).
 
+-spec destroy(trie()) -> ok.
+destroy(#trie{trie = Trie, stats = Stats}) ->
+    catch ets:delete(Trie),
+    catch ets:delete(Stats),
+    ok.
+
 %% @doc Restore trie from a dump
 -spec trie_restore(options(), [{_Key, _Val}]) -> trie().
 trie_restore(Options, Dump) ->
-    Trie = trie_create(Options),
+    trie_update(trie_create(Options), Dump).
+
+%% @doc Update a trie with a dump of operations (used for replication)
+-spec trie_update(trie(), [{_Key, _Val}]) -> trie().
+trie_update(Trie, Dump) ->
     lists:foreach(
         fun({{StateFrom, Token}, StateTo}) ->
             trie_insert(Trie, StateFrom, Token, StateTo)

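trie_restore/2 is now a thin wrapper: it creates an empty trie and feeds the dump to the new trie_update/2, so the same replay path serves both initial restore and incremental replication of LTS operations. A sketch of the call pattern, with made-up state identifiers and tokens:

    replay_sketch() ->
        Trie = emqx_ds_lts:trie_create(),
        %% Dump entries have the shape {{StateFrom, Token}, StateTo}:
        Dump = [{{0, <<"sensors">>}, 1}, {{1, <<"temp">>}, 2}],
        _ = emqx_ds_lts:trie_update(Trie, Dump),
        emqx_ds_lts:destroy(Trie).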
+ 80 - 26
apps/emqx_durable_storage/src/emqx_ds_replication_layer.erl

@@ -36,7 +36,8 @@
     update_iterator/3,
     next/3,
     delete_next/4,
-    shard_of_message/3
+    shard_of_message/3,
+    current_timestamp/2
 ]).
 
 %% internal exports:
@@ -65,6 +66,7 @@
 -export([
     init/1,
     apply/3,
+    tick/2,
 
     snapshot_module/0
 ]).
@@ -86,6 +88,7 @@
 ]).
 
 -include_lib("emqx_utils/include/emqx_message.hrl").
+-include_lib("snabbkaffe/include/trace.hrl").
 -include("emqx_ds_replication_layer.hrl").
 
 %%================================================================================
@@ -155,12 +158,14 @@
 
 %% Command. Each command is an entry in the replication log.
 -type ra_command() :: #{
-    ?tag := ?BATCH | add_generation | update_config | drop_generation,
+    ?tag := ?BATCH | add_generation | update_config | drop_generation | storage_event,
     _ => _
 }.
 
 -type timestamp_us() :: non_neg_integer().
 
+-define(gv_timestamp(SHARD), {gv_timestamp, SHARD}).
+
 %%================================================================================
 %% API functions
 %%================================================================================
@@ -363,9 +368,16 @@ shard_of_message(DB, #message{from = From, topic = Topic}, SerializeBy) ->
         end,
     integer_to_binary(Hash).
 
+-spec foreach_shard(emqx_ds:db(), fun((shard_id()) -> _)) -> ok.
 foreach_shard(DB, Fun) ->
     lists:foreach(Fun, list_shards(DB)).
 
+%% @doc Messages have been replicated up to this timestamp on the
+%% local server
+-spec current_timestamp(emqx_ds:db(), emqx_ds_replication_layer:shard_id()) -> emqx_ds:time().
+current_timestamp(DB, Shard) ->
+    emqx_ds_builtin_sup:get_gvar(DB, ?gv_timestamp(Shard), 0).
+
 %%================================================================================
 %% behavior callbacks
 %%================================================================================
@@ -490,7 +502,9 @@ do_next_v1(DB, Shard, Iter, BatchSize) ->
     ShardId = {DB, Shard},
     ?IF_STORAGE_RUNNING(
         ShardId,
-        emqx_ds_storage_layer:next(ShardId, Iter, BatchSize)
+        emqx_ds_storage_layer:next(
+            ShardId, Iter, BatchSize, emqx_ds_replication_layer:current_timestamp(DB, Shard)
+        )
     ).
 
 -spec do_delete_next_v4(
@@ -502,7 +516,13 @@ do_next_v1(DB, Shard, Iter, BatchSize) ->
 ) ->
     emqx_ds:delete_next_result(emqx_ds_storage_layer:delete_iterator()).
 do_delete_next_v4(DB, Shard, Iter, Selector, BatchSize) ->
-    emqx_ds_storage_layer:delete_next({DB, Shard}, Iter, Selector, BatchSize).
+    emqx_ds_storage_layer:delete_next(
+        {DB, Shard},
+        Iter,
+        Selector,
+        BatchSize,
+        emqx_ds_replication_layer:current_timestamp(DB, Shard)
+    ).
 
 -spec do_add_generation_v2(emqx_ds:db()) -> no_return().
 do_add_generation_v2(_DB) ->
@@ -672,50 +692,69 @@ apply(
         ?tag := ?BATCH,
         ?batch_messages := MessagesIn
     },
-    #{db_shard := DBShard, latest := Latest} = State
+    #{db_shard := DBShard = {DB, Shard}, latest := Latest0} = State0
 ) ->
     %% NOTE
     %% Unique timestamp tracking real time closely.
     %% With microsecond granularity it should be nearly impossible for it to run
     %% too far ahead than the real time clock.
-    {NLatest, Messages} = assign_timestamps(Latest, MessagesIn),
-    %% TODO
-    %% Batch is now reversed, but it should not make a lot of difference.
-    %% Even if it would be in order, it's still possible to write messages far away
-    %% in the past, i.e. when replica catches up with the leader. Storage layer
-    %% currently relies on wall clock time to decide if it's safe to iterate over
-    %% next epoch, this is likely wrong. Ideally it should rely on consensus clock
-    %% time instead.
+    ?tp(ds_ra_apply_batch, #{db => DB, shard => Shard, batch => MessagesIn, ts => Latest0}),
+    {Latest, Messages} = assign_timestamps(Latest0, MessagesIn),
     Result = emqx_ds_storage_layer:store_batch(DBShard, Messages, #{}),
-    NState = State#{latest := NLatest},
+    State = State0#{latest := Latest},
+    set_ts(DBShard, Latest),
     %% TODO: Need to measure effects of changing frequency of `release_cursor`.
-    Effect = {release_cursor, RaftIdx, NState},
-    {NState, Result, Effect};
+    Effect = {release_cursor, RaftIdx, State},
+    {State, Result, Effect};
 apply(
     _RaftMeta,
     #{?tag := add_generation, ?since := Since},
-    #{db_shard := DBShard, latest := Latest} = State
+    #{db_shard := DBShard, latest := Latest0} = State0
 ) ->
-    {Timestamp, NLatest} = ensure_monotonic_timestamp(Since, Latest),
+    {Timestamp, Latest} = ensure_monotonic_timestamp(Since, Latest0),
     Result = emqx_ds_storage_layer:add_generation(DBShard, Timestamp),
-    NState = State#{latest := NLatest},
-    {NState, Result};
+    State = State0#{latest := Latest},
+    set_ts(DBShard, Latest),
+    {State, Result};
 apply(
     _RaftMeta,
     #{?tag := update_config, ?since := Since, ?config := Opts},
-    #{db_shard := DBShard, latest := Latest} = State
+    #{db_shard := DBShard, latest := Latest0} = State0
 ) ->
-    {Timestamp, NLatest} = ensure_monotonic_timestamp(Since, Latest),
+    {Timestamp, Latest} = ensure_monotonic_timestamp(Since, Latest0),
     Result = emqx_ds_storage_layer:update_config(DBShard, Timestamp, Opts),
-    NState = State#{latest := NLatest},
-    {NState, Result};
+    State = State0#{latest := Latest},
+    {State, Result};
 apply(
     _RaftMeta,
     #{?tag := drop_generation, ?generation := GenId},
     #{db_shard := DBShard} = State
 ) ->
     Result = emqx_ds_storage_layer:drop_generation(DBShard, GenId),
-    {State, Result}.
+    {State, Result};
+apply(
+    _RaftMeta,
+    #{?tag := storage_event, ?payload := CustomEvent, ?now := Now},
+    #{db_shard := DBShard, latest := Latest0} = State
+) ->
+    Latest = max(Latest0, Now),
+    set_ts(DBShard, Latest),
+    ?tp(
+        debug,
+        emqx_ds_replication_layer_storage_event,
+        #{
+            shard => DBShard, payload => CustomEvent, latest => Latest
+        }
+    ),
+    Effects = handle_custom_event(DBShard, Latest, CustomEvent),
+    {State#{latest => Latest}, ok, Effects}.
+
+-spec tick(integer(), ra_state()) -> ra_machine:effects().
+tick(TimeMs, #{db_shard := DBShard = {DB, Shard}, latest := Latest}) ->
+    %% Leader = emqx_ds_replication_layer_shard:lookup_leader(DB, Shard),
+    {Timestamp, _} = ensure_monotonic_timestamp(timestamp_to_timeus(TimeMs), Latest),
+    ?tp(emqx_ds_replication_layer_tick, #{db => DB, shard => Shard, ts => Timestamp}),
+    handle_custom_event(DBShard, Timestamp, tick).
 
 assign_timestamps(Latest, Messages) ->
     assign_timestamps(Latest, Messages, []).
@@ -730,7 +769,7 @@ assign_timestamps(Latest, [MessageIn | Rest], Acc) ->
             assign_timestamps(Latest + 1, Rest, [Message | Acc])
     end;
 assign_timestamps(Latest, [], Acc) ->
-    {Latest, Acc}.
+    {Latest, lists:reverse(Acc)}.
 
 assign_timestamp(TimestampUs, Message) ->
     {TimestampUs, Message}.
@@ -748,3 +787,18 @@ timeus_to_timestamp(TimestampUs) ->
 
 snapshot_module() ->
     emqx_ds_replication_snapshot.
+
+handle_custom_event(DBShard, Latest, Event) ->
+    try
+        Events = emqx_ds_storage_layer:handle_event(DBShard, Latest, Event),
+        [{append, #{?tag => storage_event, ?payload => I, ?now => Latest}} || I <- Events]
+    catch
+        EC:Err:Stacktrace ->
+            ?tp(error, ds_storage_custom_event_fail, #{
+                EC => Err, stacktrace => Stacktrace, event => Event
+            }),
+            []
+    end.
+
+set_ts({DB, Shard}, TS) ->
+    emqx_ds_builtin_sup:set_gvar(DB, ?gv_timestamp(Shard), TS).

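The key change in this module: the Ra state machine records the highest replicated timestamp in a per-shard gvar (set_ts/2) whenever it applies a batch, a generation change, or a storage event, and readers consult that value via current_timestamp/2 instead of the local wall clock. A sketch of the resulting read path, with an arbitrary batch size of 100:

    read_sketch(DB, Shard, Iter) ->
        %% Timestamp up to which this replica has applied the leader's log:
        Now = emqx_ds_replication_layer:current_timestamp(DB, Shard),
        emqx_ds_storage_layer:next({DB, Shard}, Iter, 100, Now).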
+ 4 - 0
apps/emqx_durable_storage/src/emqx_ds_replication_layer.hrl

@@ -41,4 +41,8 @@
 %% drop_generation
 -define(generation, 2).
 
+%% custom events
+-define(payload, 2).
+-define(now, 3).
+
 -endif.

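These keys follow the layer's convention of encoding commands as maps with small integer keys, to avoid persisting records or sending them over the wire. Assuming ?tag = 1 as elsewhere in the layer, a storage_event log entry would look roughly like this (illustrative values only):

    %% #{1 => storage_event, 2 => dummy_event, 3 => 1712345678000000}
    %%    ?tag               ?payload          ?now (µs)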
+ 4 - 2
apps/emqx_durable_storage/src/emqx_ds_replication_layer_shard.erl

@@ -16,6 +16,7 @@
 
 -module(emqx_ds_replication_layer_shard).
 
+%% API:
 -export([start_link/3]).
 
 %% Static server configuration
@@ -325,7 +326,8 @@ start_server(DB, Shard, #{replication_options := ReplicationOpts}) ->
     ClusterName = cluster_name(DB, Shard),
     LocalServer = local_server(DB, Shard),
     Servers = shard_servers(DB, Shard),
-    case ra:restart_server(DB, LocalServer) of
+    MutableConfig = #{tick_timeout => 100},
+    case ra:restart_server(DB, LocalServer, MutableConfig) of
         {error, name_not_registered} ->
             Bootstrap = true,
             Machine = {module, emqx_ds_replication_layer, #{db => DB, shard => Shard}},
@@ -336,7 +338,7 @@ start_server(DB, Shard, #{replication_options := ReplicationOpts}) ->
                 ],
                 ReplicationOpts
             ),
-            ok = ra:start_server(DB, #{
+            ok = ra:start_server(DB, MutableConfig#{
                 id => LocalServer,
                 uid => server_uid(DB, Shard),
                 cluster_name => ClusterName,

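Passing tick_timeout => 100 makes Ra invoke the machine's tick/2 callback roughly every 100 ms, which is what advances the safe cutoff on shards that receive no writes. The assumed chain of effects, as a comment sketch:

    %% ra tick (every ~100 ms on the leader)
    %%   -> emqx_ds_replication_layer:tick/2
    %%   -> handle_custom_event(DBShard, Timestamp, tick)
    %%   -> [{append, #{?tag => storage_event, ...}}] replicated through Raft
    %%   -> apply/3 (storage_event clause) advances `latest' and the shard gvar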
+ 161 - 28
apps/emqx_durable_storage/src/emqx_ds_storage_bitfield_lts.erl

@@ -28,15 +28,18 @@
     create/4,
     open/5,
     drop/5,
-    store_batch/4,
+    prepare_batch/4,
+    commit_batch/3,
     get_streams/4,
     get_delete_streams/4,
     make_iterator/5,
     make_delete_iterator/5,
     update_iterator/4,
-    next/4,
-    delete_next/5,
-    post_creation_actions/1
+    next/5,
+    delete_next/6,
+    post_creation_actions/1,
+
+    handle_event/4
 ]).
 
 %% internal exports:
@@ -66,6 +69,9 @@
 -define(start_time, 3).
 -define(storage_key, 4).
 -define(last_seen_key, 5).
+-define(cooked_payloads, 6).
+-define(cooked_lts_ops, 7).
+-define(cooked_ts, 8).
 
 -type options() ::
     #{
@@ -88,16 +94,28 @@
     db :: rocksdb:db_handle(),
     data :: rocksdb:cf_handle(),
     trie :: emqx_ds_lts:trie(),
+    trie_cf :: rocksdb:cf_handle(),
     keymappers :: array:array(emqx_ds_bitmask_keymapper:keymapper()),
-    ts_offset :: non_neg_integer()
+    ts_bits :: non_neg_integer(),
+    ts_offset :: non_neg_integer(),
+    gvars :: ets:table()
 }).
 
+-define(lts_persist_ops, emqx_ds_storage_bitfield_lts_ops).
+
 -type s() :: #s{}.
 
 -type stream() :: emqx_ds_lts:msg_storage_key().
 
 -type delete_stream() :: emqx_ds_lts:msg_storage_key().
 
+-type cooked_batch() ::
+    #{
+        ?cooked_payloads := [{binary(), binary()}],
+        ?cooked_lts_ops := [{binary(), binary()}],
+        ?cooked_ts := integer()
+    }.
+
 -type iterator() ::
     #{
         ?tag := ?IT,
@@ -141,6 +159,10 @@
 
 -define(DS_LTS_COUNTERS, [?DS_LTS_SEEK_COUNTER, ?DS_LTS_NEXT_COUNTER, ?DS_LTS_COLLISION_COUNTER]).
 
+%% GVar used for idle detection:
+-define(IDLE_DETECT, idle_detect).
+-define(EPOCH(S, TS), (TS bsl S#s.ts_bits)).
+
 -ifdef(TEST).
 -include_lib("eunit/include/eunit.hrl").
 -endif.
@@ -212,8 +234,11 @@ open(_Shard, DBHandle, GenId, CFRefs, Schema) ->
         db = DBHandle,
         data = DataCF,
         trie = Trie,
+        trie_cf = TrieCF,
         keymappers = KeymapperCache,
-        ts_offset = TSOffsetBits
+        ts_offset = TSOffsetBits,
+        ts_bits = TSBits,
+        gvars = ets:new(?MODULE, [public, set, {read_concurrency, true}])
     }.
 
 -spec post_creation_actions(emqx_ds_storage_layer:post_creation_context()) ->
@@ -238,32 +263,78 @@ post_creation_actions(
     s()
 ) ->
     ok.
-drop(_Shard, DBHandle, GenId, CFRefs, #s{}) ->
+drop(_Shard, DBHandle, GenId, CFRefs, #s{trie = Trie, gvars = GVars}) ->
+    emqx_ds_lts:destroy(Trie),
+    catch ets:delete(GVars),
     {_, DataCF} = lists:keyfind(data_cf(GenId), 1, CFRefs),
     {_, TrieCF} = lists:keyfind(trie_cf(GenId), 1, CFRefs),
     ok = rocksdb:drop_column_family(DBHandle, DataCF),
     ok = rocksdb:drop_column_family(DBHandle, TrieCF),
     ok.
 
--spec store_batch(
+-spec prepare_batch(
     emqx_ds_storage_layer:shard_id(),
     s(),
-    [{emqx_ds:time(), emqx_types:message()}],
+    [{emqx_ds:time(), emqx_types:message()}, ...],
     emqx_ds:message_store_opts()
 ) ->
-    emqx_ds:store_batch_result().
-store_batch(_ShardId, S = #s{db = DB, data = Data}, Messages, _Options) ->
+    {ok, cooked_batch()}.
+prepare_batch(_ShardId, S, Messages, _Options) ->
+    _ = erase(?lts_persist_ops),
+    {Payloads, MaxTs} =
+        lists:mapfoldl(
+            fun({Timestamp, Msg}, Acc) ->
+                {Key, _} = make_key(S, Timestamp, Msg),
+                Payload = {Key, message_to_value_v1(Msg)},
+                {Payload, max(Acc, Timestamp)}
+            end,
+            0,
+            Messages
+        ),
+    {ok, #{
+        ?cooked_payloads => Payloads,
+        ?cooked_lts_ops => pop_lts_persist_ops(),
+        ?cooked_ts => MaxTs
+    }}.
+
+-spec commit_batch(
+    emqx_ds_storage_layer:shard_id(),
+    s(),
+    cooked_batch()
+) -> ok | emqx_ds:error(_).
+commit_batch(
+    _ShardId,
+    _Data,
+    #{?cooked_payloads := [], ?cooked_lts_ops := LTS}
+) ->
+    %% Assert:
+    [] = LTS,
+    ok;
+commit_batch(
+    _ShardId,
+    #s{db = DB, data = DataCF, trie = Trie, trie_cf = TrieCF, gvars = Gvars},
+    #{?cooked_lts_ops := LtsOps, ?cooked_payloads := Payloads, ?cooked_ts := MaxTs}
+) ->
     {ok, Batch} = rocksdb:batch(),
+    %% Commit LTS trie to the storage:
     lists:foreach(
-        fun({Timestamp, Msg}) ->
-            {Key, _} = make_key(S, Timestamp, Msg),
-            Val = serialize(Msg),
-            rocksdb:put(DB, Data, Key, Val, [])
+        fun({Key, Val}) ->
+            ok = rocksdb:batch_put(Batch, TrieCF, term_to_binary(Key), term_to_binary(Val))
         end,
-        Messages
+        LtsOps
+    ),
+    %% Apply LTS ops to the memory cache:
+    _ = emqx_ds_lts:trie_update(Trie, LtsOps),
+    %% Commit payloads:
+    lists:foreach(
+        fun({Key, Val}) ->
+            ok = rocksdb:batch_put(Batch, DataCF, Key, term_to_binary(Val))
+        end,
+        Payloads
     ),
     Result = rocksdb:write_batch(DB, Batch, []),
     rocksdb:release_batch(Batch),
+    ets:insert(Gvars, {?IDLE_DETECT, false, MaxTs}),
     %% NOTE
     %% Strictly speaking, `{error, incomplete}` is a valid result but should be impossible to
     %% observe until there's `{no_slowdown, true}` in write options.
@@ -348,13 +419,39 @@ update_iterator(
 ) ->
     {ok, OldIter#{?last_seen_key => DSKey}}.
 
-next(Shard, Schema = #s{ts_offset = TSOffset}, It, BatchSize) ->
-    %% Compute safe cutoff time.
-    %% It's the point in time where the last complete epoch ends, so we need to know
-    %% the current time to compute it.
+next(
+    Shard,
+    Schema = #s{ts_offset = TSOffset, ts_bits = TSBits},
+    It = #{?storage_key := Stream},
+    BatchSize,
+    Now
+) ->
     init_counters(),
-    Now = emqx_ds:timestamp_us(),
-    SafeCutoffTime = (Now bsr TSOffset) bsl TSOffset,
+    %% Compute safe cutoff time. It's the point in time where the last
+    %% complete epoch ends, so we need to know the current time to
+    %% compute it. This is needed because new keys can be added before
+    %% the iterator.
+    IsWildcard =
+        case Stream of
+            {_StaticKey, []} -> false;
+            _ -> true
+        end,
+    SafeCutoffTime =
+        case IsWildcard of
+            true ->
+                (Now bsr TSOffset) bsl TSOffset;
+            false ->
+                %% Iterators scanning streams without varying topic
+                %% levels can operate on incomplete epochs, since new
+                %% matching keys for the single topic are added in
+                %% lexicographic order.
+                %%
+                %% Note: this DOES NOT apply to non-wildcard topic
+                %% filters operating on streams with varying parts:
+                %% iterator can jump to the next topic and then it
+                %% won't backtrack.
+                1 bsl TSBits - 1
+        end,
     try
         next_until(Schema, It, SafeCutoffTime, BatchSize)
     after
@@ -386,12 +483,11 @@ next_until(#s{db = DB, data = CF, keymappers = Keymappers}, It, SafeCutoffTime,
         rocksdb:iterator_close(ITHandle)
     end.
 
-delete_next(Shard, Schema = #s{ts_offset = TSOffset}, It, Selector, BatchSize) ->
+delete_next(Shard, Schema = #s{ts_offset = TSOffset}, It, Selector, BatchSize, Now) ->
     %% Compute safe cutoff time.
     %% It's the point in time where the last complete epoch ends, so we need to know
     %% the current time to compute it.
     init_counters(),
-    Now = emqx_message:timestamp_now(),
     SafeCutoffTime = (Now bsr TSOffset) bsl TSOffset,
     try
         delete_next_until(Schema, It, SafeCutoffTime, Selector, BatchSize)
@@ -441,6 +537,24 @@ delete_next_until(
         rocksdb:iterator_close(ITHandle)
     end.
 
+handle_event(_ShardId, State = #s{gvars = Gvars}, Time, tick) ->
+    case ets:lookup(Gvars, ?IDLE_DETECT) of
+        [{?IDLE_DETECT, Latch, LastWrittenTs}] ->
+            ok;
+        [] ->
+            Latch = false,
+            LastWrittenTs = 0
+    end,
+    case Latch of
+        false when ?EPOCH(State, Time) > ?EPOCH(State, LastWrittenTs) ->
+            ets:insert(Gvars, {?IDLE_DETECT, true, LastWrittenTs}),
+            [dummy_event];
+        _ ->
+            []
+    end;
+handle_event(_ShardId, _Data, _Time, _Event) ->
+    [].
+
 %%================================================================================
 %% Internal functions
 %%================================================================================
@@ -722,9 +836,6 @@ value_v1_to_message({Id, Qos, From, Flags, Headers, Topic, Payload, Timestamp, E
         extra = Extra
     }.
 
-serialize(Msg) ->
-    term_to_binary(message_to_value_v1(Msg)).
-
 deserialize(Blob) ->
     value_v1_to_message(binary_to_term(Blob)).
 
@@ -752,7 +863,8 @@ make_keymapper(TopicIndexBytes, BitsPerTopicLevel, TSBits, TSOffsetBits, N) ->
 -spec restore_trie(pos_integer(), rocksdb:db_handle(), rocksdb:cf_handle()) -> emqx_ds_lts:trie().
 restore_trie(TopicIndexBytes, DB, CF) ->
     PersistCallback = fun(Key, Val) ->
-        rocksdb:put(DB, CF, term_to_binary(Key), term_to_binary(Val), [])
+        push_lts_persist_op(Key, Val),
+        ok
     end,
     {ok, IT} = rocksdb:iterator(DB, CF, []),
     try
@@ -800,8 +912,29 @@ data_cf(GenId) ->
 trie_cf(GenId) ->
     "emqx_ds_storage_bitfield_lts_trie" ++ integer_to_list(GenId).
 
+-spec push_lts_persist_op(_Key, _Val) -> ok.
+push_lts_persist_op(Key, Val) ->
+    case erlang:get(?lts_persist_ops) of
+        undefined ->
+            erlang:put(?lts_persist_ops, [{Key, Val}]);
+        L when is_list(L) ->
+            erlang:put(?lts_persist_ops, [{Key, Val} | L])
+    end.
+
+-spec pop_lts_persist_ops() -> [{_Key, _Val}].
+pop_lts_persist_ops() ->
+    case erlang:erase(?lts_persist_ops) of
+        undefined ->
+            [];
+        L when is_list(L) ->
+            L
+    end.
+
 -ifdef(TEST).
 
+serialize(Msg) ->
+    term_to_binary(message_to_value_v1(Msg)).
+
 serialize_deserialize_test() ->
     Msg = #message{
         id = <<"message_id_val">>,

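next/5 chooses the safe cutoff differently for wildcard and single-topic streams, as the comment in the hunk above explains: wildcard iterators must stop at the last complete epoch, while single-topic iterators may read everything written so far. A worked sketch of the arithmetic with illustrative values:

    cutoff_sketch() ->
        TSOffset = 10,                          %% epochs are 1024 time units wide
        Now = 5000,
        %% Wildcard streams: round down to the current epoch boundary:
        4096 = (Now bsr TSOffset) bsl TSOffset,
        %% Single-topic streams: effectively no cutoff, i.e. the maximum
        %% representable timestamp (assuming a 16-bit field here):
        TSBits = 16,
        65535 = 1 bsl TSBits - 1,
        ok.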
+ 114 - 24
apps/emqx_durable_storage/src/emqx_ds_storage_layer.erl

@@ -26,13 +26,16 @@
 
     %% Data
     store_batch/3,
+    prepare_batch/3,
+    commit_batch/2,
+
     get_streams/3,
     get_delete_streams/3,
     make_iterator/4,
     make_delete_iterator/4,
     update_iterator/3,
-    next/3,
-    delete_next/4,
+    next/4,
+    delete_next/5,
 
     %% Generations
     update_config/3,
@@ -42,7 +45,10 @@
 
     %% Snapshotting
     take_snapshot/1,
-    accept_snapshot/1
+    accept_snapshot/1,
+
+    %% Custom events
+    handle_event/3
 ]).
 
 %% gen_server
@@ -63,7 +69,8 @@
     shard_id/0,
     options/0,
     prototype/0,
-    post_creation_context/0
+    post_creation_context/0,
+    cooked_batch/0
 ]).
 
 -include_lib("snabbkaffe/include/snabbkaffe.hrl").
@@ -79,11 +86,11 @@
 
 %% # "Record" integer keys.  We use maps with integer keys to avoid persisting and sending
 %% records over the wire.
-
 %% tags:
 -define(STREAM, 1).
 -define(IT, 2).
 -define(DELETE_IT, 3).
+-define(COOKED_BATCH, 4).
 
 %% keys:
 -define(tag, 1).
@@ -130,6 +137,13 @@
         ?enc := term()
     }.
 
+-opaque cooked_batch() ::
+    #{
+        ?tag := ?COOKED_BATCH,
+        ?generation := gen_id(),
+        ?enc := term()
+    }.
+
 %%%% Generation:
 
 -define(GEN_KEY(GEN_ID), {generation, GEN_ID}).
@@ -201,16 +215,23 @@
 -callback open(shard_id(), rocksdb:db_handle(), gen_id(), cf_refs(), _Schema) ->
     _Data.
 
+%% Delete the schema and data
 -callback drop(shard_id(), rocksdb:db_handle(), gen_id(), cf_refs(), _RuntimeData) ->
     ok | {error, _Reason}.
 
--callback store_batch(
+-callback prepare_batch(
     shard_id(),
     _Data,
-    [{emqx_ds:time(), emqx_types:message()}],
+    [{emqx_ds:time(), emqx_types:message()}, ...],
     emqx_ds:message_store_opts()
 ) ->
-    emqx_ds:store_batch_result().
+    {ok, term()} | emqx_ds:error(_).
+
+-callback commit_batch(
+    shard_id(),
+    _Data,
+    _CookedBatch
+) -> ok | emqx_ds:error(_).
 
 -callback get_streams(shard_id(), _Data, emqx_ds:topic_filter(), emqx_ds:time()) ->
     [_Stream].
@@ -223,12 +244,19 @@
 ) ->
     emqx_ds:make_delete_iterator_result(_Iterator).
 
--callback next(shard_id(), _Data, Iter, pos_integer()) ->
+-callback next(shard_id(), _Data, Iter, pos_integer(), emqx_ds:time()) ->
     {ok, Iter, [emqx_types:message()]} | {error, _}.
 
+-callback delete_next(
+    shard_id(), _Data, DeleteIterator, emqx_ds:delete_selector(), pos_integer(), emqx_ds:time()
+) ->
+    {ok, DeleteIterator, _NDeleted :: non_neg_integer(), _IteratedOver :: non_neg_integer()}.
+
+-callback handle_event(shard_id(), _Data, emqx_ds:time(), CustomEvent | tick) -> [CustomEvent].
+
 -callback post_creation_actions(post_creation_context()) -> _Data.
 
--optional_callbacks([post_creation_actions/1]).
+-optional_callbacks([post_creation_actions/1, handle_event/4]).
 
 %%================================================================================
 %% API for the replication layer
@@ -251,20 +279,54 @@ drop_shard(Shard) ->
     emqx_ds:message_store_opts()
 ) ->
     emqx_ds:store_batch_result().
-store_batch(Shard, Messages = [{Time, _Msg} | _], Options) ->
+store_batch(Shard, Messages, Options) ->
+    ?tp(emqx_ds_storage_layer_store_batch, #{
+        shard => Shard, messages => Messages, options => Options
+    }),
+    case prepare_batch(Shard, Messages, Options) of
+        {ok, CookedBatch} ->
+            commit_batch(Shard, CookedBatch);
+        ignore ->
+            ok;
+        Error = {error, _, _} ->
+            Error
+    end.
+
+-spec prepare_batch(
+    shard_id(),
+    [{emqx_ds:time(), emqx_types:message()}],
+    emqx_ds:message_store_opts()
+) -> {ok, cooked_batch()} | ignore | emqx_ds:error(_).
+prepare_batch(Shard, Messages = [{Time, _Msg} | _], Options) ->
     %% NOTE
     %% We assume that batches do not span generations. Callers should enforce this.
-    ?tp(emqx_ds_storage_layer_store_batch, #{
+    ?tp(emqx_ds_storage_layer_prepare_batch, #{
         shard => Shard, messages => Messages, options => Options
     }),
-    #{module := Mod, data := GenData} = generation_at(Shard, Time),
+    {GenId, #{module := Mod, data := GenData}} = generation_at(Shard, Time),
     T0 = erlang:monotonic_time(microsecond),
-    Result = Mod:store_batch(Shard, GenData, Messages, Options),
+    Result =
+        case Mod:prepare_batch(Shard, GenData, Messages, Options) of
+            {ok, CookedBatch} ->
+                {ok, #{?tag => ?COOKED_BATCH, ?generation => GenId, ?enc => CookedBatch}};
+            Error = {error, _, _} ->
+                Error
+        end,
     T1 = erlang:monotonic_time(microsecond),
+    %% TODO store->prepare
     emqx_ds_builtin_metrics:observe_store_batch_time(Shard, T1 - T0),
     Result;
-store_batch(_Shard, [], _Options) ->
-    ok.
+prepare_batch(_Shard, [], _Options) ->
+    ignore.
+
+-spec commit_batch(shard_id(), cooked_batch()) -> emqx_ds:store_batch_result().
+commit_batch(Shard, #{?tag := ?COOKED_BATCH, ?generation := GenId, ?enc := CookedBatch}) ->
+    #{?GEN_KEY(GenId) := #{module := Mod, data := GenData}} = get_schema_runtime(Shard),
+    T0 = erlang:monotonic_time(microsecond),
+    Result = Mod:commit_batch(Shard, GenData, CookedBatch),
+    T1 = erlang:monotonic_time(microsecond),
+    emqx_ds_builtin_metrics:observe_store_batch_time(Shard, T1 - T0),
+    Result.
 
 -spec get_streams(shard_id(), emqx_ds:topic_filter(), emqx_ds:time()) ->
     [{integer(), stream()}].
@@ -277,6 +339,13 @@ get_streams(Shard, TopicFilter, StartTime) ->
             case generation_get(Shard, GenId) of
                 #{module := Mod, data := GenData} ->
                     Streams = Mod:get_streams(Shard, GenData, TopicFilter, StartTime),
+                    ?tp(get_streams_get_gen_topic, #{
+                        gen_id => GenId,
+                        topic => TopicFilter,
+                        start_time => StartTime,
+                        streams => Streams,
+                        gen_data => GenData
+                    }),
                     [
                         {GenId, ?stream_v2(GenId, InnerStream)}
                      || InnerStream <- Streams
@@ -377,13 +446,13 @@ update_iterator(
             {error, unrecoverable, generation_not_found}
     end.
 
--spec next(shard_id(), iterator(), pos_integer()) ->
+-spec next(shard_id(), iterator(), pos_integer(), emqx_ds:time()) ->
     emqx_ds:next_result(iterator()).
-next(Shard, Iter = #{?tag := ?IT, ?generation := GenId, ?enc := GenIter0}, BatchSize) ->
+next(Shard, Iter = #{?tag := ?IT, ?generation := GenId, ?enc := GenIter0}, BatchSize, Now) ->
     case generation_get(Shard, GenId) of
         #{module := Mod, data := GenData} ->
             Current = generation_current(Shard),
-            case Mod:next(Shard, GenData, GenIter0, BatchSize) of
+            case Mod:next(Shard, GenData, GenIter0, BatchSize, Now) of
                 {ok, _GenIter, []} when GenId < Current ->
                     %% This is a past generation. Storage layer won't write
                     %% any more messages here. The iterator reached the end:
@@ -399,18 +468,21 @@ next(Shard, Iter = #{?tag := ?IT, ?generation := GenId, ?enc := GenIter0}, Batch
             {error, unrecoverable, generation_not_found}
     end.
 
--spec delete_next(shard_id(), delete_iterator(), emqx_ds:delete_selector(), pos_integer()) ->
+-spec delete_next(
+    shard_id(), delete_iterator(), emqx_ds:delete_selector(), pos_integer(), emqx_ds:time()
+) ->
     emqx_ds:delete_next_result(delete_iterator()).
 delete_next(
     Shard,
     Iter = #{?tag := ?DELETE_IT, ?generation := GenId, ?enc := GenIter0},
     Selector,
-    BatchSize
+    BatchSize,
+    Now
 ) ->
     case generation_get(Shard, GenId) of
         #{module := Mod, data := GenData} ->
             Current = generation_current(Shard),
-            case Mod:delete_next(Shard, GenData, GenIter0, Selector, BatchSize) of
+            case Mod:delete_next(Shard, GenData, GenIter0, Selector, BatchSize, Now) of
                 {ok, _GenIter, _Deleted = 0, _IteratedOver = 0} when GenId < Current ->
                     %% This is a past generation. Storage layer won't write
                     %% any more messages here. The iterator reached the end:
@@ -849,6 +921,24 @@ handle_accept_snapshot(ShardId) ->
     Dir = db_dir(ShardId),
     emqx_ds_storage_snapshot:new_writer(Dir).
 
+%% FIXME: currently this interface is a hack to handle safe cutoff
+%% timestamp in LTS. It has many shortcomings (can lead to infinite
+%% loops if the CBM is not careful; events from one generation may be
+%% sent to the next one, etc.) and the API is not well thought out in
+%% general.
+%%
+%% The mechanism of storage layer events should be refined later.
+-spec handle_event(shard_id(), emqx_ds:time(), CustomEvent | tick) -> [CustomEvent].
+handle_event(Shard, Time, Event) ->
+    {_GenId, #{module := Mod, data := GenData}} = generation_at(Shard, Time),
+    ?tp(emqx_ds_storage_layer_event, #{mod => Mod, time => Time, event => Event}),
+    case erlang:function_exported(Mod, handle_event, 4) of
+        true ->
+            Mod:handle_event(Shard, GenData, Time, Event);
+        false ->
+            []
+    end.
+
 %%--------------------------------------------------------------------------------
 %% Schema access
 %%--------------------------------------------------------------------------------
@@ -881,7 +971,7 @@ generations_since(Shard, Since) ->
         Schema
     ).
 
--spec generation_at(shard_id(), emqx_ds:time()) -> generation().
+-spec generation_at(shard_id(), emqx_ds:time()) -> {gen_id(), generation()}.
 generation_at(Shard, Time) ->
     Schema = #{current_generation := Current} = get_schema_runtime(Shard),
     generation_at(Time, Current, Schema).
@@ -892,7 +982,7 @@ generation_at(Time, GenId, Schema) ->
         #{since := Since} when Time < Since andalso GenId > 0 ->
             generation_at(Time, prev_generation_id(GenId), Schema);
         _ ->
-            Gen
+            {GenId, Gen}
     end.
 
 -define(PERSISTENT_TERM(SHARD), {emqx_ds_storage_layer, SHARD}).

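store_batch/3 is now a thin composition of the two new phases: prepare_batch/3 performs the pure, generation-specific transformation into an opaque cooked batch (and may be skipped or fail), while commit_batch/2 performs the actual RocksDB write. A sketch of the write path, assuming Shard and Messages are valid:

    write_sketch(Shard, Messages) ->
        case emqx_ds_storage_layer:prepare_batch(Shard, Messages, #{}) of
            {ok, CookedBatch} ->
                %% Opaque, tied to a single generation:
                emqx_ds_storage_layer:commit_batch(Shard, CookedBatch);
            ignore ->
                ok;   %% empty batch
            {error, _, _} = Error ->
                Error
        end.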
+ 13 - 19
apps/emqx_durable_storage/src/emqx_ds_storage_reference.erl

@@ -31,14 +31,15 @@
     create/4,
     open/5,
     drop/5,
-    store_batch/4,
+    prepare_batch/4,
+    commit_batch/3,
     get_streams/4,
     get_delete_streams/4,
     make_iterator/5,
     make_delete_iterator/5,
     update_iterator/4,
-    next/4,
-    delete_next/5
+    next/5,
+    delete_next/6
 ]).
 
 %% internal exports:
@@ -101,12 +102,14 @@ drop(_ShardId, DBHandle, _GenId, _CFRefs, #s{cf = CFHandle}) ->
     ok = rocksdb:drop_column_family(DBHandle, CFHandle),
     ok.
 
-store_batch(_ShardId, #s{db = DB, cf = CF}, Messages, _Options = #{atomic := true}) ->
+prepare_batch(_ShardId, _Data, Messages, _Options) ->
+    {ok, Messages}.
+
+commit_batch(_ShardId, #s{db = DB, cf = CF}, Messages) ->
     {ok, Batch} = rocksdb:batch(),
     lists:foreach(
-        fun(Msg) ->
-            Id = erlang:unique_integer([monotonic]),
-            Key = <<Id:64>>,
+        fun({TS, Msg}) ->
+            Key = <<TS:64>>,
             Val = term_to_binary(Msg),
             rocksdb:batch_put(Batch, CF, Key, Val)
         end,
@@ -114,16 +117,7 @@ store_batch(_ShardId, #s{db = DB, cf = CF}, Messages, _Options = #{atomic := tru
     ),
     Res = rocksdb:write_batch(DB, Batch, _WriteOptions = []),
     rocksdb:release_batch(Batch),
-    Res;
-store_batch(_ShardId, #s{db = DB, cf = CF}, Messages, _Options) ->
-    lists:foreach(
-        fun({Timestamp, Msg}) ->
-            Key = <<Timestamp:64>>,
-            Val = term_to_binary(Msg),
-            rocksdb:put(DB, CF, Key, Val, [])
-        end,
-        Messages
-    ).
+    Res.
 
 get_streams(_Shard, _Data, _TopicFilter, _StartTime) ->
     [#stream{}].
@@ -154,7 +148,7 @@ update_iterator(_Shard, _Data, OldIter, DSKey) ->
         last_seen_message_key = DSKey
     }}.
 
-next(_Shard, #s{db = DB, cf = CF}, It0, BatchSize) ->
+next(_Shard, #s{db = DB, cf = CF}, It0, BatchSize, _Now) ->
     #it{topic_filter = TopicFilter, start_time = StartTime, last_seen_message_key = Key0} = It0,
     {ok, ITHandle} = rocksdb:iterator(DB, CF, []),
     Action =
@@ -170,7 +164,7 @@ next(_Shard, #s{db = DB, cf = CF}, It0, BatchSize) ->
     It = It0#it{last_seen_message_key = Key},
     {ok, It, lists:reverse(Messages)}.
 
-delete_next(_Shard, #s{db = DB, cf = CF}, It0, Selector, BatchSize) ->
+delete_next(_Shard, #s{db = DB, cf = CF}, It0, Selector, BatchSize, _Now) ->
     #delete_it{
         topic_filter = TopicFilter,
         start_time = StartTime,

+ 264 - 263
apps/emqx_durable_storage/test/emqx_ds_replication_SUITE.erl

@@ -21,10 +21,14 @@
 -include_lib("emqx/include/emqx.hrl").
 -include_lib("common_test/include/ct.hrl").
 -include_lib("stdlib/include/assert.hrl").
--include_lib("snabbkaffe/include/test_macros.hrl").
+-include_lib("snabbkaffe/include/snabbkaffe.hrl").
 
 -define(DB, testdb).
 
+-define(ON(NODE, BODY),
+    erpc:call(NODE, erlang, apply, [fun() -> BODY end, []])
+).
+
 opts() ->
     opts(#{}).
 
@@ -32,12 +36,13 @@ opts(Overrides) ->
     maps:merge(
         #{
             backend => builtin,
-            storage => {emqx_ds_storage_bitfield_lts, #{}},
+            %% storage => {emqx_ds_storage_reference, #{}},
+            storage => {emqx_ds_storage_bitfield_lts, #{epoch_bits => 10}},
             n_shards => 16,
             n_sites => 1,
             replication_factor => 3,
             replication_options => #{
-                wal_max_size_bytes => 64 * 1024,
+                wal_max_size_bytes => 64,
                 wal_max_batch_size => 1024,
                 snapshot_interval => 128
             }
@@ -67,64 +72,61 @@ t_replication_transfers_snapshots('end', Config) ->
     ok = emqx_cth_cluster:stop(?config(nodes, Config)).
 
 t_replication_transfers_snapshots(Config) ->
-    NMsgs = 4000,
-    Nodes = [Node, NodeOffline | _] = ?config(nodes, Config),
-    _Specs = [_, SpecOffline | _] = ?config(specs, Config),
-
-    %% Initialize DB on all nodes and wait for it to be online.
-    Opts = opts(#{n_shards => 1, n_sites => 3}),
-    ?assertEqual(
-        [{ok, ok} || _ <- Nodes],
-        erpc:multicall(Nodes, emqx_ds, open_db, [?DB, Opts])
-    ),
-    ?retry(
-        500,
-        10,
-        ?assertMatch([[_], [_], [_]], [shards_online(N, ?DB) || N <- Nodes])
-    ),
-
-    %% Stop the DB on the "offline" node.
-    ok = emqx_cth_cluster:stop_node(NodeOffline),
-
-    %% Fill the storage with messages and few additional generations.
-    Messages = fill_storage(Node, ?DB, NMsgs, #{p_addgen => 0.01}),
-
-    %% Restart the node.
-    [NodeOffline] = emqx_cth_cluster:restart(SpecOffline),
-    {ok, SRef} = snabbkaffe:subscribe(
-        ?match_event(#{
-            ?snk_kind := dsrepl_snapshot_accepted,
-            ?snk_meta := #{node := NodeOffline}
-        })
-    ),
-    ?assertEqual(
-        ok,
-        erpc:call(NodeOffline, emqx_ds, open_db, [?DB, opts()])
-    ),
-
-    %% Trigger storage operation and wait the replica to be restored.
-    _ = add_generation(Node, ?DB),
-    ?assertMatch(
-        {ok, _},
-        snabbkaffe:receive_events(SRef)
+    NMsgs = 400,
+    NClients = 5,
+    {Stream, TopicStreams} = emqx_ds_test_helpers:interleaved_topic_messages(
+        ?FUNCTION_NAME, NClients, NMsgs
     ),
 
-    %% Wait until any pending replication activities are finished (e.g. Raft log entries).
-    ok = timer:sleep(3_000),
-
-    %% Check that the DB has been restored.
-    Shard = hd(shards(NodeOffline, ?DB)),
-    MessagesOffline = lists:keysort(
-        #message.timestamp,
-        consume_shard(NodeOffline, ?DB, Shard, ['#'], 0)
-    ),
-    ?assertEqual(
-        sample(40, Messages),
-        sample(40, MessagesOffline)
-    ),
-    ?assertEqual(
-        Messages,
-        MessagesOffline
+    Nodes = [Node, NodeOffline | _] = ?config(nodes, Config),
+    _Specs = [_, SpecOffline | _] = ?config(specs, Config),
+    ?check_trace(
+        begin
+            %% Initialize DB on all nodes and wait for it to be online.
+            Opts = opts(#{n_shards => 1, n_sites => 3}),
+            ?assertEqual(
+                [{ok, ok} || _ <- Nodes],
+                erpc:multicall(Nodes, emqx_ds, open_db, [?DB, Opts])
+            ),
+            ?retry(
+                500,
+                10,
+                ?assertMatch([[_], [_], [_]], [shards_online(N, ?DB) || N <- Nodes])
+            ),
+
+            %% Stop the DB on the "offline" node.
+            ok = emqx_cth_cluster:stop_node(NodeOffline),
+
+            %% Fill the storage with messages and few additional generations.
+            emqx_ds_test_helpers:apply_stream(?DB, Nodes -- [NodeOffline], Stream),
+
+            %% Restart the node.
+            [NodeOffline] = emqx_cth_cluster:restart(SpecOffline),
+            {ok, SRef} = snabbkaffe:subscribe(
+                ?match_event(#{
+                    ?snk_kind := dsrepl_snapshot_accepted,
+                    ?snk_meta := #{node := NodeOffline}
+                })
+            ),
+            ?assertEqual(
+                ok,
+                erpc:call(NodeOffline, emqx_ds, open_db, [?DB, opts()])
+            ),
+
+            %% Trigger storage operation and wait for the replica to be restored.
+            _ = add_generation(Node, ?DB),
+            ?assertMatch(
+                {ok, _},
+                snabbkaffe:receive_events(SRef)
+            ),
+
+            %% Wait until any pending replication activities are finished (e.g. Raft log entries).
+            ok = timer:sleep(3_000),
+
+            %% Check that the DB has been restored:
+            emqx_ds_test_helpers:verify_stream_effects(?DB, ?FUNCTION_NAME, Nodes, TopicStreams)
+        end,
+        []
     ).
 
 t_rebalance(init, Config) ->
@@ -142,112 +144,120 @@ t_rebalance(init, Config) ->
 t_rebalance('end', Config) ->
     ok = emqx_cth_cluster:stop(?config(nodes, Config)).
 
+%% This testcase verifies that the storage rebalancing works correctly:
+%% 1. Join/leave operations are applied successfully.
+%% 2. Message data survives the rebalancing.
+%% 3. Shard cluster membership converges to the target replica allocation.
+%% 4. Replication factor is respected.
 t_rebalance(Config) ->
-    %% This testcase verifies that the storage rebalancing works correctly:
-    %% 1. Join/leave operations are applied successfully.
-    %% 2. Message data survives the rebalancing.
-    %% 3. Shard cluster membership converges to the target replica allocation.
-    %% 4. Replication factor is respected.
-
-    NMsgs = 800,
+    NMsgs = 50,
     NClients = 5,
-    Nodes = [N1, N2, N3, N4] = ?config(nodes, Config),
-
-    %% Initialize DB on the first node.
-    Opts = opts(#{n_shards => 16, n_sites => 1, replication_factor => 3}),
-    ?assertEqual(ok, erpc:call(N1, emqx_ds, open_db, [?DB, Opts])),
-    ?assertMatch(
-        Shards when length(Shards) == 16,
-        shards_online(N1, ?DB)
-    ),
-
-    %% Open DB on the rest of the nodes.
-    ?assertEqual(
-        [{ok, ok} || _ <- [N2, N3, N4]],
-        erpc:multicall([N2, N3, N4], emqx_ds, open_db, [?DB, Opts])
+    {Stream0, TopicStreams} = emqx_ds_test_helpers:interleaved_topic_messages(
+        ?FUNCTION_NAME, NClients, NMsgs
     ),
-
-    Sites = [S1, S2 | _Rest] = [ds_repl_meta(N, this_site) || N <- Nodes],
-    ct:pal("Sites: ~p~n", [Sites]),
-
-    %% Only N1 should be responsible for all shards initially.
-    ?assertEqual(
-        [[S1] || _ <- Nodes],
-        [ds_repl_meta(N, db_sites, [?DB]) || N <- Nodes]
-    ),
-
-    %% Fill the storage with messages and few additional generations.
-    %% This will force shards to trigger snapshot transfers during rebalance.
-    ClientMessages = emqx_utils:pmap(
-        fun(CID) ->
-            N = lists:nth(1 + (CID rem length(Nodes)), Nodes),
-            fill_storage(N, ?DB, NMsgs, #{client_id => integer_to_binary(CID)})
+    Nodes = [N1, N2 | _] = ?config(nodes, Config),
+    ?check_trace(
+        #{timetrap => 30_000},
+        begin
+            %% 1. Initialize DB on the first node.
+            Opts = opts(#{n_shards => 16, n_sites => 1, replication_factor => 3}),
+            ?assertEqual(ok, ?ON(N1, emqx_ds:open_db(?DB, Opts))),
+            ?assertMatch(Shards when length(Shards) == 16, shards_online(N1, ?DB)),
+
+            %% 1.1 Open DB on the rest of the nodes:
+            [
+                ?assertEqual(ok, ?ON(Node, emqx_ds:open_db(?DB, Opts)))
+             || Node <- Nodes
+            ],
+
+            Sites = [S1, S2 | _] = [ds_repl_meta(N, this_site) || N <- Nodes],
+            ct:pal("Sites: ~p~n", [Sites]),
+
+            Sequence = [
+                %% Join the second site to the DB replication sites:
+                {N1, join_db_site, S2},
+                %% Should be a no-op:
+                {N2, join_db_site, S2},
+                %% Now join the rest of the sites:
+                {N2, assign_db_sites, Sites}
+            ],
+            Stream1 = emqx_utils_stream:interleave(
+                [
+                    {50, Stream0},
+                    emqx_utils_stream:const(add_generation)
+                ],
+                false
+            ),
+            Stream = emqx_utils_stream:interleave(
+                [
+                    {50, Stream0},
+                    emqx_utils_stream:list(Sequence)
+                ],
+                true
+            ),
+
+            %% 1.2 Verify that all nodes have the same view of metadata storage:
+            [
+                ?defer_assert(
+                    ?assertEqual(
+                        [S1],
+                        ?ON(Node, emqx_ds_replication_layer_meta:db_sites(?DB)),
+                        #{
+                            msg => "Initially, only S1 should be responsible for all shards",
+                            node => Node
+                        }
+                    )
+                )
+             || Node <- Nodes
+            ],
+
+            %% 2. Start filling the storage:
+            emqx_ds_test_helpers:apply_stream(?DB, Nodes, Stream),
+            timer:sleep(5000),
+            emqx_ds_test_helpers:verify_stream_effects(?DB, ?FUNCTION_NAME, Nodes, TopicStreams),
+            [
+                ?defer_assert(
+                    ?assertEqual(
+                        16 * 3 div length(Nodes),
+                        n_shards_online(Node, ?DB),
+                        "Each node is now responsible for 3/4 of the shards"
+                    )
+                )
+             || Node <- Nodes
+            ],
+
+            %% Verify that the set of shard servers matches the target allocation.
+            Allocation = [ds_repl_meta(N, my_shards, [?DB]) || N <- Nodes],
+            ShardServers = [
+                shard_server_info(N, ?DB, Shard, Site, readiness)
+             || {N, Site, Shards} <- lists:zip3(Nodes, Sites, Allocation),
+                Shard <- Shards
+            ],
+            ?assert(
+                lists:all(fun({_Server, Status}) -> Status == ready end, ShardServers),
+                ShardServers
+            ),
+
+            %% Scale down the cluster by removing the first node.
+            ?assertEqual(ok, ds_repl_meta(N1, leave_db_site, [?DB, S1])),
+            ct:pal("Transitions (~p -> ~p): ~p~n", [
+                Sites, tl(Sites), emqx_ds_test_helpers:transitions(N1, ?DB)
+            ]),
+            ?retry(1000, 10, ?assertEqual([], emqx_ds_test_helpers:transitions(N2, ?DB))),
+
+            %% Verify that at the end each node is now responsible for each shard.
+            ?defer_assert(
+                ?assertEqual(
+                    [0, 16, 16, 16],
+                    [n_shards_online(N, ?DB) || N <- Nodes]
+                )
+            ),
+
+            %% Verify that the messages are once again preserved after the rebalance:
+            emqx_ds_test_helpers:verify_stream_effects(?DB, ?FUNCTION_NAME, Nodes, TopicStreams)
         end,
-        lists:seq(1, NClients),
-        infinity
-    ),
-    Messages1 = lists:sort(fun compare_message/2, lists:append(ClientMessages)),
-
-    %% Join the second site to the DB replication sites.
-    ?assertEqual(ok, ds_repl_meta(N1, join_db_site, [?DB, S2])),
-    %% Should be no-op.
-    ?assertEqual(ok, ds_repl_meta(N2, join_db_site, [?DB, S2])),
-    ct:pal("Transitions (~p -> ~p): ~p~n", [[S1], [S1, S2], transitions(N1, ?DB)]),
-
-    %% Fill in some more messages *during* the rebalance.
-    MessagesRB1 = fill_storage(N4, ?DB, NMsgs, #{client_id => <<"RB1">>}),
-
-    ?retry(1000, 10, ?assertEqual([], transitions(N1, ?DB))),
-
-    %% Now join the rest of the sites.
-    ?assertEqual(ok, ds_repl_meta(N2, assign_db_sites, [?DB, Sites])),
-    ct:pal("Transitions (~p -> ~p): ~p~n", [[S1, S2], Sites, transitions(N1, ?DB)]),
-
-    %% Fill in some more messages *during* the rebalance.
-    MessagesRB2 = fill_storage(N4, ?DB, NMsgs, #{client_id => <<"RB2">>}),
-
-    ?retry(1000, 10, ?assertEqual([], transitions(N2, ?DB))),
-
-    %% Verify that each node is now responsible for 3/4 of the shards.
-    ?assertEqual(
-        [(16 * 3) div length(Nodes) || _ <- Nodes],
-        [n_shards_online(N, ?DB) || N <- Nodes]
-    ),
-
-    %% Verify that the set of shard servers matches the target allocation.
-    Allocation = [ds_repl_meta(N, my_shards, [?DB]) || N <- Nodes],
-    ShardServers = [
-        shard_server_info(N, ?DB, Shard, Site, readiness)
-     || {N, Site, Shards} <- lists:zip3(Nodes, Sites, Allocation),
-        Shard <- Shards
-    ],
-    ?assert(
-        lists:all(fun({_Server, Status}) -> Status == ready end, ShardServers),
-        ShardServers
-    ),
-
-    %% Verify that the messages are preserved after the rebalance.
-    Messages = Messages1 ++ MessagesRB1 ++ MessagesRB2,
-    MessagesN4 = lists:sort(fun compare_message/2, consume(N4, ?DB, ['#'], 0)),
-    ?assertEqual(sample(20, Messages), sample(20, MessagesN4)),
-    ?assertEqual(Messages, MessagesN4),
-
-    %% Scale down the cluster by removing the first node.
-    ?assertEqual(ok, ds_repl_meta(N1, leave_db_site, [?DB, S1])),
-    ct:pal("Transitions (~p -> ~p): ~p~n", [Sites, tl(Sites), transitions(N1, ?DB)]),
-
-    ?retry(1000, 10, ?assertEqual([], transitions(N2, ?DB))),
-
-    %% Verify that each node is now responsible for each shard.
-    ?assertEqual(
-        [0, 16, 16, 16],
-        [n_shards_online(N, ?DB) || N <- Nodes]
-    ),
-
-    %% Verify that the messages are once again preserved after the rebalance.
-    MessagesN3 = lists:sort(fun compare_message/2, consume(N3, ?DB, ['#'], 0)),
-    ?assertEqual(sample(20, Messages), sample(20, MessagesN3)),
-    ?assertEqual(Messages, MessagesN3).
+        []
+    ).
 
 t_join_leave_errors(init, Config) ->
     Apps = [appspec(emqx_durable_storage)],
@@ -293,7 +303,7 @@ t_join_leave_errors(Config) ->
 
     %% Should be no-op.
     ?assertEqual(ok, ds_repl_meta(N1, join_db_site, [?DB, S1])),
-    ?assertEqual([], transitions(N1, ?DB)),
+    ?assertEqual([], emqx_ds_test_helpers:transitions(N1, ?DB)),
 
     %% Impossible to leave the last site.
     ?assertEqual(
@@ -304,12 +314,12 @@ t_join_leave_errors(Config) ->
     %% "Move" the DB to the other node.
     ?assertEqual(ok, ds_repl_meta(N1, join_db_site, [?DB, S2])),
     ?assertEqual(ok, ds_repl_meta(N2, leave_db_site, [?DB, S1])),
-    ?assertMatch([_ | _], transitions(N1, ?DB)),
-    ?retry(1000, 10, ?assertEqual([], transitions(N1, ?DB))),
+    ?assertMatch([_ | _], emqx_ds_test_helpers:transitions(N1, ?DB)),
+    ?retry(1000, 10, ?assertEqual([], emqx_ds_test_helpers:transitions(N1, ?DB))),
 
     %% Should be no-op.
     ?assertEqual(ok, ds_repl_meta(N2, leave_db_site, [?DB, S1])),
-    ?assertEqual([], transitions(N1, ?DB)).
+    ?assertEqual([], emqx_ds_test_helpers:transitions(N1, ?DB)).
 
 t_rebalance_chaotic_converges(init, Config) ->
     Apps = [appspec(emqx_durable_storage)],
@@ -333,78 +343,79 @@ t_rebalance_chaotic_converges(Config) ->
     NMsgs = 500,
     Nodes = [N1, N2, N3] = ?config(nodes, Config),
 
-    %% Initialize DB on first two nodes.
-    Opts = opts(#{n_shards => 16, n_sites => 2, replication_factor => 3}),
-    ?assertEqual(
-        [{ok, ok}, {ok, ok}],
-        erpc:multicall([N1, N2], emqx_ds, open_db, [?DB, Opts])
-    ),
-
-    %% Open DB on the last node.
-    ?assertEqual(
-        ok,
-        erpc:call(N3, emqx_ds, open_db, [?DB, Opts])
-    ),
-
-    %% Find out which sites there are.
-    Sites = [S1, S2, S3] = [ds_repl_meta(N, this_site) || N <- Nodes],
-    ct:pal("Sites: ~p~n", [Sites]),
-
-    %% Initially, the DB is assigned to [S1, S2].
-    ?retry(500, 10, ?assertEqual([16, 16], [n_shards_online(N, ?DB) || N <- [N1, N2]])),
-    ?assertEqual(
-        lists:sort([S1, S2]),
-        ds_repl_meta(N1, db_sites, [?DB])
+    NClients = 5,
+    {Stream0, TopicStreams} = emqx_ds_test_helpers:interleaved_topic_messages(
+        ?FUNCTION_NAME, NClients, NMsgs
     ),
 
-    %% Fill the storage with messages and few additional generations.
-    Messages0 = lists:append([
-        fill_storage(N1, ?DB, NMsgs, #{client_id => <<"C1">>}),
-        fill_storage(N2, ?DB, NMsgs, #{client_id => <<"C2">>}),
-        fill_storage(N3, ?DB, NMsgs, #{client_id => <<"C3">>})
-    ]),
-
-    %% Construct a chaotic transition sequence that changes assignment to [S2, S3].
-    Sequence = [
-        {N1, join_db_site, S3},
-        {N2, leave_db_site, S2},
-        {N3, leave_db_site, S1},
-        {N1, join_db_site, S2},
-        {N2, join_db_site, S1},
-        {N3, leave_db_site, S3},
-        {N1, leave_db_site, S1},
-        {N2, join_db_site, S3}
-    ],
-
-    %% Apply the sequence while also filling the storage with messages.
-    TransitionMessages = lists:map(
-        fun({N, Operation, Site}) ->
-            %% Apply the transition.
-            ?assertEqual(ok, ds_repl_meta(N, Operation, [?DB, Site])),
-            %% Give some time for at least one transition to complete.
-            Transitions = transitions(N, ?DB),
-            ct:pal("Transitions after ~p: ~p", [Operation, Transitions]),
-            ?retry(200, 10, ?assertNotEqual(Transitions, transitions(N, ?DB))),
-            %% Fill the storage with messages.
-            CID = integer_to_binary(erlang:system_time()),
-            fill_storage(N, ?DB, NMsgs, #{client_id => CID})
+    ?check_trace(
+        #{},
+        begin
+            %% Initialize DB on first two nodes.
+            Opts = opts(#{n_shards => 16, n_sites => 2, replication_factor => 3}),
+
+            ?assertEqual(
+                [{ok, ok}, {ok, ok}],
+                erpc:multicall([N1, N2], emqx_ds, open_db, [?DB, Opts])
+            ),
+
+            %% Open DB on the last node.
+            ?assertEqual(
+                ok,
+                erpc:call(N3, emqx_ds, open_db, [?DB, Opts])
+            ),
+
+            %% Find out which sites there are.
+            Sites = [S1, S2, S3] = [ds_repl_meta(N, this_site) || N <- Nodes],
+            ct:pal("Sites: ~p~n", [Sites]),
+
+            Sequence = [
+                {N1, join_db_site, S3},
+                {N2, leave_db_site, S2},
+                {N3, leave_db_site, S1},
+                {N1, join_db_site, S2},
+                {N2, join_db_site, S1},
+                {N3, leave_db_site, S3},
+                {N1, leave_db_site, S1},
+                {N2, join_db_site, S3}
+            ],
+
+            %% Interleaved list of events:
+            Stream = emqx_utils_stream:interleave(
+                [
+                    {50, Stream0},
+                    emqx_utils_stream:list(Sequence)
+                ],
+                true
+            ),
+
+            ?retry(500, 10, ?assertEqual([16, 16], [n_shards_online(N, ?DB) || N <- [N1, N2]])),
+            ?assertEqual(
+                lists:sort([S1, S2]),
+                ds_repl_meta(N1, db_sites, [?DB]),
+                "Initially, the DB is assigned to [S1, S2]"
+            ),
+
+            emqx_ds_test_helpers:apply_stream(?DB, Nodes, Stream),
+
+            %% Wait for the last transition to complete.
+            ?retry(500, 20, ?assertEqual([], emqx_ds_test_helpers:transitions(N1, ?DB))),
+
+            ?defer_assert(
+                ?assertEqual(
+                    lists:sort([S2, S3]),
+                    ds_repl_meta(N1, db_sites, [?DB])
+                )
+            ),
+
+            %% Wait until the LTS timestamp is updated:
+            timer:sleep(5000),
+
+            %% Check that all messages are still there.
+            emqx_ds_test_helpers:verify_stream_effects(?DB, ?FUNCTION_NAME, Nodes, TopicStreams)
         end,
-        Sequence
-    ),
-
-    %% Wait for the last transition to complete.
-    ?retry(500, 20, ?assertEqual([], transitions(N1, ?DB))),
-
-    ?assertEqual(
-        lists:sort([S2, S3]),
-        ds_repl_meta(N1, db_sites, [?DB])
-    ),
-
-    %% Check that all messages are still there.
-    Messages = lists:append(TransitionMessages) ++ Messages0,
-    MessagesDB = lists:sort(fun compare_message/2, consume(N1, ?DB, ['#'], 0)),
-    ?assertEqual(sample(20, Messages), sample(20, MessagesDB)),
-    ?assertEqual(Messages, MessagesDB).
+        []
+    ).
 
 t_rebalance_offline_restarts(init, Config) ->
     Apps = [appspec(emqx_durable_storage)],
@@ -447,7 +458,7 @@ t_rebalance_offline_restarts(Config) ->
     %% Shut down N3 and then remove it from the DB.
     ok = emqx_cth_cluster:stop_node(N3),
     ?assertEqual(ok, ds_repl_meta(N1, leave_db_site, [?DB, S3])),
-    Transitions = transitions(N1, ?DB),
+    Transitions = emqx_ds_test_helpers:transitions(N1, ?DB),
     ct:pal("Transitions: ~p~n", [Transitions]),
 
     %% Wait until at least one transition completes.
@@ -462,7 +473,7 @@ t_rebalance_offline_restarts(Config) ->
     ),
 
     %% Target state should still be reached eventually.
-    ?retry(1000, 20, ?assertEqual([], transitions(N1, ?DB))),
+    ?retry(1000, 20, ?assertEqual([], emqx_ds_test_helpers:transitions(N1, ?DB))),
     ?assertEqual(lists:sort([S1, S2]), ds_repl_meta(N1, db_sites, [?DB])).
 
 %%
@@ -478,15 +489,19 @@ ds_repl_meta(Node, Fun) ->
     ds_repl_meta(Node, Fun, []).
 
 ds_repl_meta(Node, Fun, Args) ->
-    erpc:call(Node, emqx_ds_replication_layer_meta, Fun, Args).
+    try
+        erpc:call(Node, emqx_ds_replication_layer_meta, Fun, Args)
+    catch
+        EC:Err:Stack ->
+            ct:pal("emqx_ds_replication_layer_meta:~p(~p) @~p failed:~n~p:~p~nStack: ~p", [
+                Fun, Args, Node, EC, Err, Stack
+            ]),
+            error(meta_op_failed)
+    end.
 
 ds_repl_shard(Node, Fun, Args) ->
     erpc:call(Node, emqx_ds_replication_layer_shard, Fun, Args).
 
-transitions(Node, DB) ->
-    Shards = shards(Node, DB),
-    [{S, T} || S <- Shards, T <- ds_repl_meta(Node, replica_set_transitions, [DB, S])].
-
 shards(Node, DB) ->
     erpc:call(Node, emqx_ds_replication_layer_meta, shards, [DB]).
 
@@ -496,25 +511,6 @@ shards_online(Node, DB) ->
 n_shards_online(Node, DB) ->
     length(shards_online(Node, DB)).
 
-fill_storage(Node, DB, NMsgs, Opts) ->
-    fill_storage(Node, DB, NMsgs, 0, Opts).
-
-fill_storage(Node, DB, NMsgs, I, Opts) when I < NMsgs ->
-    PAddGen = maps:get(p_addgen, Opts, 0.001),
-    R1 = push_message(Node, DB, I, Opts),
-    R2 = probably(PAddGen, fun() -> add_generation(Node, DB) end),
-    R1 ++ R2 ++ fill_storage(Node, DB, NMsgs, I + 1, Opts);
-fill_storage(_Node, _DB, NMsgs, NMsgs, _Opts) ->
-    [].
-
-push_message(Node, DB, I, Opts) ->
-    Topic = emqx_topic:join([<<"topic">>, <<"foo">>, integer_to_binary(I)]),
-    {Bytes, _} = rand:bytes_s(120, rand:seed_s(default, I)),
-    ClientId = maps:get(client_id, Opts, <<?MODULE_STRING>>),
-    Message = message(ClientId, Topic, Bytes, I * 100),
-    ok = erpc:call(Node, emqx_ds, store_batch, [DB, [Message], #{sync => true}]),
-    [Message].
-
 add_generation(Node, DB) ->
     ok = erpc:call(Node, emqx_ds, add_generation, [DB]),
     [].
@@ -545,9 +541,14 @@ probably(P, Fun) ->
 
 sample(N, List) ->
     L = length(List),
-    H = N div 2,
-    Filler = integer_to_list(L - N) ++ " more",
-    lists:sublist(List, H) ++ [Filler] ++ lists:sublist(List, L - H, L).
+    case L =< N of
+        true ->
+            List;
+        false ->
+            H = N div 2,
+            Filler = integer_to_list(L - N) ++ " more",
+            lists:sublist(List, H) ++ [Filler] ++ lists:sublist(List, L - H, L)
+    end.
 
 %%
 

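The membership steps above follow an apply-then-poll pattern: issue a join/leave operation, then poll until the shard transition queue drains. A minimal standalone sketch of that pattern, assuming only the `emqx_ds_replication_layer_meta` calls already used in this diff (the helper names are illustrative):

%% Sketch: apply a site membership operation and block until no shard
%% transitions remain pending for DB. Bounded retries keep a stuck
%% transition from hanging the suite forever.
apply_site_op(Node, Op, DB, Site) when
    Op =:= join_db_site; Op =:= leave_db_site
->
    ok = erpc:call(Node, emqx_ds_replication_layer_meta, Op, [DB, Site]),
    wait_transitions_done(Node, DB, 20).

wait_transitions_done(_Node, _DB, 0) ->
    error(transitions_did_not_settle);
wait_transitions_done(Node, DB, Retries) ->
    Shards = erpc:call(Node, emqx_ds_replication_layer_meta, shards, [DB]),
    Pending = [
        {S, T}
     || S <- Shards,
        T <- erpc:call(Node, emqx_ds_replication_layer_meta, replica_set_transitions, [DB, S])
    ],
    case Pending of
        [] -> ok;
        _ -> timer:sleep(1000), wait_transitions_done(Node, DB, Retries - 1)
    end.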
+ 1 - 1
apps/emqx_durable_storage/test/emqx_ds_storage_SUITE.erl

@@ -23,7 +23,7 @@
 -include_lib("stdlib/include/assert.hrl").
 
 opts() ->
-    #{storage => {emqx_ds_storage_bitfield_lts, #{}}}.
+    #{storage => {emqx_ds_storage_reference, #{}}}.
 
 %%
 

+ 11 - 5
apps/emqx_durable_storage/test/emqx_ds_storage_bitfield_lts_SUITE.erl

@@ -73,13 +73,15 @@ t_iterate(_Config) ->
         begin
             [{_Rank, Stream}] = emqx_ds_storage_layer:get_streams(?SHARD, parse_topic(Topic), 0),
             {ok, It} = emqx_ds_storage_layer:make_iterator(?SHARD, Stream, parse_topic(Topic), 0),
-            {ok, NextIt, MessagesAndKeys} = emqx_ds_storage_layer:next(?SHARD, It, 100),
+            {ok, NextIt, MessagesAndKeys} = emqx_ds_storage_layer:next(
+                ?SHARD, It, 100, emqx_ds:timestamp_us()
+            ),
             Messages = [Msg || {_DSKey, Msg} <- MessagesAndKeys],
             ?assertEqual(
                 lists:map(fun integer_to_binary/1, Timestamps),
                 payloads(Messages)
             ),
-            {ok, _, []} = emqx_ds_storage_layer:next(?SHARD, NextIt, 100)
+            {ok, _, []} = emqx_ds_storage_layer:next(?SHARD, NextIt, 100, emqx_ds:timestamp_us())
         end
      || Topic <- Topics
     ],
@@ -370,7 +372,7 @@ dump_stream(Shard, Stream, TopicFilter, StartTime) ->
         F(It, 0) ->
             error({too_many_iterations, It});
         F(It, N) ->
-            case emqx_ds_storage_layer:next(Shard, It, BatchSize) of
+            case emqx_ds_storage_layer:next(Shard, It, BatchSize, emqx_ds:timestamp_us()) of
                 end_of_stream ->
                     [];
                 {ok, _NextIt, []} ->
@@ -542,7 +544,11 @@ delete(_Shard, [], _Selector) ->
 delete(Shard, Iterators, Selector) ->
     {NewIterators0, N} = lists:foldl(
         fun(Iterator0, {AccIterators, NAcc}) ->
-            case emqx_ds_storage_layer:delete_next(Shard, Iterator0, Selector, 10) of
+            case
+                emqx_ds_storage_layer:delete_next(
+                    Shard, Iterator0, Selector, 10, emqx_ds:timestamp_us()
+                )
+            of
                 {ok, end_of_stream} ->
                     {AccIterators, NAcc};
                 {ok, _Iterator1, 0} ->
@@ -573,7 +579,7 @@ replay(_Shard, []) ->
 replay(Shard, Iterators) ->
     {NewIterators0, Messages0} = lists:foldl(
         fun(Iterator0, {AccIterators, AccMessages}) ->
-            case emqx_ds_storage_layer:next(Shard, Iterator0, 10) of
+            case emqx_ds_storage_layer:next(Shard, Iterator0, 10, emqx_ds:timestamp_us()) of
                 {ok, end_of_stream} ->
                     {AccIterators, AccMessages};
                 {ok, _Iterator1, []} ->

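Every storage-layer read in these suites now threads an explicit "now" timestamp into `next/4` and `delete_next/5`. A minimal sketch of draining an iterator in that style (`drain/3` is an illustrative name; both end-of-stream shapes matched by the suites are handled):

%% Sketch: exhaust an iterator with the 4-ary next/4, re-reading the
%% clock before each batch; batch elements are {DSKey, Message} pairs.
drain(Shard, It0, BatchSize) ->
    case emqx_ds_storage_layer:next(Shard, It0, BatchSize, emqx_ds:timestamp_us()) of
        end_of_stream ->
            [];
        {ok, end_of_stream} ->
            [];
        {ok, _It, []} ->
            [];
        {ok, It, Batch} ->
            [Msg || {_DSKey, Msg} <- Batch] ++ drain(Shard, It, BatchSize)
    end.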
+ 240 - 7
apps/emqx_durable_storage/test/emqx_ds_test_helpers.erl

@@ -18,6 +18,14 @@
 -compile(export_all).
 -compile(nowarn_export_all).
 
+-include_lib("emqx_utils/include/emqx_message.hrl").
+-include_lib("snabbkaffe/include/snabbkaffe.hrl").
+-include_lib("stdlib/include/assert.hrl").
+
+-define(ON(NODE, BODY),
+    erpc:call(NODE, erlang, apply, [fun() -> BODY end, []])
+).
+
 %% RPC mocking
 
 mock_rpc() ->
@@ -57,8 +65,221 @@ mock_rpc_result(gen_rpc, ExpectFun) ->
         end
     end).
 
+%% Consume data from the DS storage on a given node as a stream:
+-type ds_stream() :: emqx_utils_stream:stream({emqx_ds:message_key(), emqx_types:message()}).
+
+%% @doc Create a bounded stream of messages for each fake client, plus an interleaving of all of them:
+interleaved_topic_messages(TestCase, NClients, NMsgs) ->
+    %% List of fake client IDs:
+    Clients = [integer_to_binary(I) || I <- lists:seq(1, NClients)],
+    TopicStreams = [
+        {ClientId, emqx_utils_stream:limit_length(NMsgs, topic_messages(TestCase, ClientId))}
+     || ClientId <- Clients
+    ],
+    %% Interleaved stream of messages:
+    Stream = emqx_utils_stream:interleave(
+        [{2, Stream} || {_ClientId, Stream} <- TopicStreams], true
+    ),
+    {Stream, TopicStreams}.
+
+topic_messages(TestCase, ClientId) ->
+    topic_messages(TestCase, ClientId, 0).
+
+topic_messages(TestCase, ClientId, N) ->
+    fun() ->
+        NBin = integer_to_binary(N),
+        Msg = #message{
+            from = ClientId,
+            topic = client_topic(TestCase, ClientId),
+            timestamp = N * 100,
+            payload = <<NBin/binary, "                                                       ">>
+        },
+        [Msg | topic_messages(TestCase, ClientId, N + 1)]
+    end.
+
+client_topic(TestCase, ClientId) when is_atom(TestCase) ->
+    client_topic(atom_to_binary(TestCase, utf8), ClientId);
+client_topic(TestCase, ClientId) when is_binary(TestCase) ->
+    <<TestCase/binary, "/", ClientId/binary>>.
+
+ds_topic_generation_stream(DB, Node, Shard, Topic, Stream) ->
+    {ok, Iterator} = ?ON(
+        Node,
+        emqx_ds_storage_layer:make_iterator(Shard, Stream, Topic, 0)
+    ),
+    do_ds_topic_generation_stream(DB, Node, Shard, Iterator).
+
+do_ds_topic_generation_stream(DB, Node, Shard, It0) ->
+    fun() ->
+        case
+            ?ON(
+                Node,
+                begin
+                    Now = emqx_ds_replication_layer:current_timestamp(DB, Shard),
+                    emqx_ds_storage_layer:next(Shard, It0, 1, Now)
+                end
+            )
+        of
+            {ok, _It, []} ->
+                [];
+            {ok, end_of_stream} ->
+                [];
+            {ok, It, [KeyMsg]} ->
+                [KeyMsg | do_ds_topic_generation_stream(DB, Node, Shard, It)]
+        end
+    end.
+
+%% Payload generation and application:
+
+apply_stream(DB, Nodes, Stream) ->
+    apply_stream(
+        DB,
+        emqx_utils_stream:repeat(emqx_utils_stream:list(Nodes)),
+        Stream,
+        0
+    ).
+
+apply_stream(DB, NodeStream0, Stream0, N) ->
+    case emqx_utils_stream:next(Stream0) of
+        [] ->
+            ?tp(all_done, #{});
+        [Msg = #message{} | Stream] ->
+            [Node | NodeStream] = emqx_utils_stream:next(NodeStream0),
+            ?tp(
+                test_push_message,
+                maps:merge(
+                    emqx_message:to_map(Msg),
+                    #{n => N}
+                )
+            ),
+            ?ON(Node, emqx_ds:store_batch(DB, [Msg], #{sync => true})),
+            apply_stream(DB, NodeStream, Stream, N + 1);
+        [add_generation | Stream] ->
+            %% FIXME:
+            [Node | NodeStream] = emqx_utils_stream:next(NodeStream0),
+            ?ON(Node, emqx_ds:add_generation(DB)),
+            apply_stream(DB, NodeStream, Stream, N);
+        [{Node, Operation, Arg} | Stream] when
+            Operation =:= join_db_site; Operation =:= leave_db_site; Operation =:= assign_db_sites
+        ->
+            ?tp(notice, test_apply_operation, #{node => Node, operation => Operation, arg => Arg}),
+            %% Apply the transition.
+            ?assertEqual(
+                ok,
+                ?ON(
+                    Node,
+                    emqx_ds_replication_layer_meta:Operation(DB, Arg)
+                )
+            ),
+            %% Give some time for at least one transition to complete.
+            Transitions = transitions(Node, DB),
+            ct:pal("Transitions after ~p: ~p", [Operation, Transitions]),
+            ?retry(200, 10, ?assertNotEqual(Transitions, transitions(Node, DB))),
+            apply_stream(DB, NodeStream0, Stream, N);
+        [Fun | Stream] when is_function(Fun) ->
+            Fun(),
+            apply_stream(DB, NodeStream0, Stream, N)
+    end.
+
+transitions(Node, DB) ->
+    ?ON(
+        Node,
+        begin
+            Shards = emqx_ds_replication_layer_meta:shards(DB),
+            [
+                {S, T}
+             || S <- Shards, T <- emqx_ds_replication_layer_meta:replica_set_transitions(DB, S)
+            ]
+        end
+    ).
+
+%% Stream comparison
+
+message_eq(Msg1, {_Key, Msg2}) ->
+    %% Timestamps can be modified by the replication layer, ignore them:
+    Msg1#message{timestamp = 0} =:= Msg2#message{timestamp = 0}.
+
 %% Consuming streams and iterators
 
+-spec verify_stream_effects(atom(), binary(), [node()], [{emqx_types:clientid(), ds_stream()}]) ->
+    ok.
+verify_stream_effects(DB, TestCase, Nodes0, L) ->
+    Checked = lists:flatmap(
+        fun({ClientId, Stream}) ->
+            Nodes = nodes_of_clientid(DB, ClientId, Nodes0),
+            ct:pal("Nodes allocated for client ~p: ~p", [ClientId, Nodes]),
+            ?defer_assert(
+                ?assertMatch([_ | _], Nodes, ["No nodes have been allocated for ", ClientId])
+            ),
+            [verify_stream_effects(DB, TestCase, Node, ClientId, Stream) || Node <- Nodes]
+        end,
+        L
+    ),
+    ?defer_assert(?assertMatch([_ | _], Checked, "Some messages have been verified")).
+
+-spec verify_stream_effects(atom(), binary(), node(), emqx_types:clientid(), ds_stream()) -> ok.
+verify_stream_effects(DB, TestCase, Node, ClientId, ExpectedStream) ->
+    ct:pal("Checking consistency of effects for ~p on ~p", [ClientId, Node]),
+    DiffOpts = #{context => 20, window => 1000, compare_fun => fun message_eq/2},
+    ?defer_assert(
+        begin
+            snabbkaffe_diff:assert_lists_eq(
+                ExpectedStream,
+                ds_topic_stream(DB, ClientId, client_topic(TestCase, ClientId), Node),
+                DiffOpts
+            ),
+            ct:pal("Data for client ~p on ~p is consistent.", [ClientId, Node])
+        end
+    ).
+
+%% Create a stream from the topic (wildcards are NOT supported for a
+%% good reason: order of messages is implementation-dependent!).
+%%
+%% Note: the stream produces {Key, Message} pairs
+-spec ds_topic_stream(atom(), binary(), binary(), node()) -> ds_stream().
+ds_topic_stream(DB, ClientId, TopicBin, Node) ->
+    Topic = emqx_topic:words(TopicBin),
+    Shard = shard_of_clientid(DB, Node, ClientId),
+    {ShardId, DSStreams} =
+        ?ON(
+            Node,
+            begin
+                DBShard = {DB, Shard},
+                {DBShard, emqx_ds_storage_layer:get_streams(DBShard, Topic, 0)}
+            end
+        ),
+    %% Sort streams by their rank Y, and chain them together:
+    emqx_utils_stream:chain([
+        ds_topic_generation_stream(DB, Node, ShardId, Topic, S)
+     || {_RankY, S} <- lists:sort(DSStreams)
+    ]).
+
+%% Find which nodes from the list contain the shards for the given
+%% client ID:
+nodes_of_clientid(DB, ClientId, Nodes = [N0 | _]) ->
+    Shard = shard_of_clientid(DB, N0, ClientId),
+    SiteNodes = ?ON(
+        N0,
+        begin
+            Sites = emqx_ds_replication_layer_meta:replica_set(DB, Shard),
+            lists:map(fun emqx_ds_replication_layer_meta:node/1, Sites)
+        end
+    ),
+    lists:filter(
+        fun(N) ->
+            lists:member(N, SiteNodes)
+        end,
+        Nodes
+    ).
+
+shard_of_clientid(DB, Node, ClientId) ->
+    ?ON(
+        Node,
+        emqx_ds_replication_layer:shard_of_message(DB, #message{from = ClientId}, clientid)
+    ).
+
+%% Consume eagerly:
+
 consume(DB, TopicFilter) ->
     consume(DB, TopicFilter, 0).
 
@@ -85,8 +306,14 @@ consume_stream(DB, Stream, TopicFilter, StartTime) ->
 consume_iter(DB, It) ->
     consume_iter(DB, It, #{}).
 
-consume_iter(DB, It, Opts) ->
-    consume_iter_with(fun emqx_ds:next/3, [DB], It, Opts).
+consume_iter(DB, It0, Opts) ->
+    consume_iter_with(
+        fun(It, BatchSize) ->
+            emqx_ds:next(DB, It, BatchSize)
+        end,
+        It0,
+        Opts
+    ).
 
 storage_consume(ShardId, TopicFilter) ->
     storage_consume(ShardId, TopicFilter, 0).
@@ -108,16 +335,22 @@ storage_consume_stream(ShardId, Stream, TopicFilter, StartTime) ->
 storage_consume_iter(ShardId, It) ->
     storage_consume_iter(ShardId, It, #{}).
 
-storage_consume_iter(ShardId, It, Opts) ->
-    consume_iter_with(fun emqx_ds_storage_layer:next/3, [ShardId], It, Opts).
+storage_consume_iter(ShardId, It0, Opts) ->
+    consume_iter_with(
+        fun(It, BatchSize) ->
+            emqx_ds_storage_layer:next(ShardId, It, BatchSize, emqx_ds:timestamp_us())
+        end,
+        It0,
+        Opts
+    ).
 
-consume_iter_with(NextFun, Args, It0, Opts) ->
+consume_iter_with(NextFun, It0, Opts) ->
     BatchSize = maps:get(batch_size, Opts, 5),
-    case erlang:apply(NextFun, Args ++ [It0, BatchSize]) of
+    case NextFun(It0, BatchSize) of
         {ok, It, _Msgs = []} ->
             {ok, It, []};
         {ok, It1, Batch} ->
-            {ok, It, Msgs} = consume_iter_with(NextFun, Args, It1, Opts),
+            {ok, It, Msgs} = consume_iter_with(NextFun, It1, Opts),
             {ok, It, [Msg || {_DSKey, Msg} <- Batch] ++ Msgs};
         {ok, Eos = end_of_stream} ->
             {ok, Eos, []};

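The lazy generators above (`topic_messages/3`, `do_ds_topic_generation_stream/4`) rely on the `emqx_utils_stream` protocol: a stream is a nullary fun that returns either `[]` (end of stream) or `[Element | NextStream]`. A minimal hand-rolled stream in that shape (`countdown/1` is an illustrative name):

%% Sketch: a finite lazy stream following the emqx_utils_stream shape.
countdown(0) ->
    fun() -> [] end;
countdown(N) when N > 0 ->
    fun() -> [N | countdown(N - 1)] end.

Consuming it eagerly, `emqx_utils_stream:consume(countdown(3))` yields `[3, 2, 1]`; nothing is computed until each cell is forced, which is what lets `topic_messages/3` describe an unbounded message sequence.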
+ 75 - 1
apps/emqx_utils/src/emqx_utils_stream.erl

@@ -20,11 +20,15 @@
 -export([
     empty/0,
     list/1,
+    const/1,
     mqueue/1,
     map/2,
     transpose/1,
+    chain/1,
     chain/2,
-    repeat/1
+    repeat/1,
+    interleave/2,
+    limit_length/2
 ]).
 
 %% Evaluating
@@ -69,6 +73,11 @@ list([]) ->
 list([X | Rest]) ->
     fun() -> [X | list(Rest)] end.
 
+%% @doc Make a stream with a single element infinitely repeated
+-spec const(T) -> stream(T).
+const(T) ->
+    fun() -> [T | const(T)] end.
+
 %% @doc Make a stream out of process message queue.
 -spec mqueue(timeout()) -> stream(any()).
 mqueue(Timeout) ->
@@ -118,6 +127,11 @@ transpose_tail(S, Tail) ->
         end
     end.
 
+%% @doc Make a stream by concatenating multiple streams.
+-spec chain([stream(X)]) -> stream(X).
+chain(L) ->
+    lists:foldl(fun chain/2, empty(), L).
+
 %% @doc Make a stream by chaining (concatenating) two streams.
 %% The second stream begins to produce values only after the first one is exhausted.
 -spec chain(stream(X), stream(Y)) -> stream(X | Y).
@@ -144,6 +158,45 @@ repeat(S) ->
         end
     end.
 
+%% @doc Interleave the elements of the streams.
+%%
+%% This function accepts a list of tuples where the first element
+%% specifies the size of the "batch" to be consumed from the stream
+%% at a time, and the second element is the stream itself. If an
+%% element of the list is a plain stream, the batch size is assumed
+%% to be 1.
+%%
+%% If `ContinueAtEmpty' is `false', and one of the streams returns
+%% `[]', then the function will return `[]' as well. Otherwise, it
+%% will continue consuming data from the remaining streams.
+-spec interleave([stream(X) | {non_neg_integer(), stream(X)}], boolean()) -> stream(X).
+interleave(L0, ContinueAtEmpty) ->
+    L = lists:map(
+        fun
+            (Stream) when is_function(Stream) ->
+                {1, Stream};
+            (A = {N, _}) when N >= 0 ->
+                A
+        end,
+        L0
+    ),
+    fun() ->
+        do_interleave(ContinueAtEmpty, 0, L, [])
+    end.
+
+%% @doc Truncate a stream to the given length
+-spec limit_length(non_neg_integer(), stream(X)) -> stream(X).
+limit_length(0, _) ->
+    fun() -> [] end;
+limit_length(N, S) when N >= 0 ->
+    fun() ->
+        case next(S) of
+            [] ->
+                [];
+            [X | S1] ->
+                [X | limit_length(N - 1, S1)]
+        end
+    end.
+
 %%
 
 %% @doc Produce the next value from the stream.
@@ -237,3 +290,24 @@ csv_read_line([Line | Lines]) ->
     {Fields, Lines};
 csv_read_line([]) ->
     eof.
+
+do_interleave(_Cont, _, [], []) ->
+    [];
+do_interleave(Cont, N, [{N, S} | Rest], Rev) ->
+    do_interleave(Cont, 0, Rest, [{N, S} | Rev]);
+do_interleave(Cont, _, [], Rev) ->
+    do_interleave(Cont, 0, lists:reverse(Rev), []);
+do_interleave(Cont, I, [{N, S} | Rest], Rev) when I < N ->
+    case next(S) of
+        [] when Cont ->
+            do_interleave(Cont, 0, Rest, Rev);
+        [] ->
+            [];
+        [X | S1] ->
+            [
+                X
+                | fun() ->
+                    do_interleave(Cont, I + 1, [{N, S1} | Rest], Rev)
+                end
+            ]
+    end.

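The new combinators compose with the existing ones. A sketch of weaving a bounded constant stream into a plain list stream (the values are arbitrary; a plain stream defaults to a batch size of 1):

%% Sketch: two x's per one list element, until the const stream runs dry.
interleave_demo() ->
    S = emqx_utils_stream:interleave(
        [
            {2, emqx_utils_stream:limit_length(3, emqx_utils_stream:const(x))},
            emqx_utils_stream:list([a, b, c])
        ],
        true
    ),
    [x, x, a, x, b, c] = emqx_utils_stream:consume(S).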
+ 16 - 0
apps/emqx_utils/test/emqx_utils_stream_tests.erl

@@ -157,6 +157,22 @@ mqueue_test() ->
         emqx_utils_stream:consume(emqx_utils_stream:mqueue(400))
     ).
 
+interleave_test() ->
+    S1 = emqx_utils_stream:list([1, 2, 3]),
+    S2 = emqx_utils_stream:list([a, b, c, d]),
+    ?assertEqual(
+        [1, 2, a, b, 3, c, d],
+        emqx_utils_stream:consume(emqx_utils_stream:interleave([{2, S1}, {2, S2}], true))
+    ).
+
+interleave_stop_test() ->
+    S1 = emqx_utils_stream:const(1),
+    S2 = emqx_utils_stream:list([a, b, c, d]),
+    ?assertEqual(
+        [1, 1, a, b, 1, 1, c, d, 1, 1],
+        emqx_utils_stream:consume(emqx_utils_stream:interleave([{2, S1}, {2, S2}], false))
+    ).
+
 csv_test() ->
     Data1 = <<"h1,h2,h3\r\nvv1,vv2,vv3\r\nvv4,vv5,vv6">>,
     ?assertEqual(

+ 1 - 1
mix.exs

@@ -71,7 +71,7 @@ defmodule EMQXUmbrella.MixProject do
       {:telemetry, "1.1.0"},
       # in conflict by emqtt and hocon
       {:getopt, "1.0.2", override: true},
-      {:snabbkaffe, github: "kafka4beam/snabbkaffe", tag: "1.0.8", override: true},
+      {:snabbkaffe, github: "kafka4beam/snabbkaffe", tag: "1.0.10", override: true},
       {:hocon, github: "emqx/hocon", tag: "0.42.2", override: true},
       {:emqx_http_lib, github: "emqx/emqx_http_lib", tag: "0.5.3", override: true},
       {:esasl, github: "emqx/esasl", tag: "0.2.1"},

+ 1 - 1
rebar.config

@@ -96,7 +96,7 @@
     {observer_cli, "1.7.1"},
     {system_monitor, {git, "https://github.com/ieQu1/system_monitor", {tag, "3.0.3"}}},
     {getopt, "1.0.2"},
-    {snabbkaffe, {git, "https://github.com/kafka4beam/snabbkaffe.git", {tag, "1.0.8"}}},
+    {snabbkaffe, {git, "https://github.com/kafka4beam/snabbkaffe.git", {tag, "1.0.10"}}},
     {hocon, {git, "https://github.com/emqx/hocon.git", {tag, "0.42.2"}}},
     {emqx_http_lib, {git, "https://github.com/emqx/emqx_http_lib.git", {tag, "0.5.3"}}},
     {esasl, {git, "https://github.com/emqx/esasl", {tag, "0.2.1"}}},