Просмотр исходного кода

Add 'remove_header/2', 'get_headers/1' functions

- Adopt new 'export' style
- Add 'remove_header/2', 'get_headers/1' functions
- Remove 'remove_topic_alias/1' function
Feng Lee 7 лет назад
Родитель
Сommit
dba16aeea5
3 измененных файлов с 84 добавлено и 25 удалено
  1. 60 17
      src/emqx_message.erl
  2. 1 1
      src/emqx_protocol.erl
  3. 23 7
      test/emqx_message_SUITE.erl

+ 60 - 17
src/emqx_message.erl

@@ -17,13 +17,34 @@
 -include("emqx.hrl").
 -include("emqx_mqtt.hrl").
 
--export([make/2, make/3, make/4]).
+-export([ make/2
+        , make/3
+        , make/4 ]).
+
+-export([ get_flag/2
+        , get_flag/3
+        , set_flag/2
+        , set_flag/3
+        , unset_flag/2
+        ]).
 -export([set_flags/2]).
--export([get_flag/2, get_flag/3, set_flag/2, set_flag/3, unset_flag/2]).
+
+-export([ get_headers/1
+        , get_header/2
+        , get_header/3
+        , set_header/3
+        , remove_header/2
+        ]).
 -export([set_headers/2]).
--export([get_header/2, get_header/3, set_header/3]).
--export([is_expired/1, update_expiry/1]).
--export([remove_topic_alias/1]).
+
+-export([ is_expired/1
+        , update_expiry/1
+        ]).
+
+-export([ to_map/1
+        , to_list/1
+        ]).
+
 -export([format/1]).
 
 -type(flag() :: atom()).
@@ -40,13 +61,13 @@ make(From, Topic, Payload) ->
 -spec(make(atom() | emqx_types:client_id(), emqx_mqtt_types:qos(),
            emqx_topic:topic(), emqx_types:payload()) -> emqx_types:message()).
 make(From, QoS, Topic, Payload) ->
