Просмотр исходного кода

Merge pull request #14268 from keynslug/fix/EMQX-13533/postpone-stats-wsconn

fix(wsconn): avoid arming stats timer before CONNECT is processed
Andrew Mayorov 1 год назад
Родитель
Коммит
dfe20f8980

+ 36 - 20
apps/emqx/src/emqx_ws_connection.erl

@@ -31,6 +31,7 @@
 %% API
 -export([
     info/1,
+    info/2,
     stats/1
 ]).
 
@@ -78,11 +79,10 @@
     %% GC State
     gc_state :: option(emqx_gc:gc_state()),
     %% Postponed Packets|Cmds|Events
+    %% Order is reversed: most recent entry is the first element.
     postponed :: list(emqx_types:packet() | ws_cmd() | tuple()),
     %% Stats Timer
-    stats_timer :: disabled | option(reference()),
-    %% Idle Timeout
-    idle_timeout :: timeout(),
+    stats_timer :: paused | disabled | option(reference()),
     %% Idle Timer
     idle_timer :: option(reference()),
     %% Zone name
@@ -117,7 +117,6 @@
 
 -type ws_cmd() :: {active, boolean()} | close.
 
--define(ACTIVE_N, 100).
 -define(INFO_KEYS, [socktype, peername, sockname, sockstate]).
 -define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt]).
 
@@ -125,15 +124,14 @@
 -define(LIMITER_BYTES_IN, bytes).
 -define(LIMITER_MESSAGE_IN, messages).
 
