Kaynağa Gözat

fix: enrich parse_state and connection serialize opts

JimMoen 1 yıl önce
ebeveyn
işleme
37a89d0094

+ 1 - 0
apps/emqx/include/emqx_mqtt.hrl

@@ -683,6 +683,7 @@ end).
 
 -define(FRAME_PARSE_ERROR, frame_parse_error).
 -define(FRAME_SERIALIZE_ERROR, frame_serialize_error).
+
 -define(THROW_FRAME_ERROR(Reason), erlang:throw({?FRAME_PARSE_ERROR, Reason})).
 -define(THROW_SERIALIZE_ERROR(Reason), erlang:throw({?FRAME_SERIALIZE_ERROR, Reason})).
 

+ 56 - 14
apps/emqx/src/emqx_channel.erl

@@ -37,6 +37,7 @@
     get_mqtt_conf/2,
     get_mqtt_conf/3,
     set_conn_state/2,
+    set_conninfo_proto_ver/2,
     stats/1,
     caps/1
 ]).
@@ -219,6 +220,9 @@ info(impl, #channel{session = Session}) ->
 set_conn_state(ConnState, Channel) ->
     Channel#channel{conn_state = ConnState}.
 
+set_conninfo_proto_ver({none, #{version := ProtoVer}}, Channel = #channel{conninfo = ConnInfo}) ->
+    Channel#channel{conninfo = ConnInfo#{proto_ver => ProtoVer}}.
+
 -spec stats(channel()) -> emqx_types:stats().
 stats(#channel{session = undefined}) ->
     emqx_pd:get_counters(?CHANNEL_METRICS);
@@ -1002,30 +1006,61 @@ not_nacked({deliver, _Topic, Msg}) ->
 %% Handle Frame Error
 %%--------------------------------------------------------------------
 
+handle_frame_error(
+    Reason = #{cause := frame_too_large},
+    Channel = #channel{conn_state = ConnState, conninfo = ConnInfo}
+) when
+    ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState)
+->
+    ShutdownCount = shutdown_count(frame_error, Reason),
+    case proto_ver(Reason, ConnInfo) of
+        ?MQTT_PROTO_V5 ->
+            handle_out(disconnect, {?RC_PACKET_TOO_LARGE, frame_too_large}, Channel);
+        _ ->
+            shutdown(ShutdownCount, Channel)
+    end;
+%% Only send CONNACK with reason code `frame_too_large` for MQTT-v5.0 when connecting,
+%% otherwise DONOT send any CONNACK or DISCONNECT packet.
 handle_frame_error(
     Reason,
-    Channel = #channel{conn_state = idle}
-) ->
-    shutdown(shutdown_count(frame_error, Reason), Channel);
+    Channel = #channel{conn_state = ConnState, conninfo = ConnInfo}
+) when
+    is_map(Reason) andalso
+        (ConnState == idle orelse ConnState == connecting)
+->
+    ShutdownCount = shutdown_count(frame_error, Reason),
+    ProtoVer = proto_ver(Reason, ConnInfo),
+    NChannel = Channel#channel{conninfo = ConnInfo#{proto_ver => ProtoVer}},
+    case ProtoVer of
+        ?MQTT_PROTO_V5 ->
+            shutdown(ShutdownCount, ?CONNACK_PACKET(?RC_PACKET_TOO_LARGE), NChannel);
+        _ ->
+            shutdown(ShutdownCount, NChannel)
+    end;
 handle_frame_error(
-    #{cause := frame_too_large} = R, Channel = #channel{conn_state = connecting}
+    Reason,
+    Channel = #channel{conn_state = connecting}
 ) ->
     shutdown(
-        shutdown_count(frame_error, R), ?CONNACK_PACKET(?RC_PACKET_TOO_LARGE), Channel
+        shutdown_count(frame_error, Reason),
+        ?CONNACK_PACKET(?RC_MALFORMED_PACKET),
+        Channel
     );
-handle_frame_error(Reason, Channel = #channel{conn_state = connecting}) ->
-    shutdown(shutdown_count(frame_error, Reason), ?CONNACK_PACKET(?RC_MALFORMED_PACKET), Channel);
 handle_frame_error(
-    #{cause := frame_too_large}, Channel = #channel{conn_state = ConnState}
+    Reason,
+    Channel = #channel{conn_state = ConnState}
 ) when
     ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState)
 ->
-    handle_out(disconnect, {?RC_PACKET_TOO_LARGE, frame_too_large}, Channel);
-handle_frame_error(Reason, Channel = #channel{conn_state = ConnState}) when
-    ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState)
-->
-    handle_out(disconnect, {?RC_MALFORMED_PACKET, Reason}, Channel);
-handle_frame_error(Reason, Channel = #channel{conn_state = disconnected}) ->
+    handle_out(
+        disconnect,
+        {?RC_MALFORMED_PACKET, Reason},
+        Channel
+    );
+handle_frame_error(
+    Reason,
+    Channel = #channel{conn_state = disconnected}
+) ->
     ?SLOG(error, #{msg => "malformed_mqtt_message", reason => Reason}),
     {ok, Channel}.
 
@@ -2726,6 +2761,13 @@ is_durable_session(#channel{session = Session}) ->
             false
     end.
 
+proto_ver(#{proto_ver := ProtoVer}, _ConnInfo) ->
+    ProtoVer;
+proto_ver(_Reason, #{proto_ver := ProtoVer}) ->
+    ProtoVer;
+proto_ver(_, _) ->
+    ?MQTT_PROTO_V4.
+
 %%--------------------------------------------------------------------
 %% For CT tests
 %%--------------------------------------------------------------------

+ 8 - 1
apps/emqx/src/emqx_connection.erl

@@ -782,7 +782,8 @@ parse_incoming(Data, Packets, State = #state{parse_state = ParseState}) ->
                 input_bytes => Data,
                 parsed_packets => Packets
             }),
-            {[{frame_error, Reason} | Packets], State};
+            NState = enrich_state(Reason, State),
+            {[{frame_error, Reason} | Packets], NState};
         error:Reason:Stacktrace ->
             ?LOG(error, #{
                 at_state => emqx_frame:describe_state(ParseState),
@@ -1204,6 +1205,12 @@ inc_counter(Key, Inc) ->
     _ = emqx_pd:inc_counter(Key, Inc),
     ok.
 
+enrich_state(#{parse_state := NParseState}, State) ->
+    Serialize = emqx_frame:serialize_opts(NParseState),
+    State#state{parse_state = NParseState, serialize = Serialize};
+enrich_state(_, State) ->
+    State.
+
 set_tcp_keepalive({quic, _Listener}) ->
     ok;
 set_tcp_keepalive({Type, Id}) ->

+ 21 - 6
apps/emqx/src/emqx_frame.erl

@@ -266,20 +266,33 @@ packet(Header, Variable) ->
 packet(Header, Variable, Payload) ->
     #mqtt_packet{header = Header, variable = Variable, payload = Payload}.
 
-parse_connect(FrameBin, StrictMode) ->
+parse_connect(FrameBin, Options = #{strict_mode := StrictMode}) ->
     {ProtoName, Rest0} = parse_utf8_string_with_cause(FrameBin, StrictMode, invalid_proto_name),
     %% No need to parse and check proto_ver if proto_name is invalid, check it first
     %% And the matching check of `proto_name` and `proto_ver` fields will be done in `emqx_packet:check_proto_ver/2`
     _ = validate_proto_name(ProtoName),
     {IsBridge, ProtoVer, Rest2} = parse_connect_proto_ver(Rest0),
-    Meta = #{proto_name => ProtoName, proto_ver => ProtoVer},
+    NOptions = Options#{version => ProtoVer},
     try
         do_parse_connect(ProtoName, IsBridge, ProtoVer, Rest2, StrictMode)
     catch
         throw:{?FRAME_PARSE_ERROR, ReasonM} when is_map(ReasonM) ->
-            ?PARSE_ERR(maps:merge(ReasonM, Meta));
+            ?PARSE_ERR(
+                ReasonM#{
+                    proto_ver => ProtoVer,
+                    proto_name => ProtoName,
+                    parse_state => ?NONE(NOptions)
+                }
+            );
         throw:{?FRAME_PARSE_ERROR, Reason} ->
-            ?PARSE_ERR(Meta#{cause => Reason})
+            ?PARSE_ERR(
+                #{
+                    cause => Reason,
+                    proto_ver => ProtoVer,
+                    proto_name => ProtoName,
+                    parse_state => ?NONE(NOptions)
+                }
+            )
     end.
 
 do_parse_connect(
@@ -358,9 +371,9 @@ do_parse_connect(_ProtoName, _IsBridge, _ProtoVer, Bin, _StrictMode) ->
 parse_packet(
     #mqtt_packet_header{type = ?CONNECT},
     FrameBin,
-    #{strict_mode := StrictMode}
+    Options
 ) ->
-    parse_connect(FrameBin, StrictMode);
+    parse_connect(FrameBin, Options);
 parse_packet(
     #mqtt_packet_header{type = ?CONNACK},
     <<AckFlags:8, ReasonCode:8, Rest/binary>>,
@@ -753,6 +766,8 @@ serialize_fun(#{version := Ver, max_size := MaxSize}) ->
 serialize_opts() ->
     ?DEFAULT_OPTIONS.
 
+serialize_opts(?NONE(Options)) ->
+    maps:merge(?DEFAULT_OPTIONS, Options);
 serialize_opts(#mqtt_packet_connect{proto_ver = ProtoVer, properties = ConnProps}) ->
     MaxSize = get_property('Maximum-Packet-Size', ConnProps, ?MAX_PACKET_SIZE),
     #{version => ProtoVer, max_size => MaxSize}.

+ 10 - 1
apps/emqx/src/emqx_ws_connection.erl

@@ -436,6 +436,7 @@ websocket_handle({Frame, _}, State) ->
     %% TODO: should not close the ws connection
     ?LOG(error, #{msg => "unexpected_frame", frame => Frame}),
     shutdown(unexpected_ws_frame, State).
+
 websocket_info({call, From, Req}, State) ->
     handle_call(From, Req, State);
 websocket_info({cast, rate_limit}, State) ->
@@ -725,7 +726,8 @@ parse_incoming(Data, Packets, State = #state{parse_state = ParseState}) ->
                 input_bytes => Data
             }),
             FrameError = {frame_error, Reason},
-            {[{incoming, FrameError} | Packets], State};
+            NState = enrich_state(Reason, State),
+            {[{incoming, FrameError} | Packets], NState};
         error:Reason:Stacktrace ->
             ?LOG(error, #{
                 at_state => emqx_frame:describe_state(ParseState),
@@ -1059,6 +1061,13 @@ check_max_connection(Type, Listener) ->
                     {denny, Reason}
             end
     end.
+
+enrich_state(#{parse_state := NParseState}, State) ->
+    Serialize = emqx_frame:serialize_opts(NParseState),
+    State#state{parse_state = NParseState, serialize = Serialize};
+enrich_state(_, State) ->
+    State.
+
 %%--------------------------------------------------------------------
 %% For CT tests
 %%--------------------------------------------------------------------