Преглед изворни кода

Validate packet id if strict mode.

Feng Lee пре 6 година
родитељ
комит
d3107facf9
2 измењених фајлова са 61 додато и 27 уклоњено
  1. 26 22
      src/emqx_frame.erl
  2. 35 5
      test/emqx_frame_SUITE.erl

+ 26 - 22
src/emqx_frame.erl

@@ -89,12 +89,12 @@ parse(<<>>, {none, Options}) ->
 parse(<<Type:4, Dup:1, QoS:2, Retain:1, Rest/binary>>,
       {none, Options = #{strict_mode := StrictMode}}) ->
     %% Validate header if strict mode.
+    StrictMode andalso validate_header(Type, Dup, QoS, Retain),
     Header = #mqtt_packet_header{type   = Type,
                                  dup    = bool(Dup),
                                  qos    = QoS,
                                  retain = bool(Retain)
                                 },
-    StrictMode andalso validate_header(Type, Dup, QoS, Retain),
     Header1 = case fixqos(Type, QoS) of
                   QoS      -> Header;
                   FixedQoS -> Header#mqtt_packet_header{qos = FixedQoS}
@@ -164,7 +164,8 @@ packet(Header, Variable, Payload) ->
 parse_packet(#mqtt_packet_header{type = ?CONNECT}, FrameBin, _Options) ->
     {ProtoName, Rest} = parse_utf8_string(FrameBin),
     <<BridgeTag:4, ProtoVer:4, Rest1/binary>> = Rest,
-    % Note: Crash when reserved flag doesn't equal to 0, there is no strict compliance with the MQTT5.0.
+    % Note: Crash when reserved flag doesn't equal to 0, there is no strict
+    % compliance with the MQTT5.0.
     <<UsernameFlag : 1,
       PasswordFlag : 1,
       WillRetain   : 1,
@@ -201,13 +202,15 @@ parse_packet(#mqtt_packet_header{type = ?CONNACK},
                          properties  = Properties
                         };
 
-parse_packet(#mqtt_packet_header{type = ?PUBLISH, qos = QoS}, Bin, #{version := Ver}) ->
+parse_packet(#mqtt_packet_header{type = ?PUBLISH, qos = QoS}, Bin,
+             #{strict_mode := StrictMode, version := Ver}) ->
     {TopicName, Rest} = parse_utf8_string(Bin),
     {PacketId, Rest1} = case QoS of
                             ?QOS_0 -> {undefined, Rest};
                             _ -> parse_packet_id(Rest)
                         end,
-    (PacketId =/= undefined) andalso validate_packet_id(PacketId),
+    (PacketId =/= undefined) andalso
+      StrictMode andalso validate_packet_id(PacketId),
     {Properties, Payload} = parse_properties(Rest1, Ver),
     Publish = #mqtt_packet_publish{topic_name = TopicName,
                                    packet_id  = PacketId,
@@ -215,15 +218,15 @@ parse_packet(#mqtt_packet_header{type = ?PUBLISH, qos = QoS}, Bin, #{version :=
                                   },
     {Publish, Payload};
 
-parse_packet(#mqtt_packet_header{type = PubAck}, <<PacketId:16/big>>, _Options)
-    when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP ->
-    ok = validate_packet_id(PacketId),
+parse_packet(#mqtt_packet_header{type = PubAck}, <<PacketId:16/big>>, #{strict_mode := StrictMode})
+  when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP ->
+    StrictMode andalso validate_packet_id(PacketId),
     #mqtt_packet_puback{packet_id = PacketId, reason_code = 0};
 
 parse_packet(#mqtt_packet_header{type = PubAck}, <<PacketId:16/big, ReasonCode, Rest/binary>>,
-             #{version := Ver = ?MQTT_PROTO_V5})
-    when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP ->
-    ok = validate_packet_id(PacketId),
+             #{strict_mode := StrictMode, version := Ver = ?MQTT_PROTO_V5})
+  when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP ->
+    StrictMode andalso validate_packet_id(PacketId),
     {Properties, <<>>} = parse_properties(Rest, Ver),
     #mqtt_packet_puback{packet_id   = PacketId,
                         reason_code = ReasonCode,
@@ -231,8 +234,8 @@ parse_packet(#mqtt_packet_header{type = PubAck}, <<PacketId:16/big, ReasonCode,
                        };
 
 parse_packet(#mqtt_packet_header{type = ?SUBSCRIBE}, <<PacketId:16/big, Rest/binary>>,
-             #{version := Ver}) ->
-    ok = validate_packet_id(PacketId),
+             #{strict_mode := StrictMode, version := Ver}) ->
+    StrictMode andalso validate_packet_id(PacketId),
     {Properties, Rest1} = parse_properties(Rest, Ver),
     TopicFilters = parse_topic_filters(subscribe, Rest1),
     ok = validate_subqos([QoS || {_, #{qos := QoS}} <- TopicFilters]),
@@ -242,8 +245,8 @@ parse_packet(#mqtt_packet_header{type = ?SUBSCRIBE}, <<PacketId:16/big, Rest/bin
                           };
 
 parse_packet(#mqtt_packet_header{type = ?SUBACK}, <<PacketId:16/big, Rest/binary>>,
-             #{version := Ver}) ->
-    ok = validate_packet_id(PacketId),
+             #{strict_mode := StrictMode, version := Ver}) ->
+    StrictMode andalso validate_packet_id(PacketId),
     {Properties, Rest1} = parse_properties(Rest, Ver),
     ReasonCodes = parse_reason_codes(Rest1),
     #mqtt_packet_suback{packet_id    = PacketId,
@@ -252,8 +255,8 @@ parse_packet(#mqtt_packet_header{type = ?SUBACK}, <<PacketId:16/big, Rest/binary
                        };
 
 parse_packet(#mqtt_packet_header{type = ?UNSUBSCRIBE}, <<PacketId:16/big, Rest/binary>>,
-             #{version := Ver}) ->
-    ok = validate_packet_id(PacketId),
+             #{strict_mode := StrictMode, version := Ver}) ->
+    StrictMode andalso validate_packet_id(PacketId),
     {Properties, Rest1} = parse_properties(Rest, Ver),
     TopicFilters = parse_topic_filters(unsubscribe, Rest1),
     #mqtt_packet_unsubscribe{packet_id     = PacketId,
@@ -261,13 +264,14 @@ parse_packet(#mqtt_packet_header{type = ?UNSUBSCRIBE}, <<PacketId:16/big, Rest/b
                              topic_filters = TopicFilters
                             };
 
-parse_packet(#mqtt_packet_header{type = ?UNSUBACK}, <<PacketId:16/big>>, _Options) ->
-    ok = validate_packet_id(PacketId),
+parse_packet(#mqtt_packet_header{type = ?UNSUBACK}, <<PacketId:16/big>>,
+             #{strict_mode := StrictMode}) ->
+    StrictMode andalso validate_packet_id(PacketId),
     #mqtt_packet_unsuback{packet_id = PacketId};
 
 parse_packet(#mqtt_packet_header{type = ?UNSUBACK}, <<PacketId:16/big, Rest/binary>>,
-             #{version := Ver}) ->
-    ok = validate_packet_id(PacketId),
+             #{strict_mode := StrictMode, version := Ver}) ->
+    StrictMode andalso validate_packet_id(PacketId),
     {Properties, Rest1} = parse_properties(Rest, Ver),
     ReasonCodes = parse_reason_codes(Rest1),
     #mqtt_packet_unsuback{packet_id    = PacketId,
@@ -296,8 +300,7 @@ parse_will_message(Packet = #mqtt_packet_connect{will_flag = true,
                                 will_topic   = Topic,
                                 will_payload = Payload
                                }, Rest2};
-parse_will_message(Packet, Bin) ->
-    {Packet, Bin}.
+parse_will_message(Packet, Bin) -> {Packet, Bin}.
 
 -compile({inline, [parse_packet_id/1]}).
 parse_packet_id(<<PacketId:16/big, Rest/binary>>) ->
@@ -720,6 +723,7 @@ validate_header(?DISCONNECT, 0, 0, 0)   -> ok;
 validate_header(?AUTH, 0, 0, 0)         -> ok;
 validate_header(_Type, _Dup, _QoS, _Rt) -> error(bad_frame_header).
 
+-compile({inline, [validate_packet_id/1]}).
 validate_packet_id(0) -> error(bad_packet_id);
 validate_packet_id(_) -> ok.
 

+ 35 - 5
test/emqx_frame_SUITE.erl

@@ -40,7 +40,8 @@ all() ->
      {group, unsuback},
      {group, ping},
      {group, disconnect},
-     {group, auth}].
+     {group, auth}
+    ].
 
 groups() ->
     [{parse, [parallel],
@@ -333,7 +334,10 @@ t_serialize_parse_qos1_publish(_) ->
                           payload  = <<"haha">>},
     ?assertEqual(Bin, serialize_to_binary(Packet)),
     ?assertMatch(Packet, parse_to_packet(Bin, #{strict_mode => true})),
-    ?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_1, <<"Topic">>, 0, <<>>))).
+    %% strict_mode = true
+    ?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_1, <<"Topic">>, 0, <<>>))),
+    %% strict_mode = false
+    _ = parse_serialize(?PUBLISH_PACKET(?QOS_1, <<"Topic">>, 0, <<>>), #{strict_mode => false}).
 
 t_serialize_parse_qos2_publish(_) ->
     Packet = ?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 1, <<>>),
@@ -341,7 +345,10 @@ t_serialize_parse_qos2_publish(_) ->
     ?assertEqual(Packet, parse_serialize(Packet)),
     ?assertEqual(Bin, serialize_to_binary(Packet)),
     ?assertMatch(Packet, parse_to_packet(Bin, #{strict_mode => true})),
-    ?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 0, <<>>))).
+    %% strict_mode = true
+    ?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 0, <<>>))),
+    %% strict_mode = false
+    _ = parse_serialize(?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 0, <<>>), #{strict_mode => false}).
 
 t_serialize_parse_publish_v5(_) ->
     Props = #{'Payload-Format-Indicator' => 1,
@@ -358,7 +365,10 @@ t_serialize_parse_puback(_) ->
     Packet = ?PUBACK_PACKET(1),
     ?assertEqual(<<64,2,0,1>>, serialize_to_binary(Packet)),
     ?assertEqual(Packet, parse_serialize(Packet)),
-    ?catch_error(bad_packet_id, parse_serialize(?PUBACK_PACKET(0))).
+    %% strict_mode = true
+    ?catch_error(bad_packet_id, parse_serialize(?PUBACK_PACKET(0))),
+    %% strict_mode = false
+    ?PUBACK_PACKET(0) = parse_serialize(?PUBACK_PACKET(0), #{strict_mode => false}).
 
 t_serialize_parse_puback_v3_4(_) ->
     Bin = <<64,2,0,1>>,
@@ -376,7 +386,10 @@ t_serialize_parse_pubrec(_) ->
     Packet = ?PUBREC_PACKET(1),
     ?assertEqual(<<5:4,0:4,2,0,1>>, serialize_to_binary(Packet)),
     ?assertEqual(Packet, parse_serialize(Packet)),
-    ?catch_error(bad_packet_id, parse_serialize(?PUBREC_PACKET(0))).
+    %% strict_mode = true
+    ?catch_error(bad_packet_id, parse_serialize(?PUBREC_PACKET(0))),
+    %% strict_mode = false
+    ?PUBREC_PACKET(0) = parse_serialize(?PUBREC_PACKET(0), #{strict_mode => false}).
 
 t_serialize_parse_pubrec_v5(_) ->
     Packet = ?PUBREC_PACKET(16, ?RC_SUCCESS, #{'Reason-String' => <<"success">>}),
@@ -391,6 +404,9 @@ t_serialize_parse_pubrel(_) ->
     Bin0 = <<6:4,0:4,2,0,1>>,
     ?assertMatch(Packet, parse_to_packet(Bin0, #{strict_mode => false})),
     ?catch_error(bad_frame_header, parse_to_packet(Bin0, #{strict_mode => true})),
+    %% strict_mode = false
+    ?PUBREL_PACKET(0) = parse_serialize(?PUBREL_PACKET(0), #{strict_mode => false}),
+    %% strict_mode = true
     ?catch_error(bad_packet_id, parse_serialize(?PUBREL_PACKET(0))).
 
 t_serialize_parse_pubrel_v5(_) ->
@@ -402,6 +418,9 @@ t_serialize_parse_pubcomp(_) ->
     Bin = serialize_to_binary(Packet),
     ?assertEqual(<<7:4,0:4,2,0,1>>, Bin),
     ?assertEqual(Packet, parse_serialize(Packet)),
+    %% strict_mode = false
+    ?PUBCOMP_PACKET(0) = parse_serialize(?PUBCOMP_PACKET(0), #{strict_mode => false}),
+    %% strict_mode = true
     ?catch_error(bad_packet_id, parse_serialize(?PUBCOMP_PACKET(0))).
 
 t_serialize_parse_pubcomp_v5(_) ->
@@ -419,7 +438,12 @@ t_serialize_parse_subscribe(_) ->
     %% SUBSCRIBE with bad qos 0
     Bin0 = <<?SUBSCRIBE:4,0:4,11,0,2,0,6,84,111,112,105,99,65,2>>,
     ?assertMatch(Packet, parse_to_packet(Bin0, #{strict_mode => false})),
+    %% strict_mode = false
+    _ = parse_to_packet(Bin0, #{strict_mode => false}),
     ?catch_error(bad_frame_header, parse_to_packet(Bin0, #{strict_mode => true})),
+    %% strict_mode = false
+    _ = parse_serialize(?SUBSCRIBE_PACKET(0, TopicFilters), #{strict_mode => false}),
+    %% strict_mode = true
     ?catch_error(bad_packet_id, parse_serialize(?SUBSCRIBE_PACKET(0, TopicFilters))),
     ?catch_error(bad_subqos, parse_serialize(?SUBSCRIBE_PACKET(1, [{<<"t">>, #{qos => 3}}]))).
 
@@ -432,6 +456,9 @@ t_serialize_parse_subscribe_v5(_) ->
 t_serialize_parse_suback(_) ->
     Packet = ?SUBACK_PACKET(10, [?QOS_0, ?QOS_1, 128]),
     ?assertEqual(Packet, parse_serialize(Packet)),
+    %% strict_mode = false
+    _ = parse_serialize(?SUBACK_PACKET(0, [?QOS_0]), #{strict_mode => false}),
+    %% strict_mode = true
     ?catch_error(bad_packet_id, parse_serialize(?SUBACK_PACKET(0, [?QOS_0]))).
 
 t_serialize_parse_suback_v5(_) ->
@@ -451,6 +478,9 @@ t_serialize_parse_unsubscribe(_) ->
     Bin0 = <<?UNSUBSCRIBE:4,0:4,10,0,2,0,6,84,111,112,105,99,65>>,
     ?assertMatch(Packet, parse_to_packet(Bin0, #{strict_mode => false})),
     ?catch_error(bad_frame_header, parse_to_packet(Bin0, #{strict_mode => true})),
+    %% strict_mode = false
+    _ = parse_serialize(?UNSUBSCRIBE_PACKET(0, [<<"TopicA">>]), #{strict_mode => false}),
+    %% strict_mode = true
     ?catch_error(bad_packet_id, parse_serialize(?UNSUBSCRIBE_PACKET(0, [<<"TopicA">>]))).
 
 t_serialize_parse_unsubscribe_v5(_) ->