Browse Source

fix: try throw proto_ver and proto_name when parsing CONNECT packet

JimMoen 1 year ago
parent
commit
c313aa89f0
2 changed files with 64 additions and 30 deletions
  1. 9 7
      apps/emqx/src/emqx_channel.erl
  2. 55 23
      apps/emqx/src/emqx_frame.erl

+ 9 - 7
apps/emqx/src/emqx_channel.erl

@@ -145,7 +145,9 @@
 -type replies() :: emqx_types:packet() | reply() | [reply()].
 
 -define(IS_MQTT_V5, #channel{conninfo = #{proto_ver := ?MQTT_PROTO_V5}}).
-
+-define(IS_CONNECTED_OR_REAUTHENTICATING(ConnState),
+    ((ConnState == connected) orelse (ConnState == reauthenticating))
+).
 -define(IS_COMMON_SESSION_TIMER(N),
     ((N == retry_delivery) orelse (N == expire_awaiting_rel))
 ).
@@ -333,7 +335,7 @@ take_conn_info_fields(Fields, ClientInfo, ConnInfo) ->
     | {shutdown, Reason :: term(), channel()}
     | {shutdown, Reason :: term(), replies(), channel()}.
 handle_in(?CONNECT_PACKET(), Channel = #channel{conn_state = ConnState}) when
-    ConnState =:= connected orelse ConnState =:= reauthenticating
+    ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState)
 ->
     handle_out(disconnect, ?RC_PROTOCOL_ERROR, Channel);
 handle_in(?CONNECT_PACKET(), Channel = #channel{conn_state = connecting}) ->
@@ -1016,11 +1018,11 @@ handle_frame_error(Reason, Channel = #channel{conn_state = connecting}) ->
 handle_frame_error(
     #{cause := frame_too_large}, Channel = #channel{conn_state = ConnState}
 ) when
-    ConnState =:= connected orelse ConnState =:= reauthenticating
+    ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState)
 ->
     handle_out(disconnect, {?RC_PACKET_TOO_LARGE, frame_too_large}, Channel);
 handle_frame_error(Reason, Channel = #channel{conn_state = ConnState}) when
-    ConnState =:= connected orelse ConnState =:= reauthenticating
+    ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState)
 ->
     handle_out(disconnect, {?RC_MALFORMED_PACKET, Reason}, Channel);
 handle_frame_error(Reason, Channel = #channel{conn_state = disconnected}) ->
@@ -1295,7 +1297,7 @@ handle_info(
             session = Session
         }
 ) when
-    ConnState =:= connected orelse ConnState =:= reauthenticating
+    ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState)
 ->
     {Intent, Session1} = session_disconnect(ClientInfo, ConnInfo, Session),
     Channel1 = ensure_disconnected(Reason, maybe_publish_will_msg(sock_closed, Channel)),
@@ -2675,13 +2677,13 @@ disconnect_and_shutdown(
         ?IS_MQTT_V5 =
         #channel{conn_state = ConnState}
 ) when
-    ConnState =:= connected orelse ConnState =:= reauthenticating
+    ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState)
 ->
     NChannel = ensure_disconnected(Reason, Channel),
     shutdown(Reason, Reply, ?DISCONNECT_PACKET(reason_code(Reason)), NChannel);
 %% mqtt v3/v4 connected sessions
 disconnect_and_shutdown(Reason, Reply, Channel = #channel{conn_state = ConnState}) when
-    ConnState =:= connected orelse ConnState =:= reauthenticating
+    ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState)
 ->
     NChannel = ensure_disconnected(Reason, Channel),
     shutdown(Reason, Reply, NChannel);

+ 55 - 23
apps/emqx/src/emqx_frame.erl

@@ -267,27 +267,36 @@ packet(Header, Variable, Payload) ->
     #mqtt_packet{header = Header, variable = Variable, payload = Payload}.
 
 parse_connect(FrameBin, StrictMode) ->
