Parcourir la source

feat: refactor flapping detect conf

zhongwencool il y a 2 ans
Parent
commit
48381d4c86

+ 3 - 4
apps/emqx/src/emqx_channel.erl

@@ -1632,10 +1632,9 @@ check_banned(_ConnPkt, #channel{clientinfo = ClientInfo}) ->
 %%--------------------------------------------------------------------
 %% Flapping
 
-count_flapping_event(_ConnPkt, Channel = #channel{clientinfo = ClientInfo = #{zone := Zone}}) ->
-    is_integer(emqx_config:get_zone_conf(Zone, [flapping_detect, window_time])) andalso
-        emqx_flapping:detect(ClientInfo),
-    {ok, Channel}.
+count_flapping_event(_ConnPkt, #channel{clientinfo = ClientInfo}) ->
+    _ = emqx_flapping:detect(ClientInfo),
+    ok.
 
 %%--------------------------------------------------------------------
 %% Authenticate

+ 16 - 8
apps/emqx/src/emqx_config.erl

@@ -707,7 +707,10 @@ do_put(Type, Putter, [], DeepValue) ->
 do_put(Type, Putter, [RootName | KeyPath], DeepValue) ->
     OldValue = do_get(Type, [RootName], #{}),
     NewValue = do_deep_put(Type, Putter, KeyPath, OldValue, DeepValue),
-    persistent_term:put(?PERSIS_KEY(Type, RootName), NewValue).
+    Key = ?PERSIS_KEY(Type, RootName),
+    persistent_term:put(Key, NewValue),
+    post_save_config_hook(Key, NewValue),
+    ok.
 
 do_deep_get(?CONF, AtomKeyPath, Map, Default) ->
     emqx_utils_maps:deep_get(AtomKeyPath, Map, Default);
@@ -829,15 +832,12 @@ merge_with_global_defaults(GlobalDefaults, ZoneVal) ->
 maybe_update_zone([zones | T], ZonesValue, Value) ->
     %% note, do not write to PT, return *New value* instead
     NewZonesValue = emqx_utils_maps:deep_put(T, ZonesValue, Value),
-    ExistingZoneNames = maps:keys(?MODULE:get([zones], #{})),
-    %% Update only new zones with global defaults
     GLD = zone_global_defaults(),
-    maps:fold(
-        fun(ZoneName, ZoneValue, Acc) ->
-            Acc#{ZoneName := merge_with_global_defaults(GLD, ZoneValue)}
+    maps:map(
+        fun(_ZoneName, ZoneValue) ->
+            merge_with_global_defaults(GLD, ZoneValue)
         end,
-        NewZonesValue,
-        maps:without(ExistingZoneNames, NewZonesValue)
+        NewZonesValue
     );
 maybe_update_zone([RootName | T], RootValue, Value) when is_atom(RootName) ->
     NewRootValue = emqx_utils_maps:deep_put(T, RootValue, Value),
@@ -911,3 +911,11 @@ rawconf_to_conf(SchemaModule, RawPath, RawValue) ->
         ),
     AtomPath = to_atom_conf_path(RawPath, {raise_error, maybe_update_zone_error}),
     emqx_utils_maps:deep_get(AtomPath, RawUserDefinedValues).
+
+%% When the global zone changes, the zones are updated with the new global zone.
+%% The zones config has no config_handler callback, so we need to update via this hook.
+post_save_config_hook(?PERSIS_KEY(?CONF, zones), _Zones) ->
+    emqx_flapping:update_config(),
+    ok;
+post_save_config_hook(_Key, _NewValue) ->
+    ok.

+ 41 - 36
apps/emqx/src/emqx_flapping.erl

@@ -22,13 +22,13 @@
 -include("types.hrl").
 -include("logger.hrl").
 
--export([start_link/0, stop/0]).
+-export([start_link/0, update_config/0, stop/0]).
 
 %% API
 -export([detect/1]).
 
 -ifdef(TEST).
--export([get_policy/2]).
+-export([get_policy/1]).
 -endif.
 
 %% gen_server callbacks
@@ -59,12 +59,17 @@
 start_link() ->
     gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
 
+update_config() ->
+    gen_server:cast(?MODULE, update_config).
+
 stop() -> gen_server:stop(?MODULE).
 
 %% @doc Detect flapping when an MQTT client disconnects.
 -spec detect(emqx_types:clientinfo()) -> boolean().
 detect(#{clientid := ClientId, peerhost := PeerHost, zone := Zone}) ->
-    Policy = #{max_count := Threshold} = get_policy([max_count, window_time, ban_time], Zone),
+    detect(ClientId, PeerHost, get_policy(Zone)).
+
+detect(ClientId, PeerHost, #{enable := true, max_count := Threshold} = Policy) ->
     %% The initial flapping record sets the detect_cnt to 0.
     InitVal = #flapping{
         clientid = ClientId,
@@ -82,24 +87,12 @@ detect(#{clientid := ClientId, peerhost := PeerHost, zone := Zone}) ->
                 [] ->
                     false
             end
-    end.
+    end;
+detect(_ClientId, _PeerHost, #{enable := false}) ->
+    false.
 
-get_policy(Keys, Zone) when is_list(Keys) ->
-    RootKey = flapping_detect,
-    Conf = emqx_config:get_zone_conf(Zone, [RootKey]),
-    lists:foldl(
-        fun(Key, Acc) ->
-            case maps:find(Key, Conf) of
-                {ok, V} -> Acc#{Key => V};
-                error -> Acc#{Key => emqx_config:get([RootKey, Key])}
-            end
-        end,
-        #{},
-        Keys
-    );
-get_policy(Key, Zone) ->
-    #{Key := Conf} = get_policy([Key], Zone),
-    Conf.
+get_policy(Zone) ->
+    emqx_config:get_zone_conf(Zone, [flapping_detect]).
 
 now_diff(TS) -> erlang:system_time(millisecond) - TS.
 
@@ -115,8 +108,8 @@ init([]) ->
         {read_concurrency, true},
         {write_concurrency, true}
     ]),
-    start_timers(),
-    {ok, #{}, hibernate}.
+    Timers = start_timers(),
+    {ok, Timers, hibernate}.
 
 handle_call(Req, _From, State) ->
     ?SLOG(error, #{msg => "unexpected_call", call => Req}),
@@ -169,17 +162,20 @@ handle_cast(
             )
     end,
     {noreply, State};
+handle_cast(update_config, State) ->
+    NState = update_timer(State),
+    {noreply, NState};
 handle_cast(Msg, State) ->
     ?SLOG(error, #{msg => "unexpected_cast", cast => Msg}),
     {noreply, State}.
 
 handle_info({timeout, _TRef, {garbage_collect, Zone}}, State) ->
-    Timestamp =
-        erlang:system_time(millisecond) - get_policy(window_time, Zone),
+    Policy = #{window_time := WindowTime} = get_policy(Zone),
+    Timestamp = erlang:system_time(millisecond) - WindowTime,
     MatchSpec = [{{'_', '_', '_', '$1', '_'}, [{'<', '$1', Timestamp}], [true]}],
     ets:select_delete(?FLAPPING_TAB, MatchSpec),
-    _ = start_timer(Zone),
-    {noreply, State, hibernate};
+    Timer = start_timer(Policy, Zone),
+    {noreply, State#{Zone => Timer}, hibernate};
 handle_info(Info, State) ->
     ?SLOG(error, #{msg => "unexpected_info", info => Info}),
     {noreply, State}.
@@ -190,18 +186,27 @@ terminate(_Reason, _State) ->
 code_change(_OldVsn, State, _Extra) ->
     {ok, State}.
 
-start_timer(Zone) ->
-    case get_policy(window_time, Zone) of
-        WindowTime when is_integer(WindowTime) ->
-            emqx_utils:start_timer(WindowTime, {garbage_collect, Zone});
-        disabled ->
-            ok
-    end.
+start_timer(#{enable := true, window_time := WindowTime}, Zone) ->
+    emqx_utils:start_timer(WindowTime, {garbage_collect, Zone});
+start_timer(_Policy, _Zone) ->
+    undefined.
 
 start_timers() ->
-    maps:foreach(
-        fun(Zone, _ZoneConf) ->
-            start_timer(Zone)
+    maps:map(
+        fun(ZoneName, #{flapping_detect := FlappingDetect}) ->
+            start_timer(FlappingDetect, ZoneName)
+        end,
+        emqx:get_config([zones], #{})
+    ).
+
+update_timer(Timers) ->
+    maps:map(
+        fun(ZoneName, #{flapping_detect := FlappingDetect}) ->
+            case maps:get(ZoneName, Timers, undefined) of
+                undefined -> start_timer(FlappingDetect, ZoneName);
+                %% Don't reset this timer, it will be updated after next timeout.
+                TRef -> TRef
+            end
         end,
         emqx:get_config([zones], #{})
     ).

+ 3 - 4
apps/emqx/src/emqx_schema.erl

@@ -275,7 +275,7 @@ roots(low) ->
         {"flapping_detect",
             sc(
                 ref("flapping_detect"),
-                #{importance => ?IMPORTANCE_HIDDEN}
+                #{importance => ?DEFAULT_IMPORTANCE}
             )},
         {"persistent_session_store",
             sc(
@@ -685,15 +685,14 @@ fields("flapping_detect") ->
                 boolean(),
                 #{
                     default => false,
-                    deprecated => {since, "5.0.23"},
                     desc => ?DESC(flapping_detect_enable)
                 }
             )},
         {"window_time",
             sc(
-                hoconsc:union([disabled, duration()]),
+                duration(),
                 #{
-                    default => disabled,
+                    default => "1m",
                     importance => ?IMPORTANCE_HIGH,
                     desc => ?DESC(flapping_detect_window_time)
                 }

+ 1 - 2
apps/emqx/src/emqx_zone_schema.erl

@@ -58,8 +58,7 @@ hidden() ->
     [
         "stats",
         "overload_protection",
-        "conn_congestion",
-        "flapping_detect"
+        "conn_congestion"
     ].
 
 %% zone schemas are clones from the same name from root level

+ 11 - 9
apps/emqx/test/emqx_flapping_SUITE.erl

@@ -30,6 +30,7 @@ init_per_suite(Config) ->
         default,
         [flapping_detect],
         #{
+            enable => true,
             max_count => 3,
             % 0.1s
             window_time => 100,
@@ -102,20 +103,21 @@ t_expired_detecting(_) ->
         )
     ).
 
-t_conf_without_window_time(_) ->
-    %% enable is deprecated, so we need to make sure it won't be used.
+t_conf_update(_) ->
     Global = emqx_config:get([flapping_detect]),
-    ?assertNot(maps:is_key(enable, Global)),
-    %% zones don't have default value, so we need to make sure fallback to global conf.
-    %% this new_zone will fallback to global conf.
+    #{
+        ban_time := _BanTime,
+        enable := _Enable,
+        max_count := _MaxCount,
+        window_time := _WindowTime
+    } = Global,
+
     emqx_config:put_zone_conf(new_zone, [flapping_detect], #{}),
     ?assertEqual(Global, get_policy(new_zone)),
 
     emqx_config:put_zone_conf(new_zone_1, [flapping_detect], #{window_time => 100}),
-    ?assertEqual(100, emqx_flapping:get_policy(window_time, new_zone_1)),
-    ?assertEqual(maps:get(ban_time, Global), emqx_flapping:get_policy(ban_time, new_zone_1)),
-    ?assertEqual(maps:get(max_count, Global), emqx_flapping:get_policy(max_count, new_zone_1)),
+    ?assertEqual(Global#{window_time := 100}, emqx_flapping:get_policy(new_zone_1)),
     ok.
 
 get_policy(Zone) ->
-    emqx_flapping:get_policy([window_time, ban_time, max_count], Zone).
+    emqx_flapping:get_policy(Zone).