Jelajahi Sumber

fix(delayed): unify and optimize the enable/disable codes

firest 3 tahun lalu
induk
melakukan
fc391c7b9e

+ 34 - 28
apps/emqx_modules/src/emqx_delayed.erl

@@ -143,10 +143,10 @@ store(DelayedMsg) ->
     gen_server:call(?SERVER, {store, DelayedMsg}, infinity).
 
 enable() ->
-    gen_server:call(?SERVER, enable).
+    enable(true).
 
 disable() ->
-    gen_server:call(?SERVER, disable).
+    enable(false).
 
 set_max_delayed_messages(Max) ->
     gen_server:call(?SERVER, {set_max_delayed_messages, Max}).
@@ -238,21 +238,7 @@ update_config(Config) ->
     emqx_conf:update([delayed], Config, #{rawconf_with_defaults => true, override_to => cluster}).
 
 post_config_update(_KeyPath, Config, _NewConf, _OldConf, _AppEnvs) ->
-    case maps:get(<<"enable">>, Config, undefined) of
-        undefined ->
-            ignore;
-        true ->
-            emqx_delayed:enable();
-        false ->
-            emqx_delayed:disable()
-    end,
-    case maps:get(<<"max_delayed_messages">>, Config, undefined) of
-        undefined ->
-            ignore;
-        Max ->
-            ok = emqx_delayed:set_max_delayed_messages(Max)
-    end,
-    ok.
+    gen_server:call(?SERVER, {update_config, Config}).
 
 %%--------------------------------------------------------------------
 %% gen_server callback
@@ -262,7 +248,7 @@ init([Opts]) ->
     erlang:process_flag(trap_exit, true),
     emqx_conf:add_handler([delayed], ?MODULE),
     MaxDelayedMessages = maps:get(max_delayed_messages, Opts, 0),
-    {ok,
+    State =
         ensure_stats_event(
             ensure_publish_timer(#{
                 publish_timer => undefined,
@@ -271,7 +257,8 @@ init([Opts]) ->
                 stats_fun => undefined,
                 max_delayed_messages => MaxDelayedMessages
             })
-        )}.
+        ),
+    {ok, ensure_enable(emqx:get_config([delayed, enable]), State)}.
 
 handle_call({set_max_delayed_messages, Max}, _From, State) ->
     {reply, ok, State#{max_delayed_messages => Max}};
@@ -293,12 +280,11 @@ handle_call(
             emqx_metrics:inc('messages.delayed'),
             {reply, ok, ensure_publish_timer(Key, State)}
     end;
-handle_call(enable, _From, State) ->
-    emqx_hooks:put('message.publish', {?MODULE, on_message_publish, []}),
-    {reply, ok, State};
-handle_call(disable, _From, State) ->
-    emqx_hooks:del('message.publish', {?MODULE, on_message_publish}),
-    {reply, ok, State};
+handle_call({update_config, Config}, _From, #{max_delayed_messages := Max} = State) ->
+    Max2 = maps:get(<<"max_delayed_messages">>, Config, Max),
+    State2 = State#{max_delayed_messages := Max2},
+    State3 = ensure_enable(maps:get(<<"enable">>, Config, undefined), State2),
+    {reply, ok, State3};
 handle_call(Req, _From, State) ->
     ?tp(error, emqx_delayed_unexpected_call, #{call => Req}),
     {reply, ignored, State}.
@@ -320,10 +306,10 @@ handle_info(Info, State) ->
     ?tp(error, emqx_delayed_unexpected_info, #{info => Info}),
     {noreply, State}.
 
-terminate(_Reason, #{publish_timer := PublishTimer, stats_timer := StatsTimer}) ->
+terminate(_Reason, #{stats_timer := StatsTimer} = State) ->
     emqx_conf:remove_handler([delayed]),
-    emqx_misc:cancel_timer(PublishTimer),
-    emqx_misc:cancel_timer(StatsTimer).
+    emqx_misc:cancel_timer(StatsTimer),
+    ensure_enable(false, State).
 
 code_change(_Vsn, State, _Extra) ->
     {ok, State}.
@@ -378,3 +364,23 @@ do_publish(Key = {Ts, _Id}, Now, Acc) when Ts =< Now ->
 
 -spec delayed_count() -> non_neg_integer().
 delayed_count() -> mnesia:table_info(?TAB, size).
+
+enable(Enable) ->
+    case emqx:get_raw_config([delayed]) of
+        #{<<"enable">> := Enable} ->
+            ok;
+        Cfg ->
+            {ok, _} = update_config(Cfg#{<<"enable">> := Enable}),
+            ok
+    end.
+
+ensure_enable(true, State) ->
+    emqx_hooks:put('message.publish', {?MODULE, on_message_publish, []}),
+    State;
+ensure_enable(false, #{publish_timer := PubTimer} = State) ->
+    emqx_hooks:del('message.publish', {?MODULE, on_message_publish}),
+    emqx_misc:cancel_timer(PubTimer),
+    ets:delete_all_objects(?TAB),
+    State#{publish_timer := undefined, publish_at := 0};
+ensure_enable(_, State) ->
+    State.

+ 1 - 1
apps/emqx_modules/test/emqx_delayed_SUITE.erl

@@ -76,7 +76,7 @@ t_delayed_message(_) ->
 
     [#delayed_message{msg = #message{payload = Payload}}] = ets:tab2list(emqx_delayed),
     ?assertEqual(<<"delayed_m">>, Payload),
-    ct:sleep(2000),
+    ct:sleep(2500),
 
     EmptyKey = mnesia:dirty_all_keys(emqx_delayed),
     ?assertEqual([], EmptyKey).

+ 12 - 5
apps/emqx_modules/test/emqx_delayed_api_SUITE.erl

@@ -98,6 +98,7 @@ t_status(_Config) ->
 
 t_messages(_) ->
     clear_all_record(),
+    emqx_delayed:enable(),
 
     {ok, C1} = emqtt:start_link([{clean_start, true}]),
     {ok, _} = emqtt:connect(C1),
@@ -114,7 +115,7 @@ t_messages(_) ->
     end,
 
     lists:foreach(Each, lists:seq(1, 5)),
-    timer:sleep(500),
+    timer:sleep(1000),
 
     Msgs = get_messages(5),
     [First | _] = Msgs,
@@ -197,6 +198,7 @@ t_messages(_) ->
 
 t_large_payload(_) ->
     clear_all_record(),
+    emqx_delayed:enable(),
 
     {ok, C1} = emqtt:start_link([{clean_start, true}]),
     {ok, _} = emqtt:connect(C1),
@@ -209,7 +211,7 @@ t_large_payload(_) ->
         [{qos, 0}, {retain, true}]
     ),
 
-    timer:sleep(500),
+    timer:sleep(1000),
 
     [#{msgid := MsgId}] = get_messages(1),
 
@@ -241,8 +243,13 @@ get_messages(Len) ->
     {ok, 200, MsgsJson} = request(get, uri(["mqtt", "delayed", "messages"])),
     #{data := Msgs} = decode_json(MsgsJson),
     MsgLen = erlang:length(Msgs),
-    ?assert(
-        MsgLen =:= Len,
-        lists:flatten(io_lib:format("message length is:~p~n", [MsgLen]))
+    ?assertEqual(
+        Len,
+        MsgLen,
+        lists:flatten(
+            io_lib:format("message length is:~p~nWhere:~p~nHooks:~p~n", [
+                MsgLen, erlang:whereis(emqx_delayed), ets:tab2list(emqx_hooks)
+            ])
+        )
     ),
     Msgs.