Kaynağa Gözat

Improve the emqx_message module and add more test cases

- Add 'emqx_message:clean_dup/1' function
- Clean dup flag before publishing a message
- Add more test cases for emqx_message module
Feng Lee 6 yıl önce
ebeveyn
işleme
0f9f1258b6

+ 8 - 7
src/emqx_broker.erl

@@ -191,20 +191,21 @@ do_unsubscribe(undefined, Topic, SubPid, SubOpts) ->
 do_unsubscribe(Group, Topic, SubPid, _SubOpts) ->
     emqx_shared_sub:unsubscribe(Group, Topic, SubPid).
 
-%%------------------------------------------------------------------------------
+%%--------------------------------------------------------------------
 %% Publish
-%%------------------------------------------------------------------------------
+%%--------------------------------------------------------------------
 
 -spec(publish(emqx_types:message()) -> emqx_types:publish_result()).
 publish(Msg) when is_record(Msg, message) ->
     _ = emqx_tracer:trace(publish, Msg),
-    Headers = Msg#message.headers,
-    case emqx_hooks:run_fold('message.publish', [], Msg#message{headers = Headers#{allow_publish => true}}) of
+    Msg1 = emqx_message:set_header(allow_publish, true,
+                                   emqx_message:clean_dup(Msg)),
+    case emqx_hooks:run_fold('message.publish', [], Msg1) of
         #message{headers = #{allow_publish := false}} ->
-            ?LOG(notice, "Publishing interrupted: ~s", [emqx_message:format(Msg)]),
+            ?LOG(notice, "Stop publishing: ~s", [emqx_message:format(Msg1)]),
             [];
-        #message{topic = Topic} = Msg1 ->
-            route(aggre(emqx_router:match_routes(Topic)), delivery(Msg1))
+        #message{topic = Topic} = Msg2 ->
+            route(aggre(emqx_router:match_routes(Topic)), delivery(Msg2))
     end.
 
 %% Called internally

+ 2 - 2
src/emqx_channel.erl

@@ -417,8 +417,8 @@ do_publish(PacketId, Msg = #message{qos = ?QOS_2},
     end.
 
 -compile({inline, [puback_reason_code/1]}).
-puback_reason_code([]) -> ?RC_NO_MATCHING_SUBSCRIBERS;
-puback_reason_code(_)  -> ?RC_SUCCESS.
+puback_reason_code([])    -> ?RC_NO_MATCHING_SUBSCRIBERS;
+puback_reason_code([_|_]) -> ?RC_SUCCESS.
 
 %%--------------------------------------------------------------------
 %% Process Subscribe

+ 56 - 42
src/emqx_message.erl

@@ -16,6 +16,8 @@
 
 -module(emqx_message).
 
+-compile(inline).
+
 -include("emqx.hrl").
 -include("emqx_mqtt.hrl").
 -include("types.hrl").
@@ -36,7 +38,8 @@
         ]).
 
 %% Flags
