Browse Source

refactor(gw): support session takeover

JianBo He 3 years ago
parent
commit
ddf3585b22

+ 226 - 60
apps/emqx_gateway/src/emqx_gateway_cm.erl

@@ -25,6 +25,7 @@
 
 
 -include("include/emqx_gateway.hrl").
 -include("include/emqx_gateway.hrl").
 -include_lib("emqx/include/logger.hrl").
 -include_lib("emqx/include/logger.hrl").
+-include_lib("snabbkaffe/include/snabbkaffe.hrl").
 
 
 %% APIs
 %% APIs
 -export([start_link/1]).
 -export([start_link/1]).
@@ -33,6 +34,7 @@
         , open_session/6
         , open_session/6
         , kick_session/2
         , kick_session/2
         , kick_session/3
         , kick_session/3
+        , takeover_session/2
         , register_channel/4
         , register_channel/4
         , unregister_channel/2
         , unregister_channel/2
         , insert_channel_info/4
         , insert_channel_info/4
@@ -48,6 +50,11 @@
         , connection_closed/2
         , connection_closed/2
         ]).
         ]).
 
 
+-export([ call/3
+        , call/4
+        , cast/3
+        ]).
+
 -export([ with_channel/3
 -export([ with_channel/3
         , lookup_channels/2
         , lookup_channels/2
         ]).
         ]).
@@ -70,9 +77,11 @@
         , do_set_chan_info/4
         , do_set_chan_info/4
         , do_get_chan_stats/3
         , do_get_chan_stats/3
         , do_set_chan_stats/4
         , do_set_chan_stats/4
-        , do_discard_session/3
-        , do_kick_session/3
+        , do_kick_session/4
         , do_get_chann_conn_mod/3
         , do_get_chann_conn_mod/3