-    {ProtoName, Rest} = parse_utf8_string_with_cause(FrameBin, StrictMode, invalid_proto_name),
-    case ProtoName of
-        <<"MQTT">> ->
-            ok;
-        <<"MQIsdp">> ->
-            ok;
-        _ ->
-            %% from spec: the server MAY send disconnect with reason code 0x84
-            %% we chose to close socket because the client is likely not talking MQTT anyway
-            ?PARSE_ERR(#{
-                cause => invalid_proto_name,
-                expected => <<"'MQTT' or 'MQIsdp'">>,
-                received => ProtoName
-            })
-    end,
-    parse_connect2(ProtoName, Rest, StrictMode).
+    {ProtoName, Rest0} = parse_utf8_string_with_cause(FrameBin, StrictMode, invalid_proto_name),
+    %% No need to parse and check proto_ver if proto_name is invalid, check it first
+    %% And the matching check of `proto_name` and `proto_ver` fields will be done in `emqx_packet:check_proto_ver/2`
+    _ = validate_proto_name(ProtoName),
+    {IsBridge, ProtoVer, Rest2} = parse_connect_proto_ver(Rest0),
+    Meta = #{proto_name => ProtoName, proto_ver => ProtoVer},
+    try
+        do_parse_connect(ProtoName, IsBridge, ProtoVer, Rest2, StrictMode)
+    catch
+        throw:{?FRAME_PARSE_ERROR, ReasonM} when is_map(ReasonM) ->
+            ?PARSE_ERR(maps:merge(ReasonM, Meta));
+        throw:{?FRAME_PARSE_ERROR, Reason} ->
+            ?PARSE_ERR(Meta#{cause => Reason})
+    end.
 
-parse_connect2(
+do_parse_connect(
     ProtoName,
-    <<BridgeTag:4, ProtoVer:4, UsernameFlagB:1, PasswordFlagB:1, WillRetainB:1, WillQoS:2,
-        WillFlagB:1, CleanStart:1, Reserved:1, KeepAlive:16/big, Rest2/binary>>,
+    IsBridge,
+    ProtoVer,
+    <<
+        UsernameFlagB:1,
+        PasswordFlagB:1,
+        WillRetainB:1,
+        WillQoS:2,
+        WillFlagB:1,
+        CleanStart:1,
+        Reserved:1,
+        KeepAlive:16/big,
+        Rest/binary
+    >>,
     StrictMode
 ) ->
     _ = validate_connect_reserved(Reserved),
@@ -302,14 +311,14 @@ parse_connect2(
         UsernameFlag = bool(UsernameFlagB),
         PasswordFlag = bool(PasswordFlagB)
     ),
-    {Properties, Rest3} = parse_properties(Rest2, ProtoVer, StrictMode),
+    {Properties, Rest3} = parse_properties(Rest, ProtoVer, StrictMode),
     {ClientId, Rest4} = parse_utf8_string_with_cause(Rest3, StrictMode, invalid_clientid),
     ConnPacket = #mqtt_packet_connect{
         proto_name = ProtoName,
         proto_ver = ProtoVer,
         %% For bridge mode, non-standard implementation
         %% Invented by mosquitto, named 'try_private': https://mosquitto.org/man/mosquitto-conf-5.html
-        is_bridge = (BridgeTag =:= 8),
+        is_bridge = IsBridge,
         clean_start = bool(CleanStart),
         will_flag = WillFlag,
         will_qos = WillQoS,
@@ -342,8 +351,8 @@ parse_connect2(
                 unexpected_trailing_bytes => size(Rest7)
             })
     end;
-parse_connect2(_ProtoName, Bin, _StrictMode) ->
-    %% sent less than 32 bytes
+do_parse_connect(_ProtoName, _IsBridge, _ProtoVer, Bin, _StrictMode) ->
+    %% sent less than 24 bytes
     ?PARSE_ERR(#{cause => malformed_connect, header_bytes => Bin}).
 
 parse_packet(
@@ -515,6 +524,12 @@ parse_packet_id(<<PacketId:16/big, Rest/binary>>) ->
 parse_packet_id(_) ->
     ?PARSE_ERR(invalid_packet_id).
 
+parse_connect_proto_ver(<<BridgeTag:4, ProtoVer:4, Rest/binary>>) ->
+    {_IsBridge = (BridgeTag =:= 8), ProtoVer, Rest};
+parse_connect_proto_ver(Bin) ->
+    %% sent less than 1 bytes or empty
+    ?PARSE_ERR(#{cause => malformed_connect, header_bytes => Bin}).
+
 parse_properties(Bin, Ver, _StrictMode) when Ver =/= ?MQTT_PROTO_V5 ->
     {#{}, Bin};
 %% TODO: version mess?
@@ -1129,10 +1144,25 @@ validate_subqos([3 | _]) -> ?PARSE_ERR(bad_subqos);
 validate_subqos([_ | T]) -> validate_subqos(T);
 validate_subqos([]) -> ok.
 
+%% from spec: the server MAY send disconnect with reason code 0x84
+%% we chose to close socket because the client is likely not talking MQTT anyway
+validate_proto_name(<<"MQTT">>) ->
+    ok;
+validate_proto_name(<<"MQIsdp">>) ->
+    ok;
+validate_proto_name(ProtoName) ->
+    ?PARSE_ERR(#{
+        cause => invalid_proto_name,
+        expected => <<"'MQTT' or 'MQIsdp'">>,
+        received => ProtoName
+    }).
+
 %% MQTT-v3.1.1-[MQTT-3.1.2-3], MQTT-v5.0-[MQTT-3.1.2-3]
+-compile({inline, [validate_connect_reserved/1]}).
 validate_connect_reserved(0) -> ok;
 validate_connect_reserved(1) -> ?PARSE_ERR(reserved_connect_flag).
 
+-compile({inline, [validate_connect_will/3]}).
 %% MQTT-v3.1.1-[MQTT-3.1.2-13], MQTT-v5.0-[MQTT-3.1.2-11]
 validate_connect_will(false, _, WillQoS) when WillQoS > 0 -> ?PARSE_ERR(invalid_will_qos);
 %% MQTT-v3.1.1-[MQTT-3.1.2-14], MQTT-v5.0-[MQTT-3.1.2-12]
@@ -1141,6 +1171,7 @@ validate_connect_will(true, _, WillQoS) when WillQoS > 2 -> ?PARSE_ERR(invalid_w
 validate_connect_will(false, WillRetain, _) when WillRetain -> ?PARSE_ERR(invalid_will_retain);
 validate_connect_will(_, _, _) -> ok.
 
+-compile({inline, [validate_connect_password_flag/4]}).
 %% MQTT-v3.1
 %% Username flag and password flag are not strongly related
 %% https://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html#connect
@@ -1155,6 +1186,7 @@ validate_connect_password_flag(true, ?MQTT_PROTO_V5, _, _) ->
 validate_connect_password_flag(_, _, _, _) ->
     ok.
 
+-compile({inline, [bool/1]}).
 bool(0) -> false;
 bool(1) -> true.