Explorar o código

feat(delayed): check if the source client is banned when publishing a delayed message

firest %!s(int64=3) %!d(string=hai) anos
pai
achega
dd7d4224ce

+ 17 - 2
apps/emqx_modules/src/emqx_delayed.erl

@@ -373,8 +373,23 @@ do_publish({Ts, _Id}, Now, Acc) when Ts > Now ->
     Acc;
     Acc;
 do_publish(Key = {Ts, _Id}, Now, Acc) when Ts =< Now ->
 do_publish(Key = {Ts, _Id}, Now, Acc) when Ts =< Now ->
     case mnesia:dirty_read(?TAB, Key) of
     case mnesia:dirty_read(?TAB, Key) of
-        [] -> ok;
-        [#delayed_message{msg = Msg}] -> emqx_pool:async_submit(fun emqx:publish/1, [Msg])
+        [] ->
+            ok;
+        [#delayed_message{msg = Msg}] ->
+            case emqx_banned:look_up({clientid, Msg#message.from}) of
+                [] ->
+                    emqx_pool:async_submit(fun emqx:publish/1, [Msg]);
+                _ ->
+                    ?tp(
+                        notice,
+                        ignore_delayed_message_publish,
+                        #{
+                            reason => "client is banned",
+                            clienid => Msg#message.from
+                        }
+                    ),
+                    ok
+            end
     end,
     end,
     do_publish(mnesia:dirty_next(?TAB, Key), Now, [Key | Acc]).
     do_publish(mnesia:dirty_next(?TAB, Key), Now, [Key | Acc]).
 
 

+ 35 - 1
apps/emqx_modules/test/emqx_delayed_SUITE.erl

@@ -26,6 +26,7 @@
 -include_lib("common_test/include/ct.hrl").
 -include_lib("common_test/include/ct.hrl").
 -include_lib("eunit/include/eunit.hrl").
 -include_lib("eunit/include/eunit.hrl").
 -include_lib("emqx/include/emqx.hrl").
 -include_lib("emqx/include/emqx.hrl").
+-include_lib("snabbkaffe/include/snabbkaffe.hrl").
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 %% Setups
 %% Setups
@@ -36,7 +37,8 @@
 }).
 }).
 
 
 all() ->
 all() ->
-    emqx_common_test_helpers:all(?MODULE).
+    [t_banned_delayed].
+%%    emqx_common_test_helpers:all(?MODULE).
 
 
 init_per_suite(Config) ->
 init_per_suite(Config) ->
     ok = emqx_common_test_helpers:load_config(emqx_modules_schema, ?BASE_CONF, #{
     ok = emqx_common_test_helpers:load_config(emqx_modules_schema, ?BASE_CONF, #{
@@ -212,6 +214,38 @@ t_delayed_precision(_) ->
     _ = on_message_publish(DelayedMsg0),
     _ = on_message_publish(DelayedMsg0),
     ?assert(FutureDiff() =< MaxSpan).
     ?assert(FutureDiff() =< MaxSpan).
 
 
+t_banned_delayed(_) ->
+    emqx:update_config([delayed, max_delayed_messages], 10000),
+    ClientId1 = <<"bc1">>,
+    ClientId2 = <<"bc2">>,
+
+    Now = erlang:system_time(second),
+    Who = {clientid, ClientId2},
+    emqx_banned:create(#{
+        who => Who,
+        by => <<"test">>,
+        reason => <<"test">>,
+        at => Now,
+        until => Now + 120
+    }),
+
+    snabbkaffe:start_trace(),
+    lists:foreach(
+        fun(ClientId) ->
+            Msg = emqx_message:make(ClientId, <<"$delayed/1/bc">>, <<"payload">>),
+            emqx_delayed:on_message_publish(Msg)
+        end,
+        [ClientId1, ClientId1, ClientId1, ClientId2, ClientId2]
+    ),
+
+    timer:sleep(2000),
+    Trace = snabbkaffe:collect_trace(),
+    snabbkaffe:stop(),
+    emqx_banned:delete(Who),
+    mnesia:clear_table(emqx_delayed),
+
+    ?assertEqual(2, length(?of_kind(ignore_delayed_message_publish, Trace))).
+
 subscribe_proc() ->
 subscribe_proc() ->
     Self = self(),
     Self = self(),
     Ref = erlang:make_ref(),
     Ref = erlang:make_ref(),