+        , do_call/4
+        , do_call/5
+        , do_cast/4
         ]).
         ]).
 
 
 -export_type([ gateway_name/0
 -export_type([ gateway_name/0
@@ -304,10 +313,39 @@ open_session(GwName, true = _CleanStart, ClientInfo, ConnInfo, CreateSessionFun,
           end,
           end,
     locker_trans(GwName, ClientId, Fun);
     locker_trans(GwName, ClientId, Fun);
 
 
-open_session(_Type, false = _CleanStart,
-             _ClientInfo, _ConnInfo, _CreateSessionFun, _SessionMod) ->
-    %% TODO: The session takeover logic will be implemented on 0.9?
-    {error, not_supported_now}.
+%% Resume path (CleanStart = false). Runs inside locker_trans/3: try to take
+%% over the session of an existing channel for ClientId; when there is none,
+%% or when any takeover step returns an error, create a new session instead.
+%% The calling process is registered as the channel in both cases.
+open_session(GwName, false = _CleanStart,
+             ClientInfo = #{clientid := ClientId},
+             ConnInfo, CreateSessionFun, SessionMod) ->
+    Self = self(),
+    Register = fun() -> register_channel(GwName, ClientId, Self, ConnInfo) end,
+    StartFresh =
+        fun() ->
+            NewSession = create_session(GwName, ClientInfo, ConnInfo,
+                                        CreateSessionFun, SessionMod),
+            Register(),
+            {ok, #{session => NewSession, present => false}}
+        end,
+    Resume =
+        fun(_) ->
+            case takeover_session(GwName, ClientId) of
+                {error, _Reason} ->
+                    StartFresh();
+                {ok, ConnMod, ChanPid, Session} ->
+                    ok = emqx_session:resume(ClientInfo, Session),
+                    %% Ask the old channel to finish the takeover and hand
+                    %% over whatever it still had pending.
+                    case request_stepdown({takeover, 'end'}, ConnMod, ChanPid) of
+                        {error, _} ->
+                            StartFresh();
+                        {ok, Pendings} ->
+                            Register(),
+                            {ok, #{session  => Session,
+                                   present  => true,
+                                   pendings => Pendings}}
+                    end
+            end
+        end,
+    locker_trans(GwName, ClientId, Resume).
 
 
 %% @private
 %% @private
 create_session(GwName, ClientInfo, ConnInfo, CreateSessionFun, SessionMod) ->
 create_session(GwName, ClientInfo, ConnInfo, CreateSessionFun, SessionMod) ->
@@ -341,77 +379,167 @@ create_session(GwName, ClientInfo, ConnInfo, CreateSessionFun, SessionMod) ->
         throw(Reason)
         throw(Reason)
     end.
     end.
 
 
+%% @doc Try to takeover a session.
+%% Returns {error, not_found} when no channel is registered for ClientId.
+%% When several channels are found, the last one returned by
+%% lookup_channels/2 is taken over and all the others are treated as stale
+%% and discarded on a best-effort basis.
+-spec(takeover_session(gateway_name(), emqx_types:clientid())
+      -> {error, term()}
+       | {ok, atom(), pid(), emqx_session:session()}).
+takeover_session(GwName, ClientId) ->
+    case lookup_channels(GwName, ClientId) of
+        [] -> {error, not_found};
+        [ChanPid] ->
+            do_takeover_session(GwName, ClientId, ChanPid);
+        ChanPids ->
+            [ChanPid|StalePids] = lists:reverse(ChanPids),
+            ?SLOG(warning, #{ msg => "more_than_one_channel_found"
+                            , chan_pids => ChanPids
+                            }),
+            %% The gateway name must be passed: a two-argument call would
+            %% reach discard_session/2, whose is_binary/1 guard rejects a
+            %% pid, and the `catch` would silently hide that failure so the
+            %% stale channels would never be discarded.
+            lists:foreach(fun(StalePid) ->
+                                  catch discard_session(GwName, ClientId, StalePid)
+                          end, StalePids),
+            do_takeover_session(GwName, ClientId, ChanPid)
+    end.
+
+%% Take over the session held by ChanPid.
+%% A channel living on this node is asked directly to begin the takeover;
+%% a channel on another node is reached through emqx_gateway_cm_proto_v1.
+do_takeover_session(GwName, ClientId, ChanPid) when node(ChanPid) == node() ->
+    case get_chann_conn_mod(GwName, ClientId, ChanPid) of
+        undefined ->
+            {error, not_found};
+        ConnMod when is_atom(ConnMod) ->
+            case request_stepdown({takeover, 'begin'}, ConnMod, ChanPid) of
+                {error, _Reason} = Failure ->
+                    Failure;
+                {ok, Session} ->
+                    {ok, ConnMod, ChanPid, Session}
+            end
+    end;
+do_takeover_session(GwName, ClientId, ChanPid) ->
+    RemoteResult = emqx_gateway_cm_proto_v1:takeover_session(GwName, ClientId, ChanPid),
+    wrap_rpc(RemoteResult).
+
 %% @doc Discard all the sessions identified by the ClientId.
 %% @doc Discard all the sessions identified by the ClientId.
 -spec discard_session(GwName :: gateway_name(), binary()) -> ok.
 -spec discard_session(GwName :: gateway_name(), binary()) -> ok.
 discard_session(GwName, ClientId) when is_binary(ClientId) ->
 discard_session(GwName, ClientId) when is_binary(ClientId) ->
     case lookup_channels(GwName, ClientId) of
     case lookup_channels(GwName, ClientId) of
         [] -> ok;
         [] -> ok;
-        ChanPids -> lists:foreach(fun(Pid) -> safe_discard_session(GwName, ClientId, Pid) end, ChanPids)
+        ChanPids -> lists:foreach(fun(Pid) -> discard_session(GwName, ClientId, Pid) end, ChanPids)
     end.
     end.
 
 
-%% @private
-safe_discard_session(GwName, ClientId, Pid) ->
+discard_session(GwName, ClientId, ChanPid) -> % discard a single channel: kick_session/4 with Action = discard
+    kick_session(GwName, discard, ClientId, ChanPid).
+
+%% @doc Kick every channel registered for ClientId in the given gateway.
+%% Returns ok whether or not a channel was found; a warning is logged when
+%% more than one channel is registered for the same ClientId.
+-spec kick_session(gateway_name(), emqx_types:clientid()) -> ok.
+kick_session(GwName, ClientId) ->
+    case lookup_channels(GwName, ClientId) of
+        [] -> ok;
+        ChanPids ->
+            %% Match on the list shape. Comparing the list itself with an
+            %% integer (`ChanPids > 1`) is always true in Erlang term order,
+            %% which logged this warning even for a single channel.
+            case ChanPids of
+                [_, _ | _] ->
+                    ?SLOG(warning, #{ msg => "more_than_one_channel_found"
+                                    , chan_pids => ChanPids
+                                    },
+                          #{clientid => ClientId});
+                _ ->
+                    ok
+            end,
+            lists:foreach(fun(Pid) ->
+                kick_session(GwName, ClientId, Pid)
+            end, ChanPids)
+    end.
+
+kick_session(GwName, ClientId, ChanPid) -> % kick a single channel: kick_session/4 with Action = kick
+    kick_session(GwName, kick, ClientId, ChanPid).
+
+%% @private This function is shared for session 'kick' and 'discard' (as the first arg Action).
+kick_session(GwName, Action, ClientId, ChanPid) ->
     try
     try
-        discard_session(GwName, ClientId, Pid)
+        wrap_rpc(emqx_gateway_cm_proto_v1:kick_session(GwName, Action, ClientId, ChanPid))
     catch
     catch
-        _ : noproc -> % emqx_ws_connection: call
-            ok;
-        _ : {noproc, _} -> % emqx_connection: gen_server:call
-            ok;
-        _ : {{shutdown, _}, _} ->
-            ok;
-        _ : _Error : _St ->
-            ok
+        Error : Reason ->
+            %% This should mostly be RPC failures.
+            %% However, if the node is still running the old version
+            %% code (prior to emqx app 4.3.10) some of the RPC handler
+            %% exceptions may get propagated to a new version node
+            ?SLOG(error, #{ msg => "failed_to_kick_session_on_remote_node"
+                          , node => node(ChanPid)
+                          , action => Action
+                          , error => Error
+                          , reason => Reason
+                          },
+                #{clientid => ClientId})
     end.
     end.
 
 
--spec do_discard_session(gateway_name(), emqx_types:clientid(), pid()) ->
-          _.
-do_discard_session(GwName, ClientId, ChanPid) ->
+-spec do_kick_session(gateway_name(),
+                      kick | discard,
+                      emqx_types:clientid(),
+                      pid()) -> ok.
+do_kick_session(GwName, Action, ClientId, ChanPid) ->
     case get_chann_conn_mod(GwName, ClientId, ChanPid) of
     case get_chann_conn_mod(GwName, ClientId, ChanPid) of
         undefined -> ok;
         undefined -> ok;
         ConnMod when is_atom(ConnMod) ->
         ConnMod when is_atom(ConnMod) ->
-            ConnMod:call(ChanPid, discard, ?T_TAKEOVER)
+            ok = request_stepdown(Action, ConnMod, ChanPid)
     end.
     end.
 
 
-%% @private
--spec discard_session(gateway_name(), emqx_types:clientid(), pid()) ->
-          _.
-discard_session(GwName, ClientId, ChanPid) ->
-    wrap_rpc(emqx_gateway_cm_proto_v1:discard_session(GwName, ClientId, ChanPid)).
-
--spec kick_session(gateway_name(), emqx_types:clientid())
-    -> {error, any()}
-     | ok.
-kick_session(GwName, ClientId) ->
-    case lookup_channels(GwName, ClientId) of
-        [] -> {error, not_found};
-        [ChanPid] ->
-            kick_session(GwName, ClientId, ChanPid);
-        ChanPids ->
-            [ChanPid|StalePids] = lists:reverse(ChanPids),
-            ?SLOG(error, #{ msg => "more_than_one_channel_found"
-                          , chan_pids => ChanPids
-                          }),
-            lists:foreach(fun(StalePid) ->
-                              catch discard_session(GwName, ClientId, StalePid)
-                          end, StalePids),
-            kick_session(GwName, ClientId, ChanPid)
+%% @private Call a local stale session to execute an Action.
+%% If it fails to respond (e.g. timeout), force a kill.
+%% Keeping the stale pid around, returning an error, or raising an exception
+%% benefits nobody.
+-spec request_stepdown(Action, module(), pid())
+-> ok
+   | {ok, emqx_session:session() | list(emqx_types:deliver())}
+   | {error, term()}
+     when Action :: kick | discard | {takeover, 'begin'} | {takeover, 'end'}.
+request_stepdown(Action, ConnMod, Pid) ->
+    Timeout =
+    case Action == kick orelse Action == discard of
+        true -> ?T_KICK;
+        _ -> ?T_TAKEOVER
+    end,
+    Return =
+    %% this is essentially a gen_server:call implemented in emqx_connection
+    %% and emqx_ws_connection.
+    %% the handle_call is implemented in emqx_channel
+    try apply(ConnMod, call, [Pid, Action, Timeout]) of
+        ok -> ok;
+        Reply -> {ok, Reply}
+    catch
+        _ : noproc -> % emqx_ws_connection: call
+            ok = ?tp(debug, "session_already_gone", #{pid => Pid, action => Action}),
+            {error, noproc};
+        _ : {noproc, _} -> % emqx_connection: gen_server:call
+            ok = ?tp(debug, "session_already_gone", #{pid => Pid, action => Action}),
+            {error, noproc};
+        _ : Reason = {shutdown, _} ->
+            ok = ?tp(debug, "session_already_shutdown", #{pid => Pid, action => Action}),
+            {error, Reason};
+        _ : Reason = {{shutdown, _}, _} ->
+            ok = ?tp(debug, "session_already_shutdown", #{pid => Pid, action => Action}),
+            {error, Reason};
+        _ : {timeout, {gen_server, call, _}} ->
+            ?tp(warning, "session_stepdown_request_timeout",
+                #{pid => Pid,
+                  action => Action,
+                  stale_channel => stale_channel_info(Pid)
+                 }),
+            ok = force_kill(Pid),
+            {error, timeout};
+        _ : Error : St ->
+            ?tp(error, "session_stepdown_request_exception",
+                #{pid => Pid,
+                  action => Action,
+                  reason => Error,
+                  stacktrace => St,
+                  stale_channel => stale_channel_info(Pid)
+                 }),
+            ok = force_kill(Pid),
+            {error, Error}
+    end,
+    case Action == kick orelse Action == discard of
+        true -> ok;
+        _ -> Return
     end.
     end.
 
 
--spec do_kick_session(gateway_name(), emqx_types:clientid(), pid()) ->
-          _.
-do_kick_session(GwName, ClientId, ChanPid) ->
-    case get_chan_info(GwName, ClientId, ChanPid) of
-        #{conninfo := #{conn_mod := ConnMod}} ->
-            ConnMod:call(ChanPid, kick, ?T_TAKEOVER);
-        undefined ->
-            {error, not_found}
-    end.
+force_kill(Pid) -> % unconditionally terminate Pid; always returns ok
+    exit(Pid, kill), % sends an exit signal with reason kill to Pid
+    ok.
 
 
--spec kick_session(gateway_name(), emqx_types:clientid(), pid()) ->
-          _.
-kick_session(GwName, ClientId, ChanPid) ->
-    wrap_rpc(emqx_gateway_cm_proto_v1:kick_session(GwName, ClientId, ChanPid)).
+stale_channel_info(Pid) -> % diagnostic snapshot of a channel process, used in stepdown failure traces
+    process_info(Pid, [status, message_queue_len, current_stacktrace]). % NOTE(review): process_info/2 only accepts pids on the local node - confirm callers never pass a remote pid
 
 
 with_channel(GwName, ClientId, Fun) ->
 with_channel(GwName, ClientId, Fun) ->
     case lookup_channels(GwName, ClientId) of
     case lookup_channels(GwName, ClientId) of
@@ -437,9 +565,47 @@ do_get_chann_conn_mod(GwName, ClientId, ChanPid) ->
 get_chann_conn_mod(GwName, ClientId, ChanPid) ->
 get_chann_conn_mod(GwName, ClientId, ChanPid) ->
     wrap_rpc(emqx_gateway_cm_proto_v1:get_chann_conn_mod(GwName, ClientId, ChanPid)).
     wrap_rpc(emqx_gateway_cm_proto_v1:get_chann_conn_mod(GwName, ClientId, ChanPid)).
 
 
+%% @doc Send the request Req to the channel of ClientId and return the reply.
+%% The channel pid is supplied by with_channel/3 and the request is routed
+%% through emqx_gateway_cm_proto_v1:call/4.
+-spec call(gateway_name(), emqx_types:clientid(), term()) -> term().
+call(GwName, ClientId, Req) ->
+    CallChannel =
+        fun(ChanPid) ->
+            wrap_rpc(emqx_gateway_cm_proto_v1:call(GwName, ClientId, ChanPid, Req))
+        end,
+    with_channel(GwName, ClientId, CallChannel).
+
+%% @doc Same as call/3, with an explicit Timeout passed on to
+%% emqx_gateway_cm_proto_v1:call/5.
+-spec call(gateway_name(), emqx_types:clientid(), term(), timeout()) -> term().
+call(GwName, ClientId, Req, Timeout) ->
+    CallChannel =
+        fun(ChanPid) ->
+            Reply = emqx_gateway_cm_proto_v1:call(
+                      GwName, ClientId, ChanPid, Req, Timeout),
+            wrap_rpc(Reply)
+        end,
+    with_channel(GwName, ClientId, CallChannel).
+
+%% Look up the connection module of ChanPid and forward Req to its call/2.
+%% Raises error(noproc) when no connection module is known for the channel.
+do_call(GwName, ClientId, ChanPid, Req) ->
+    ConnMod = do_get_chann_conn_mod(GwName, ClientId, ChanPid),
+    ConnMod =:= undefined andalso error(noproc),
+    ConnMod:call(ChanPid, Req).
+
+%% Look up the connection module of ChanPid and forward Req to its call/3
+%% with the given Timeout.
+%% Raises error(noproc) when no connection module is known for the channel.
+do_call(GwName, ClientId, ChanPid, Req, Timeout) ->
+    ConnMod = do_get_chann_conn_mod(GwName, ClientId, ChanPid),
+    ConnMod =:= undefined andalso error(noproc),
+    ConnMod:call(ChanPid, Req, Timeout).
+
+%% @doc Send the request Req to the channel of ClientId without waiting for
+%% a reply from the channel. The channel pid is supplied by with_channel/3
+%% and the request is routed through emqx_gateway_cm_proto_v1:cast/4.
+-spec cast(gateway_name(), emqx_types:clientid(), term()) -> term().
+cast(GwName, ClientId, Req) ->
+    CastChannel =
+        fun(ChanPid) ->
+            wrap_rpc(emqx_gateway_cm_proto_v1:cast(GwName, ClientId, ChanPid, Req))
+        end,
+    with_channel(GwName, ClientId, CastChannel).
+
+%% Look up the connection module of ChanPid and forward Req to its cast/2.
+%% Raises error(noproc) when no connection module is known for the channel.
+do_cast(GwName, ClientId, ChanPid, Req) ->
+    ConnMod = do_get_chann_conn_mod(GwName, ClientId, ChanPid),
+    ConnMod =:= undefined andalso error(noproc),
+    ConnMod:cast(ChanPid, Req).
+
 %% Locker
 %% Locker
 
 
-locker_trans(_Type, undefined, Fun) ->
+locker_trans(_GwName, undefined, Fun) ->
     Fun([]);
     Fun([]);
 locker_trans(GwName, ClientId, Fun) ->
 locker_trans(GwName, ClientId, Fun) ->
     Locker = lockername(GwName),
     Locker = lockername(GwName),
@@ -530,7 +696,7 @@ code_change(_OldVsn, State, _Extra) ->
 do_unregister_channel_task(Items, GwName, CmTabs) ->
 do_unregister_channel_task(Items, GwName, CmTabs) ->
     lists:foreach(
     lists:foreach(
       fun({ChanPid, ClientId}) ->
       fun({ChanPid, ClientId}) ->
-          do_unregister_channel(GwName, {ClientId, ChanPid}, CmTabs)
+        do_unregister_channel(GwName, {ClientId, ChanPid}, CmTabs)
       end, Items).
       end, Items).
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------

+ 175 - 38
apps/emqx_gateway/src/mqttsn/emqx_sn_channel.erl

@@ -20,6 +20,7 @@
 
 
 -include("src/mqttsn/include/emqx_sn.hrl").
 -include("src/mqttsn/include/emqx_sn.hrl").
 -include_lib("emqx/include/emqx.hrl").
 -include_lib("emqx/include/emqx.hrl").
+-include_lib("emqx/include/types.hrl").
 -include_lib("emqx/include/emqx_mqtt.hrl").
 -include_lib("emqx/include/emqx_mqtt.hrl").
 -include_lib("emqx/include/logger.hrl").
 -include_lib("emqx/include/logger.hrl").
 
 
@@ -66,6 +67,10 @@
           clientinfo_override :: map(),
           clientinfo_override :: map(),
           %% Connection State
           %% Connection State
           conn_state    :: conn_state(),
           conn_state    :: conn_state(),
+          %% Inflight register message queue
+          register_inflight :: maybe(term()),
+          %% Topics list for awaiting to register to client
+          register_awaiting_queue :: list(),
           %% Timer
           %% Timer
           timers :: #{atom() => disable | undefined | reference()},
           timers :: #{atom() => disable | undefined | reference()},
           %%% Takeover
           %%% Takeover
@@ -88,10 +93,12 @@
 -type(replies() :: reply() | [reply()]).
 -type(replies() :: reply() | [reply()]).
 
 
 -define(TIMER_TABLE, #{
 -define(TIMER_TABLE, #{
-          alive_timer  => keepalive,
-          retry_timer  => retry_delivery,
-          await_timer  => expire_awaiting_rel,
-          asleep_timer => expire_asleep
+          alive_timer    => keepalive,
+          retry_timer    => retry_delivery,
+          await_timer    => expire_awaiting_rel,
+          expire_timer   => expire_session,
+          asleep_timer   => expire_asleep,
+          register_timer => retry_register
          }).
          }).
 
 
 -define(DEFAULT_OVERRIDE,
 -define(DEFAULT_OVERRIDE,
@@ -104,6 +111,9 @@
 
 
 -define(NEG_QOS_CLIENT_ID, <<"NegQoS-Client">>).
 -define(NEG_QOS_CLIENT_ID, <<"NegQoS-Client">>).
 
 
+-define(REGISTER_TIMEOUT, 10000). % 10s
+-define(DEFAULT_SESSION_EXPIRY, 7200000). %% 2h
+
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 %% Init the channel
 %% Init the channel
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
@@ -148,6 +158,7 @@ init(ConnInfo = #{peername := {PeerHost, _},
             , clientinfo = ClientInfo
             , clientinfo = ClientInfo
             , clientinfo_override = Override
             , clientinfo_override = Override
             , conn_state = idle
             , conn_state = idle
+            , register_awaiting_queue = []
             , timers = #{}
             , timers = #{}
             , takeover = false
             , takeover = false
             , resuming = false
             , resuming = false
@@ -195,17 +206,24 @@ stats(#channel{session = Session})->
 set_conn_state(ConnState, Channel) ->
 set_conn_state(ConnState, Channel) ->
     Channel#channel{conn_state = ConnState}.
     Channel#channel{conn_state = ConnState}.
 
 
-enrich_conninfo(?SN_CONNECT_MSG(_Flags, _ProtoId, Duration, ClientId),
+enrich_conninfo(?SN_CONNECT_MSG(Flags, _ProtoId, Duration, ClientId),
                 Channel = #channel{conninfo = ConnInfo}) ->
                 Channel = #channel{conninfo = ConnInfo}) ->
+    CleanStart = Flags#mqtt_sn_flags.clean_start,
     NConnInfo = ConnInfo#{ clientid => ClientId
     NConnInfo = ConnInfo#{ clientid => ClientId
                          , proto_name => <<"MQTT-SN">>
                          , proto_name => <<"MQTT-SN">>
                          , proto_ver => <<"1.2">>
                          , proto_ver => <<"1.2">>
-                         , clean_start => true
+                         , clean_start => CleanStart
                          , keepalive => Duration
                          , keepalive => Duration
-                         , expiry_interval => 0
+                         , expiry_interval => expiry_interval(Flags)
                          },
                          },
     {ok, Channel#channel{conninfo = NConnInfo}}.
     {ok, Channel#channel{conninfo = NConnInfo}}.
 
 
+%% Session expiry interval derived from the CONNECT flags:
+%% 0 for a clean start, ?DEFAULT_SESSION_EXPIRY otherwise.
+expiry_interval(#mqtt_sn_flags{clean_start = true}) ->
+    0;
+expiry_interval(#mqtt_sn_flags{clean_start = false}) ->
+    %% TODO: make it configurable
+    ?DEFAULT_SESSION_EXPIRY.
+
 run_conn_hooks(Packet, Channel = #channel{ctx = Ctx,
 run_conn_hooks(Packet, Channel = #channel{ctx = Ctx,
                                           conninfo = ConnInfo}) ->
                                           conninfo = ConnInfo}) ->
     %% XXX: Assign headers of Packet to ConnProps
     %% XXX: Assign headers of Packet to ConnProps
@@ -308,13 +326,13 @@ ensure_connected(Channel = #channel{
 
 
 process_connect(Channel = #channel{
 process_connect(Channel = #channel{
                            ctx = Ctx,
                            ctx = Ctx,
-                           conninfo = ConnInfo,
+                           conninfo = ConnInfo = #{clean_start := CleanStart},
                            clientinfo = ClientInfo
                            clientinfo = ClientInfo
                           }) ->
                           }) ->
     SessFun = fun(_,_) -> emqx_session:init(#{max_inflight => 1}) end,
     SessFun = fun(_,_) -> emqx_session:init(#{max_inflight => 1}) end,
     case emqx_gateway_ctx:open_session(
     case emqx_gateway_ctx:open_session(
            Ctx,
            Ctx,
-           true,
+           CleanStart,
            ClientInfo,
            ClientInfo,
            ConnInfo,
            ConnInfo,
            SessFun
            SessFun
@@ -327,7 +345,7 @@ process_connect(Channel = #channel{
             ?SLOG(error, #{ msg => "failed_to_open_session"
             ?SLOG(error, #{ msg => "failed_to_open_session"
                           , reason => Reason
                           , reason => Reason
                           }),
                           }),
-            handle_out(connack, ?SN_RC_FAILED_SESSION, Channel)
+            handle_out(connack, ?SN_RC2_FAILED_SESSION, Channel)
     end.
     end.
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
@@ -501,6 +519,40 @@ handle_in(?SN_REGISTER_MSG(_TopicId, MsgId, TopicName),
             {ok, {outgoing, AckPacket}, Channel}
             {ok, {outgoing, AckPacket}, Channel}
     end;
     end;
 
 
+handle_in(?SN_REGACK_MSG(TopicId, MsgId, ?SN_RC_ACCEPTED),
+          Channel = #channel{register_inflight = Inflight}) -> % REGACK with return code ?SN_RC_ACCEPTED
+    case Inflight of
+        {TopicId, _, TopicName} -> % the ack carries the topic id of the in-flight REGISTER
+            ?SLOG(debug, #{ msg => "register_topic_name_to_client_succesfully"
+                          , topic_id => TopicId
+                          , topic_name => TopicName
+                          }),
+            NChannel = cancel_timer( % stop register_timer and clear the in-flight slot
+                         register_timer,
+                         Channel#channel{register_inflight = undefined}),
+            send_next_register_or_replay_publish(TopicName, NChannel);
+        _ -> % nothing in flight, or the ack is for another topic id: log it and keep the channel unchanged
+            ?SLOG(error, #{ msg => "unexpected_regack_msg"
+                          , msg_id => MsgId
+                          , topic_id => TopicId
+                          , current_inflight => Inflight
+                          }),
+            {ok, Channel}
+    end;
+
+%% REGACK reporting congestion: keep the channel as it is.
+handle_in(?SN_REGACK_MSG(_TopicId, _MsgId, ?SN_RC_CONGESTION), Channel) ->
+    %% TODO: a or b?
+    %% a. waiting for next register timer
+    %% b. re-new the re-transmit timer
+    {ok, Channel};
+
+%% REGACK with any other return code: disconnect this client, e.g. when the
+%% reason is ?SN_RC_NOT_SUPPORTED, ?SN_RC_INVALID_TOPIC_ID, etc.
+handle_in(?SN_REGACK_MSG(_TopicId, _MsgId, _Reason), Channel) ->
+    handle_out(disconnect, ?SN_RC_NOT_SUPPORTED, Channel);
+
 handle_in(PubPkt = ?SN_PUBLISH_MSG(_Flags, TopicId0, MsgId, _Data), Channel) ->
 handle_in(PubPkt = ?SN_PUBLISH_MSG(_Flags, TopicId0, MsgId, _Data), Channel) ->
     TopicId = case is_integer(TopicId0) of
     TopicId = case is_integer(TopicId0) of
                   true -> TopicId0;
                   true -> TopicId0;
@@ -560,8 +612,7 @@ handle_in(?SN_PUBACK_MSG(TopicId, MsgId, ReturnCode),
                     %% involving the predefined topic name in register to
                     %% involving the predefined topic name in register to
                     %% enhance the gateway's robustness even inconsistent
                     %% enhance the gateway's robustness even inconsistent
                     %% with MQTT-SN channels
                     %% with MQTT-SN channels
-                    RegPkt = ?SN_REGISTER_MSG(TopicId, MsgId, TopicName),
-                    {ok, {outgoing, RegPkt}, Channel}
+                    handle_out(register, {TopicId, MsgId, TopicName}, Channel)
             end;
             end;
         _ ->
         _ ->
             ?SLOG(error, #{ msg => "cannt_handle_PUBACK"
             ?SLOG(error, #{ msg => "cannt_handle_PUBACK"
@@ -687,12 +738,17 @@ handle_in(?SN_PINGRESP_MSG(), Channel) ->
     {ok, Channel};
     {ok, Channel};
 
 
 handle_in(?SN_DISCONNECT_MSG(Duration), Channel) ->
 handle_in(?SN_DISCONNECT_MSG(Duration), Channel) ->
-    AckPkt = ?SN_DISCONNECT_MSG(undefined),
     case Duration of
     case Duration of
         undefined ->
         undefined ->
-            shutdown(normal, AckPkt, Channel);
+            handle_out(disconnect, normal, Channel);
         _ ->
         _ ->
-            %% TODO: asleep mechnisa
+            %% A DISCONNECT message with a Duration field is sent by a client
+            %% when it wants to go to the “asleep” state. The receipt of this
+            %% message is also acknowledged by the gateway by means of a
+            %% DISCONNECT message (without a duration field) [5.4.21]
+            %%
+            %% TODO: asleep mechanism
+            AckPkt = ?SN_DISCONNECT_MSG(undefined),
             {ok, {outgoing, AckPkt}, asleep(Duration, Channel)}
             {ok, {outgoing, AckPkt}, asleep(Duration, Channel)}
     end;
     end;
 
 
@@ -729,6 +785,31 @@ after_message_acked(ClientInfo, Msg, #channel{ctx = Ctx}) ->
 outgoing_and_update(Pkt) ->
 outgoing_and_update(Pkt) ->
     [{outgoing, Pkt}, {event, update}].
     [{outgoing, Pkt}, {event, update}].
 
 
+send_next_register_or_replay_publish(TopicName,
+                                     Channel = #channel{
+                                                  session = Session,
+                                                  register_awaiting_queue = []}) -> % no queued REGISTER left: consider replaying the in-flight publish
+    case emqx_inflight:to_list(emqx_session:info(inflight, Session)) of
+        [] -> {ok, Channel}; % nothing in flight, nothing to replay
+        [{PktId, {inflight_data, _, Msg, _}}] -> % NOTE(review): only a single in-flight entry is matched, any other shape raises case_clause - presumably relies on max_inflight = 1, confirm
+            case TopicName =:= emqx_message:topic(Msg) of
+                false -> % the acknowledged topic is not the one of the in-flight message: log and do not replay
+                    ?SLOG(warning, #{ msg => "replay_inflight_message_failed"
+                                    , acked_topic_name => TopicName
+                                    , inflight_message => Msg
+                                    }),
+                    {ok, Channel};
+                true -> % same topic: send the message again with the dup flag set
+                    NMsg = emqx_message:set_flag(dup, true, Msg),
+                    handle_out(publish, {PktId, NMsg}, Channel)
+            end
+    end;
+send_next_register_or_replay_publish(_TopicName,
+                                     Channel = #channel{
+                                                  register_awaiting_queue = RAQueue}) -> % queue is not empty: send the next queued REGISTER
+    [RegisterReq | NRAQueue] = RAQueue, % take the head of the queue
+    handle_out(register, RegisterReq, Channel#channel{register_awaiting_queue = NRAQueue}).
+
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 %% Handle Publish
 %% Handle Publish
 
 
@@ -786,7 +867,7 @@ check_pub_authz({TopicName, _Flags, _Data},
               #channel{ctx = Ctx, clientinfo = ClientInfo}) ->
               #channel{ctx = Ctx, clientinfo = ClientInfo}) ->
     case emqx_gateway_ctx:authorize(Ctx, ClientInfo, publish, TopicName) of
     case emqx_gateway_ctx:authorize(Ctx, ClientInfo, publish, TopicName) of
         allow -> ok;
         allow -> ok;
-        deny  -> {error, ?SN_RC_NOT_AUTHORIZE}
+        deny  -> {error, ?SN_RC2_NOT_AUTHORIZE}
     end.
     end.
 
 
 convert_pub_to_msg({TopicName, Flags, Data},
 convert_pub_to_msg({TopicName, Flags, Data},
@@ -857,7 +938,7 @@ preproc_subs_type(?SN_SUBSCRIBE_MSG_TYPE(?SN_NORMAL_TOPIC,
     %% and returns it within a SUBACK message
     %% and returns it within a SUBACK message
     case emqx_sn_registry:register_topic(Registry, ClientId, TopicName) of
     case emqx_sn_registry:register_topic(Registry, ClientId, TopicName) of
         {error, too_large} ->
         {error, too_large} ->
-            {error, ?SN_EXCEED_LIMITATION};
+            {error, ?SN_RC2_EXCEED_LIMITATION};
         {error, wildcard_topic} ->
         {error, wildcard_topic} ->
             %% If the client subscribes to a topic name which contains a
             %% If the client subscribes to a topic name which contains a
             %% wildcard character, the returning SUBACK message will contain
             %% wildcard character, the returning SUBACK message will contain
@@ -904,7 +985,7 @@ check_subscribe_authz({_TopicId, TopicName, _QoS},
         allow ->
         allow ->
             {ok, Channel};
             {ok, Channel};
         _ ->
         _ ->
-            {error, ?SN_RC_NOT_AUTHORIZE}
+            {error, ?SN_RC2_NOT_AUTHORIZE}
     end.
     end.
 
 
 run_client_subs_hook({TopicId, TopicName, QoS},
 run_client_subs_hook({TopicId, TopicName, QoS},
@@ -920,7 +1001,7 @@ run_client_subs_hook({TopicId, TopicName, QoS},
                             , topic_name => TopicName
                             , topic_name => TopicName
                             , reason => "'client.subscribe' filtered it"
                             , reason => "'client.subscribe' filtered it"
                             }),
                             }),
-            {error, ?SN_EXCEED_LIMITATION};
+            {error, ?SN_RC2_EXCEED_LIMITATION};
         [{NTopicName, NSubOpts}|_] ->
         [{NTopicName, NSubOpts}|_] ->
             {ok, {TopicId, NTopicName, NSubOpts}, Channel}
             {ok, {TopicId, NTopicName, NSubOpts}, Channel}
     end.
     end.
@@ -941,7 +1022,7 @@ do_subscribe({TopicId, TopicName, SubOpts},
                             , topic_name => TopicName
                             , topic_name => TopicName
                             , reason => emqx_reason_codes:text(?RC_QUOTA_EXCEEDED)
                             , reason => emqx_reason_codes:text(?RC_QUOTA_EXCEEDED)
                             }),
                             }),
-            {error, ?SN_EXCEED_LIMITATION}
+            {error, ?SN_RC2_EXCEED_LIMITATION}
     end.
     end.
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
@@ -1080,6 +1161,43 @@ handle_out(pubrel, MsgId, Channel) ->
 handle_out(pubcomp, MsgId, Channel) ->
 handle_out(pubcomp, MsgId, Channel) ->
     {ok, {outgoing, ?SN_PUBREC_MSG(?SN_PUBCOMP, MsgId)}, Channel};
     {ok, {outgoing, ?SN_PUBREC_MSG(?SN_PUBCOMP, MsgId)}, Channel};
 
 
+handle_out(register, {TopicId, MsgId, TopicName},
+           Channel = #channel{register_inflight = undefined}) ->
+    Outgoing = {outgoing, ?SN_REGISTER_MSG(TopicId, MsgId, TopicName)},
+    NChannel = Channel#channel{register_inflight = {TopicId, MsgId, TopicName}},
+    {ok, Outgoing, ensure_timer(register_timer, ?REGISTER_TIMEOUT, NChannel)};
+
+handle_out(register, {TopicId, MsgId, TopicName},
+           Channel = #channel{register_inflight = Inflight,
+                              register_awaiting_queue = RAQueue}) ->
+    case Inflight of
+        {_, _, TopicName} ->
+            ?SLOG(debug, #{ msg => "ingore_handle_out_register"
+                          , requested_register_msg =>
+                             #{ topic_id => TopicId
+                              , msg_id => MsgId
+                              , topic_name => TopicName
+                              }
+                          }),
+            {ok, Channel};
+        {InflightTopicId, InflightMsgId, InflightTopicName} ->
+            NRAQueue = RAQueue ++ [{TopicId, MsgId, TopicName}],
+            ?SLOG(debug, #{ msg => "put_register_msg_into_awaiting_queue"
+                          , inflight_register_msg =>
+                             #{ topic_id => InflightTopicId
+                              , msg_id => InflightMsgId
+                              , topic_name => InflightTopicName
+                              }
+                          , queued_register_msg =>
+                             #{ topic_id => TopicId
+                              , msg_id => MsgId
+                              , topic_name => TopicName
+                              }
+                          , register_awaiting_queue_size => length(NRAQueue)
+                          }),
+            {ok, Channel#channel{register_awaiting_queue = NRAQueue}}
+    end;
+
 handle_out(disconnect, RC, Channel) ->
 handle_out(disconnect, RC, Channel) ->
     DisPkt = ?SN_DISCONNECT_MSG(undefined),
     DisPkt = ?SN_DISCONNECT_MSG(undefined),
     {ok, [{outgoing, DisPkt}, {close, RC}], Channel}.
     {ok, [{outgoing, DisPkt}, {close, RC}], Channel}.
@@ -1196,7 +1314,7 @@ handle_call({subscribe, Topic, SubOpts}, _From, Channel) ->
                                        Topic, SubOpts}, Channel) of
                                        Topic, SubOpts}, Channel) of
                         {ok, {_, NTopicName, NSubOpts}, NChannel} ->
                         {ok, {_, NTopicName, NSubOpts}, NChannel} ->
                             reply({ok, {NTopicName, NSubOpts}}, NChannel);
                             reply({ok, {NTopicName, NSubOpts}}, NChannel);
-                        {error, ?SN_EXCEED_LIMITATION} ->
+                        {error, ?SN_RC2_EXCEED_LIMITATION} ->
                             reply({error, exceed_limitation}, Channel)
                             reply({error, exceed_limitation}, Channel)
                     end;
                     end;
                 _ ->
                 _ ->
@@ -1223,17 +1341,21 @@ handle_call(kick, _From, Channel) ->
 handle_call(discard, _From, Channel) ->
 handle_call(discard, _From, Channel) ->
     shutdown_and_reply(discarded, ok, Channel);
     shutdown_and_reply(discarded, ok, Channel);
 
 
-%% XXX: No Session Takeover
-%handle_call({takeover, 'begin'}, _From, Channel = #channel{session = Session}) ->
-%    reply(Session, Channel#channel{takeover = true});
-%
-%handle_call({takeover, 'end'}, _From, Channel = #channel{session  = Session,
-%                                                  pendings = Pendings}) ->
-%    ok = emqx_session:takeover(Session),
-%    %% TODO: Should not drain deliver here (side effect)
-%    Delivers = emqx_misc:drain_deliver(),
-%    AllPendings = lists:append(Delivers, Pendings),
-%    shutdown_and_reply(takenover, AllPendings, Channel);
+handle_call({takeover, 'begin'}, _From, Channel = #channel{session = Session}) ->
+    %% In MQTT-SN the meaning of a “clean session” is extended to the Will
+    %% feature, i.e. not only the subscriptions are persistent, but also the
+    %% Will topic and the Will message. [6.3]
+    %%
+    %% FIXME: The reply should carry the WillMsg as well as the Session
+    reply(Session, Channel#channel{takeover = true});
+
+handle_call({takeover, 'end'}, _From, Channel = #channel{session  = Session,
+                                                  pendings = Pendings}) ->
+    ok = emqx_session:takeover(Session),
+    %% TODO: Should not drain deliver here (side effect)
+    Delivers = emqx_misc:drain_deliver(),
+    AllPendings = lists:append(Delivers, Pendings),
+    shutdown_and_reply(takenover, AllPendings, Channel);
 
 
 %handle_call(list_authz_cache, _From, Channel) ->
 %handle_call(list_authz_cache, _From, Channel) ->
 %    {reply, emqx_authz_cache:list_authz_cache(), Channel};
 %    {reply, emqx_authz_cache:list_authz_cache(), Channel};
@@ -1282,8 +1404,10 @@ handle_info({sock_closed, Reason},
     %emqx_zone:enable_flapping_detect(Zone)
     %emqx_zone:enable_flapping_detect(Zone)
     %    andalso emqx_flapping:detect(ClientInfo),
     %    andalso emqx_flapping:detect(ClientInfo),
     NChannel = ensure_disconnected(Reason, mabye_publish_will_msg(Channel)),
     NChannel = ensure_disconnected(Reason, mabye_publish_will_msg(Channel)),
-    %% XXX: Session keepper detect here
-    shutdown(Reason, NChannel);
+    case maybe_shutdown(Reason, NChannel) of
+        {ok, NChannel1} -> {ok, {event, disconnected}, NChannel1};
+        Shutdown -> Shutdown
+    end;
 
 
 handle_info({sock_closed, Reason},
 handle_info({sock_closed, Reason},
             Channel = #channel{conn_state = disconnected}) ->
             Channel = #channel{conn_state = disconnected}) ->
@@ -1305,6 +1429,14 @@ handle_info(Info, Channel) ->
                   }),
                   }),
     {ok, Channel}.
     {ok, Channel}.
 
 
+maybe_shutdown(Reason, Channel = #channel{conninfo = ConnInfo}) ->
+    case maps:get(expiry_interval, ConnInfo) of
+        ?UINT_MAX -> {ok, Channel};
+        I when I > 0 ->
+            {ok, ensure_timer(expire_timer, I, Channel)};
+        _ -> shutdown(Reason, Channel)
+    end.
+
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 %% Ensure disconnected
 %% Ensure disconnected
 
 
@@ -1420,7 +1552,7 @@ handle_timeout(_TRef, {keepalive, StatVal},
             NChannel = Channel#channel{keepalive = NKeepalive},
             NChannel = Channel#channel{keepalive = NKeepalive},
             {ok, reset_timer(alive_timer, NChannel)};
             {ok, reset_timer(alive_timer, NChannel)};
         {error, timeout} ->
         {error, timeout} ->
-            handle_out(disconnect, ?RC_KEEP_ALIVE_TIMEOUT, Channel)
+            handle_out(disconnect, ?SN_RC2_KEEPALIVE_TIMEOUT, Channel)
     end;
     end;
 
 
 handle_timeout(_TRef, retry_delivery,
 handle_timeout(_TRef, retry_delivery,
@@ -1436,6 +1568,7 @@ handle_timeout(_TRef, retry_delivery,
             {ok, clean_timer(retry_timer, Channel#channel{session = NSession})};
             {ok, clean_timer(retry_timer, Channel#channel{session = NSession})};
         {ok, Publishes, Timeout, NSession} ->
         {ok, Publishes, Timeout, NSession} ->
             NChannel = Channel#channel{session = NSession},
             NChannel = Channel#channel{session = NSession},
+            %% XXX: Should these replayed messages wait until the REGISTER is acked?
             handle_out(publish, Publishes, reset_timer(retry_timer, Timeout, NChannel))
             handle_out(publish, Publishes, reset_timer(retry_timer, Timeout, NChannel))
     end;
     end;
 
 
@@ -1454,6 +1587,9 @@ handle_timeout(_TRef, expire_awaiting_rel,
             {ok, reset_timer(await_timer, Timeout, Channel#channel{session = NSession})}
             {ok, reset_timer(await_timer, Timeout, Channel#channel{session = NSession})}
     end;
     end;
 
 
+handle_timeout(_TRef, expire_session, Channel) ->
+    shutdown(expired, Channel);
+
 handle_timeout(_TRef, expire_asleep, Channel) ->
 handle_timeout(_TRef, expire_asleep, Channel) ->
     shutdown(asleep_timeout, Channel);
     shutdown(asleep_timeout, Channel);
 
 
@@ -1563,7 +1699,8 @@ returncode_name(?SN_RC_ACCEPTED) -> accepted;
 returncode_name(?SN_RC_CONGESTION) -> rejected_congestion;
 returncode_name(?SN_RC_CONGESTION) -> rejected_congestion;
 returncode_name(?SN_RC_INVALID_TOPIC_ID) -> rejected_invaild_topic_id;
 returncode_name(?SN_RC_INVALID_TOPIC_ID) -> rejected_invaild_topic_id;
 returncode_name(?SN_RC_NOT_SUPPORTED) -> rejected_not_supported;
 returncode_name(?SN_RC_NOT_SUPPORTED) -> rejected_not_supported;
-returncode_name(?SN_RC_NOT_AUTHORIZE) -> rejected_not_authorize;
-returncode_name(?SN_RC_FAILED_SESSION) -> rejected_failed_open_session;
-returncode_name(?SN_EXCEED_LIMITATION) -> rejected_exceed_limitation;
+returncode_name(?SN_RC2_NOT_AUTHORIZE) -> rejected_not_authorize;
+returncode_name(?SN_RC2_FAILED_SESSION) -> rejected_failed_open_session;
+returncode_name(?SN_RC2_KEEPALIVE_TIMEOUT) -> rejected_keepalive_timeout;
+returncode_name(?SN_RC2_EXCEED_LIMITATION) -> rejected_exceed_limitation;
 returncode_name(_) -> accepted.
 returncode_name(_) -> accepted.

+ 65 - 23
apps/emqx_gateway/src/mqttsn/emqx_sn_frame.erl

@@ -291,42 +291,84 @@ message_type(16#1d) ->
 message_type(Type) ->
 message_type(Type) ->
     io_lib:format("Unknown Type ~p", [Type]).
     io_lib:format("Unknown Type ~p", [Type]).
 
 
+format(?SN_CONNECT_MSG(Flags, ProtocolId, Duration, ClientId)) ->
+    #mqtt_sn_flags{
+       will = Will,
+       clean_start = CleanStart} = Flags,
+    io_lib:format("SN_CONNECT(W~w, C~w, ProtocolId=~w, Duration=~w, "
+                  "ClientId=~ts)",
+                  [bool(Will), bool(CleanStart),
+                   ProtocolId, Duration, ClientId]);
+format(?SN_CONNACK_MSG(ReturnCode)) ->
+    io_lib:format("SN_CONNACK(ReturnCode=~w)", [ReturnCode]);
+format(?SN_WILLTOPICREQ_MSG()) ->
+    "SN_WILLTOPICREQ()";
+format(?SN_WILLTOPIC_MSG(Flags, Topic)) ->
+    #mqtt_sn_flags{
+       qos = QoS,
+       retain = Retain} = Flags,
+    io_lib:format("SN_WILLTOPIC(Q~w, R~w, Topic=~s)",
+                  [QoS, bool(Retain), Topic]);
+format(?SN_WILLTOPIC_EMPTY_MSG) ->
+    "SN_WILLTOPIC(_)";
+format(?SN_WILLMSGREQ_MSG()) ->
+    "SN_WILLMSGREQ()";
+format(?SN_WILLMSG_MSG(Msg)) ->
+    io_lib:format("SN_WILLMSG_MSG(Msg=~p)", [Msg]);
 format(?SN_PUBLISH_MSG(Flags, TopicId, MsgId, Data)) ->
 format(?SN_PUBLISH_MSG(Flags, TopicId, MsgId, Data)) ->
-    io_lib:format("mqtt_sn_message SN_PUBLISH, ~ts, TopicId=~w, MsgId=~w, Payload=~w",
-                  [format_flag(Flags), TopicId, MsgId, Data]);
-format(?SN_PUBACK_MSG(Flags, MsgId, ReturnCode)) ->
-    io_lib:format("mqtt_sn_message SN_PUBACK, ~ts, MsgId=~w, ReturnCode=~w",
-                  [format_flag(Flags), MsgId, ReturnCode]);
+    #mqtt_sn_flags{
+       dup = Dup,
+       qos = QoS,
+       retain = Retain,
+       topic_id_type = TopicIdType} = Flags,
+    io_lib:format("SN_PUBLISH(D~w, Q~w, R~w, TopicIdType=~w, TopicId=~w, "
+                  "MsgId=~w, Payload=~p)",
+                  [bool(Dup), QoS, bool(Retain),
+                   TopicIdType, TopicId, MsgId, Data]);
+format(?SN_PUBACK_MSG(TopicId, MsgId, ReturnCode)) ->
+    io_lib:format("SN_PUBACK(TopicId=~w, MsgId=~w, ReturnCode=~w)",
+                  [TopicId, MsgId, ReturnCode]);
 format(?SN_PUBREC_MSG(?SN_PUBCOMP, MsgId)) ->
 format(?SN_PUBREC_MSG(?SN_PUBCOMP, MsgId)) ->
-    io_lib:format("mqtt_sn_message SN_PUBCOMP, MsgId=~w", [MsgId]);
+    io_lib:format("SN_PUBCOMP(MsgId=~w)", [MsgId]);
 format(?SN_PUBREC_MSG(?SN_PUBREC, MsgId)) ->
 format(?SN_PUBREC_MSG(?SN_PUBREC, MsgId)) ->
-    io_lib:format("mqtt_sn_message SN_PUBREC, MsgId=~w", [MsgId]);
+    io_lib:format("SN_PUBREC(MsgId=~w)", [MsgId]);
 format(?SN_PUBREC_MSG(?SN_PUBREL, MsgId)) ->
 format(?SN_PUBREC_MSG(?SN_PUBREL, MsgId)) ->
-    io_lib:format("mqtt_sn_message SN_PUBREL, MsgId=~w", [MsgId]);
+    io_lib:format("SN_PUBREL(MsgId=~w)", [MsgId]);
 format(?SN_SUBSCRIBE_MSG(Flags, Msgid, Topic)) ->
 format(?SN_SUBSCRIBE_MSG(Flags, Msgid, Topic)) ->
-    io_lib:format("mqtt_sn_message SN_SUBSCRIBE, ~ts, MsgId=~w, TopicId=~w",
-                  [format_flag(Flags), Msgid, Topic]);
+    #mqtt_sn_flags{
+       dup = Dup,
+       qos = QoS,
+       topic_id_type = TopicIdType} = Flags,
+    io_lib:format("SN_SUBSCRIBE(D~w, Q~w, TopicIdType=~w, MsgId=~w, "
+                  "TopicId=~w)",
+                  [bool(Dup), QoS, TopicIdType, Msgid, Topic]);
 format(?SN_SUBACK_MSG(Flags, TopicId, MsgId, ReturnCode)) ->
 format(?SN_SUBACK_MSG(Flags, TopicId, MsgId, ReturnCode)) ->
-    io_lib:format("mqtt_sn_message SN_SUBACK, ~ts, MsgId=~w, TopicId=~w, ReturnCode=~w",
-                  [format_flag(Flags), MsgId, TopicId, ReturnCode]);
+    #mqtt_sn_flags{qos = QoS} = Flags,
+    io_lib:format("SN_SUBACK(GrantedQoS=~w, MsgId=~w, TopicId=~w, "
+                  "ReturnCode=~w)",
+                  [QoS, MsgId, TopicId, ReturnCode]);
 format(?SN_UNSUBSCRIBE_MSG(Flags, Msgid, Topic)) ->
 format(?SN_UNSUBSCRIBE_MSG(Flags, Msgid, Topic)) ->
-    io_lib:format("mqtt_sn_message SN_UNSUBSCRIBE, ~ts, MsgId=~w, TopicId=~w",
-                  [format_flag(Flags), Msgid, Topic]);
+    #mqtt_sn_flags{topic_id_type = TopicIdType} = Flags,
+    io_lib:format("SN_UNSUBSCRIBE(TopicIdType=~s, MsgId=~w, TopicId=~w)",
+                  [TopicIdType, Msgid, Topic]);
 format(?SN_UNSUBACK_MSG(MsgId)) ->
 format(?SN_UNSUBACK_MSG(MsgId)) ->
-    io_lib:format("mqtt_sn_message SN_UNSUBACK, MsgId=~w", [MsgId]);
+    io_lib:format("SN_UNSUBACK(MsgId=~w)", [MsgId]);
 format(?SN_REGISTER_MSG(TopicId, MsgId, TopicName)) ->
 format(?SN_REGISTER_MSG(TopicId, MsgId, TopicName)) ->
-    io_lib:format("mqtt_sn_message SN_REGISTER, TopicId=~w, MsgId=~w, TopicName=~w",
+    io_lib:format("SN_REGISTER(TopicId=~w, MsgId=~w, TopicName=~s)",
                   [TopicId, MsgId, TopicName]);
                   [TopicId, MsgId, TopicName]);
 format(?SN_REGACK_MSG(TopicId, MsgId, ReturnCode)) ->
 format(?SN_REGACK_MSG(TopicId, MsgId, ReturnCode)) ->
-    io_lib:format("mqtt_sn_message SN_REGACK, TopicId=~w, MsgId=~w, ReturnCode=~w",
+    io_lib:format("SN_REGACK(TopicId=~w, MsgId=~w, ReturnCode=~w)",
                   [TopicId, MsgId, ReturnCode]);
                   [TopicId, MsgId, ReturnCode]);
-format(#mqtt_sn_message{type = Type, variable = Var}) ->
-    io_lib:format("mqtt_sn_message type=~ts, Var=~w", [emqx_sn_frame:message_type(Type), Var]).
+format(?SN_PINGREQ_MSG(ClientId)) ->
+    io_lib:format("SN_PINGREQ(ClientId=~s)", [ClientId]);
+format(?SN_PINGRESP_MSG()) ->
+    "SN_PINGREQ()";
+format(?SN_DISCONNECT_MSG(Duration)) ->
+    io_lib:format("SN_DISCONNECT(Duration=~s)", [Duration]);
 
 
-format_flag(#mqtt_sn_flags{dup = Dup, qos = QoS, retain = Retain, will = Will, clean_start = CleanStart, topic_id_type = TopicType}) ->
-    io_lib:format("mqtt_sn_flags{dup=~p, qos=~p, retain=~p, will=~p, clean_session=~p, topic_id_type=~p}",
-                  [Dup, QoS, Retain, Will, CleanStart, TopicType]);
-format_flag(_Flag) -> "invalid flag".
+format(#mqtt_sn_message{type = Type, variable = Var}) ->
+    io_lib:format("mqtt_sn_message(type=~s, Var=~w)",
+                  [emqx_sn_frame:message_type(Type), Var]).
 
 
 is_message(#mqtt_sn_message{type = Type})
 is_message(#mqtt_sn_message{type = Type})
     when Type == ?SN_PUBLISH ->
     when Type == ?SN_PUBLISH ->

+ 5 - 4
apps/emqx_gateway/src/mqttsn/include/emqx_sn.hrl

@@ -54,13 +54,14 @@
 -define(SN_RC_INVALID_TOPIC_ID, 16#02).
 -define(SN_RC_INVALID_TOPIC_ID, 16#02).
 -define(SN_RC_NOT_SUPPORTED,    16#03).
 -define(SN_RC_NOT_SUPPORTED,    16#03).
 %% Custom Reason code by emqx
 %% Custom Reason code by emqx
--define(SN_RC_NOT_AUTHORIZE,    16#04).
--define(SN_RC_FAILED_SESSION,   16#05).
--define(SN_EXCEED_LIMITATION,   16#06).
+-define(SN_RC2_NOT_AUTHORIZE,     16#80).
+-define(SN_RC2_FAILED_SESSION,    16#81).
+-define(SN_RC2_KEEPALIVE_TIMEOUT, 16#82).
+-define(SN_RC2_EXCEED_LIMITATION, 16#83).
 
 
 -define(QOS_NEG1, 3).
 -define(QOS_NEG1, 3).
 
 
--type(mqtt_sn_return_code() :: ?SN_RC_ACCEPTED .. ?SN_EXCEED_LIMITATION).
+-type(mqtt_sn_return_code() :: ?SN_RC_ACCEPTED .. ?SN_RC2_EXCEED_LIMITATION).
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 %% MQTT-SN Message
 %% MQTT-SN Message

+ 40 - 11
apps/emqx_gateway/src/proto/emqx_gateway_cm_proto_v1.erl

@@ -24,10 +24,12 @@
         , set_chan_info/4
         , set_chan_info/4
         , get_chan_stats/3
         , get_chan_stats/3
         , set_chan_stats/4
         , set_chan_stats/4
-        , discard_session/3
-        , kick_session/3
+        , kick_session/4
         , get_chann_conn_mod/3
         , get_chann_conn_mod/3
         , lookup_by_clientid/3
         , lookup_by_clientid/3
+        , call/4
+        , call/5
+        , cast/4
         ]).
         ]).
 
 
 -include_lib("emqx/include/bpapi.hrl").
 -include_lib("emqx/include/bpapi.hrl").
@@ -64,14 +66,41 @@ get_chan_stats(GwName, ClientId, ChanPid) ->
 set_chan_stats(GwName, ClientId, ChanPid, Stats) ->
 set_chan_stats(GwName, ClientId, ChanPid, Stats) ->
     rpc:call(node(ChanPid), emqx_gateway_cm, do_set_chan_stats, [GwName, ClientId, ChanPid, Stats]).
     rpc:call(node(ChanPid), emqx_gateway_cm, do_set_chan_stats, [GwName, ClientId, ChanPid, Stats]).
 
 
--spec discard_session(emqx_gateway_cm:gateway_name(), emqx_types:clientid(), pid()) -> _.
-discard_session(GwName, ClientId, ChanPid) ->
-    rpc:call(node(ChanPid), emqx_gateway_cm, do_discard_session, [GwName, ClientId, ChanPid]).
+-spec kick_session(emqx_gateway_cm:gateway_name(),
+                   kick | discard,
+                   emqx_types:clientid(), pid()) -> _.
+kick_session(GwName, Action, ClientId, ChanPid) ->
+    rpc:call(node(ChanPid),
+             emqx_gateway_cm, do_kick_session,
+             [GwName, Action, ClientId, ChanPid]).
 
 
--spec kick_session(emqx_gateway_cm:gateway_name(), emqx_types:clientid(), pid()) -> _.
-kick_session(GwName, ClientId, ChanPid) ->
-    rpc:call(node(ChanPid), emqx_gateway_cm, do_kick_session, [GwName, ClientId, ChanPid]).
-
--spec get_chann_conn_mod(emqx_gateway_cm:gateway_name(), emqx_types:clientid(), pid()) -> atom() | {badrpc, _}.
+-spec get_chann_conn_mod(emqx_gateway_cm:gateway_name(),
+                         emqx_types:clientid(),
+                         pid()) -> atom() | {badrpc, _}.
 get_chann_conn_mod(GwName, ClientId, ChanPid) ->
 get_chann_conn_mod(GwName, ClientId, ChanPid) ->
-    rpc:call(node(ChanPid), emqx_gateway_cm, do_get_chann_conn_mod, [GwName, ClientId, ChanPid]).
+    rpc:call(node(ChanPid), emqx_gateway_cm, do_get_chann_conn_mod,
+             [GwName, ClientId, ChanPid]).
+
+-spec call(emqx_gateway_cm:gateway_name(),
+           emqx_types:clientid(),
+           pid(),
+           term(),
+           timeout()) -> term() | {badrpc, _}.
+call(GwName, ClientId, ChanPid, Req, Timeout) ->
+    rpc:call(node(ChanPid), emqx_gateway_cm, do_call,
+             [GwName, ClientId, ChanPid, Req, Timeout]).
+
+-spec call(emqx_gateway_cm:gateway_name(),
+           emqx_types:clientid(),
+           pid(),
+           term()) -> term() | {badrpc, _}.
+call(GwName, ClientId, ChanPid, Req) ->
+    rpc:call(node(ChanPid), emqx_gateway_cm, do_call,
+             [GwName, ClientId, ChanPid, Req]).
+
+-spec cast(emqx_gateway_cm:gateway_name(),
+           emqx_types:clientid(),
+           pid(),
+           term()) -> term() | {badrpc, _}.
+cast(GwName, ClientId, ChanPid, Req) ->
+    rpc:call(node(ChanPid), emqx_gateway_cm, do_cast, [GwName, ClientId, ChanPid, Req]).

+ 159 - 2
apps/emqx_gateway/test/emqx_sn_protocol_SUITE.erl

@@ -964,6 +964,157 @@ t_publish_qos2_case03(_) ->
     ?assertEqual(<<2, ?SN_DISCONNECT>>, receive_response(Socket)),
     ?assertEqual(<<2, ?SN_DISCONNECT>>, receive_response(Socket)),
     gen_udp:close(Socket).
     gen_udp:close(Socket).
 
 
+t_delivery_qos1_register_invalid_topic_id(_) ->
+    Dup = 0,
+    QoS = 1,
+    Retain = 0,
+    Will = 0,
+    CleanSession = 0,
+    MsgId = 1,
+    TopicId = ?MAX_PRED_TOPIC_ID + 1,
+    {ok, Socket} = gen_udp:open(0, [binary]),
+    send_connect_msg(Socket, <<"test">>),
+    ?assertEqual(<<3, ?SN_CONNACK, 0>>, receive_response(Socket)),
+
+    send_subscribe_msg_normal_topic(Socket, QoS, <<"ab">>, MsgId),
+    ?assertEqual(<<8, ?SN_SUBACK, Dup:1, QoS:2, Retain:1, Will:1, CleanSession:1,
+                   ?SN_NORMAL_TOPIC:2, TopicId:16, MsgId:16, ?SN_RC_ACCEPTED>>,
+                 receive_response(Socket)),
+
+    Payload = <<"test-registration-inconsistent">>,
+    _ = emqx:publish(emqx_message:make(test, ?QOS_1, <<"ab">>, Payload)),
+
+    ?assertEqual(
+       <<(7 + byte_size(Payload)), ?SN_PUBLISH,
+         Dup:1, QoS:2, Retain:1,
+         Will:1, CleanSession:1, ?SN_NORMAL_TOPIC:2,
+         TopicId:16, MsgId:16, Payload/binary>>, receive_response(Socket)),
+    %% PUBACK with ?SN_RC_INVALID_TOPIC_ID; a REGISTER for <<"ab">> is expected next
+    send_puback_msg(Socket, TopicId, MsgId, ?SN_RC_INVALID_TOPIC_ID),
+
+    ?assertEqual(
+       {TopicId, MsgId},
+       check_register_msg_on_udp(<<"ab">>, receive_response(Socket))),
+    send_regack_msg(Socket, TopicId, MsgId + 1),
+
+    %% receive the replayed message
+    ?assertEqual(
+       <<(7 + byte_size(Payload)), ?SN_PUBLISH,
+         Dup:1, QoS:2, Retain:1,
+         Will:1, CleanSession:1, ?SN_NORMAL_TOPIC:2,
+         TopicId:16, (MsgId):16, Payload/binary>>, receive_response(Socket)),
+
+    send_disconnect_msg(Socket, undefined),
+    ?assertEqual(<<2, ?SN_DISCONNECT>>, receive_response(Socket)),
+    gen_udp:close(Socket).
+
+t_delivery_takeover_and_re_register(_) ->
+    MsgId = 1,
+    {ok, Socket} = gen_udp:open(0, [binary]),
+    send_connect_msg(Socket, <<"test">>, 0),
+    ?assertMatch(<<_, ?SN_CONNACK, ?SN_RC_ACCEPTED>>,
+                 receive_response(Socket)),
+
+    send_subscribe_msg_normal_topic(Socket, ?QOS_1, <<"topic-a">>, MsgId+1),
+    <<_, ?SN_SUBACK, 2#00100000,
+      TopicIdA:16, _:16, ?SN_RC_ACCEPTED>> = receive_response(Socket),
+
+    send_subscribe_msg_normal_topic(Socket, ?QOS_2, <<"topic-b">>, MsgId+2),
+    <<_, ?SN_SUBACK, 2#01000000,
+      TopicIdB:16, _:16, ?SN_RC_ACCEPTED>> = receive_response(Socket),
+
+    _ = emqx:publish(
+          emqx_message:make(test, ?QOS_1, <<"topic-a">>, <<"test-a">>)),
+    _ = emqx:publish(
+          emqx_message:make(test, ?QOS_2, <<"topic-b">>, <<"test-b">>)),
+
+    <<_, ?SN_PUBLISH, 2#00100000,
+      TopicIdA:16, MsgId1:16, "test-a">> = receive_response(Socket),
+    send_puback_msg(Socket, TopicIdA, MsgId1, ?SN_RC_ACCEPTED),
+
+    <<_, ?SN_PUBLISH, 2#01000000,
+      TopicIdB:16, MsgId2:16, "test-b">> = receive_response(Socket),
+    send_puback_msg(Socket, TopicIdB, MsgId2, ?SN_RC_ACCEPTED),
+
+    send_disconnect_msg(Socket, undefined),
+    ?assertMatch(<<2, ?SN_DISCONNECT>>, receive_response(Socket)),
+    gen_udp:close(Socket),
+
+    %% offline messages will be queued into the MQTT-SN session
+    _ = emqx:publish(emqx_message:make(test, ?QOS_1, <<"topic-a">>, <<"m1">>)),
+    _ = emqx:publish(emqx_message:make(test, ?QOS_1, <<"topic-a">>, <<"m2">>)),
+    _ = emqx:publish(emqx_message:make(test, ?QOS_1, <<"topic-a">>, <<"m3">>)),
+    _ = emqx:publish(emqx_message:make(test, ?QOS_2, <<"topic-b">>, <<"m1">>)),
+    _ = emqx:publish(emqx_message:make(test, ?QOS_2, <<"topic-b">>, <<"m2">>)),
+    _ = emqx:publish(emqx_message:make(test, ?QOS_2, <<"topic-b">>, <<"m3">>)),
+
+    emqx_logger:set_log_level(debug),
+    dbg:tracer(),dbg:p(all,call),
+    dbg:tp(emqx_gateway_cm,x),
+    %dbg:tpl(emqx_gateway_cm, request_stepdown,x),
+
+    {ok, NSocket} = gen_udp:open(0, [binary]),
+    send_connect_msg(NSocket, <<"test">>, 0),
+    ?assertMatch(<<_, ?SN_CONNACK, ?SN_RC_ACCEPTED>>,
+                 receive_response(NSocket)),
+
+    %% qos1
+
+    %% receive the resumed messages
+    <<_, ?SN_PUBLISH, 2#00100000,
+      TopicIdA:16, MsgIdA0:16, "m1">> = receive_response(NSocket),
+    %% only one qos1/qos2 inflight
+    ?assertEqual(udp_receive_timeout, receive_response(NSocket)),
+    send_puback_msg(NSocket, TopicIdA, MsgIdA0, ?SN_RC_INVALID_TOPIC_ID),
+    %% recv register
+    <<_, ?SN_REGISTER,
+      TopicIdA:16, RegMsgIdA:16, "topic-a">> = receive_response(NSocket),
+    send_regack_msg(NSocket, TopicIdA, RegMsgIdA),
+    %% receive the replayed messages
+    <<_, ?SN_PUBLISH, 2#00100000,
+      TopicIdA:16, MsgIdA1:16, "m1">> = receive_response(NSocket),
+    send_puback_msg(NSocket, TopicIdA, MsgIdA1, ?SN_RC_ACCEPTED),
+
+    <<_, ?SN_PUBLISH, 2#00100000,
+      TopicIdA:16, MsgIdA2:16, "m2">> = receive_response(NSocket),
+    send_puback_msg(NSocket, TopicIdA, MsgIdA2, ?SN_RC_ACCEPTED),
+
+    <<_, ?SN_PUBLISH, 2#00100000,
+      TopicIdA:16, MsgIdA3:16, "m3">> = receive_response(NSocket),
+    send_puback_msg(NSocket, TopicIdA, MsgIdA3, ?SN_RC_ACCEPTED),
+
+    %% qos2
+    <<_, ?SN_PUBLISH, 2#01000000,
+      TopicIdB:16, MsgIdB0:16, "m1">> = receive_response(NSocket),
+    %% only one qos1/qos2 inflight
+    ?assertEqual(udp_receive_timeout, receive_response(NSocket)),
+    send_puback_msg(NSocket, TopicIdB, MsgIdB0, ?SN_RC_INVALID_TOPIC_ID),
+    %% recv register
+    <<_, ?SN_REGISTER,
+      TopicIdB:16, RegMsgIdB:16, "topic-b">> = receive_response(NSocket),
+    send_regack_msg(NSocket, TopicIdB, RegMsgIdB),
+    %% receive the replayed messages
+    <<_, ?SN_PUBLISH, 2#01000000,
+      TopicIdB:16, MsgIdB1:16, "m1">> = receive_response(NSocket),
+    send_pubrec_msg(NSocket, MsgIdB1),
+    <<_, ?SN_PUBREL, MsgIdB1:16>> = receive_response(NSocket),
+    send_pubcomp_msg(NSocket, MsgIdB1),
+
+    <<_, ?SN_PUBLISH, 2#01000000,
+      TopicIdB:16, MsgIdB2:16, "m2">> = receive_response(NSocket),
+    send_puback_msg(NSocket, TopicIdB, MsgIdB2, ?SN_RC_ACCEPTED),
+
+    <<_, ?SN_PUBLISH, 2#01000000,
+      TopicIdB:16, MsgIdB3:16, "m3">> = receive_response(NSocket),
+    send_puback_msg(NSocket, TopicIdB, MsgIdB3, ?SN_RC_ACCEPTED),
+
+    %% no more messages
+    ?assertEqual(udp_receive_timeout, receive_response(NSocket)),
+
+    send_disconnect_msg(NSocket, undefined),
+    ?assertMatch(<<2, ?SN_DISCONNECT>>, receive_response(NSocket)),
+    gen_udp:close(NSocket).
+
 t_will_case01(_) ->
 t_will_case01(_) ->
     QoS = 1,
     QoS = 1,
     Duration = 1,
     Duration = 1,
@@ -1843,13 +1994,16 @@ send_searchgw_msg(Socket) ->
     ok = gen_udp:send(Socket, ?HOST, ?PORT, <<Length:8, MsgType:8, Radius:8>>).
     ok = gen_udp:send(Socket, ?HOST, ?PORT, <<Length:8, MsgType:8, Radius:8>>).
 
 
 send_connect_msg(Socket, ClientId) ->
 send_connect_msg(Socket, ClientId) ->
+    send_connect_msg(Socket, ClientId, 1).
+
+send_connect_msg(Socket, ClientId, CleanSession) when CleanSession == 0;
+                                                      CleanSession == 1 ->
     Length = 6 + byte_size(ClientId),
     Length = 6 + byte_size(ClientId),
     MsgType = ?SN_CONNECT,
     MsgType = ?SN_CONNECT,
     Dup = 0,
     Dup = 0,
     QoS = 0,
     QoS = 0,
     Retain = 0,
     Retain = 0,
     Will = 0,
     Will = 0,
-    CleanSession = 1,
     TopicIdType = 0,
     TopicIdType = 0,
     ProtocolId = 1,
     ProtocolId = 1,
     Duration = 10,
     Duration = 10,
@@ -1965,9 +2119,12 @@ send_publish_msg_short_topic(Socket, QoS, MsgId, TopicName, Data) ->
     ok = gen_udp:send(Socket, ?HOST, ?PORT, PublishPacket).
     ok = gen_udp:send(Socket, ?HOST, ?PORT, PublishPacket).
 
 
 send_puback_msg(Socket, TopicId, MsgId) ->
 send_puback_msg(Socket, TopicId, MsgId) ->
+    send_puback_msg(Socket, TopicId, MsgId, ?SN_RC_ACCEPTED).
+
+send_puback_msg(Socket, TopicId, MsgId, Rc) ->
     Length = 7,
     Length = 7,
     MsgType = ?SN_PUBACK,
     MsgType = ?SN_PUBACK,
-    PubAckPacket = <<Length:8, MsgType:8, TopicId:16, MsgId:16, ?SN_RC_ACCEPTED:8>>,
+    PubAckPacket = <<Length:8, MsgType:8, TopicId:16, MsgId:16, Rc:8>>,
     ?LOG("send_puback_msg TopicId=~p, MsgId=~p", [TopicId, MsgId]),
     ?LOG("send_puback_msg TopicId=~p, MsgId=~p", [TopicId, MsgId]),
     ok = gen_udp:send(Socket, ?HOST, ?PORT, PubAckPacket).
     ok = gen_udp:send(Socket, ?HOST, ?PORT, PubAckPacket).