--export([ get_flag/2
+-export([ clean_dup/1
+        , get_flag/2
         , get_flag/3
         , get_flags/1
         , set_flag/2
@@ -71,25 +74,24 @@
 make(Topic, Payload) ->
     make(undefined, Topic, Payload).
 
--spec(make(atom() | emqx_types:clientid(),
+-spec(make(emqx_types:clientid(),
            emqx_topic:topic(),
            emqx_types:payload()) -> emqx_types:message()).
 make(From, Topic, Payload) ->
     make(From, ?QOS_0, Topic, Payload).
 
--spec(make(atom() | emqx_types:clientid(),
+-spec(make(emqx_types:clientid(),
            emqx_types:qos(),
            emqx_topic:topic(),
            emqx_types:payload()) -> emqx_types:message()).
 make(From, QoS, Topic, Payload) when ?QOS_0 =< QoS, QoS =< ?QOS_2 ->
+    Now = erlang:system_time(millisecond),
     #message{id = emqx_guid:gen(),
              qos = QoS,
              from = From,
-             flags = #{dup => false},
-             headers = #{},
              topic = Topic,
              payload = Payload,
-             timestamp = erlang:system_time(millisecond)
+             timestamp = Now
             }.
 
 -spec(id(emqx_types:message()) -> maybe(binary())).
@@ -110,6 +112,11 @@ payload(#message{payload = Payload}) -> Payload.
 -spec(timestamp(emqx_types:message()) -> integer()).
 timestamp(#message{timestamp = TS}) -> TS.
 
+-spec(clean_dup(emqx_types:message()) -> emqx_types:message()).
+clean_dup(Msg = #message{flags = Flags = #{dup := true}}) ->
+    Msg#message{flags = Flags#{dup => false}};
+clean_dup(Msg) -> Msg.
+
 -spec(set_flags(map(), emqx_types:message()) -> emqx_types:message()).
 set_flags(Flags, Msg = #message{flags = undefined}) when is_map(Flags) ->
     Msg#message{flags = Flags};
@@ -117,8 +124,13 @@ set_flags(New, Msg = #message{flags = Old}) when is_map(New) ->
     Msg#message{flags = maps:merge(Old, New)}.
 
 -spec(get_flag(flag(), emqx_types:message()) -> boolean()).
+get_flag(_Flag, #message{flags = undefined}) ->
+    false;
 get_flag(Flag, Msg) ->
     get_flag(Flag, Msg, false).
+
+get_flag(_Flag, #message{flags = undefined}, Default) ->
+    Default;
 get_flag(Flag, #message{flags = Flags}, Default) ->
     maps:get(Flag, Flags, Default).
 
@@ -141,25 +153,27 @@ set_flag(Flag, Val, Msg = #message{flags = Flags}) when is_atom(Flag) ->
 -spec(unset_flag(flag(), emqx_types:message()) -> emqx_types:message()).
 unset_flag(Flag, Msg = #message{flags = Flags}) ->
     case maps:is_key(Flag, Flags) of
-        true ->
-            Msg#message{flags = maps:remove(Flag, Flags)};
+        true  -> Msg#message{flags = maps:remove(Flag, Flags)};
         false -> Msg
     end.
 
--spec(set_headers(undefined | map(), emqx_types:message()) -> emqx_types:message()).
+-spec(set_headers(map(), emqx_types:message()) -> emqx_types:message()).
 set_headers(Headers, Msg = #message{headers = undefined}) when is_map(Headers) ->
     Msg#message{headers = Headers};
 set_headers(New, Msg = #message{headers = Old}) when is_map(New) ->
     Msg#message{headers = maps:merge(Old, New)}.
 
--spec(get_headers(emqx_types:message()) -> map()).
-get_headers(Msg) ->
-    Msg#message.headers.
+-spec(get_headers(emqx_types:message()) -> maybe(map())).
+get_headers(Msg) -> Msg#message.headers.
 
 -spec(get_header(term(), emqx_types:message()) -> term()).
+get_header(_Hdr, #message{headers = undefined}) ->
+    undefined;
 get_header(Hdr, Msg) ->
     get_header(Hdr, Msg, undefined).
--spec(get_header(term(), emqx_types:message(), Default :: term()) -> term()).
+-spec(get_header(term(), emqx_types:message(), term()) -> term()).
+get_header(_Hdr, #message{headers = undefined}, Default) ->
+    Default;
 get_header(Hdr, #message{headers = Headers}, Default) ->
     maps:get(Hdr, Headers, Default).
 
@@ -170,10 +184,11 @@ set_header(Hdr, Val, Msg = #message{headers = Headers}) ->
     Msg#message{headers = maps:put(Hdr, Val, Headers)}.
 
 -spec(remove_header(term(), emqx_types:message()) -> emqx_types:message()).
+remove_header(_Hdr, Msg = #message{headers = undefined}) ->
+    Msg;
 remove_header(Hdr, Msg = #message{headers = Headers}) ->
     case maps:is_key(Hdr, Headers) of
-        true ->
-            Msg#message{headers = maps:remove(Hdr, Headers)};
+        true  -> Msg#message{headers = maps:remove(Hdr, Headers)};
         false -> Msg
     end.
 
@@ -181,15 +196,15 @@ remove_header(Hdr, Msg = #message{headers = Headers}) ->
 is_expired(#message{headers = #{'Message-Expiry-Interval' := Interval},
                     timestamp = CreatedAt}) ->
     elapsed(CreatedAt) > timer:seconds(Interval);
-is_expired(_Msg) ->
-    false.
+is_expired(_Msg) -> false.
 
 -spec(update_expiry(emqx_types:message()) -> emqx_types:message()).
 update_expiry(Msg = #message{headers = #{'Message-Expiry-Interval' := Interval},
                              timestamp = CreatedAt}) ->
     case elapsed(CreatedAt) of
         Elapsed when Elapsed > 0 ->
-            set_header('Message-Expiry-Interval', max(1, Interval - (Elapsed div 1000)), Msg);
+            Interval1 = max(1, Interval - (Elapsed div 1000)),
+            set_header('Message-Expiry-Interval', Interval1, Msg);
         _ -> Msg
     end;
 update_expiry(Msg) -> Msg.
@@ -197,30 +212,29 @@ update_expiry(Msg) -> Msg.
 %% @doc Message to PUBLISH Packet.
 -spec(to_packet(emqx_types:packet_id(), emqx_types:message())
       -> emqx_types:packet()).
-to_packet(PacketId, #message{qos = QoS, flags = Flags, headers = Headers,
-                                topic = Topic, payload = Payload}) ->
-    Flags1 = if Flags =:= undefined -> #{};
-                true -> Flags
-             end,
-    Dup = maps:get(dup, Flags1, false),
-    Retain = maps:get(retain, Flags1, false),
-    Publish = #mqtt_packet_publish{topic_name = Topic,
-                                   packet_id  = PacketId,
-                                   properties = publish_props(Headers)},
-    #mqtt_packet{header = #mqtt_packet_header{type   = ?PUBLISH,
-                                              dup    = Dup,
-                                              qos    = QoS,
-                                              retain = Retain},
-                 variable = Publish, payload = Payload}.
-
-publish_props(Headers) ->
-    maps:with(['Payload-Format-Indicator',
-               'Response-Topic',
-               'Correlation-Data',
-               'User-Property',
-               'Subscription-Identifier',
-               'Content-Type',
-               'Message-Expiry-Interval'], Headers).
+to_packet(PacketId, Msg = #message{qos = QoS, headers = Headers,
+                                   topic = Topic, payload = Payload}) ->
+    #mqtt_packet{header   = #mqtt_packet_header{type   = ?PUBLISH,
+                                                dup    = get_flag(dup, Msg),
+                                                qos    = QoS,
+                                                retain = get_flag(retain, Msg)
+                                               },
+                 variable = #mqtt_packet_publish{topic_name = Topic,
+                                                 packet_id  = PacketId,
+                                                 properties = props(Headers)
+                                                },
+                 payload  = Payload
+                }.
+
+props(undefined) -> undefined;
+props(Headers)   -> maps:with(['Payload-Format-Indicator',
+                               'Response-Topic',
+                               'Correlation-Data',
+                               'User-Property',
+                               'Subscription-Identifier',
+                               'Content-Type',
+                               'Message-Expiry-Interval'
+                              ], Headers).
 
 %% @doc Message to map
 -spec(to_map(emqx_types:message()) -> map()).

+ 2 - 2
src/emqx_types.erl

@@ -176,8 +176,8 @@
 -type(deliver() :: {deliver, topic(), message()}).
 -type(delivery() :: #delivery{}).
 -type(deliver_result() :: ok | {error, term()}).
--type(publish_result() :: [ {node(), topic(), deliver_result()}
-                          | {share, topic(), deliver_result()}]).
+-type(publish_result() :: [{node(), topic(), deliver_result()} |
+                           {share, topic(), deliver_result()}]).
 -type(route() :: #route{}).
 -type(sub_group() :: tuple() | binary()).
 -type(route_entry() :: {topic(), node()} | {topic, sub_group()}).

+ 3 - 4
test/emqx_channel_SUITE.erl

@@ -93,8 +93,7 @@ t_chan_info(_) ->
 
 t_chan_caps(_) ->
     Caps = emqx_mqtt_caps:default(),
-    ?assertEqual(Caps#{max_packet_size => 1048576},
-                 emqx_channel:caps(channel())).
+    ?assertEqual(Caps, emqx_channel:caps(channel())).
 
 %%--------------------------------------------------------------------
 %% Test cases for channel init
@@ -129,14 +128,14 @@ t_handle_in_unexpected_connect_packet(_) ->
       = emqx_channel:handle_in(?CONNECT_PACKET(connpkt()), Channel).
 
 t_handle_in_qos0_publish(_) ->
-    ok = meck:expect(emqx_broker, publish, fun(_) -> ok end),
+    ok = meck:expect(emqx_broker, publish, fun(_) -> [] end),
     Channel = channel(#{conn_state => connected}),
     Publish = ?PUBLISH_PACKET(?QOS_0, <<"topic">>, undefined, <<"payload">>),
     {ok, _NChannel} = emqx_channel:handle_in(Publish, Channel).
     % ?assertEqual(#{publish_in => 1}, emqx_channel:info(pub_stats, NChannel)).
 
 t_handle_in_qos1_publish(_) ->
-    ok = meck:expect(emqx_broker, publish, fun(_) -> ok end),
+    ok = meck:expect(emqx_broker, publish, fun(_) -> [] end),
     Channel = channel(#{conn_state => connected}),
     Publish = ?PUBLISH_PACKET(?QOS_1, <<"topic">>, 1, <<"payload">>),
     {ok, ?PUBACK_PACKET(1, RC), _NChannel} = emqx_channel:handle_in(Publish, Channel),

+ 76 - 17
test/emqx_message_SUITE.erl

@@ -43,17 +43,56 @@ t_make(_) ->
     ?assertEqual(<<"topic">>, emqx_message:topic(Msg2)),
     ?assertEqual(<<"payload">>, emqx_message:payload(Msg2)).
 
+t_id(_) ->
+    Msg = emqx_message:make(<<"topic">>, <<"payload">>),
+    ?assert(is_binary(emqx_message:id(Msg))).
+
+t_qos(_) ->
+    Msg = emqx_message:make(<<"topic">>, <<"payload">>),
+    ?assertEqual(?QOS_0, emqx_message:qos(Msg)),
+    Msg1 = emqx_message:make(id, ?QOS_1, <<"t">>, <<"payload">>),
+    ?assertEqual(?QOS_1, emqx_message:qos(Msg1)),
+    Msg2 = emqx_message:make(id, ?QOS_2, <<"t">>, <<"payload">>),
+    ?assertEqual(?QOS_2, emqx_message:qos(Msg2)).
+
+t_topic(_) ->
+    Msg = emqx_message:make(<<"t">>, <<"payload">>),
+    ?assertEqual(<<"t">>, emqx_message:topic(Msg)).
+
+t_payload(_) ->
+    Msg = emqx_message:make(<<"t">>, <<"payload">>),
+    ?assertEqual(<<"payload">>, emqx_message:payload(Msg)).
+
+t_timestamp(_) ->
+    Msg = emqx_message:make(<<"t">>, <<"payload">>),
+    timer:sleep(1),
+    ?assert(erlang:system_time(millisecond) > emqx_message:timestamp(Msg)).
+
+t_clean_dup(_) ->
+    Msg = emqx_message:make(<<"topic">>, <<"payload">>),
+    ?assertNot(emqx_message:get_flag(dup, Msg)),
+    Msg = emqx_message:clean_dup(Msg),
+    Msg1 = emqx_message:set_flag(dup, Msg),
+    ?assert(emqx_message:get_flag(dup, Msg1)),
+    Msg2 = emqx_message:clean_dup(Msg1),
+    ?assertNot(emqx_message:get_flag(dup, Msg2)).
+
 t_get_set_flags(_) ->
     Msg = #message{id = <<"id">>, qos = ?QOS_1, flags = undefined},
     Msg1 = emqx_message:set_flags(#{retain => true}, Msg),
-    ?assertEqual(#{retain => true}, emqx_message:get_flags(Msg1)).
+    ?assertEqual(#{retain => true}, emqx_message:get_flags(Msg1)),
+    Msg2 = emqx_message:set_flags(#{dup => true}, Msg1),
+    ?assertEqual(#{retain => true, dup => true}, emqx_message:get_flags(Msg2)).
 
 t_get_set_flag(_) ->
     Msg = emqx_message:make(<<"clientid">>, <<"topic">>, <<"payload">>),
-    Msg2 = emqx_message:set_flag(retain, false, Msg),
+    ?assertNot(emqx_message:get_flag(dup, Msg)),
+    ?assertNot(emqx_message:get_flag(retain, Msg)),
+    Msg1 = emqx_message:set_flag(dup, true, Msg),
+    Msg2 = emqx_message:set_flag(retain, true, Msg1),
     Msg3 = emqx_message:set_flag(dup, Msg2),
     ?assert(emqx_message:get_flag(dup, Msg3)),
-    ?assertNot(emqx_message:get_flag(retain, Msg3)),
+    ?assert(emqx_message:get_flag(retain, Msg3)),
     Msg4 = emqx_message:unset_flag(dup, Msg3),
     Msg5 = emqx_message:unset_flag(retain, Msg4),
     Msg5 = emqx_message:unset_flag(badflag, Msg5),
@@ -76,6 +115,8 @@ t_get_set_headers(_) ->
 
 t_get_set_header(_) ->
     Msg = emqx_message:make(<<"clientid">>, <<"topic">>, <<"payload">>),
+    Msg = emqx_message:remove_header(x, Msg),
+    ?assertEqual(undefined, emqx_message:get_header(a, Msg)),
     Msg1 = emqx_message:set_header(a, 1, Msg),
     Msg2 = emqx_message:set_header(b, 2, Msg1),
     Msg3 = emqx_message:set_header(c, 3, Msg2),
@@ -95,11 +136,8 @@ t_undefined_headers(_) ->
 t_format(_) ->
     Msg = emqx_message:make(<<"clientid">>, <<"topic">>, <<"payload">>),
     io:format("~s~n", [emqx_message:format(Msg)]),
-    Msg1 = #message{id = <<"id">>,
-                    qos = ?QOS_0,
-                    flags = undefined,
-                    headers = undefined
-                   },
+    Msg1 = emqx_message:set_header('Subscription-Identifier', 1,
+                                   emqx_message:set_flag(dup, Msg)),
     io:format("~s~n", [emqx_message:format(Msg1)]).
 
 t_is_expired(_) ->
@@ -117,28 +155,49 @@ t_is_expired(_) ->
 
 % t_to_list(_) ->
 %     error('TODO').
-    
+
 t_to_packet(_) ->
-    Pkt = #mqtt_packet{header = #mqtt_packet_header{type   = ?PUBLISH,
-                                                    qos    = ?QOS_0,
-                                                    retain = false,
-                                                    dup    = false},
+    Pkt = #mqtt_packet{header   = #mqtt_packet_header{type   = ?PUBLISH,
+                                                      qos    = ?QOS_0,
+                                                      retain = false,
+                                                      dup    = false
+                                                     },
                        variable = #mqtt_packet_publish{topic_name = <<"topic">>,
                                                        packet_id  = 10,
-                                                       properties = #{}},
-                       payload = <<"payload">>},
+                                                       properties = undefined
+                                                      },
+                       payload  = <<"payload">>
+                      },
     Msg = emqx_message:make(<<"clientid">>, ?QOS_0, <<"topic">>, <<"payload">>),
     ?assertEqual(Pkt, emqx_message:to_packet(10, Msg)).
 
+t_to_packet_with_props(_) ->
+    Props = #{'Subscription-Identifier' => 1},
+    Pkt = #mqtt_packet{header   = #mqtt_packet_header{type   = ?PUBLISH,
+                                                      qos    = ?QOS_0,
+                                                      retain = false,
+                                                      dup    = false
+                                                     },
+                       variable = #mqtt_packet_publish{topic_name = <<"topic">>,
+                                                       packet_id  = 10,
+                                                       properties = Props
+                                                      },
+                       payload  = <<"payload">>
+                      },
+    Msg = emqx_message:make(<<"clientid">>, ?QOS_0, <<"topic">>, <<"payload">>),
+    Msg1 = emqx_message:set_header('Subscription-Identifier', 1, Msg),
+    ?assertEqual(Pkt, emqx_message:to_packet(10, Msg1)).
+
 t_to_map(_) ->
     Msg = emqx_message:make(<<"clientid">>, ?QOS_1, <<"topic">>, <<"payload">>),
     List = [{id, emqx_message:id(Msg)},
             {qos, ?QOS_1},
             {from, <<"clientid">>},
-            {flags, #{dup => false}},
-            {headers, #{}},
+            {flags, undefined},
+            {headers, undefined},
             {topic, <<"topic">>},
             {payload, <<"payload">>},
             {timestamp, emqx_message:timestamp(Msg)}],
     ?assertEqual(List, emqx_message:to_list(Msg)),
     ?assertEqual(maps:from_list(List), emqx_message:to_map(Msg)).
+

+ 0 - 6
test/emqx_mqtt_caps_SUITE.erl

@@ -24,12 +24,6 @@
 
 all() -> emqx_ct:all(?MODULE).
 
-% t_get_caps(_) ->
-%     error('TODO').
-
-% t_default(_) ->
-%     error('TODO').
-
 t_check_pub(_) ->
     PubCaps = #{max_qos_allowed => ?QOS_1,
                 retain_available => false

+ 1 - 1
test/emqx_packet_SUITE.erl

@@ -153,7 +153,7 @@ t_check_connect(_) ->
 
 t_from_to_message(_) ->
     ExpectedMsg = emqx_message:make(<<"clientid">>, ?QOS_0, <<"topic">>, <<"payload">>),
-    ExpectedMsg1 = emqx_message:set_flag(retain, false, ExpectedMsg),
+    ExpectedMsg1 = emqx_message:set_flags(#{dup => false, retain => false}, ExpectedMsg),
     ExpectedMsg2 = emqx_message:set_headers(#{peerhost => {127,0,0,1},
                                               protocol => mqtt,
                                               username => <<"test">>

+ 7 - 7
test/emqx_session_SUITE.erl

@@ -114,21 +114,21 @@ t_unsubscribe(_) ->
     Error = emqx_session:unsubscribe(clientinfo(), <<"#">>, NSession),
     ?assertEqual({error, ?RC_NO_SUBSCRIPTION_EXISTED}, Error).
 
-t_publish_qos2(_) ->
+t_publish_qos0(_) ->
     ok = meck:expect(emqx_broker, publish, fun(_) -> [] end),
-    Msg = emqx_message:make(test, ?QOS_2, <<"t">>, <<"payload">>),
-    {ok, [], Session} = emqx_session:publish(1, Msg, session()),
-    ?assertEqual(1, emqx_session:info(awaiting_rel_cnt, Session)).
+    Msg = emqx_message:make(test, ?QOS_0, <<"t">>, <<"payload">>),
+    {ok, [], Session} = emqx_session:publish(0, Msg, Session = session()).
 
 t_publish_qos1(_) ->
     ok = meck:expect(emqx_broker, publish, fun(_) -> [] end),
     Msg = emqx_message:make(test, ?QOS_1, <<"t">>, <<"payload">>),
     {ok, [], _Session} = emqx_session:publish(1, Msg, session()).
 
-t_publish_qos0(_) ->
+t_publish_qos2(_) ->
     ok = meck:expect(emqx_broker, publish, fun(_) -> [] end),
-    Msg = emqx_message:make(test, ?QOS_1, <<"t">>, <<"payload">>),
-    {ok, [], _Session} = emqx_session:publish(0, Msg, session()).
+    Msg = emqx_message:make(test, ?QOS_2, <<"t">>, <<"payload">>),
+    {ok, [], Session} = emqx_session:publish(1, Msg, session()),
+    ?assertEqual(1, emqx_session:info(awaiting_rel_cnt, Session)).
 
 t_is_awaiting_full_false(_) ->
     ?assertNot(emqx_session:is_awaiting_full(session(#{max_awaiting_rel => 0}))).

+ 1 - 1
test/emqx_shared_sub_SUITE.erl

@@ -41,7 +41,7 @@ init_per_suite(Config) ->
 
 end_per_suite(_Config) ->
     emqx_ct_helpers:stop_apps([]).
-    
+
 t_is_ack_required(_) ->
     ?assertEqual(false, emqx_shared_sub:is_ack_required(#message{headers = #{}})).