Merge pull request #12624 from keynslug/fix/EMQX-11901/ds-error-class

feat(sessds): handle recoverable errors during replay
Andrew Mayorov 1 year ago
parent
commit
d725206bcb

+ 71 - 1
apps/emqx/integration_test/emqx_persistent_session_ds_SUITE.erl

@@ -118,7 +118,6 @@ app_specs() ->
 app_specs(Opts) ->
     ExtraEMQXConf = maps:get(extra_emqx_conf, Opts, ""),
     [
-        emqx_durable_storage,
         {emqx, "session_persistence = {enable = true}" ++ ExtraEMQXConf}
     ].
 
@@ -154,6 +153,14 @@ start_client(Opts0 = #{}) ->
     on_exit(fun() -> catch emqtt:stop(Client) end),
     Client.
 
+start_connect_client(Opts = #{}) ->
+    Client = start_client(Opts),
+    ?assertMatch({ok, _}, emqtt:connect(Client)),
+    Client.
+
+mk_clientid(Prefix, ID) ->
+    iolist_to_binary(io_lib:format("~p/~p", [Prefix, ID])).
+
 restart_node(Node, NodeSpec) ->
     ?tp(will_restart_node, #{}),
     emqx_cth_cluster:restart(Node, NodeSpec),
@@ -599,3 +606,66 @@ t_session_gc(Config) ->
         []
     ),
     ok.
+
+t_session_replay_retry(_Config) ->
+    %% Verify that the session recovers smoothly from transient errors during
+    %% replay.
+
+    ok = emqx_ds_test_helpers:mock_rpc(),
+
+    NClients = 10,
+    ClientSubOpts = #{
+        clientid => mk_clientid(?FUNCTION_NAME, sub),
+        auto_ack => never
+    },
+    ClientSub = start_connect_client(ClientSubOpts),
+    ?assertMatch(
+        {ok, _, [?RC_GRANTED_QOS_1]},
+        emqtt:subscribe(ClientSub, <<"t/#">>, ?QOS_1)
+    ),
+
+    ClientsPub = [
+        start_connect_client(#{
+            clientid => mk_clientid(?FUNCTION_NAME, I),
+            properties => #{'Session-Expiry-Interval' => 0}
+        })
+     || I <- lists:seq(1, NClients)
+    ],
+    lists:foreach(
+        fun(Client) ->
+            Index = integer_to_binary(rand:uniform(NClients)),
+            Topic = <<"t/", Index/binary>>,
+            ?assertMatch({ok, #{}}, emqtt:publish(Client, Topic, Index, 1))
+        end,
+        ClientsPub
+    ),
+
+    Pubs0 = emqx_common_test_helpers:wait_publishes(NClients, 5_000),
+    NPubs = length(Pubs0),
+    ?assertEqual(NClients, NPubs, ?drainMailbox()),
+
+    ok = emqtt:stop(ClientSub),
+
+    %% Make `emqx_ds` believe that roughly half of the shards are unavailable.
+    ok = emqx_ds_test_helpers:mock_rpc_result(
+        fun(_Node, emqx_ds_replication_layer, _Function, [_DB, Shard | _]) ->
+            case erlang:phash2(Shard) rem 2 of
+                0 -> unavailable;
+                1 -> passthrough
+            end
+        end
+    ),
+
+    _ClientSub = start_connect_client(ClientSubOpts#{clean_start => false}),
+
+    Pubs1 = emqx_common_test_helpers:wait_publishes(NPubs, 5_000),
+    ?assert(length(Pubs1) < length(Pubs0), Pubs1),
+
+    %% "Recover" the shards.
+    emqx_ds_test_helpers:unmock_rpc(),
+
+    Pubs2 = emqx_common_test_helpers:wait_publishes(NPubs - length(Pubs1), 5_000),
+    ?assertEqual(
+        [maps:with([topic, payload, qos], P) || P <- Pubs0],
+        [maps:with([topic, payload, qos], P) || P <- Pubs1 ++ Pubs2]
+    ).

+ 88 - 37
apps/emqx/src/emqx_persistent_session_ds.erl

@@ -123,7 +123,12 @@
 -define(TIMER_PULL, timer_pull).
 -define(TIMER_GET_STREAMS, timer_get_streams).
 -define(TIMER_BUMP_LAST_ALIVE_AT, timer_bump_last_alive_at).
--type timer() :: ?TIMER_PULL | ?TIMER_GET_STREAMS | ?TIMER_BUMP_LAST_ALIVE_AT.
+-define(TIMER_RETRY_REPLAY, timer_retry_replay).
+
+-type timer() :: ?TIMER_PULL | ?TIMER_GET_STREAMS | ?TIMER_BUMP_LAST_ALIVE_AT | ?TIMER_RETRY_REPLAY.
+
+%% TODO: Needs configuration?
+-define(TIMEOUT_RETRY_REPLAY, 1000).
 
 -type session() :: #{
     %% Client ID
@@ -134,10 +139,15 @@
     s := emqx_persistent_session_ds_state:t(),
     %% Buffer:
     inflight := emqx_persistent_session_ds_inflight:t(),
+    %% In-progress replay:
+    %% List of stream replay states to be added to the inflight buffer.
+    replay => [{_StreamKey, stream_state()}, ...],
     %% Timers:
     timer() => reference()
 }.
 
+-define(IS_REPLAY_ONGOING(SESS), is_map_key(replay, SESS)).
+
 -record(req_sync, {
     from :: pid(),
     ref :: reference()
@@ -450,12 +460,14 @@ deliver(ClientInfo, Delivers, Session0) ->
 
 -spec handle_timeout(clientinfo(), _Timeout, session()) ->
     {ok, replies(), session()} | {ok, replies(), timeout(), session()}.
-handle_timeout(
-    ClientInfo,
-    ?TIMER_PULL,
-    Session0
-) ->
-    {Publishes, Session1} = drain_buffer(fetch_new_messages(Session0, ClientInfo)),
+handle_timeout(ClientInfo, ?TIMER_PULL, Session0) ->
+    {Publishes, Session1} =
+        case ?IS_REPLAY_ONGOING(Session0) of
+            false ->
+                drain_buffer(fetch_new_messages(Session0, ClientInfo));
+            true ->
+                {[], Session0}
+        end,
     Timeout =
         case Publishes of
             [] ->
@@ -465,6 +477,9 @@ handle_timeout(
         end,
     Session = emqx_session:ensure_timer(?TIMER_PULL, Timeout, Session1),
     {ok, Publishes, Session};
+handle_timeout(ClientInfo, ?TIMER_RETRY_REPLAY, Session0) ->
+    Session = replay_streams(Session0, ClientInfo),
+    {ok, [], Session};
 handle_timeout(_ClientInfo, ?TIMER_GET_STREAMS, Session0 = #{s := S0}) ->
     S1 = emqx_persistent_session_ds_subs:gc(S0),
     S = emqx_persistent_session_ds_stream_scheduler:renew_streams(S1),
@@ -503,30 +518,47 @@ bump_last_alive(S0) ->
     {ok, replies(), session()}.
 replay(ClientInfo, [], Session0 = #{s := S0}) ->
     Streams = emqx_persistent_session_ds_stream_scheduler:find_replay_streams(S0),
-    Session = lists:foldl(
-        fun({_StreamKey, Stream}, SessionAcc) ->
-            replay_batch(Stream, SessionAcc, ClientInfo)
-        end,
-        Session0,
-        Streams
-    ),
+    Session = replay_streams(Session0#{replay => Streams}, ClientInfo),
+    {ok, [], Session}.
+
+replay_streams(Session0 = #{replay := [{_StreamKey, Srs0} | Rest]}, ClientInfo) ->
+    case replay_batch(Srs0, Session0, ClientInfo) of
+        Session = #{} ->
+            replay_streams(Session#{replay := Rest}, ClientInfo);
+        {error, recoverable, Reason} ->
+            RetryTimeout = ?TIMEOUT_RETRY_REPLAY,
+            ?SLOG(warning, #{
+                msg => "failed_to_fetch_replay_batch",
+                stream => Srs0,
+                reason => Reason,
+                class => recoverable,
+                retry_in_ms => RetryTimeout
+            }),
+            emqx_session:ensure_timer(?TIMER_RETRY_REPLAY, RetryTimeout, Session0)
+        %% TODO: Handle unrecoverable errors.
+    end;
+replay_streams(Session0 = #{replay := []}, _ClientInfo) ->
+    Session = maps:remove(replay, Session0),
     %% Note: we filled the buffer with the historical messages, and
     %% from now on we'll rely on the normal inflight/flow control
     %% mechanisms to replay them:
-    {ok, [], pull_now(Session)}.
+    pull_now(Session).
 
--spec replay_batch(stream_state(), session(), clientinfo()) -> session().
-replay_batch(Srs0, Session, ClientInfo) ->
+-spec replay_batch(stream_state(), session(), clientinfo()) -> session() | emqx_ds:error(_).
+replay_batch(Srs0, Session0, ClientInfo) ->
     #srs{batch_size = BatchSize} = Srs0,
-    %% TODO: retry on errors:
-    {Srs, Inflight} = enqueue_batch(true, BatchSize, Srs0, Session, ClientInfo),
-    %% Assert:
-    Srs =:= Srs0 orelse
-        ?tp(warning, emqx_persistent_session_ds_replay_inconsistency, #{
-            expected => Srs0,
-            got => Srs
-        }),
-    Session#{inflight => Inflight}.
+    case enqueue_batch(true, BatchSize, Srs0, Session0, ClientInfo) of
+        {ok, Srs, Session} ->
+            %% Assert:
+            Srs =:= Srs0 orelse
+                ?tp(warning, emqx_persistent_session_ds_replay_inconsistency, #{
+                    expected => Srs0,
+                    got => Srs
+                }),
+            Session;
+        {error, _, _} = Error ->
+            Error
+    end.
 
 %%--------------------------------------------------------------------
 
@@ -743,7 +775,7 @@ fetch_new_messages([I | Streams], Session0 = #{inflight := Inflight}, ClientInfo
             fetch_new_messages(Streams, Session, ClientInfo)
     end.
 
-new_batch({StreamKey, Srs0}, BatchSize, Session = #{s := S0}, ClientInfo) ->
+new_batch({StreamKey, Srs0}, BatchSize, Session0 = #{s := S0}, ClientInfo) ->
     SN1 = emqx_persistent_session_ds_state:get_seqno(?next(?QOS_1), S0),
     SN2 = emqx_persistent_session_ds_state:get_seqno(?next(?QOS_2), S0),
     Srs1 = Srs0#srs{
@@ -753,11 +785,30 @@ new_batch({StreamKey, Srs0}, BatchSize, Session = #{s := S0}, ClientInfo) ->
         last_seqno_qos1 = SN1,
         last_seqno_qos2 = SN2
     },
-    {Srs, Inflight} = enqueue_batch(false, BatchSize, Srs1, Session, ClientInfo),
-    S1 = emqx_persistent_session_ds_state:put_seqno(?next(?QOS_1), Srs#srs.last_seqno_qos1, S0),
-    S2 = emqx_persistent_session_ds_state:put_seqno(?next(?QOS_2), Srs#srs.last_seqno_qos2, S1),
-    S = emqx_persistent_session_ds_state:put_stream(StreamKey, Srs, S2),
-    Session#{s => S, inflight => Inflight}.
+    case enqueue_batch(false, BatchSize, Srs1, Session0, ClientInfo) of
+        {ok, Srs, Session} ->
+            S1 = emqx_persistent_session_ds_state:put_seqno(
+                ?next(?QOS_1),
+                Srs#srs.last_seqno_qos1,
+                S0
+            ),
+            S2 = emqx_persistent_session_ds_state:put_seqno(
+                ?next(?QOS_2),
+                Srs#srs.last_seqno_qos2,
+                S1
+            ),
+            S = emqx_persistent_session_ds_state:put_stream(StreamKey, Srs, S2),
+            Session#{s => S};
+        {error, Class, Reason} ->
+            %% TODO: Handle unrecoverable error.
+            ?SLOG(info, #{
+                msg => "failed_to_fetch_batch",
+                stream => Srs1,
+                reason => Reason,
+                class => Class
+            }),
+            Session0
+    end.
 
 enqueue_batch(IsReplay, BatchSize, Srs0, Session = #{inflight := Inflight0}, ClientInfo) ->
     #srs{
@@ -786,13 +837,13 @@ enqueue_batch(IsReplay, BatchSize, Srs0, Session = #{inflight := Inflight0}, Cli
                 last_seqno_qos1 = LastSeqnoQos1,
                 last_seqno_qos2 = LastSeqnoQos2
             },
-            {Srs, Inflight};
+            {ok, Srs, Session#{inflight := Inflight}};
         {ok, end_of_stream} ->
             %% No new messages; just update the end iterator:
-            {Srs0#srs{it_begin = ItBegin, it_end = end_of_stream, batch_size = 0}, Inflight0};
-        {error, _} when not IsReplay ->
-            ?SLOG(info, #{msg => "failed_to_fetch_batch", iterator => ItBegin}),
-            {Srs0, Inflight0}
+            Srs = Srs0#srs{it_begin = ItBegin, it_end = end_of_stream, batch_size = 0},
+            {ok, Srs, Session#{inflight := Inflight0}};
+        {error, _, _} = Error ->
+            Error
     end.
 
 %% key_of_iter(#{3 := #{3 := #{5 := K}}}) ->

+ 1 - 0
apps/emqx/src/emqx_rpc.erl

@@ -35,6 +35,7 @@
 
 -export_type([
     badrpc/0,
+    call_result/1,
     call_result/0,
     cast_result/0,
     multicall_result/1,

+ 11 - 0
apps/emqx/test/emqx_common_test_helpers.erl

@@ -61,6 +61,7 @@
     read_schema_configs/2,
     render_config_file/2,
     wait_for/4,
+    wait_publishes/2,
     wait_mqtt_payload/1,
     select_free_port/1
 ]).
@@ -426,6 +427,16 @@ wait_for(Fn, Ln, F, Timeout) ->
     {Pid, Mref} = erlang:spawn_monitor(fun() -> wait_loop(F, catch_call(F)) end),
     wait_for_down(Fn, Ln, Timeout, Pid, Mref, false).
 
+wait_publishes(0, _Timeout) ->
+    [];
+wait_publishes(Count, Timeout) ->
+    receive
+        {publish, Msg} ->
+            [Msg | wait_publishes(Count - 1, Timeout)]
+    after Timeout ->
+        []
+    end.
+
 flush() ->
     flush([]).
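
A note on semantics, with a self-contained sketch (no broker involved, the payloads are made up): `wait_publishes/2` simply drains `{publish, Msg}` messages from the calling process's mailbox, returning whatever it has collected once `Timeout` milliseconds pass without a new message arriving. Since the timeout applies per message, the total wait can exceed it.

    %% Sketch: two publishes are already in the mailbox, a third never arrives.
    self() ! {publish, #{topic => <<"t/1">>, payload => <<"a">>}},
    self() ! {publish, #{topic => <<"t/2">>, payload => <<"b">>}},
    Msgs = emqx_common_test_helpers:wait_publishes(3, 1_000),
    %% After ~1 s of silence the call gives up and returns just the two collected messages:
    2 = length(Msgs).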
 

+ 7 - 3
apps/emqx_durable_storage/src/emqx_ds.erl

@@ -68,6 +68,8 @@
     make_iterator_result/1, make_iterator_result/0,
     make_delete_iterator_result/1, make_delete_iterator_result/0,
 
+    error/1,
+
     ds_specific_stream/0,
     ds_specific_iterator/0,
     ds_specific_generation_rank/0,
@@ -118,14 +120,14 @@
 
 -type message_key() :: binary().
 
--type store_batch_result() :: ok | {error, _}.
+-type store_batch_result() :: ok | error(_).
 
--type make_iterator_result(Iterator) :: {ok, Iterator} | {error, _}.
+-type make_iterator_result(Iterator) :: {ok, Iterator} | error(_).
 
 -type make_iterator_result() :: make_iterator_result(iterator()).
 
 -type next_result(Iterator) ::
-    {ok, Iterator, [{message_key(), emqx_types:message()}]} | {ok, end_of_stream} | {error, _}.
+    {ok, Iterator, [{message_key(), emqx_types:message()}]} | {ok, end_of_stream} | error(_).
 
 -type next_result() :: next_result(iterator()).
 
@@ -142,6 +144,8 @@
 
 -type delete_next_result() :: delete_next_result(delete_iterator()).
 
+-type error(Reason) :: {error, recoverable | unrecoverable, Reason}.
+
 %% Timestamp
 %% Earliest possible timestamp is 0.
 %% TODO granularity?  Currently, we should always use milliseconds, as that's the unit we
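
The new `error/1` type is the contract the rest of this change builds on: DS APIs that used to return a bare `{error, _}` now tag errors as either `recoverable` (e.g. a transient RPC failure) or `unrecoverable` (e.g. a dropped generation). A minimal caller-side sketch of how a consumer might react to the two classes follows; the retry count, the sleep, and the `handle_batch/1` callback are illustrative, not part of the API.

    %% Hypothetical consumer loop over an `emqx_ds' iterator.
    consume(DB, It, BatchSize, Retries) ->
        case emqx_ds:next(DB, It, BatchSize) of
            {ok, It1, Batch} ->
                ok = handle_batch(Batch),             %% hypothetical callback
                consume(DB, It1, BatchSize, Retries);
            {ok, end_of_stream} ->
                done;
            {error, recoverable, _Reason} when Retries > 0 ->
                %% Transient failure (node down, RPC timeout, ...): back off and retry,
                %% much like the session does with ?TIMER_RETRY_REPLAY.
                timer:sleep(1000),
                consume(DB, It, BatchSize, Retries - 1);
            {error, _Class, Reason} ->
                %% Unrecoverable, or out of retries: surface the error to the caller.
                {failed, Reason}
        end.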

+ 35 - 25
apps/emqx_durable_storage/src/emqx_ds_replication_layer.erl

@@ -171,7 +171,12 @@ drop_db(DB) ->
 -spec store_batch(emqx_ds:db(), [emqx_types:message(), ...], emqx_ds:message_store_opts()) ->
     emqx_ds:store_batch_result().
 store_batch(DB, Messages, Opts) ->
-    emqx_ds_replication_layer_egress:store_batch(DB, Messages, Opts).
+    try
+        emqx_ds_replication_layer_egress:store_batch(DB, Messages, Opts)
+    catch
+        error:{Reason, _Call} when Reason == timeout; Reason == noproc ->
+            {error, recoverable, Reason}
+    end.
 
 -spec get_streams(emqx_ds:db(), emqx_ds:topic_filter(), emqx_ds:time()) ->
     [{emqx_ds:stream_rank(), stream()}].
@@ -180,7 +185,14 @@ get_streams(DB, TopicFilter, StartTime) ->
     lists:flatmap(
         fun(Shard) ->
             Node = node_of_shard(DB, Shard),
-            Streams = emqx_ds_proto_v4:get_streams(Node, DB, Shard, TopicFilter, StartTime),
+            Streams =
+                try
+                    emqx_ds_proto_v4:get_streams(Node, DB, Shard, TopicFilter, StartTime)
+                catch
+                    error:{erpc, _} ->
+                        %% TODO: log?
+                        []
+                end,
             lists:map(
                 fun({RankY, StorageLayerStream}) ->
                     RankX = Shard,
@@ -198,35 +210,29 @@ get_streams(DB, TopicFilter, StartTime) ->
 make_iterator(DB, Stream, TopicFilter, StartTime) ->
     ?stream_v2(Shard, StorageStream) = Stream,
     Node = node_of_shard(DB, Shard),
-    case emqx_ds_proto_v4:make_iterator(Node, DB, Shard, StorageStream, TopicFilter, StartTime) of
+    try emqx_ds_proto_v4:make_iterator(Node, DB, Shard, StorageStream, TopicFilter, StartTime) of
         {ok, Iter} ->
             {ok, #{?tag => ?IT, ?shard => Shard, ?enc => Iter}};
-        Err = {error, _} ->
-            Err
+        Error = {error, _, _} ->
+            Error
+    catch
+        error:RPCError = {erpc, _} ->
+            {error, recoverable, RPCError}
     end.
 
--spec update_iterator(
-    emqx_ds:db(),
-    iterator(),
-    emqx_ds:message_key()
-) ->
+-spec update_iterator(emqx_ds:db(), iterator(), emqx_ds:message_key()) ->
     emqx_ds:make_iterator_result(iterator()).
 update_iterator(DB, OldIter, DSKey) ->
     #{?tag := ?IT, ?shard := Shard, ?enc := StorageIter} = OldIter,
     Node = node_of_shard(DB, Shard),
-    case
-        emqx_ds_proto_v4:update_iterator(
-            Node,
-            DB,
-            Shard,
-            StorageIter,
-            DSKey
-        )
-    of
+    try emqx_ds_proto_v4:update_iterator(Node, DB, Shard, StorageIter, DSKey) of
         {ok, Iter} ->
             {ok, #{?tag => ?IT, ?shard => Shard, ?enc => Iter}};
-        Err = {error, _} ->
-            Err
+        Error = {error, _, _} ->
+            Error
+    catch
+        error:RPCError = {erpc, _} ->
+            {error, recoverable, RPCError}
     end.
 
 -spec next(emqx_ds:db(), iterator(), pos_integer()) -> emqx_ds:next_result(iterator()).
@@ -245,8 +251,12 @@ next(DB, Iter0, BatchSize) ->
         {ok, StorageIter, Batch} ->
             Iter = Iter0#{?enc := StorageIter},
             {ok, Iter, Batch};
-        Other ->
-            Other
+        Ok = {ok, _} ->
+            Ok;
+        Error = {error, _, _} ->
+            Error;
+        RPCError = {badrpc, _} ->
+            {error, recoverable, RPCError}
     end.
 
 -spec node_of_shard(emqx_ds:db(), shard_id()) -> node().
@@ -337,7 +347,7 @@ do_get_streams_v2(DB, Shard, TopicFilter, StartTime) ->
     emqx_ds:topic_filter(),
     emqx_ds:time()
 ) ->
-    {ok, emqx_ds_storage_layer:iterator()} | {error, _}.
+    emqx_ds:make_iterator_result(emqx_ds_storage_layer:iterator()).
 do_make_iterator_v1(_DB, _Shard, _Stream, _TopicFilter, _StartTime) ->
     error(obsolete_api).
 
@@ -348,7 +358,7 @@ do_make_iterator_v1(_DB, _Shard, _Stream, _TopicFilter, _StartTime) ->
     emqx_ds:topic_filter(),
     emqx_ds:time()
 ) ->
-    {ok, emqx_ds_storage_layer:iterator()} | {error, _}.
+    emqx_ds:make_iterator_result(emqx_ds_storage_layer:iterator()).
 do_make_iterator_v2(DB, Shard, Stream, TopicFilter, StartTime) ->
     emqx_ds_storage_layer:make_iterator({DB, Shard}, Stream, TopicFilter, StartTime).
 

+ 11 - 12
apps/emqx_durable_storage/src/emqx_ds_storage_bitfield_lts.erl

@@ -230,7 +230,7 @@ drop(_Shard, DBHandle, GenId, CFRefs, #s{}) ->
     emqx_ds_storage_layer:shard_id(), s(), [emqx_types:message()], emqx_ds:message_store_opts()
 ) ->
     emqx_ds:store_batch_result().
-store_batch(_ShardId, S = #s{db = DB, data = Data}, Messages, _Options = #{atomic := true}) ->
+store_batch(_ShardId, S = #s{db = DB, data = Data}, Messages, _Options) ->
     {ok, Batch} = rocksdb:batch(),
     lists:foreach(
         fun(Msg) ->
@@ -240,18 +240,17 @@ store_batch(_ShardId, S = #s{db = DB, data = Data}, Messages, _Options = #{atomi
         end,
         Messages
     ),
-    Res = rocksdb:write_batch(DB, Batch, _WriteOptions = []),
+    Result = rocksdb:write_batch(DB, Batch, []),
     rocksdb:release_batch(Batch),
-    Res;
-store_batch(_ShardId, S = #s{db = DB, data = Data}, Messages, _Options) ->
-    lists:foreach(
-        fun(Msg) ->
-            {Key, _} = make_key(S, Msg),
-            Val = serialize(Msg),
-            rocksdb:put(DB, Data, Key, Val, [])
-        end,
-        Messages
-    ).
+    %% NOTE
+    %% Strictly speaking, `{error, incomplete}` is a valid result but should be impossible to
+    %% observe until there's `{no_slowdown, true}` in write options.
+    case Result of
+        ok ->
+            ok;
+        {error, {error, Reason}} ->
+            {error, unrecoverable, {rocksdb, Reason}}
+    end.
 
 -spec get_streams(
     emqx_ds_storage_layer:shard_id(),

+ 5 - 7
apps/emqx_durable_storage/src/emqx_ds_storage_layer.erl

@@ -256,12 +256,10 @@ make_iterator(
                     Err
             end;
         {error, not_found} ->
-            {error, end_of_stream}
+            {error, unrecoverable, generation_not_found}
     end.
 
--spec update_iterator(
-    shard_id(), iterator(), emqx_ds:message_key()
-) ->
+-spec update_iterator(shard_id(), iterator(), emqx_ds:message_key()) ->
     emqx_ds:make_iterator_result(iterator()).
 update_iterator(
     Shard,
@@ -281,7 +279,7 @@ update_iterator(
                     Err
             end;
         {error, not_found} ->
-            {error, end_of_stream}
+            {error, unrecoverable, generation_not_found}
     end.
 
 -spec next(shard_id(), iterator(), pos_integer()) ->
@@ -298,12 +296,12 @@ next(Shard, Iter = #{?tag := ?IT, ?generation := GenId, ?enc := GenIter0}, Batch
                     {ok, end_of_stream};
                 {ok, GenIter, Batch} ->
                     {ok, Iter#{?enc := GenIter}, Batch};
-                Error = {error, _} ->
+                Error = {error, _, _} ->
                     Error
             end;
         {error, not_found} ->
             %% generation was possibly dropped by GC
-            {ok, end_of_stream}
+            {error, unrecoverable, generation_not_found}
     end.
 
 -spec update_config(shard_id(), emqx_ds:create_db_opts()) -> ok.

+ 3 - 5
apps/emqx_durable_storage/src/proto/emqx_ds_proto_v4.erl

@@ -64,7 +64,7 @@ get_streams(Node, DB, Shard, TopicFilter, Time) ->
     emqx_ds:topic_filter(),
     emqx_ds:time()
 ) ->
-    {ok, emqx_ds_storage_layer:iterator()} | {error, _}.
+    emqx_ds:make_iterator_result().
 make_iterator(Node, DB, Shard, Stream, TopicFilter, StartTime) ->
     erpc:call(Node, emqx_ds_replication_layer, do_make_iterator_v2, [
         DB, Shard, Stream, TopicFilter, StartTime
@@ -77,9 +77,7 @@ make_iterator(Node, DB, Shard, Stream, TopicFilter, StartTime) ->
     emqx_ds_storage_layer:iterator(),
     pos_integer()
 ) ->
-    {ok, emqx_ds_storage_layer:iterator(), [{emqx_ds:message_key(), [emqx_types:message()]}]}
-    | {ok, end_of_stream}
-    | {error, _}.
+    emqx_rpc:call_result(emqx_ds:next_result()).
 next(Node, DB, Shard, Iter, BatchSize) ->
     emqx_rpc:call(Shard, Node, emqx_ds_replication_layer, do_next_v1, [DB, Shard, Iter, BatchSize]).
 
@@ -103,7 +101,7 @@ store_batch(Node, DB, Shard, Batch, Options) ->
     emqx_ds_storage_layer:iterator(),
     emqx_ds:message_key()
 ) ->
-    {ok, emqx_ds_storage_layer:iterator()} | {error, _}.
+    emqx_ds:make_iterator_result().
 update_iterator(Node, DB, Shard, OldIter, DSKey) ->
     erpc:call(Node, emqx_ds_replication_layer, do_update_iterator_v2, [
         DB, Shard, OldIter, DSKey

+ 111 - 11
apps/emqx_durable_storage/test/emqx_ds_SUITE.erl

@@ -21,6 +21,7 @@
 -include_lib("emqx/include/emqx.hrl").
 -include_lib("common_test/include/ct.hrl").
 -include_lib("stdlib/include/assert.hrl").
+-include_lib("emqx/include/asserts.hrl").
 -include_lib("snabbkaffe/include/snabbkaffe.hrl").
 
 -define(N_SHARDS, 1).
@@ -404,7 +405,10 @@ t_drop_generation_with_never_used_iterator(_Config) ->
     ],
     ?assertMatch(ok, emqx_ds:store_batch(DB, Msgs1)),
 
-    ?assertMatch({ok, end_of_stream, []}, iterate(DB, Iter0, 1)),
+    ?assertMatch(
+        {error, unrecoverable, generation_not_found, []},
+        iterate(DB, Iter0, 1)
+    ),
 
     %% New iterator for the new stream will only see the later messages.
     [{_, Stream1}] = emqx_ds:get_streams(DB, TopicFilter, StartTime),
@@ -453,9 +457,10 @@ t_drop_generation_with_used_once_iterator(_Config) ->
     ],
     ?assertMatch(ok, emqx_ds:store_batch(DB, Msgs1)),
 
-    ?assertMatch({ok, end_of_stream, []}, iterate(DB, Iter1, 1)),
-
-    ok.
+    ?assertMatch(
+        {error, unrecoverable, generation_not_found, []},
+        iterate(DB, Iter1, 1)
+    ).
 
 t_drop_generation_update_iterator(_Config) ->
     %% This checks the behavior of `emqx_ds:update_iterator' after the generation
@@ -481,9 +486,10 @@ t_drop_generation_update_iterator(_Config) ->
     ok = emqx_ds:add_generation(DB),
     ok = emqx_ds:drop_generation(DB, GenId0),
 
-    ?assertEqual({error, end_of_stream}, emqx_ds:update_iterator(DB, Iter1, Key2)),
-
-    ok.
+    ?assertEqual(
+        {error, unrecoverable, generation_not_found},
+        emqx_ds:update_iterator(DB, Iter1, Key2)
+    ).
 
 t_make_iterator_stale_stream(_Config) ->
     %% This checks the behavior of `emqx_ds:make_iterator' after the generation underlying
@@ -507,7 +513,7 @@ t_make_iterator_stale_stream(_Config) ->
     ok = emqx_ds:drop_generation(DB, GenId0),
 
     ?assertEqual(
-        {error, end_of_stream},
+        {error, unrecoverable, generation_not_found},
         emqx_ds:make_iterator(DB, Stream0, TopicFilter, StartTime)
     ),
 
@@ -548,9 +554,99 @@ t_get_streams_concurrently_with_drop_generation(_Config) ->
             ok
         end,
         []
+    ).
+
+t_error_mapping_replication_layer(_Config) ->
+    %% This checks that the replication layer maps recoverable errors correctly.
+
+    ok = emqx_ds_test_helpers:mock_rpc(),
+    ok = snabbkaffe:start_trace(),
+
+    DB = ?FUNCTION_NAME,
+    ?assertMatch(ok, emqx_ds:open_db(DB, (opts())#{n_shards => 2})),
+    [Shard1, Shard2] = emqx_ds_replication_layer_meta:shards(DB),
+
+    TopicFilter = emqx_topic:words(<<"foo/#">>),
+    Msgs = [
+        message(<<"C1">>, <<"foo/bar">>, <<"1">>, 0),
+        message(<<"C1">>, <<"foo/baz">>, <<"2">>, 1),
+        message(<<"C2">>, <<"foo/foo">>, <<"3">>, 2),
+        message(<<"C3">>, <<"foo/xyz">>, <<"4">>, 3),
+        message(<<"C4">>, <<"foo/bar">>, <<"5">>, 4),
+        message(<<"C5">>, <<"foo/oof">>, <<"6">>, 5)
+    ],
+
+    ?assertMatch(ok, emqx_ds:store_batch(DB, Msgs)),
+
+    ?block_until(#{?snk_kind := emqx_ds_replication_layer_egress_flush, shard := Shard1}),
+    ?block_until(#{?snk_kind := emqx_ds_replication_layer_egress_flush, shard := Shard2}),
+
+    Streams0 = emqx_ds:get_streams(DB, TopicFilter, 0),
+    Iterators0 = lists:map(
+        fun({_Rank, S}) ->
+            {ok, Iter} = emqx_ds:make_iterator(DB, S, TopicFilter, 0),
+            Iter
+        end,
+        Streams0
     ),
 
-    ok.
+    %% Disrupt the link to the second shard.
+    ok = emqx_ds_test_helpers:mock_rpc_result(
+        fun(_Node, emqx_ds_replication_layer, _Function, Args) ->
+            case Args of
+                [DB, Shard1 | _] -> passthrough;
+                [DB, Shard2 | _] -> unavailable
+            end
+        end
+    ),
+
+    %% Result of `emqx_ds:get_streams/3` will just contain partial results, not an error.
+    Streams1 = emqx_ds:get_streams(DB, TopicFilter, 0),
+    ?assert(
+        length(Streams1) > 0 andalso length(Streams1) =< length(Streams0),
+        Streams1
+    ),
+
+    %% At least one of the `emqx_ds:make_iterator/4` calls will end in an error.
+    Results1 = lists:map(
+        fun({_Rank, S}) ->
+            case emqx_ds:make_iterator(DB, S, TopicFilter, 0) of
+                Ok = {ok, _Iter} ->
+                    Ok;
+                Error = {error, recoverable, {erpc, _}} ->
+                    Error;
+                Other ->
+                    ct:fail({unexpected_result, Other})
+            end
+        end,
+        Streams0
+    ),
+    ?assert(
+        length([error || {error, _, _} <- Results1]) > 0,
+        Results1
+    ),
+
+    %% At least one `emqx_ds:next/3` call over the initial set of iterators will end in an error.
+    Results2 = lists:map(
+        fun(Iter) ->
+            case emqx_ds:next(DB, Iter, _BatchSize = 42) of
+                Ok = {ok, _Iter, [_ | _]} ->
+                    Ok;
+                Error = {error, recoverable, {badrpc, _}} ->
+                    Error;
+                Other ->
+                    ct:fail({unexpected_result, Other})
+            end
+        end,
+        Iterators0
+    ),
+    ?assert(
+        length([error || {error, _, _} <- Results2]) > 0,
+        Results2
+    ),
+
+    snabbkaffe:stop(),
+    meck:unload().
 
 update_data_set() ->
     [
@@ -586,6 +682,10 @@ fetch_all(DB, TopicFilter, StartTime) ->
         Streams
     ).
 
+message(ClientId, Topic, Payload, PublishedAt) ->
+    Msg = message(Topic, Payload, PublishedAt),
+    Msg#message{from = ClientId}.
+
 message(Topic, Payload, PublishedAt) ->
     #message{
         topic = Topic,
@@ -605,8 +705,8 @@ iterate(DB, It0, BatchSize, Acc) ->
             iterate(DB, It, BatchSize, Acc ++ Msgs);
         {ok, end_of_stream} ->
             {ok, end_of_stream, Acc};
-        Ret ->
-            Ret
+        {error, Class, Reason} ->
+            {error, Class, Reason, Acc}
     end.
 
 %% CT callbacks

+ 58 - 0
apps/emqx_durable_storage/test/emqx_ds_test_helpers.erl

@@ -0,0 +1,58 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+-module(emqx_ds_test_helpers).
+
+-compile(export_all).
+-compile(nowarn_export_all).
+
+%% RPC mocking
+
+mock_rpc() ->
+    ok = meck:new(erpc, [passthrough, no_history, unstick]),
+    ok = meck:new(gen_rpc, [passthrough, no_history]).
+
+unmock_rpc() ->
+    catch meck:unload(erpc),
+    catch meck:unload(gen_rpc).
+
+mock_rpc_result(ExpectFun) ->
+    mock_rpc_result(erpc, ExpectFun),
+    mock_rpc_result(gen_rpc, ExpectFun).
+
+mock_rpc_result(erpc, ExpectFun) ->
+    ok = meck:expect(erpc, call, fun(Node, Mod, Function, Args) ->
+        case ExpectFun(Node, Mod, Function, Args) of
+            passthrough ->
+                meck:passthrough([Node, Mod, Function, Args]);
+            unavailable ->
+                meck:exception(error, {erpc, noconnection});
+            {timeout, Timeout} ->
+                ok = timer:sleep(Timeout),
+                meck:exception(error, {erpc, timeout})
+        end
+    end);
+mock_rpc_result(gen_rpc, ExpectFun) ->
+    ok = meck:expect(gen_rpc, call, fun(Dest = {Node, _}, Mod, Function, Args) ->
+        case ExpectFun(Node, Mod, Function, Args) of
+            passthrough ->
+                meck:passthrough([Dest, Mod, Function, Args]);
+            unavailable ->
+                {badtcp, econnrefused};
+            {timeout, Timeout} ->
+                ok = timer:sleep(Timeout),
+                {badrpc, timeout}
+        end
+    end).
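
A usage sketch of these helpers, in the same spirit as the new test cases above (the shard id literal is made up): `mock_rpc_result/1` installs the given fun in front of both `erpc:call/4` and `gen_rpc:call/4`, and the fun decides per call whether to pass through, simulate an unreachable peer, or simulate a timeout.

    %% Illustrative: fail every replication-layer RPC towards one shard,
    %% let everything else through to the real implementation.
    ok = emqx_ds_test_helpers:mock_rpc(),
    ok = emqx_ds_test_helpers:mock_rpc_result(
        fun(_Node, emqx_ds_replication_layer, _Function, [_DB, Shard | _]) ->
            case Shard of
                <<"unlucky_shard">> -> unavailable;
                _ -> passthrough
            end
        end
    ),
    %% ... exercise code that should now see {error, recoverable, _} results ...
    ok = emqx_ds_test_helpers:unmock_rpc().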