--dialyzer({no_match, [info/2]}).
--dialyzer({nowarn_function, [websocket_init/1]}).
-
 -define(LOG(Level, Data), ?SLOG(Level, (Data)#{tag => "MQTT"})).
 
 %%--------------------------------------------------------------------
 %% Info, Stats
 %%--------------------------------------------------------------------
 
+-type info() :: atom() | {channel, _Info}.
+
 -spec info(pid() | state()) -> emqx_types:infos().
 info(WsPid) when is_pid(WsPid) ->
     call(WsPid, info);
@@ -144,6 +142,9 @@ info(State = #state{channel = Channel}) ->
     ),
     ChanInfo#{sockinfo => SockInfo}.
 
+-spec info
+    (info(), state()) -> _Value; % single key -> its value
+    ([info()], state()) -> [{atom(), _Value}]. % list of keys -> proplist of {Key, Value}
 info(Keys, State) when is_list(Keys) ->
     [{Key, info(Key, State)} || Key <- Keys];
 info(socktype, _State) ->
@@ -164,10 +165,10 @@ info(postponed, #state{postponed = Postponed}) ->
     Postponed;
 info(stats_timer, #state{stats_timer = TRef}) ->
     TRef;
-info(idle_timeout, #state{idle_timeout = Timeout}) ->
-    Timeout;
 info(idle_timer, #state{idle_timer = TRef}) ->
-    TRef.
+    TRef;
+info({channel, Info}, #state{channel = Channel}) ->
+    emqx_channel:info(Info, Channel).
 
 -spec stats(pid() | state()) -> emqx_types:stats().
 stats(WsPid) when is_pid(WsPid) ->
@@ -310,7 +311,7 @@ websocket_init([Req, Opts]) ->
             %% MQTT Idle Timeout
             IdleTimeout = emqx_channel:get_mqtt_conf(Zone, idle_timeout),
             IdleTimer = start_timer(IdleTimeout, idle_timeout),
-            tune_heap_size(Channel),
+            _ = tune_heap_size(Channel),
             emqx_logger:set_metadata_peername(esockd:format(Peername)),
             {ok,
                 #state{
@@ -325,7 +326,6 @@ websocket_init([Req, Opts]) ->
                     gc_state = GcState,
                     postponed = [],
                     stats_timer = StatsTimer,
-                    idle_timeout = IdleTimeout,
                     idle_timer = IdleTimer,
                     zone = Zone,
                     listener = {Type, Listener},
@@ -350,7 +350,7 @@ tune_heap_size(Channel) ->
 
 get_stats_enable(Zone) ->
     case emqx_config:get_zone_conf(Zone, [stats, enable]) of
-        true -> undefined;
+        true -> paused;
         false -> disabled
     end.
 
@@ -533,7 +533,8 @@ handle_info({close, Reason}, State) ->
 handle_info({event, connected}, State = #state{channel = Channel}) ->
     ClientId = emqx_channel:info(clientid, Channel),
     emqx_cm:insert_channel_info(ClientId, info(State), stats(State)),
-    return(State);
+    NState = resume_stats_timer(State),
+    return(NState);
 handle_info({event, disconnected}, State = #state{channel = Channel}) ->
     ClientId = emqx_channel:info(clientid, Channel),
     emqx_cm:set_chan_info(ClientId, info(State)),
@@ -951,12 +952,19 @@ cancel_idle_timer(State = #state{idle_timer = IdleTimer}) ->
 
 ensure_stats_timer(
     State = #state{
-        idle_timeout = Timeout,
+        zone = Zone,
         stats_timer = undefined
     }
 ) ->
+    Timeout = emqx_channel:get_mqtt_conf(Zone, idle_timeout),
     State#state{stats_timer = start_timer(Timeout, emit_stats)};
 ensure_stats_timer(State) ->
+    %% Either already active, disabled or paused.
+    State.
+
+resume_stats_timer(State = #state{stats_timer = paused}) ->
+    State#state{stats_timer = undefined};
+resume_stats_timer(State = #state{stats_timer = disabled}) ->
     State.
 
 -compile({inline, [postpone/2, enqueue/2, return/1, shutdown/2]}).
@@ -964,8 +972,6 @@ ensure_stats_timer(State) ->
 %%--------------------------------------------------------------------
 %% Postpone the packet, cmd or event
 
-postpone(Packet, State) when is_record(Packet, mqtt_packet) ->
-    enqueue(Packet, State);
 postpone(Event, State) when is_tuple(Event) ->
     enqueue(Event, State);
 postpone(More, State) when is_list(More) ->
@@ -1001,9 +1007,19 @@ return(State = #state{postponed = Postponed}) ->
 
 classify([], Packets, Cmds, Events) ->
     {Packets, Cmds, Events};
-classify([Packet | More], Packets, Cmds, Events) when
-    is_record(Packet, mqtt_packet)
-->
+classify([{outgoing, Outgoing} | More], Packets, Cmds, Events) ->
+    case is_list(Outgoing) of
+        true ->
+            %% Outgoing is a list in least-to-most recent order (i.e. not reversed).
+            %% Prepending will keep the overall order correct.
+            NPackets = Outgoing ++ Packets;
+        false ->
+            NPackets = [Outgoing | Packets]
+    end,
+    classify(More, NPackets, Cmds, Events);
+classify([{connack, Packet} | More], Packets, Cmds, Events) ->
+    classify(More, [Packet | Packets], Cmds, Events);
+classify([Packet = #mqtt_packet{} | More], Packets, Cmds, Events) ->
     classify(More, [Packet | Packets], Cmds, Events);
 classify([Cmd = {active, _} | More], Packets, Cmds, Events) ->
     classify(More, Packets, [Cmd | Cmds], Events);

+ 25 - 7
apps/emqx/test/emqx_mqtt_protocol_v5_SUITE.erl

@@ -48,6 +48,7 @@
 all() ->
     [
         {group, tcp},
+        {group, ws},
         {group, quic}
     ].
 
@@ -55,18 +56,25 @@ groups() ->
     TCs = emqx_common_test_helpers:all(?MODULE),
     [
         {tcp, [], TCs},
+        {ws, [], TCs},
         {quic, [], TCs}
     ].
 
 init_per_group(tcp, Config) ->
     Apps = emqx_cth_suite:start([emqx], #{work_dir => emqx_cth_suite:work_dir(Config)}),
-    [{port, 1883}, {conn_fun, connect}, {group_apps, Apps} | Config];
+    [{conn_type, tcp}, {port, 1883}, {conn_fun, connect}, {group_apps, Apps} | Config];
 init_per_group(quic, Config) ->
     Apps = emqx_cth_suite:start(
         [{emqx, "listeners.quic.test { enable = true, bind = 1884 }"}],
         #{work_dir => emqx_cth_suite:work_dir(Config)}
     ),
-    [{port, 1884}, {conn_fun, quic_connect}, {group_apps, Apps} | Config].
+    [{conn_type, quic}, {port, 1884}, {conn_fun, quic_connect}, {group_apps, Apps} | Config];
+init_per_group(ws, Config) ->
+    Apps = emqx_cth_suite:start(
+        [{emqx, "listeners.ws.test { enable = true, bind = 8888 }"}],
+        #{work_dir => emqx_cth_suite:work_dir(Config)}
+    ),
+    [{conn_type, ws}, {port, 8888}, {conn_fun, ws_connect}, {group_apps, Apps} | Config].
 
 end_per_group(_Group, Config) ->
     emqx_cth_suite:stop(?config(group_apps, Config)).
@@ -89,7 +97,17 @@ end_per_testcase(TestCase, Config) ->
 %%--------------------------------------------------------------------
 
 client_info(Key, Client) ->
-    maps:get(Key, maps:from_list(emqtt:info(Client)), undefined).
+    proplists:get_value(Key, emqtt:info(Client), undefined).
+
+connection_info(Info, ClientPid, Config) when is_list(Config) ->
+    connection_info(Info, ClientPid, ?config(conn_type, Config));
+connection_info(Info, ClientPid, tcp) ->
+    emqx_connection:info(Info, sys:get_state(ClientPid));
+connection_info(Info, ClientPid, quic) ->
+    emqx_connection:info(Info, sys:get_state(ClientPid));
+connection_info(Info, ClientPid, ws) ->
+    {_WSState, ConnState, _} = sys:get_state(ClientPid),
+    emqx_ws_connection:info(Info, ConnState).
 
 receive_messages(Count) ->
     receive_messages(Count, []).
@@ -206,9 +224,9 @@ t_connect_will_message(Config) ->
     ]),
     {ok, _} = emqtt:ConnFun(Client1),
     [ClientPid] = emqx_cm:lookup_channels(client_info(clientid, Client1)),
-    Info = emqx_connection:info(sys:get_state(ClientPid)),
+    WillMsg = connection_info({channel, will_msg}, ClientPid, Config),
     %% [MQTT-3.1.2-7]
-    ?assertNotEqual(undefined, maps:find(will_msg, Info)),
+    ?assertNotEqual(undefined, WillMsg),
 
     {ok, Client2} = emqtt:start_link([{proto_ver, v5} | Config]),
     {ok, _} = emqtt:ConnFun(Client2),
@@ -392,10 +410,10 @@ t_connect_emit_stats_timeout(Config) ->
     [ClientPid] = emqx_cm:lookup_channels(client_info(clientid, Client)),
     ?assertMatch(
         TRef when is_reference(TRef),
-        emqx_connection:info(stats_timer, sys:get_state(ClientPid))
+        connection_info(stats_timer, ClientPid, Config)
     ),
     ?block_until(#{?snk_kind := cancel_stats_timer}, IdleTimeout * 2, _BackInTime = 0),
-    ?assertEqual(undefined, emqx_connection:info(stats_timer, sys:get_state(ClientPid))),
+    ?assertEqual(undefined, connection_info(stats_timer, ClientPid, Config)),
     ok = emqtt:disconnect(Client).
 
 %% [MQTT-3.1.2-22]

+ 33 - 66
apps/emqx/test/emqx_ws_connection_SUITE.erl

@@ -41,6 +41,16 @@ all() -> emqx_common_test_helpers:all(?MODULE).
 %% CT callbacks
 %%--------------------------------------------------------------------
 
+init_per_suite(Config) ->
+    Apps = emqx_cth_suite:start(
+        [emqx],
+        #{work_dir => emqx_cth_suite:work_dir(Config)}
+    ),
+    [{apps, Apps} | Config].
+
+end_per_suite(Config) ->
+    ok = emqx_cth_suite:stop(?config(apps, Config)).
+
 init_per_testcase(TestCase, Config) when
     TestCase =/= t_ws_sub_protocols_mqtt_equivalents,
     TestCase =/= t_ws_sub_protocols_mqtt,
@@ -49,10 +59,6 @@ init_per_testcase(TestCase, Config) when
     TestCase =/= t_ws_non_check_origin
 ->
     add_bucket(),
-    Apps = emqx_cth_suite:start(
-        [emqx],
-        #{work_dir => emqx_cth_suite:work_dir(TestCase, Config)}
-    ),
     %% Meck Cm
     ok = meck:new(emqx_cm, [passthrough, no_history, no_link]),
     ok = meck:expect(emqx_cm, mark_channel_connected, fun(_) -> ok end),
@@ -64,75 +70,31 @@ init_per_testcase(TestCase, Config) when
     ok = meck:expect(cowboy_req, sock, fun(_) -> {{127, 0, 0, 1}, 18083} end),
     ok = meck:expect(cowboy_req, cert, fun(_) -> undefined end),
     ok = meck:expect(cowboy_req, parse_cookies, fun(_) -> error(badarg) end),
-    %% Mock emqx_access_control
-    ok = meck:new(emqx_access_control, [passthrough, no_history, no_link]),
-    ok = meck:expect(emqx_access_control, authorize, fun(_, _, _) -> allow end),
-    %% Mock emqx_hooks
-    ok = meck:new(emqx_hooks, [passthrough, no_history, no_link]),
-    ok = meck:expect(emqx_hooks, run, fun(_Hook, _Args) -> ok end),
-    ok = meck:expect(emqx_hooks, run_fold, fun(_Hook, _Args, Acc) -> Acc end),
-    %% Mock emqx_broker
-    ok = meck:new(emqx_broker, [passthrough, no_history, no_link]),
-    ok = meck:expect(emqx_broker, subscribe, fun(_, _, _) -> ok end),
-    ok = meck:expect(emqx_broker, publish, fun(#message{topic = Topic}) ->
-        [{node(), Topic, 1}]
-    end),
-    ok = meck:expect(emqx_broker, unsubscribe, fun(_) -> ok end),
-    %% Mock emqx_metrics
-    ok = meck:new(emqx_metrics, [passthrough, no_history, no_link]),
-    ok = meck:expect(emqx_metrics, inc, fun(_) -> ok end),
-    ok = meck:expect(emqx_metrics, inc, fun(_, _) -> ok end),
-    ok = meck:expect(emqx_metrics, inc_recv, fun(_) -> ok end),
-    ok = meck:expect(emqx_metrics, inc_sent, fun(_) -> ok end),
-    [{apps, Apps} | Config];
-init_per_testcase(t_ws_non_check_origin = TestCase, Config) ->
+    Config;
+init_per_testcase(t_ws_non_check_origin, Config) ->
     add_bucket(),
-    Apps = emqx_cth_suite:start(
-        [emqx],
-        #{work_dir => emqx_cth_suite:work_dir(TestCase, Config)}
-    ),
     emqx_config:put_listener_conf(ws, default, [websocket, check_origin_enable], false),
     emqx_config:put_listener_conf(ws, default, [websocket, check_origins], []),
-    [{apps, Apps} | Config];
-init_per_testcase(TestCase, Config) ->
+    Config;
+init_per_testcase(_TestCase, Config) ->
     add_bucket(),
-    Apps = emqx_cth_suite:start(
-        [emqx],
-        #{work_dir => emqx_cth_suite:work_dir(TestCase, Config)}
-    ),
-    [{apps, Apps} | Config].
+    Config.
 
-end_per_testcase(TestCase, Config) when
+end_per_testcase(TestCase, _Config) when
     TestCase =/= t_ws_sub_protocols_mqtt_equivalents,
     TestCase =/= t_ws_sub_protocols_mqtt,
     TestCase =/= t_ws_check_origin,
     TestCase =/= t_ws_non_check_origin,
     TestCase =/= t_ws_pingreq_before_connected
 ->
-    Apps = ?config(apps, Config),
     del_bucket(),
-    lists:foreach(
-        fun meck:unload/1,
-        [
-            emqx_cm,
-            cowboy_req,
-            emqx_access_control,
-            emqx_broker,
-            emqx_hooks,
-            emqx_metrics
-        ]
-    ),
-    ok = emqx_cth_suite:stop(Apps),
-    ok;
-end_per_testcase(t_ws_non_check_origin, Config) ->
-    Apps = ?config(apps, Config),
-    del_bucket(),
-    ok = emqx_cth_suite:stop(Apps),
+    meck:unload([
+        emqx_cm,
+        cowboy_req
+    ]),
     ok;
 end_per_testcase(_, Config) ->
-    Apps = ?config(apps, Config),
     del_bucket(),
-    ok = emqx_cth_suite:stop(Apps),
     Config.
 
 %%--------------------------------------------------------------------
@@ -388,6 +350,7 @@ t_websocket_info_cast(_) ->
     {ok, _St} = websocket_info({cast, msg}, st()).
 
 t_websocket_info_incoming(_) ->
+    ok = emqx_broker:subscribe(<<"#">>, <<?MODULE_STRING>>),
     ConnPkt = #mqtt_packet_connect{
         proto_name = <<"MQTT">>,
         proto_ver = ?MQTT_PROTO_V5,
@@ -399,8 +362,9 @@ t_websocket_info_incoming(_) ->
         username = <<"username">>,
         password = <<"passwd">>
     },
-    {[{close, protocol_error}], St1} = websocket_info({incoming, ?CONNECT_PACKET(ConnPkt)}, st()),
-    % ?assertEqual(<<224,2,130,0>>, iolist_to_binary(IoData1)),
+    {[{binary, IoData1}, {close, protocol_error}], St1} =
+        websocket_info({incoming, ?CONNECT_PACKET(ConnPkt)}, st()),
+    ?assertEqual(<<224, 2, ?RC_PROTOCOL_ERROR, 0>>, iolist_to_binary(IoData1)),
     %% PINGREQ
     {[{binary, IoData2}], St2} =
         websocket_info({incoming, ?PACKET(?PINGREQ)}, St1),
@@ -416,7 +380,9 @@ t_websocket_info_incoming(_) ->
     FrameError = {frame_error, #{cause => Cause, property_code => 16#2B}},
     %% cowboy_websocket's close reason must be an atom to avoid crashing the sender process.
     %% ensure the cause is atom
-    {[{close, CauseReq}], _St4} = websocket_info({incoming, FrameError}, St3),
+    {[{binary, IoData4}, {close, CauseReq}], _St4} =
+        websocket_info({incoming, FrameError}, St3),
+    ?assertEqual(<<224, 2, ?RC_MALFORMED_PACKET, 0>>, iolist_to_binary(IoData4)),
     ?assertEqual(Cause, CauseReq).
 
 t_websocket_info_check_gc(_) ->
@@ -427,8 +393,8 @@ t_websocket_info_deliver(_) ->
     Msg0 = emqx_message:make(clientid, ?QOS_0, <<"t">>, <<"">>),
     Msg1 = emqx_message:make(clientid, ?QOS_1, <<"t">>, <<"">>),
     self() ! {deliver, <<"#">>, Msg1},
-    {ok, _St} = websocket_info({deliver, <<"#">>, Msg0}, st()).
-% ?assertEqual(<<48,3,0,1,116,50,5,0,1,116,0,1>>, iolist_to_binary(IoData)).
+    {[{binary, _Pub1}, {binary, _Pub2}], _St} =
+        websocket_info({deliver, <<"#">>, Msg0}, st()).
 
 t_websocket_info_timeout_limiter(_) ->
     Ref = make_ref(),
@@ -550,9 +516,10 @@ t_parse_incoming_frame_error(_) ->
 
 t_handle_incomming_frame_error(_) ->
     FrameError = {frame_error, bad_qos},
-    Serialize = emqx_frame:serialize_fun(#{version => 5, max_size => 16#FFFF, strict_mode => false}),
-    {[{close, bad_qos}], _St} = ?ws_conn:handle_incoming(FrameError, st(#{serialize => Serialize})).
-% ?assertEqual(<<224,2,129,0>>, iolist_to_binary(IoData)).
+    Serialize = #{version => 5, max_size => 16#FFFF, strict_mode => false},
+    {[{binary, IoData}, {close, bad_qos}], _St} =
+        ?ws_conn:handle_incoming(FrameError, st(#{serialize => Serialize})),
+    ?assertEqual(<<224, 2, 129, 0>>, iolist_to_binary(IoData)).
 
 t_handle_outgoing(_) ->
     Packets = [

+ 1 - 0
changes/ce/fix-14268.en.md

@@ -0,0 +1 @@
+Fix a rare race condition that could cause the WebSocket connection process to crash when the CONNECT packet was not fully received before the idle timeout expired.