-    #message{id         = emqx_guid:gen(),
-             qos        = QoS,
-             from       = From,
-             flags      = #{dup => false},
-             topic      = Topic,
-             payload    = Payload,
-             timestamp  = os:timestamp()}.
+    #message{id = emqx_guid:gen(),
+             qos = QoS,
+             from = From,
+             flags = #{dup => false},
+             topic = Topic,
+             payload = Payload,
+             timestamp = os:timestamp()}.
 
 -spec(set_flags(map(), emqx_types:message()) -> emqx_types:message()).
 set_flags(Flags, Msg = #message{flags = undefined}) when is_map(Flags) ->
@@ -88,6 +109,10 @@ set_headers(New, Msg = #message{headers = Old}) when is_map(New) ->
     Msg#message{headers = maps:merge(Old, New)};
 set_headers(undefined, Msg) -> Msg.
 
+-spec(get_headers(emqx_types:message()) -> map()).
+get_headers(Msg) ->
+    Msg#message.headers.
+
 -spec(get_header(term(), emqx_types:message()) -> term()).
 get_header(Hdr, Msg) ->
     get_header(Hdr, Msg, undefined).
@@ -101,14 +126,24 @@ set_header(Hdr, Val, Msg = #message{headers = undefined}) ->
 set_header(Hdr, Val, Msg = #message{headers = Headers}) ->
     Msg#message{headers = maps:put(Hdr, Val, Headers)}.
 
+-spec(remove_header(term(), emqx_types:message()) -> emqx_types:message()).
+remove_header(Hdr, Msg = #message{headers = Headers}) ->
+    case maps:is_key(Hdr, Headers) of
+        true ->
+            Msg#message{headers = maps:remove(Hdr, Headers)};
+        false -> Msg
+    end.
+
 -spec(is_expired(emqx_types:message()) -> boolean()).
-is_expired(#message{headers = #{'Message-Expiry-Interval' := Interval}, timestamp = CreatedAt}) ->
+is_expired(#message{headers = #{'Message-Expiry-Interval' := Interval},
+                    timestamp = CreatedAt}) ->
     elapsed(CreatedAt) > timer:seconds(Interval);
 is_expired(_Msg) ->
     false.
 
 -spec(update_expiry(emqx_types:message()) -> emqx_types:message()).
-update_expiry(Msg = #message{headers = #{'Message-Expiry-Interval' := Interval}, timestamp = CreatedAt}) ->
+update_expiry(Msg = #message{headers = #{'Message-Expiry-Interval' := Interval},
+                             timestamp = CreatedAt}) ->
     case elapsed(CreatedAt) of
         Elapsed when Elapsed > 0 ->
             set_header('Message-Expiry-Interval', max(1, Interval - (Elapsed div 1000)), Msg);
@@ -116,14 +151,21 @@ update_expiry(Msg = #message{headers = #{'Message-Expiry-Interval' := Interval},
     end;
 update_expiry(Msg) -> Msg.
 
-remove_topic_alias(Msg = #message{headers = Headers}) ->
-    Msg#message{headers = maps:remove('Topic-Alias', Headers)}.
+%% @doc Message to map
+-spec(to_map(emqx_types:message()) -> map()).
+to_map(Msg) ->
+    maps:from_list(to_list(Msg)).
+
+%% @doc Message to tuple list
+-spec(to_list(emqx_types:message()) -> map()).
+to_list(Msg) ->
+    lists:zip(record_info(fields, message), tl(tuple_to_list(Msg))).
 
 %% MilliSeconds
 elapsed(Since) ->
     max(0, timer:now_diff(os:timestamp(), Since) div 1000).
 
-format(#message{id = Id, qos = QoS, topic = Topic, from = From, flags = Flags, headers = Headers}) ->
+format(#message{id = Id,qos = QoS, topic = Topic, from = From, flags = Flags, headers = Headers}) ->
     io_lib:format("Message(Id=~s, QoS=~w, Topic=~s, From=~p, Flags=~s, Headers=~s)",
                   [Id, QoS, Topic, From, format(flags, Flags), format(headers, Headers)]).
 
@@ -133,3 +175,4 @@ format(flags, Flags) ->
     io_lib:format("~p", [[Flag || {Flag, true} <- maps:to_list(Flags)]]);
 format(headers, Headers) ->
     io_lib:format("~p", [Headers]).
+

+ 1 - 1
src/emqx_protocol.erl

@@ -654,7 +654,7 @@ deliver({publish, PacketId, Msg}, PState = #pstate{mountpoint = MountPoint}) ->
     Msg0 = emqx_hooks:run_fold('message.deliver', [credentials(PState)], Msg),
     Msg1 = emqx_message:update_expiry(Msg0),
     Msg2 = emqx_mountpoint:unmount(MountPoint, Msg1),
-    send(emqx_packet:from_message(PacketId, emqx_message:remove_topic_alias(Msg2)), PState);
+    send(emqx_packet:from_message(PacketId, Msg2), PState);
 
 deliver({puback, PacketId, ReasonCode}, PState) ->
     send(?PUBACK_PACKET(PacketId, ReasonCode), PState);

+ 23 - 7
test/emqx_message_SUITE.erl

@@ -24,12 +24,12 @@
 -include_lib("eunit/include/eunit.hrl").
 
 all() ->
-    [
-        message_make,
-        message_flag,
-        message_header,
-        message_format,
-        message_expired
+    [ message_make
+    , message_flag
+    , message_header
+    , message_format
+    , message_expired
+    , message_to_map
     ].
 
 message_make(_) ->
@@ -60,7 +60,9 @@ message_header(_) ->
     Msg1 = emqx_message:set_headers(#{a => 1, b => 2}, Msg),
     Msg2 = emqx_message:set_header(c, 3, Msg1),
     ?assertEqual(1, emqx_message:get_header(a, Msg2)),
-    ?assertEqual(4, emqx_message:get_header(d, Msg2, 4)).
+    ?assertEqual(4, emqx_message:get_header(d, Msg2, 4)),
+    Msg3 = emqx_message:remove_header(a, Msg2),
+    ?assertEqual(#{b => 2, c => 3}, emqx_message:get_headers(Msg3)).
 
 message_format(_) ->
     io:format("~s", [emqx_message:format(emqx_message:make(<<"clientid">>, <<"topic">>, <<"payload">>))]).
@@ -75,3 +77,17 @@ message_expired(_) ->
     timer:sleep(1000),
     Msg2 = emqx_message:update_expiry(Msg1),
     ?assertEqual(1, emqx_message:get_header('Message-Expiry-Interval', Msg2)).
+
+message_to_map(_) ->
+    Msg = emqx_message:make(<<"clientid">>, ?QOS_1, <<"topic">>, <<"payload">>),
+    List = [{id, Msg#message.id},
+            {qos, ?QOS_1},
+            {from, <<"clientid">>},
+            {flags, #{dup => false}},
+            {headers, #{}},
+            {topic, <<"topic">>},
+            {payload, <<"payload">>},
+            {timestamp, Msg#message.timestamp}],
+    ?assertEqual(List, emqx_message:to_list(Msg)),
+    ?assertEqual(maps:from_list(List), emqx_message:to_map(Msg)).
+