Forráskód Böngészése

Merge branch 'develop' into introduce-new-bridge-impl

Gilbert 7 éve
szülő
commit
771f8c052a

+ 2 - 0
include/logger.hrl

@@ -35,6 +35,8 @@
 -define(ALERT(Format), ?LOG(alert, Format, [])).
 -define(ALERT(Format, Args), ?LOG(alert, Format, Args)).
 
+-define(LOG(Level, Format), ?LOG(Level, Format, [])).
+
 -define(LOG(Level, Format, Args),
         begin
           (logger:log(Level,#{},#{report_cb => fun(_) -> {(Format), (Args)} end}))

+ 281 - 239
src/emqx_connection.erl

@@ -14,20 +14,22 @@
 
 -module(emqx_connection).
 
--behaviour(gen_server).
+-behaviour(gen_statem).
 
 -include("emqx.hrl").
 -include("emqx_mqtt.hrl").
 -include("logger.hrl").
 
 -export([start_link/3]).
--export([info/1, attrs/1, stats/1]).
+-export([info/1]).
+-export([attrs/1]).
+-export([stats/1]).
 -export([kick/1]).
 -export([session/1]).
 
-%% gen_server callbacks
--export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2,
-         code_change/3]).
+%% gen_statem callbacks
+-export([idle/3, connected/3]).
+-export([init/1, callback_mode/0, code_change/4, terminate/3]).
 
 -record(state, {
           transport,
@@ -37,7 +39,7 @@
           conn_state,
           active_n,
           proto_state,
-          parser_state,
+          parse_state,
           gc_state,
           keepalive,
           enable_stats,
@@ -48,28 +50,29 @@
           idle_timeout
          }).
 
--define(DEFAULT_ACTIVE_N, 100).
+-define(ACTIVE_N, 100).
+-define(HANDLE(T, C, D), handle((T), (C), (D))).
 -define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]).
 
 start_link(Transport, Socket, Options) ->
-    {ok, proc_lib:spawn_link(?MODULE, init, [[Transport, Socket, Options]])}.
+    {ok, proc_lib:spawn_link(?MODULE, init, [{Transport, Socket, Options}])}.
 
 %%------------------------------------------------------------------------------
 %% API
 %%------------------------------------------------------------------------------
 
-%% for debug
+%% For debug
 info(CPid) when is_pid(CPid) ->
     call(CPid, info);
 
-info(#state{transport   = Transport,
-            socket      = Socket,
-            peername    = Peername,
-            sockname    = Sockname,
-            conn_state  = ConnState,
-            active_n    = ActiveN,
-            rate_limit  = RateLimit,
-            pub_limit   = PubLimit,
+info(#state{transport = Transport,
+            socket = Socket,
+            peername = Peername,
+            sockname = Sockname,
+            conn_state = ConnState,
+            active_n = ActiveN,
+            rate_limit = RateLimit,
+            pub_limit = PubLimit,
             proto_state = ProtoState}) ->
     ConnInfo = [{socktype, Transport:type(Socket)},
                 {peername, Peername},
@@ -81,10 +84,12 @@ info(#state{transport   = Transport,
     ProtoInfo = emqx_protocol:info(ProtoState),
     lists:usort(lists:append(ConnInfo, ProtoInfo)).
 
-rate_limit_info(undefined) -> #{};
-rate_limit_info(Limit) -> esockd_rate_limit:info(Limit).
+rate_limit_info(undefined) ->
+    #{};
+rate_limit_info(Limit) ->
+    esockd_rate_limit:info(Limit).
 
-%% for dashboard
+%% For dashboard
 attrs(CPid) when is_pid(CPid) ->
     call(CPid, attrs);
 
@@ -100,277 +105,305 @@ attrs(#state{peername = Peername,
 stats(CPid) when is_pid(CPid) ->
     call(CPid, stats);
 
-stats(#state{transport   = Transport,
-             socket      = Socket,
+stats(#state{transport = Transport,
+             socket = Socket,
              proto_state = ProtoState}) ->
-    lists:append([emqx_misc:proc_stats(),
-                  emqx_protocol:stats(ProtoState),
-                  case Transport:getstat(Socket, ?SOCK_STATS) of
-                      {ok, Ss}   -> Ss;
-                      {error, _} -> []
-                  end]).
+    SockStats = case Transport:getstat(Socket, ?SOCK_STATS) of
+                    {ok, Ss}   -> Ss;
+                    {error, _} -> []
+                end,
+    lists:append([SockStats,
+                  emqx_misc:proc_stats(),
+                  emqx_protocol:stats(ProtoState)]).
 
-kick(CPid) -> call(CPid, kick).
+kick(CPid) ->
+    call(CPid, kick).
 
-session(CPid) -> call(CPid, session).
+session(CPid) ->
+    call(CPid, session).
 
 call(CPid, Req) ->
-    gen_server:call(CPid, Req, infinity).
+    gen_statem:call(CPid, Req, infinity).
 
 %%------------------------------------------------------------------------------
-%% gen_server callbacks
+%% gen_statem callbacks
 %%------------------------------------------------------------------------------
 
-init([Transport, RawSocket, Options]) ->
-    case Transport:wait(RawSocket) of
-        {ok, Socket} ->
-            Zone = proplists:get_value(zone, Options),
-            {ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]),
-            {ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]),
-            Peercert = Transport:ensure_ok_or_exit(peercert, [Socket]),
-            RateLimit = init_limiter(proplists:get_value(rate_limit, Options)),
-            PubLimit = init_limiter(emqx_zone:get_env(Zone, publish_limit)),
-            ActiveN = proplists:get_value(active_n, Options, ?DEFAULT_ACTIVE_N),
-            EnableStats = emqx_zone:get_env(Zone, enable_stats, true),
-            IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000),
-            SendFun = send_fun(Transport, Socket),
-            ProtoState = emqx_protocol:init(#{peername => Peername,
-                                              sockname => Sockname,
-                                              peercert => Peercert,
-                                              sendfun  => SendFun}, Options),
-            ParserState = emqx_protocol:parser(ProtoState),
-            GcPolicy = emqx_zone:get_env(Zone, force_gc_policy, false),
-            GcState = emqx_gc:init(GcPolicy),
-            State = run_socket(#state{transport    = Transport,
-                                      socket       = Socket,
-                                      peername     = Peername,
-                                      conn_state   = running,
-                                      active_n     = ActiveN,
-                                      rate_limit   = RateLimit,
-                                      pub_limit    = PubLimit,
-                                      proto_state  = ProtoState,
-                                      parser_state = ParserState,
-                                      gc_state     = GcState,
-                                      enable_stats = EnableStats,
-                                      idle_timeout = IdleTimout
-                                     }),
-            ok = emqx_misc:init_proc_mng_policy(Zone),
-            emqx_logger:set_metadata_peername(esockd_net:format(Peername)),
-            gen_server:enter_loop(?MODULE, [{hibernate_after, IdleTimout}],
-                                  State, self(), IdleTimout);
-        {error, Reason} ->
-            {stop, Reason}
-    end.
+init({Transport, RawSocket, Options}) ->
+    {ok, Socket} = Transport:wait(RawSocket),
+    {ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]),
+    {ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]),
+    Peercert = Transport:ensure_ok_or_exit(peercert, [Socket]),
+    emqx_logger:set_metadata_peername(esockd_net:format(Peername)),
+    Zone = proplists:get_value(zone, Options),
+    RateLimit = init_limiter(proplists:get_value(rate_limit, Options)),
+    PubLimit = init_limiter(emqx_zone:get_env(Zone, publish_limit)),
+    ActiveN = proplists:get_value(active_n, Options, ?ACTIVE_N),
+    EnableStats = emqx_zone:get_env(Zone, enable_stats, true),
+    IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000),
+    SendFun = fun(Data) -> Transport:async_send(Socket, Data) end,
+    ProtoState = emqx_protocol:init(#{peername => Peername,
+                                      sockname => Sockname,
+                                      peercert => Peercert,
+                                      sendfun  => SendFun}, Options),
+    ParseState = emqx_protocol:parser(ProtoState),
+    GcPolicy = emqx_zone:get_env(Zone, force_gc_policy, false),
+    GcState = emqx_gc:init(GcPolicy),
+    State = #state{transport    = Transport,
+                   socket       = Socket,
+                   peername     = Peername,
+                   conn_state   = running,
+                   active_n     = ActiveN,
+                   rate_limit   = RateLimit,
+                   pub_limit    = PubLimit,
+                   proto_state  = ProtoState,
+                   parse_state  = ParseState,
+                   gc_state     = GcState,
+                   enable_stats = EnableStats,
+                   idle_timeout = IdleTimout},
+    ok = emqx_misc:init_proc_mng_policy(Zone),
+    gen_statem:enter_loop(?MODULE, [{hibernate_after, 2 * IdleTimout}],
+                          idle, State, self(), [IdleTimout]).
 
 init_limiter(undefined) ->
     undefined;
 init_limiter({Rate, Burst}) ->
     esockd_rate_limit:new(Rate, Burst).
 
-send_fun(Transport, Socket) ->
-    fun(Packet, Options) ->
-        Data = emqx_frame:serialize(Packet, Options),
-        try Transport:async_send(Socket, Data) of
-            ok ->
-                emqx_metrics:trans(inc, 'bytes/sent', iolist_size(Data)),
-                ok;
-            Error -> Error
-        catch
-            error:Error ->
-                {error, Error}
-        end
-    end.
-
-handle_call(info, _From, State) ->
-    {reply, info(State), State};
+callback_mode() ->
+    [state_functions, state_enter].
 
-handle_call(attrs, _From, State) ->
-    {reply, attrs(State), State};
+%%------------------------------------------------------------------------------
+%% Idle state
 
-handle_call(stats, _From, State) ->
-    {reply, stats(State), State};
+idle(enter, _, State) ->
+    ok = activate_socket(State),
+    keep_state_and_data;
 
-handle_call(kick, _From, State) ->
-    {stop, {shutdown, kicked}, ok, State};
+idle(timeout, _Timeout, State) ->
+    {stop, idle_timeout, State};
 
-handle_call(session, _From, State = #state{proto_state = ProtoState}) ->
-    {reply, emqx_protocol:session(ProtoState), State};
+idle(cast, {incoming, Packet}, State) ->
+    handle_packet(Packet, fun(NState) ->
+                              {next_state, connected, NState}
+                          end, State);
 
-handle_call(Req, _From, State) ->
-    ?LOG(error, "unexpected call: ~p", [Req]),
-    {reply, ignored, State}.
+idle(EventType, Content, State) ->
+    ?HANDLE(EventType, Content, State).
 
-handle_cast(Msg, State) ->
-    ?LOG(error, "unexpected cast: ~p", [Msg]),
-    {noreply, State}.
-
-handle_info({deliver, PubOrAck}, State = #state{proto_state = ProtoState}) ->
+%%------------------------------------------------------------------------------
+%% Connected state
+
+connected(enter, _, _State) ->
+    %% What to do?
+    keep_state_and_data;
+
+%% Handle Input
+connected(cast, {incoming, Packet = ?PACKET(Type)}, State) ->
+    _ = emqx_metrics:received(Packet),
+    (Type == ?PUBLISH) andalso emqx_pd:update_counter(incoming_pubs, 1),
+    handle_packet(Packet, fun(NState) ->
+                              {keep_state, NState}
+                          end, State);
+
+%% Handle Output
+connected(info, {deliver, PubOrAck}, State = #state{proto_state = ProtoState}) ->
     case emqx_protocol:deliver(PubOrAck, ProtoState) of
-        {ok, ProtoState1} ->
-            State1 = State#state{proto_state = ProtoState1},
-            {noreply, maybe_gc(PubOrAck, ensure_stats_timer(State1))};
+        {ok, NProtoState} ->
+            NState = State#state{proto_state = NProtoState},
+            {keep_state, maybe_gc(PubOrAck, NState)};
         {error, Reason} ->
             shutdown(Reason, State)
     end;
 
-handle_info({timeout, Timer, emit_stats},
-            State = #state{stats_timer = Timer,
-                           proto_state = ProtoState,
-                           gc_state = GcState}) ->
+%% Start Keepalive
+connected(info, {keepalive, start, Interval},
+          State = #state{transport = Transport, socket = Socket}) ->
+    StatFun = fun() ->
+                case Transport:getstat(Socket, [recv_oct]) of
+                    {ok, [{recv_oct, RecvOct}]} -> {ok, RecvOct};
+                    Error -> Error
+                end
+              end,
+    case emqx_keepalive:start(StatFun, Interval, {keepalive, check}) of
+        {ok, KeepAlive} ->
+            {keep_state, State#state{keepalive = KeepAlive}};
+        {error, Error} ->
+            shutdown(Error, State)
+    end;
+
+%% Keepalive timer
+connected(info, {keepalive, check}, State = #state{keepalive = KeepAlive}) ->
+    case emqx_keepalive:check(KeepAlive) of
+        {ok, KeepAlive1} ->
+            {keep_state, State#state{keepalive = KeepAlive1}};
+        {error, timeout} ->
+            shutdown(keepalive_timeout, State);
+        {error, Error} ->
+            shutdown(Error, State)
+    end;
+
+connected(EventType, Content, State) ->
+    ?HANDLE(EventType, Content, State).
+
+%% Handle call
+handle({call, From}, info, State) ->
+    reply(From, info(State), State);
+
+handle({call, From}, attrs, State) ->
+    reply(From, attrs(State), State);
+
+handle({call, From}, stats, State) ->
+    reply(From, stats(State), State);
+
+handle({call, From}, kick, State) ->
+    ok = gen_statem:reply(From, ok),
+    shutdown(kicked, State);
+
+handle({call, From}, session, State = #state{proto_state = ProtoState}) ->
+    reply(From, emqx_protocol:session(ProtoState), State);
+
+handle({call, From}, Req, State) ->
+    ?LOG(error, "unexpected call: ~p", [Req]),
+    reply(From, ignored, State);
+
+%% Handle cast
+handle(cast, Msg, State) ->
+    ?LOG(error, "unexpected cast: ~p", [Msg]),
+    {keep_state, State};
+
+%% Handle Incoming
+handle(info, {Inet, _Sock, Data}, State) when Inet == tcp; Inet == ssl ->
+    Oct = iolist_size(Data),
+    ?LOG(debug, "RECV ~p", [Data]),
+    emqx_pd:update_counter(incoming_bytes, Oct),
+    emqx_metrics:trans(inc, 'bytes/received', Oct),
+    NState = ensure_stats_timer(maybe_gc({1, Oct}, State)),
+    process_incoming(Data, [], NState);
+
+handle(info, {Error, _Sock, Reason}, State)
+  when Error == tcp_error; Error == ssl_error ->
+    shutdown(Reason, State);
+
+handle(info, {Closed, _Sock}, State)
+  when Closed == tcp_closed; Closed == ssl_closed ->
+    shutdown(closed, State);
+
+handle(info, {tcp_passive, _Sock}, State) ->
+    %% Rate limit here:)
+    NState = ensure_rate_limit(State),
+    ok = activate_socket(NState),
+    {keep_state, NState};
+
+handle(info, activate_socket, State) ->
+    %% Rate limit timer expired.
+    ok = activate_socket(State),
+    {keep_state, State#state{conn_state = running, limit_timer = undefined}};
+
+handle(info, {inet_reply, _Sock, ok}, State) ->
+    %% something sent
+    {keep_state, ensure_stats_timer(State)};
+
+handle(info, {inet_reply, _Sock, {error, Reason}}, State) ->
+    shutdown(Reason, State);
+
+handle(info, {timeout, Timer, emit_stats},
+       State = #state{stats_timer = Timer,
+                      proto_state = ProtoState,
+                      gc_state = GcState}) ->
     emqx_metrics:commit(),
     emqx_cm:set_conn_stats(emqx_protocol:client_id(ProtoState), stats(State)),
-    NewState = State#state{stats_timer = undefined},
+    NState = State#state{stats_timer = undefined},
     Limits = erlang:get(force_shutdown_policy),
     case emqx_misc:conn_proc_mng_policy(Limits) of
         continue ->
-            {noreply, NewState};
+            {keep_state, NState};
         hibernate ->
             %% going to hibernate, reset gc stats
             GcState1 = emqx_gc:reset(GcState),
-            {noreply, NewState#state{gc_state = GcState1}, hibernate};
+            {keep_state, NState#state{gc_state = GcState1}, hibernate};
         {shutdown, Reason} ->
             ?LOG(warning, "shutdown due to ~p", [Reason]),
-            shutdown(Reason, NewState)
+            shutdown(Reason, NState)
     end;
 
-handle_info(timeout, State) ->
-    shutdown(idle_timeout, State);
-
-handle_info({shutdown, Reason}, State) ->
-    shutdown(Reason, State);
-
-handle_info({shutdown, discard, {ClientId, ByPid}}, State) ->
+handle(info, {shutdown, discard, {ClientId, ByPid}}, State) ->
     ?LOG(warning, "discarded by ~s:~p", [ClientId, ByPid]),
     shutdown(discard, State);
 
-handle_info({shutdown, conflict, {ClientId, NewPid}}, State) ->
+handle(info, {shutdown, conflict, {ClientId, NewPid}}, State) ->
     ?LOG(warning, "clientid '~s' conflict with ~p", [ClientId, NewPid]),
     shutdown(conflict, State);
 
-handle_info({TcpOrSsL, _Sock, Data}, State) when TcpOrSsL =:= tcp; TcpOrSsL =:= ssl ->
-    process_incoming(Data, State);
-
-%% Rate limit here, cool:)
-handle_info({tcp_passive, _Sock}, State) ->
-    {noreply, run_socket(ensure_rate_limit(State))};
-%% FIXME Later
-handle_info({ssl_passive, _Sock}, State) ->
-    {noreply, run_socket(ensure_rate_limit(State))};
-
-handle_info({Err, _Sock, Reason}, State) when Err =:= tcp_error; Err =:= ssl_error ->
+handle(info, {shutdown, Reason}, State) ->
     shutdown(Reason, State);
 
-handle_info({Closed, _Sock}, State) when Closed =:= tcp_closed; Closed =:= ssl_closed ->
-    shutdown(closed, State);
-
-%% Rate limit timer
-handle_info(activate_sock, State) ->
-    {noreply, run_socket(State#state{conn_state = running, limit_timer = undefined})};
-
-handle_info({inet_reply, _Sock, ok}, State) ->
-    {noreply, State};
-
-handle_info({inet_reply, _Sock, {error, Reason}}, State) ->
-    shutdown(Reason, State);
-
-handle_info({keepalive, start, Interval}, State = #state{transport = Transport, socket = Socket}) ->
-    ?LOG(debug, "Keepalive at the interval of ~p", [Interval]),
-    StatFun = fun() ->
-                case Transport:getstat(Socket, [recv_oct]) of
-                    {ok, [{recv_oct, RecvOct}]} -> {ok, RecvOct};
-                    Error                       -> Error
-                end
-             end,
-    case emqx_keepalive:start(StatFun, Interval, {keepalive, check}) of
-        {ok, KeepAlive} ->
-            {noreply, State#state{keepalive = KeepAlive}};
-        {error, Error} ->
-            shutdown(Error, State)
-    end;
-
-handle_info({keepalive, check}, State = #state{keepalive = KeepAlive}) ->
-    case emqx_keepalive:check(KeepAlive) of
-        {ok, KeepAlive1} ->
-            {noreply, State#state{keepalive = KeepAlive1}};
-        {error, timeout} ->
-            shutdown(keepalive_timeout, State);
-        {error, Error} ->
-            shutdown(Error, State)
-    end;
-
-handle_info(Info, State) ->
+handle(info, Info, State) ->
     ?LOG(error, "unexpected info: ~p", [Info]),
-    {noreply, State}.
+    {keep_state, State}.
+
+code_change(_Vsn, State, Data, _Extra) ->
+    {ok, State, Data}.
 
-terminate(Reason, #state{transport   = Transport,
-                         socket      = Socket,
-                         keepalive   = KeepAlive,
-                         proto_state = ProtoState}) ->
+terminate(Reason, _StateName, #state{transport = Transport,
+                                     socket = Socket,
+                                     keepalive = KeepAlive,
+                                     proto_state = ProtoState}) ->
     ?LOG(debug, "Terminated for ~p", [Reason]),
     Transport:fast_close(Socket),
     emqx_keepalive:cancel(KeepAlive),
     case {ProtoState, Reason} of
         {undefined, _} -> ok;
         {_, {shutdown, Error}} ->
-            emqx_protocol:shutdown(Error, ProtoState);
+            emqx_protocol:terminate(Error, ProtoState);
         {_, Reason} ->
-            emqx_protocol:shutdown(Reason, ProtoState)
+            emqx_protocol:terminate(Reason, ProtoState)
     end.
 
-code_change(_OldVsn, State, _Extra) ->
-    {ok, State}.
-
-%%------------------------------------------------------------------------------
-%% Internals: process incoming, parse and handle packets
 %%------------------------------------------------------------------------------
+%% Process incoming data
 
-process_incoming(Data, State) ->
-    Oct = iolist_size(Data),
-    ?LOG(debug, "RECV ~p", [Data]),
-    emqx_pd:update_counter(incoming_bytes, Oct),
-    emqx_metrics:trans(inc, 'bytes/received', Oct),
-    case handle_packet(Data, State) of
-        {noreply, State1} ->
-            State2 = maybe_gc({1, Oct}, State1),
-            {noreply, ensure_stats_timer(State2)};
-        Shutdown -> Shutdown
-    end.
+process_incoming(<<>>, Packets, State) ->
+    {keep_state, State, next_events(Packets)};
 
-%% Parse and handle packets
-handle_packet(<<>>, State) ->
-    {noreply, State};
-
-handle_packet(Data, State = #state{proto_state  = ProtoState,
-                                   parser_state = ParserState,
-                                   idle_timeout = IdleTimeout}) ->
-    try emqx_frame:parse(Data, ParserState) of
-        {more, ParserState1} ->
-            {noreply, State#state{parser_state = ParserState1}, IdleTimeout};
-        {ok, Packet = ?PACKET(Type), Rest} ->
-            emqx_metrics:received(Packet),
-            (Type == ?PUBLISH) andalso emqx_pd:update_counter(incoming_pubs, 1),
-            case emqx_protocol:received(Packet, ProtoState) of
-                {ok, ProtoState1} ->
-                    handle_packet(Rest, reset_parser(State#state{proto_state = ProtoState1}));
-                {error, Reason} ->
-                    ?LOG(error, "Process packet error - ~p", [Reason]),
-                    shutdown(Reason, State);
-                {error, Reason, ProtoState1} ->
-                    shutdown(Reason, State#state{proto_state = ProtoState1});
-                {stop, Error, ProtoState1} ->
-                    stop(Error, State#state{proto_state = ProtoState1})
-            end;
+process_incoming(Data, Packets, State = #state{parse_state = ParseState}) ->
+    try emqx_frame:parse(Data, ParseState) of
+        {ok, Packet, Rest} ->
+            process_incoming(Rest, [Packet|Packets], reset_parser(State));
+        {more, NewParseState} ->
+            {keep_state, State#state{parse_state = NewParseState}, next_events(Packets)};
         {error, Reason} ->
-            ?LOG(error, "Parse frame error - ~p", [Reason]),
             shutdown(Reason, State)
     catch
-        _:Error ->
-            ?LOG(error, "Parse failed for ~p~nError data:~p", [Error, Data]),
-            shutdown(parse_error, State)
+        _:Error:Stk->
+            ?LOG(error, "Parse failed for ~p~nStacktrace:~p~nError data:~p", [Error, Stk, Data]),
+            shutdown(Error, State)
     end.
 
 reset_parser(State = #state{proto_state = ProtoState}) ->
-    State#state{parser_state = emqx_protocol:parser(ProtoState)}.
+    State#state{parse_state = emqx_protocol:parser(ProtoState)}.
+
+next_events([]) ->
+    [];
+next_events([Packet]) ->
+    {next_event, cast, {incoming, Packet}};
+next_events(Packets) ->
+    [next_events([Packet]) || Packet <- lists:reverse(Packets)].
+
+%%------------------------------------------------------------------------------
+%% Handle incoming packet
+
+handle_packet(Packet, SuccFun, State = #state{proto_state = ProtoState}) ->
+    case emqx_protocol:received(Packet, ProtoState) of
+        {ok, NProtoState} ->
+            SuccFun(State#state{proto_state = NProtoState});
+        {error, Reason} ->
+            shutdown(Reason, State);
+        {error, Reason, NProtoState} ->
+            shutdown(Reason, State#state{proto_state = NProtoState});
+        {stop, Error, NProtoState} ->
+            stop(Error, State#state{proto_state = NProtoState})
+    end.
 
 %%------------------------------------------------------------------------------
 %% Ensure rate limit
@@ -389,27 +422,27 @@ ensure_rate_limit([{Rl, Pos, Cnt}|Limiters], State) ->
        {0, Rl1} ->
            ensure_rate_limit(Limiters, setelement(Pos, State, Rl1));
        {Pause, Rl1} ->
-           TRef = erlang:send_after(Pause, self(), activate_sock),
+           TRef = erlang:send_after(Pause, self(), activate_socket),
            setelement(Pos, State#state{conn_state = blocked, limit_timer = TRef}, Rl1)
    end.
 
 %%------------------------------------------------------------------------------
 %% Activate socket
 
-run_socket(State = #state{conn_state = blocked}) ->
-    State;
+activate_socket(#state{conn_state = blocked}) ->
+    ok;
 
-run_socket(State = #state{transport = Transport, socket = Socket, active_n = N}) ->
+activate_socket(#state{transport = Transport, socket = Socket, active_n = N}) ->
     TrueOrN = case Transport:is_ssl(Socket) of
                   true  -> true; %% Cannot set '{active, N}' for SSL:(
                   false -> N
               end,
-    ensure_ok_or_exit(Transport:setopts(Socket, [{active, TrueOrN}])),
-    State.
-
-ensure_ok_or_exit(ok) -> ok;
-ensure_ok_or_exit({error, Reason}) ->
-    self() ! {shutdown, Reason}.
+    case Transport:setopts(Socket, [{active, TrueOrN}]) of
+        ok -> ok;
+        {error, Reason} ->
+            self() ! {shutdown, Reason},
+            ok
+    end.
 
 %%------------------------------------------------------------------------------
 %% Ensure stats timer
@@ -418,6 +451,7 @@ ensure_stats_timer(State = #state{enable_stats = true,
                                   stats_timer = undefined,
                                   idle_timeout = IdleTimeout}) ->
     State#state{stats_timer = emqx_misc:start_timer(IdleTimeout, emit_stats)};
+
 ensure_stats_timer(State) -> State.
 
 %%------------------------------------------------------------------------------
@@ -425,20 +459,28 @@ ensure_stats_timer(State) -> State.
 
 maybe_gc(_, State = #state{gc_state = undefined}) ->
     State;
-maybe_gc({publish, _PacketId, #message{payload = Payload}}, State) ->
+maybe_gc({publish, _, #message{payload = Payload}}, State) ->
     Oct = iolist_size(Payload),
     maybe_gc({1, Oct}, State);
+maybe_gc(Packets, State) when is_list(Packets) ->
+    {Cnt, Oct} =
+    lists:unzip([{1, iolist_size(Payload)}
+                 || {publish, _, #message{payload = Payload}} <- Packets]),
+    maybe_gc({lists:sum(Cnt), lists:sum(Oct)}, State);
 maybe_gc({Cnt, Oct}, State = #state{gc_state = GCSt}) ->
     {_, GCSt1} = emqx_gc:run(Cnt, Oct, GCSt),
     State#state{gc_state = GCSt1};
-maybe_gc(_, State) ->
-    State.
+maybe_gc(_, State) -> State.
 
 %%------------------------------------------------------------------------------
-%% Shutdown or stop
+%% Helper functions
+
+reply(From, Reply, State) ->
+    {keep_state, State, [{reply, From, Reply}]}.
 
 shutdown(Reason, State) ->
     stop({shutdown, Reason}, State).
 
 stop(Reason, State) ->
     {stop, Reason, State}.
+

+ 1 - 1
src/emqx_listeners.erl

@@ -56,7 +56,7 @@ start_listener(Proto, ListenOn, Options) when Proto == http; Proto == ws ->
 
 %% Start MQTT/WSS listener
 start_listener(Proto, ListenOn, Options) when Proto == https; Proto == wss ->
-    Dispatch = cowboy_router:compile([{'_', [{mqtt_path(Options),  emqx_ws_connection, Options}]}]),
+    Dispatch = cowboy_router:compile([{'_', [{mqtt_path(Options), emqx_ws_connection, Options}]}]),
     start_http_listener(fun cowboy:start_tls/3, 'mqtt:wss', ListenOn, ranch_opts(Options), Dispatch).
 
 start_mqtt_listener(Name, ListenOn, Options) ->

+ 103 - 85
src/emqx_protocol.erl

@@ -29,10 +29,10 @@
 -export([parser/1]).
 -export([session/1]).
 -export([received/2]).
--export([process_packet/2]).
+-export([process/2]).
 -export([deliver/2]).
 -export([send/2]).
--export([shutdown/2]).
+-export([terminate/2]).
 
 -export_type([state/0]).
 
@@ -53,6 +53,8 @@
           clean_start,
           topic_aliases,
           packet_size,
+          will_topic,
+          will_msg,
           keepalive,
           mountpoint,
           is_super,
@@ -130,11 +132,13 @@ info(PState = #pstate{conn_props    = ConnProps,
                       ack_props     = AckProps,
                       session       = Session,
                       topic_aliases = Aliases,
+                      will_msg      = WillMsg,
                       enable_acl    = EnableAcl}) ->
     attrs(PState) ++ [{conn_props, ConnProps},
                       {ack_props, AckProps},
                       {session, Session},
                       {topic_aliases, Aliases},
+                      {will_msg, WillMsg},
                       {enable_acl, EnableAcl}].
 
 attrs(#pstate{zone         = Zone,
@@ -218,15 +222,16 @@ parser(#pstate{packet_size = Size, proto_ver = Ver}) ->
 %% Packet Received
 %%------------------------------------------------------------------------------
 
-set_protover(?CONNECT_PACKET(#mqtt_packet_connect{
-                                proto_ver = ProtoVer}),
-             PState) ->
-    PState#pstate{ proto_ver = ProtoVer };
+set_protover(?CONNECT_PACKET(#mqtt_packet_connect{proto_ver = ProtoVer}), PState) ->
+    PState#pstate{proto_ver = ProtoVer};
 set_protover(_Packet, PState) ->
     PState.
 
--spec(received(emqx_mqtt_types:packet(), state()) ->
-    {ok, state()} | {error, term()} | {error, term(), state()} | {stop, term(), state()}).
+-spec(received(emqx_mqtt_types:packet(), state())
+      -> {ok, state()}
+       | {error, term()}
+       | {error, term(), state()}
+       | {stop, term(), state()}).
 received(?PACKET(Type), PState = #pstate{connected = false}) when Type =/= ?CONNECT ->
     {error, proto_not_connected, PState};
 
@@ -234,15 +239,15 @@ received(?PACKET(?CONNECT), PState = #pstate{connected = true}) ->
     {error, proto_unexpected_connect, PState};
 
 received(Packet = ?PACKET(Type), PState) ->
-    PState1 = set_protover(Packet, PState),
     trace(recv, Packet),
+    PState1 = set_protover(Packet, PState),
     try emqx_packet:validate(Packet) of
         true ->
             case preprocess_properties(Packet, PState1) of
+                {ok, Packet1, PState2} ->
+                    process(Packet1, inc_stats(recv, Type, PState2));
                 {error, ReasonCode} ->
-                    {error, ReasonCode, PState1};
-                {Packet1, PState2} ->
-                    process_packet(Packet1, inc_stats(recv, Type, PState2))
+                    {error, ReasonCode, PState1}
             end
     catch
         error:protocol_error ->
@@ -268,13 +273,14 @@ received(Packet = ?PACKET(Type), PState) ->
 %%------------------------------------------------------------------------------
 %% Preprocess MQTT Properties
 %%------------------------------------------------------------------------------
+
 preprocess_properties(Packet = #mqtt_packet{
                                    variable = #mqtt_packet_connect{
                                                   properties = #{'Topic-Alias-Maximum' := ToClient}
                                               }
                                },
                       PState = #pstate{topic_alias_maximum = TopicAliasMaximum}) ->
-    {Packet, PState#pstate{topic_alias_maximum = TopicAliasMaximum#{to_client => ToClient}}};
+    {ok, Packet, PState#pstate{topic_alias_maximum = TopicAliasMaximum#{to_client => ToClient}}};
 
 %% Subscription Identifier
 preprocess_properties(Packet = #mqtt_packet{
@@ -285,7 +291,7 @@ preprocess_properties(Packet = #mqtt_packet{
                                  },
                       PState = #pstate{proto_ver = ?MQTT_PROTO_V5}) ->
     TopicFilters1 = [{Topic, SubOpts#{subid => SubId}} || {Topic, SubOpts} <- TopicFilters],
-    {Packet#mqtt_packet{variable = Subscribe#mqtt_packet_subscribe{topic_filters = TopicFilters1}}, PState};
+    {ok, Packet#mqtt_packet{variable = Subscribe#mqtt_packet_subscribe{topic_filters = TopicFilters1}}, PState};
 
 %% Topic Alias Mapping
 preprocess_properties(#mqtt_packet{
@@ -306,8 +312,8 @@ preprocess_properties(Packet = #mqtt_packet{
                                        topic_alias_maximum = #{from_client := TopicAliasMaximum}}) ->
     case AliasId =< TopicAliasMaximum of
         true ->
-            {Packet#mqtt_packet{variable = Publish#mqtt_packet_publish{
-                                               topic_name = maps:get(AliasId, Aliases, <<>>)}}, PState};
+            {ok, Packet#mqtt_packet{variable = Publish#mqtt_packet_publish{
+                                                 topic_name = maps:get(AliasId, Aliases, <<>>)}}, PState};
         false ->
             deliver({disconnect, ?RC_TOPIC_ALIAS_INVALID}, PState),
             {error, ?RC_TOPIC_ALIAS_INVALID}
@@ -323,28 +329,28 @@ preprocess_properties(Packet = #mqtt_packet{
                                        topic_alias_maximum = #{from_client := TopicAliasMaximum}}) ->
     case AliasId =< TopicAliasMaximum of
         true ->
-            {Packet, PState#pstate{topic_aliases = maps:put(AliasId, Topic, Aliases)}};
+            {ok, Packet, PState#pstate{topic_aliases = maps:put(AliasId, Topic, Aliases)}};
         false ->
             deliver({disconnect, ?RC_TOPIC_ALIAS_INVALID}, PState),
             {error, ?RC_TOPIC_ALIAS_INVALID}
     end;
 
 preprocess_properties(Packet, PState) ->
-    {Packet, PState}.
+    {ok, Packet, PState}.
 
 %%------------------------------------------------------------------------------
 %% Process MQTT Packet
 %%------------------------------------------------------------------------------
-process_packet(?CONNECT_PACKET(
-                  #mqtt_packet_connect{proto_name  = ProtoName,
-                                       proto_ver   = ProtoVer,
-                                       is_bridge   = IsBridge,
-                                       clean_start = CleanStart,
-                                       keepalive   = Keepalive,
-                                       properties  = ConnProps,
-                                       client_id   = ClientId,
-                                       username    = Username,
-                                       password    = Password} = ConnPkt), PState) ->
+process(?CONNECT_PACKET(
+           #mqtt_packet_connect{proto_name  = ProtoName,
+                                proto_ver   = ProtoVer,
+                                is_bridge   = IsBridge,
+                                clean_start = CleanStart,
+                                keepalive   = Keepalive,
+                                properties  = ConnProps,
+                                client_id   = ClientId,
+                                username    = Username,
+                                password    = Password} = ConnPkt), PState) ->
 
     NewClientId = maybe_use_username_as_clientid(ClientId, Username, PState),
 
@@ -394,17 +400,17 @@ process_packet(?CONNECT_PACKET(
               {ReasonCode, PState1}
       end);
 
-process_packet(Packet = ?PUBLISH_PACKET(?QOS_0, Topic, _PacketId, _Payload), PState) ->
+process(Packet = ?PUBLISH_PACKET(?QOS_0, Topic, _PacketId, _Payload), PState) ->
     case check_publish(Packet, PState) of
         {ok, PState1} ->
             do_publish(Packet, PState1);
         {error, ReasonCode} ->
             ?LOG(warning, "Cannot publish qos0 message to ~s for ~s",
-                [Topic, emqx_reason_codes:text(ReasonCode)]),
+                 [Topic, emqx_reason_codes:text(ReasonCode)]),
             do_acl_deny_action(Packet, ReasonCode, PState)
     end;
 
-process_packet(Packet = ?PUBLISH_PACKET(?QOS_1, Topic, PacketId, _Payload), PState) ->
+process(Packet = ?PUBLISH_PACKET(?QOS_1, Topic, PacketId, _Payload), PState) ->
     case check_publish(Packet, PState) of
         {ok, PState1} ->
             do_publish(Packet, PState1);
@@ -414,30 +420,28 @@ process_packet(Packet = ?PUBLISH_PACKET(?QOS_1, Topic, PacketId, _Payload), PSta
             case deliver({puback, PacketId, ReasonCode}, PState) of
                 {ok, PState1} ->
                     do_acl_deny_action(Packet, ReasonCode, PState1);
-                Error ->
-                    Error
+                Error -> Error
             end
     end;
 
-process_packet(Packet = ?PUBLISH_PACKET(?QOS_2, Topic, PacketId, _Payload), PState) ->
+process(Packet = ?PUBLISH_PACKET(?QOS_2, Topic, PacketId, _Payload), PState) ->
     case check_publish(Packet, PState) of
         {ok, PState1} ->
             do_publish(Packet, PState1);
         {error, ReasonCode} ->
             ?LOG(warning, "Cannot publish qos2 message to ~s for ~s",
-                [Topic, emqx_reason_codes:text(ReasonCode)]),
+                 [Topic, emqx_reason_codes:text(ReasonCode)]),
             case deliver({pubrec, PacketId, ReasonCode}, PState) of
                 {ok, PState1} ->
                     do_acl_deny_action(Packet, ReasonCode, PState1);
-                Error ->
-                    Error
+                Error -> Error
             end
     end;
 
-process_packet(?PUBACK_PACKET(PacketId, ReasonCode), PState = #pstate{session = SPid}) ->
+process(?PUBACK_PACKET(PacketId, ReasonCode), PState = #pstate{session = SPid}) ->
     {ok = emqx_session:puback(SPid, PacketId, ReasonCode), PState};
 
-process_packet(?PUBREC_PACKET(PacketId, ReasonCode), PState = #pstate{session = SPid}) ->
+process(?PUBREC_PACKET(PacketId, ReasonCode), PState = #pstate{session = SPid}) ->
     case emqx_session:pubrec(SPid, PacketId, ReasonCode) of
         ok ->
             send(?PUBREL_PACKET(PacketId), PState);
@@ -445,7 +449,7 @@ process_packet(?PUBREC_PACKET(PacketId, ReasonCode), PState = #pstate{session =
             send(?PUBREL_PACKET(PacketId, NotFound), PState)
     end;
 
-process_packet(?PUBREL_PACKET(PacketId, ReasonCode), PState = #pstate{session = SPid}) ->
+process(?PUBREL_PACKET(PacketId, ReasonCode), PState = #pstate{session = SPid}) ->
     case emqx_session:pubrel(SPid, PacketId, ReasonCode) of
         ok ->
             send(?PUBCOMP_PACKET(PacketId), PState);
@@ -453,22 +457,22 @@ process_packet(?PUBREL_PACKET(PacketId, ReasonCode), PState = #pstate{session =
             send(?PUBCOMP_PACKET(PacketId, NotFound), PState)
     end;
 
-process_packet(?PUBCOMP_PACKET(PacketId, ReasonCode), PState = #pstate{session = SPid}) ->
+process(?PUBCOMP_PACKET(PacketId, ReasonCode), PState = #pstate{session = SPid}) ->
     {ok = emqx_session:pubcomp(SPid, PacketId, ReasonCode), PState};
 
-process_packet(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters),
-               PState = #pstate{session = SPid, mountpoint = Mountpoint,
-                                proto_ver = ProtoVer, is_bridge = IsBridge,
-                                ignore_loop = IgnoreLoop}) ->
-    RawTopicFilters1 =  if ProtoVer < ?MQTT_PROTO_V5 ->
-                            IfIgnoreLoop = case IgnoreLoop of true -> 1; false -> 0 end,
-                            case IsBridge of
-                                true -> [{RawTopic, SubOpts#{rap => 1, nl => IfIgnoreLoop}} || {RawTopic, SubOpts} <- RawTopicFilters];
-                                false -> [{RawTopic, SubOpts#{rap => 0, nl => IfIgnoreLoop}} || {RawTopic, SubOpts} <- RawTopicFilters]
-                            end;
-                           true ->
-                               RawTopicFilters
-                        end,
+process(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters),
+        PState = #pstate{session = SPid, mountpoint = Mountpoint,
+                         proto_ver = ProtoVer, is_bridge = IsBridge,
+                         ignore_loop = IgnoreLoop}) ->
+    RawTopicFilters1 = if ProtoVer < ?MQTT_PROTO_V5 ->
+                           IfIgnoreLoop = case IgnoreLoop of true -> 1; false -> 0 end,
+                           case IsBridge of
+                               true -> [{RawTopic, SubOpts#{rap => 1, nl => IfIgnoreLoop}} || {RawTopic, SubOpts} <- RawTopicFilters];
+                               false -> [{RawTopic, SubOpts#{rap => 0, nl => IfIgnoreLoop}} || {RawTopic, SubOpts} <- RawTopicFilters]
+                           end;
+                          true ->
+                              RawTopicFilters
+                       end,
     case check_subscribe(
            parse_topic_filters(?SUBSCRIBE, RawTopicFilters1), PState) of
         {ok, TopicFilters} ->
@@ -483,15 +487,14 @@ process_packet(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters)
                     deliver({suback, PacketId, ReasonCodes}, PState)
             end;
         {error, TopicFilters} ->
-            {ReverseSubTopics, ReverseReasonCodes} =
-                lists:foldl(fun({Topic, #{rc := ?RC_SUCCESS}}, {Topics, Codes}) ->
+            {SubTopics, ReasonCodes} =
+                lists:foldr(fun({Topic, #{rc := ?RC_SUCCESS}}, {Topics, Codes}) ->
                                     {[Topic|Topics], [?RC_IMPLEMENTATION_SPECIFIC_ERROR | Codes]};
                                 ({Topic, #{rc := Code}}, {Topics, Codes}) ->
                                     {[Topic|Topics], [Code|Codes]}
                             end, {[], []}, TopicFilters),
-            {SubTopics, ReasonCodes} = {lists:reverse(ReverseSubTopics), lists:reverse(ReverseReasonCodes)},
             ?LOG(warning, "Cannot subscribe ~p for ~p",
-                [SubTopics, [emqx_reason_codes:text(R) || R <- ReasonCodes]]),
+                 [SubTopics, [emqx_reason_codes:text(R) || R <- ReasonCodes]]),
             case deliver({suback, PacketId, ReasonCodes}, PState) of
                 {ok, PState1} ->
                     do_acl_deny_action(Packet, ReasonCodes, PState1);
@@ -500,8 +503,8 @@ process_packet(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters)
             end
     end;
 
-process_packet(?UNSUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters),
-               PState = #pstate{session = SPid, mountpoint = MountPoint}) ->
+process(?UNSUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters),
+        PState = #pstate{session = SPid, mountpoint = MountPoint}) ->
     case emqx_hooks:run('client.unsubscribe', [credentials(PState)],
                         parse_topic_filters(?UNSUBSCRIBE, RawTopicFilters)) of
         {ok, TopicFilters} ->
@@ -514,22 +517,25 @@ process_packet(?UNSUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters),
             deliver({unsuback, PacketId, ReasonCodes}, PState)
     end;
 
-process_packet(?PACKET(?PINGREQ), PState) ->
+process(?PACKET(?PINGREQ), PState) ->
     send(?PACKET(?PINGRESP), PState);
 
-process_packet(?DISCONNECT_PACKET(?RC_SUCCESS, #{'Session-Expiry-Interval' := Interval}),
-                PState = #pstate{session = SPid, conn_props = #{'Session-Expiry-Interval' := OldInterval}}) ->
+process(?DISCONNECT_PACKET(?RC_SUCCESS, #{'Session-Expiry-Interval' := Interval}),
+        PState = #pstate{session = SPid, conn_props = #{'Session-Expiry-Interval' := OldInterval}}) ->
     case Interval =/= 0 andalso OldInterval =:= 0 of
         true ->
             deliver({disconnect, ?RC_PROTOCOL_ERROR}, PState),
-            {error, protocol_error, PState};
+            {error, protocol_error, PState#pstate{will_msg = undefined}};
         false ->
             emqx_session:update_expiry_interval(SPid, Interval),
-            {stop, normal, PState}
+            %% Clean willmsg
+            {stop, normal, PState#pstate{will_msg = undefined}}
     end;
-process_packet(?DISCONNECT_PACKET(?RC_SUCCESS), PState) ->
-    {stop, normal, PState};
-process_packet(?DISCONNECT_PACKET(_), PState) ->
+
+process(?DISCONNECT_PACKET(?RC_SUCCESS), PState) ->
+    {stop, normal, PState#pstate{will_msg = undefined}};
+
+process(?DISCONNECT_PACKET(_), PState) ->
     {stop, {shutdown, abnormal_disconnet}, PState}.
 
 %%------------------------------------------------------------------------------
@@ -562,15 +568,16 @@ do_publish(Packet = ?PUBLISH_PACKET(QoS, PacketId),
 
 puback(?QOS_0, _PacketId, _Result, PState) ->
     {ok, PState};
-puback(?QOS_1, PacketId, [], PState) ->
+puback(?QOS_1, PacketId, {ok, []}, PState) ->
     deliver({puback, PacketId, ?RC_NO_MATCHING_SUBSCRIBERS}, PState);
-puback(?QOS_1, PacketId, [_|_], PState) -> %%TODO: check the dispatch?
+%%TODO: calc the deliver count?
+puback(?QOS_1, PacketId, {ok, _Result}, PState) ->
     deliver({puback, PacketId, ?RC_SUCCESS}, PState);
 puback(?QOS_1, PacketId, {error, ReasonCode}, PState) ->
     deliver({puback, PacketId, ReasonCode}, PState);
-puback(?QOS_2, PacketId, [], PState) ->
+puback(?QOS_2, PacketId, {ok, []}, PState) ->
     deliver({pubrec, PacketId, ?RC_NO_MATCHING_SUBSCRIBERS}, PState);
-puback(?QOS_2, PacketId, [_|_], PState) -> %%TODO: check the dispatch?
+puback(?QOS_2, PacketId, {ok, _Result}, PState) ->
     deliver({pubrec, PacketId, ?RC_SUCCESS}, PState);
 puback(?QOS_2, PacketId, {error, ReasonCode}, PState) ->
     deliver({pubrec, PacketId, ReasonCode}, PState).
@@ -579,7 +586,17 @@ puback(?QOS_2, PacketId, {error, ReasonCode}, PState) ->
 %% Deliver Packet -> Client
 %%------------------------------------------------------------------------------
 
--spec(deliver(tuple(), state()) -> {ok, state()} | {error, term()}).
+-spec(deliver(list(tuple()) | tuple(), state()) -> {ok, state()} | {error, term()}).
+deliver([], PState) ->
+    {ok, PState};
+deliver([Pub|More], PState) ->
+    case deliver(Pub, PState) of
+        {ok, PState1} ->
+            deliver(More, PState1);
+        {error, _} = Error ->
+            Error
+    end;
+
 deliver({connack, ReasonCode}, PState) ->
     send(?CONNACK_PACKET(ReasonCode), PState);
 
@@ -666,11 +683,13 @@ deliver({disconnect, _ReasonCode}, PState) ->
 %% Send Packet to Client
 
 -spec(send(emqx_mqtt_types:packet(), state()) -> {ok, state()} | {error, term()}).
-send(Packet = ?PACKET(Type), PState = #pstate{proto_ver = Ver, sendfun = SendFun}) ->
-    trace(send, Packet),
-    case SendFun(Packet, #{version => Ver}) of
+send(Packet = ?PACKET(Type), PState = #pstate{proto_ver = Ver, sendfun = Send}) ->
+    Data = emqx_frame:serialize(Packet, #{version => Ver}),
+    case Send(Data) of
         ok ->
+            trace(send, Packet),
             emqx_metrics:sent(Packet),
+            emqx_metrics:trans(inc, 'bytes/sent', iolist_size(Data)),
             {ok, inc_stats(send, Type, PState)};
         {error, Reason} ->
             {error, Reason}
@@ -809,14 +828,13 @@ check_will_topic(#mqtt_packet_connect{will_topic = WillTopic} = ConnPkt, PState)
             {error, ?RC_TOPIC_NAME_INVALID}
     end.
 
-check_will_acl(_ConnPkt, #pstate{enable_acl = EnableAcl})
-  when not EnableAcl ->
+check_will_acl(_ConnPkt, #pstate{enable_acl = EnableAcl}) when not EnableAcl ->
     ok;
 check_will_acl(#mqtt_packet_connect{will_topic = WillTopic}, PState) ->
     case emqx_access_control:check_acl(credentials(PState), publish, WillTopic) of
         allow -> ok;
         deny ->
-            ?LOG(warning, "Will message (to ~s) validation failed, acl denied", [WillTopic]),
+            ?LOG(warning, "Cannot publish will message to ~p for acl denied", [WillTopic]),
             {error, ?RC_NOT_AUTHORIZED}
     end.
 
@@ -825,7 +843,7 @@ check_publish(Packet, PState) ->
                      fun check_pub_acl/2], Packet, PState).
 
 check_pub_caps(#mqtt_packet{header = #mqtt_packet_header{qos = QoS, retain = Retain},
-                            variable = #mqtt_packet_publish{ properties = _Properties}},
+                            variable = #mqtt_packet_publish{properties = _Properties}},
                #pstate{zone = Zone}) ->
     emqx_mqtt_caps:check_pub(Zone, #{qos => QoS, retain => Retain}).
 
@@ -892,15 +910,15 @@ inc_stats(Type, Stats = #{pkt := PktCnt, msg := MsgCnt}) ->
                                          false -> MsgCnt
                                      end}.
 
-shutdown(_Reason, #pstate{client_id = undefined}) ->
+terminate(_Reason, #pstate{client_id = undefined}) ->
     ok;
-shutdown(_Reason, #pstate{connected = false}) ->
+terminate(_Reason, #pstate{connected = false}) ->
     ok;
-shutdown(conflict, _PState) ->
+terminate(conflict, _PState) ->
     ok;
-shutdown(discard, _PState) ->
+terminate(discard, _PState) ->
     ok;
-shutdown(Reason, PState) ->
+terminate(Reason, PState) ->
     ?LOG(info, "Shutdown for ~p", [Reason]),
     emqx_hooks:run('client.disconnected', [credentials(PState), Reason]).
 

+ 205 - 114
src/emqx_session.erl

@@ -71,8 +71,11 @@
           %% Clean Start Flag
           clean_start = false :: boolean(),
 
-          %% Client Binding: local | remote
-          binding = local :: local | remote,
+          %% Conn Binding: local | remote
+          %% binding = local :: local | remote,
+
+          %% Deliver fun
+          deliver_fun :: function(),
 
           %% ClientId: Identifier of Session
           client_id :: binary(),
@@ -157,6 +160,8 @@
 
 -export_type([attr/0]).
 
+-define(DEFAULT_BATCH_N, 1000).
+
 %% @doc Start a session proc.
 -spec(start_link(SessAttrs :: map()) -> {ok, pid()}).
 start_link(SessAttrs) ->
@@ -196,13 +201,13 @@ attrs(SPid) when is_pid(SPid) ->
     gen_server:call(SPid, attrs, infinity);
 
 attrs(#state{clean_start = CleanStart,
-             binding = Binding,
              client_id = ClientId,
+             conn_pid  = ConnPid,
              username = Username,
              expiry_interval = ExpiryInterval,
              created_at = CreatedAt}) ->
     [{clean_start, CleanStart},
-     {binding, Binding},
+     {binding, binding(ConnPid)},
      {client_id, ClientId},
      {username, Username},
      {expiry_interval, ExpiryInterval div 1000},
@@ -249,19 +254,19 @@ subscribe(SPid, PacketId, Properties, TopicFilters) ->
 
 %% @doc Called by connection processes when publishing messages
 -spec(publish(spid(), emqx_mqtt_types:packet_id(), emqx_types:message())
-      -> emqx_types:deliver_results() | {error, term()}).
+      -> {ok, emqx_types:deliver_results()} | {error, term()}).
 publish(_SPid, _PacketId, Msg = #message{qos = ?QOS_0}) ->
     %% Publish QoS0 message directly
-    emqx_broker:publish(Msg);
+    {ok, emqx_broker:publish(Msg)};
 
 publish(_SPid, _PacketId, Msg = #message{qos = ?QOS_1}) ->
     %% Publish QoS1 message directly
-    emqx_broker:publish(Msg);
+    {ok, emqx_broker:publish(Msg)};
 
 publish(SPid, PacketId, Msg = #message{qos = ?QOS_2, timestamp = Ts}) ->
     %% Register QoS2 message packet ID (and timestamp) to session, then publish
     case gen_server:call(SPid, {register_publish_packet_id, PacketId, Ts}, infinity) of
-        ok -> emqx_broker:publish(Msg);
+        ok -> {ok, emqx_broker:publish(Msg)};
         {error, Reason} -> {error, Reason}
     end.
 
@@ -342,7 +347,7 @@ init([Parent, #{zone            := Zone,
     IdleTimout = get_env(Zone, idle_timeout, 30000),
     State = #state{idle_timeout      = IdleTimout,
                    clean_start       = CleanStart,
-                   binding           = binding(ConnPid),
+                   deliver_fun       = deliver_fun(ConnPid),
                    client_id         = ClientId,
                    username          = Username,
                    conn_pid          = ConnPid,
@@ -376,9 +381,18 @@ init_mqueue(Zone) ->
                        default_priority => get_env(Zone, mqueue_default_priority)
                       }).
 
+binding(undefined) -> undefined;
 binding(ConnPid) ->
     case node(ConnPid) =:= node() of true -> local; false -> remote end.
 
+deliver_fun(ConnPid) when node(ConnPid) == node() ->
+    fun(Packet) -> ConnPid ! {deliver, Packet}, ok end;
+deliver_fun(ConnPid) ->
+    Node = node(ConnPid),
+    fun(Packet) ->
+            emqx_rpc:cast(Node, erlang, send, [ConnPid, {deliver, Packet}])
+    end.
+
 handle_call(info, _From, State) ->
     reply(info(State), State);
 
@@ -539,7 +553,7 @@ handle_cast({resume, #{conn_pid        := ConnPid,
     true = link(ConnPid),
 
     State1 = State#state{conn_pid         = ConnPid,
-                         binding          = binding(ConnPid),
+                         deliver_fun      = deliver_fun(ConnPid),
                          old_conn_pid     = OldConnPid,
                          clean_start      = false,
                          retry_timer      = undefined,
@@ -566,25 +580,11 @@ handle_cast(Msg, State) ->
     emqx_logger:error("[Session] unexpected cast: ~p", [Msg]),
     {noreply, State}.
 
-%% Batch dispatch
+handle_info({dispatch, Topic, Msg}, State) when is_record(Msg, message) ->
+    handle_dispatch([{Topic, Msg}], State);
+
 handle_info({dispatch, Topic, Msgs}, State) when is_list(Msgs) ->
-    noreply(lists:foldl(
-              fun(Msg, St) ->
-                  element(2, handle_info({dispatch, Topic, Msg}, St))
-              end, State, Msgs));
-
-%% Dispatch message
-handle_info({dispatch, Topic, Msg = #message{}}, State) ->
-    case emqx_shared_sub:is_ack_required(Msg) andalso not has_connection(State) of
-        true ->
-            %% Require ack, but we do not have connection
-            %% negative ack the message so it can try the next subscriber in the group
-            ok = emqx_shared_sub:nack_no_connection(Msg),
-            {noreply, State};
-        false ->
-            NewState = handle_dispatch(Topic, Msg, State),
-            noreply(ensure_stats_timer(maybe_gc({1, msg_size(Msg)}, NewState)))
-    end;
+    handle_dispatch([{Topic, Msg} || Msg <- Msgs], State);
 
 %% Do nothing if the client has been disconnected.
 handle_info({timeout, Timer, retry_delivery}, State = #state{conn_pid = undefined, retry_timer = Timer}) ->
@@ -684,18 +684,11 @@ maybe_shutdown(Pid, Reason) ->
 %% Internal functions
 %%------------------------------------------------------------------------------
 
-has_connection(#state{conn_pid = Pid}) ->
+is_connection_alive(#state{conn_pid = Pid}) ->
     is_pid(Pid) andalso is_process_alive(Pid).
 
-handle_dispatch(Topic, Msg, State = #state{subscriptions = SubMap}) ->
-    case maps:find(Topic, SubMap) of
-        {ok, #{nl := Nl, qos := QoS, rap := Rap, subid := SubId}} ->
-            run_dispatch_steps([{nl, Nl}, {qos, QoS}, {rap, Rap}, {subid, SubId}], Msg, State);
-        {ok, #{nl := Nl, qos := QoS, rap := Rap}} ->
-            run_dispatch_steps([{nl, Nl}, {qos, QoS}, {rap, Rap}], Msg, State);
-        error ->
-            dispatch(emqx_message:unset_flag(dup, Msg), State)
-    end.
+%%------------------------------------------------------------------------------
+%% Suback and unsuback
 
 suback(_From, undefined, _ReasonCodes) ->
     ignore;
@@ -722,7 +715,6 @@ kick(ClientId, OldConnPid, ConnPid) ->
 
 %%------------------------------------------------------------------------------
 %% Replay or Retry Delivery
-%%------------------------------------------------------------------------------
 
 %% Redeliver at once if force is true
 retry_delivery(Force, State = #state{inflight = Inflight}) ->
@@ -766,6 +758,7 @@ retry_delivery(Force, [{Type, Msg0, Ts} | Msgs], Now,
 %%------------------------------------------------------------------------------
 %% Send Will Message
 %%------------------------------------------------------------------------------
+
 send_willmsg(undefined) ->
     ignore;
 send_willmsg(WillMsg) ->
@@ -801,64 +794,156 @@ expire_awaiting_rel([{PacketId, Ts} | More], Now,
 
 is_awaiting_full(#state{max_awaiting_rel = 0}) ->
     false;
-is_awaiting_full(#state{awaiting_rel = AwaitingRel, max_awaiting_rel = MaxLen}) ->
+is_awaiting_full(#state{awaiting_rel = AwaitingRel,
+                        max_awaiting_rel = MaxLen}) ->
     maps:size(AwaitingRel) >= MaxLen.
 
 %%------------------------------------------------------------------------------
-%% Dispatch Messages
+%% Dispatch messages
 %%------------------------------------------------------------------------------
 
-run_dispatch_steps([], Msg, State) ->
-    dispatch(Msg, State);
-run_dispatch_steps([{nl, 1}|_Steps], #message{from = ClientId}, State = #state{client_id = ClientId}) ->
-    State;
-run_dispatch_steps([{nl, _}|Steps], Msg, State) ->
-    run_dispatch_steps(Steps, Msg, State);
-run_dispatch_steps([{qos, SubQoS}|Steps], Msg0 = #message{qos = PubQoS}, State = #state{upgrade_qos = false}) ->
-    %% Ack immediately if a shared dispatch QoS is downgraded to 0
-    Msg = case SubQoS =:= ?QOS_0 of
-              true -> emqx_shared_sub:maybe_ack(Msg0);
-              false -> Msg0
-          end,
-    run_dispatch_steps(Steps, Msg#message{qos = min(SubQoS, PubQoS)}, State);
-run_dispatch_steps([{qos, SubQoS}|Steps], Msg = #message{qos = PubQoS}, State = #state{upgrade_qos = true}) ->
-    run_dispatch_steps(Steps, Msg#message{qos = max(SubQoS, PubQoS)}, State);
-run_dispatch_steps([{rap, _Rap}|Steps], Msg = #message{flags = Flags, headers = #{retained := true}}, State = #state{}) ->
-    run_dispatch_steps(Steps, Msg#message{flags = maps:put(retain, true, Flags)}, State);
-run_dispatch_steps([{rap, 0}|Steps], Msg = #message{flags = Flags}, State = #state{}) ->
-    run_dispatch_steps(Steps, Msg#message{flags = maps:put(retain, false, Flags)}, State);
-run_dispatch_steps([{rap, _}|Steps], Msg, State) ->
-    run_dispatch_steps(Steps, Msg, State);
-run_dispatch_steps([{subid, SubId}|Steps], Msg, State) ->
-    run_dispatch_steps(Steps, emqx_message:set_header('Subscription-Identifier', SubId, Msg), State).
+handle_dispatch(Msgs, State = #state{inflight = Inflight, subscriptions = SubMap}) ->
+    %% Drain the mailbox and batch deliver
+    Msgs1 = drain_m(batch_n(Inflight), Msgs),
+    %% Ack the messages for shared subscription
+    Msgs2 = maybe_ack_shared(Msgs1, State),
+    %% Process suboptions
+    Msgs3 = lists:foldr(
+              fun({Topic, Msg}, Acc) ->
+                      SubOpts = find_subopts(Topic, SubMap),
+                      case process_subopts(SubOpts, Msg, State) of
+                          {ok, Msg1} -> [Msg1|Acc];
+                          ignore -> Acc
+                      end
+              end, [], Msgs2),
+    NState = batch_process(Msgs3, State),
+    noreply(ensure_stats_timer(NState)).
+
+batch_n(Inflight) ->
+    case emqx_inflight:max_size(Inflight) of
+        0 -> ?DEFAULT_BATCH_N;
+        Sz -> Sz - emqx_inflight:size(Inflight)
+    end.
+
+drain_m(Cnt, Msgs) when Cnt =< 0 ->
+    lists:reverse(Msgs);
+drain_m(Cnt, Msgs) ->
+    receive
+        {dispatch, Topic, Msg} ->
+            drain_m(Cnt-1, [{Topic, Msg}|Msgs])
+    after 0 ->
+        lists:reverse(Msgs)
+    end.
+
+%% Ack or nack the messages of shared subscription?
+maybe_ack_shared(Msgs, State) when is_list(Msgs) ->
+    lists:foldr(
+      fun({Topic, Msg}, Acc) ->
+            case maybe_ack_shared(Msg, State) of
+                ok -> Acc;
+                Msg1 -> [{Topic, Msg1}|Acc]
+            end
+      end, [], Msgs);
+
+maybe_ack_shared(Msg, State) ->
+    case emqx_shared_sub:is_ack_required(Msg) of
+        true -> do_ack_shared(Msg, State);
+        false -> Msg
+    end.
+
+do_ack_shared(Msg, State = #state{inflight = Inflight}) ->
+    case {is_connection_alive(State),
+          emqx_inflight:is_full(Inflight)} of
+        {false, _} ->
+            %% Require ack, but we do not have connection
+            %% negative ack the message so it can try the next subscriber in the group
+            emqx_shared_sub:nack_no_connection(Msg);
+        {_, true} ->
+            emqx_shared_sub:maybe_nack_dropped(Msg);
+         _ ->
+            %% Ack QoS1/QoS2 messages when message is delivered to connection.
+            %% NOTE: NOT to wait for PUBACK because:
+            %% The sender is monitoring this session process,
+            %% if the message is delivered to client but connection or session crashes,
+            %% sender will try to dispatch the message to the next shared subscriber.
+            %% This violates spec as QoS2 messages are not allowed to be sent to more
+            %% than one member in the group.
+            emqx_shared_sub:maybe_ack(Msg)
+    end.
+
+process_subopts([], Msg, _State) ->
+    {ok, Msg};
+process_subopts([{nl, 1}|_Opts], #message{from = ClientId}, #state{client_id = ClientId}) ->
+    ignore;
+process_subopts([{nl, _}|Opts], Msg, State) ->
+    process_subopts(Opts, Msg, State);
+process_subopts([{qos, SubQoS}|Opts], Msg = #message{qos = PubQoS}, State = #state{upgrade_qos = false}) ->
+    process_subopts(Opts, Msg#message{qos = min(SubQoS, PubQoS)}, State);
+process_subopts([{qos, SubQoS}|Opts], Msg = #message{qos = PubQoS}, State = #state{upgrade_qos = true}) ->
+    process_subopts(Opts, Msg#message{qos = max(SubQoS, PubQoS)}, State);
+process_subopts([{rap, _Rap}|Opts], Msg = #message{flags = Flags, headers = #{retained := true}}, State = #state{}) ->
+    process_subopts(Opts, Msg#message{flags = maps:put(retain, true, Flags)}, State);
+process_subopts([{rap, 0}|Opts], Msg = #message{flags = Flags}, State = #state{}) ->
+    process_subopts(Opts, Msg#message{flags = maps:put(retain, false, Flags)}, State);
+process_subopts([{rap, _}|Opts], Msg, State) ->
+    process_subopts(Opts, Msg, State);
+process_subopts([{subid, SubId}|Opts], Msg, State) ->
+    process_subopts(Opts, emqx_message:set_header('Subscription-Identifier', SubId, Msg), State).
+
+find_subopts(Topic, SubMap) ->
+    case maps:find(Topic, SubMap) of
+        {ok, #{nl := Nl, qos := QoS, rap := Rap, subid := SubId}} ->
+            [{nl, Nl}, {qos, QoS}, {rap, Rap}, {subid, SubId}];
+        {ok, #{nl := Nl, qos := QoS, rap := Rap}} ->
+            [{nl, Nl}, {qos, QoS}, {rap, Rap}];
+        error -> []
+    end.
+
+batch_process(Msgs, State) ->
+    {ok, Publishes, NState} = process_msgs(Msgs, [], State),
+    ok = batch_deliver(Publishes, NState),
+    maybe_gc(msg_cnt(Msgs), NState).
+
+process_msgs([], Publishes, State) ->
+    {ok, lists:reverse(Publishes), State};
+
+process_msgs([Msg|Msgs], Publishes, State) ->
+    case process_msg(Msg, State) of
+        {ok, Publish, NState} ->
+            process_msgs(Msgs, [Publish|Publishes], NState);
+        {ignore, NState} ->
+            process_msgs(Msgs, Publishes, NState)
+    end.
 
 %% Enqueue message if the client has been disconnected
-dispatch(Msg, State = #state{client_id = ClientId, username = Username, conn_pid = undefined}) ->
-    case emqx_hooks:run('message.dropped', [#{client_id => ClientId, username => Username}, Msg]) of
-        ok -> enqueue_msg(Msg, State);
-        stop -> State
-    end;
+process_msg(Msg, State = #state{conn_pid = undefined}) ->
+    {ignore, enqueue_msg(Msg, State)};
 
-%% Deliver qos0 message directly to client
-dispatch(Msg = #message{qos = ?QOS_0} = Msg, State) ->
-    ok = deliver(undefined, Msg, State),
-    State;
+%% Prepare the qos0 message delivery
+process_msg(Msg = #message{qos = ?QOS_0}, State) ->
+    {ok, {publish, undefined, Msg}, State};
 
-dispatch(Msg = #message{qos = QoS} = Msg,
-         State = #state{next_pkt_id = PacketId, inflight = Inflight})
+process_msg(Msg = #message{qos = QoS},
+            State = #state{next_pkt_id = PacketId, inflight = Inflight})
     when QoS =:= ?QOS_1 orelse QoS =:= ?QOS_2 ->
     case emqx_inflight:is_full(Inflight) of
         true ->
-            enqueue_msg(Msg, State);
+            {ignore, enqueue_msg(Msg, State)};
         false ->
-            ok = deliver(PacketId, Msg, State),
-            await(PacketId, Msg, next_pkt_id(State))
+            Publish = {publish, PacketId, Msg},
+            NState = await(PacketId, Msg, State),
+            {ok, Publish, next_pkt_id(NState)}
     end.
 
-enqueue_msg(Msg, State = #state{mqueue = Q}) ->
+enqueue_msg(Msg, State = #state{mqueue = Q, client_id = ClientId, username = Username}) ->
     emqx_pd:update_counter(enqueue_stats, 1),
     {Dropped, NewQ} = emqx_mqueue:in(Msg, Q),
-    Dropped =/= undefined andalso emqx_shared_sub:maybe_nack_dropped(Dropped),
+    if
+        Dropped =/= undefined ->
+            SessProps = #{client_id => ClientId, username => Username},
+            emqx_hooks:run('message.dropped', [SessProps, Msg]);
+        true -> ok
+    end,
     State#state{mqueue = NewQ}.
 
 %%------------------------------------------------------------------------------
@@ -866,28 +951,22 @@ enqueue_msg(Msg, State = #state{mqueue = Q}) ->
 %%------------------------------------------------------------------------------
 
 redeliver({PacketId, Msg = #message{qos = QoS}}, State) ->
-    deliver(PacketId, if QoS =:= ?QOS_2 -> Msg;
-                         true -> emqx_message:set_flag(dup, Msg)
-                      end, State);
+    Msg1 = if
+               QoS =:= ?QOS_2 -> Msg;
+               true -> emqx_message:set_flag(dup, Msg)
+           end,
+    do_deliver(PacketId, Msg1, State);
 
-redeliver({pubrel, PacketId}, #state{conn_pid = ConnPid}) ->
-    ConnPid ! {deliver, {pubrel, PacketId}}.
+redeliver({pubrel, PacketId}, #state{deliver_fun = DeliverFun}) ->
+    DeliverFun({pubrel, PacketId}).
 
-deliver(PacketId, Msg, State) ->
+do_deliver(PacketId, Msg, #state{deliver_fun = DeliverFun}) ->
     emqx_pd:update_counter(deliver_stats, 1),
-    %% Ack QoS1/QoS2 messages when message is delivered to connection.
-    %% NOTE: NOT to wait for PUBACK because:
-    %% The sender is monitoring this session process,
-    %% if the message is delivered to client but connection or session crashes,
-    %% sender will try to dispatch the message to the next shared subscriber.
-    %% This violates spec as QoS2 messages are not allowed to be sent to more
-    %% than one member in the group.
-    do_deliver(PacketId, emqx_shared_sub:maybe_ack(Msg), State).
-
-do_deliver(PacketId, Msg, #state{conn_pid = ConnPid, binding = local}) ->
-    ConnPid ! {deliver, {publish, PacketId, Msg}}, ok;
-do_deliver(PacketId, Msg, #state{conn_pid = ConnPid, binding = remote}) ->
-    emqx_rpc:cast(node(ConnPid), erlang, send, [ConnPid, {deliver, {publish, PacketId, Msg}}]).
+    DeliverFun({publish, PacketId, Msg}).
+
+batch_deliver(Publishes, #state{deliver_fun = DeliverFun}) ->
+    emqx_pd:update_counter(deliver_stats, length(Publishes)),
+    DeliverFun(Publishes).
 
 %%------------------------------------------------------------------------------
 %% Awaiting ACK for QoS1/QoS2 Messages
@@ -932,26 +1011,31 @@ acked(pubcomp, PacketId, State = #state{inflight = Inflight}) ->
 dequeue(State = #state{conn_pid = undefined}) ->
     State;
 
-dequeue(State = #state{inflight = Inflight}) ->
-    case emqx_inflight:is_full(Inflight) of
-        true  -> State;
-        false -> dequeue2(State)
+dequeue(State = #state{inflight = Inflight, mqueue = Q}) ->
+    case emqx_mqueue:is_empty(Q)
+         orelse emqx_inflight:is_full(Inflight) of
+        true -> State;
+        false ->
+            {Msgs, Q1} = drain_q(batch_n(Inflight), [], Q),
+            batch_process(lists:reverse(Msgs), State#state{mqueue = Q1})
     end.
 
-dequeue2(State = #state{mqueue = Q}) ->
+drain_q(Cnt, Msgs, Q) when Cnt =< 0 ->
+    {Msgs, Q};
+
+drain_q(Cnt, Msgs, Q) ->
     case emqx_mqueue:out(Q) of
-        {empty, _Q} -> State;
+        {empty, _Q} -> {Msgs, Q};
         {{value, Msg}, Q1} ->
-            %% Dequeue more
-            dequeue(dispatch(Msg, State#state{mqueue = Q1}))
+            drain_q(Cnt-1, [Msg|Msgs], Q1)
     end.
 
 %%------------------------------------------------------------------------------
 %% Ensure timers
 
-ensure_await_rel_timer(State = #state{await_rel_timer = undefined, await_rel_timeout = Timeout}) ->
+ensure_await_rel_timer(State = #state{await_rel_timer = undefined,
+                                      await_rel_timeout = Timeout}) ->
     ensure_await_rel_timer(Timeout, State);
-
 ensure_await_rel_timer(State) ->
     State.
 
@@ -960,7 +1044,8 @@ ensure_await_rel_timer(Timeout, State = #state{await_rel_timer = undefined}) ->
 ensure_await_rel_timer(_Timeout, State) ->
     State.
 
-ensure_retry_timer(State = #state{retry_timer = undefined, retry_interval = Interval}) ->
+ensure_retry_timer(State = #state{retry_timer = undefined,
+                                  retry_interval = Interval}) ->
     ensure_retry_timer(Interval, State);
 ensure_retry_timer(State) ->
     State.
@@ -970,7 +1055,8 @@ ensure_retry_timer(Interval, State = #state{retry_timer = undefined}) ->
 ensure_retry_timer(_Timeout, State) ->
     State.
 
-ensure_expire_timer(State = #state{expiry_interval = Interval}) when Interval > 0 andalso Interval =/= 16#ffffffff ->
+ensure_expire_timer(State = #state{expiry_interval = Interval})
+  when Interval > 0 andalso Interval =/= 16#ffffffff ->
     State#state{expiry_timer = emqx_misc:start_timer(Interval * 1000, expired)};
 ensure_expire_timer(State) ->
     State.
@@ -997,15 +1083,20 @@ next_pkt_id(State = #state{next_pkt_id = 16#FFFF}) ->
 next_pkt_id(State = #state{next_pkt_id = Id}) ->
     State#state{next_pkt_id = Id + 1}.
 
+%%------------------------------------------------------------------------------
+%% Maybe GC
+
+msg_cnt(Msgs) ->
+    lists:foldl(fun(Msg, {Cnt, Oct}) ->
+                        {Cnt+1, Oct+msg_size(Msg)}
+                end, {0, 0}, Msgs).
+
 %% Take only the payload size into account, add other fields if necessary
 msg_size(#message{payload = Payload}) -> payload_size(Payload).
 
 %% Payload should be binary(), but not 100% sure. Need dialyzer!
 payload_size(Payload) -> erlang:iolist_size(Payload).
 
-%%------------------------------------------------------------------------------
-%% Maybe GC
-
 maybe_gc(_, State = #state{gc_state = undefined}) ->
     State;
 maybe_gc({Cnt, Oct}, State = #state{gc_state = GCSt}) ->

+ 17 - 20
src/emqx_ws_connection.erl

@@ -18,7 +18,8 @@
 -include("emqx_mqtt.hrl").
 -include("logger.hrl").
 
--export([info/1, attrs/1]).
+-export([info/1]).
+-export([attrs/1]).
 -export([stats/1]).
 -export([kick/1]).
 -export([session/1]).
@@ -37,7 +38,7 @@
           sockname,
           idle_timeout,
           proto_state,
-          parser_state,
+          parse_state,
           keepalive,
           enable_stats,
           stats_timer,
@@ -128,24 +129,21 @@ websocket_init(#state{request = Req, options = Options}) ->
                                       sockname => Sockname,
                                       peercert => Peercert,
                                       sendfun  => send_fun(self())}, Options),
-    ParserState = emqx_protocol:parser(ProtoState),
+    ParseState = emqx_protocol:parser(ProtoState),
     Zone = proplists:get_value(zone, Options),
     EnableStats = emqx_zone:get_env(Zone, enable_stats, true),
     IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000),
-
     emqx_logger:set_metadata_peername(esockd_net:format(Peername)),
     {ok, #state{peername     = Peername,
                 sockname     = Sockname,
-                parser_state = ParserState,
+                parse_state  = ParseState,
                 proto_state  = ProtoState,
                 enable_stats = EnableStats,
                 idle_timeout = IdleTimout}}.
 
 send_fun(WsPid) ->
-    fun(Packet, Options) ->
-        Data = emqx_frame:serialize(Packet, Options),
+    fun(Data) ->
         BinSize = iolist_size(Data),
-        emqx_metrics:trans(inc, 'bytes/sent', BinSize),
         emqx_pd:update_counter(send_cnt, 1),
         emqx_pd:update_counter(send_oct, BinSize),
         WsPid ! {binary, iolist_to_binary(Data)},
@@ -159,15 +157,15 @@ websocket_handle({binary, <<>>}, State) ->
     {ok, ensure_stats_timer(State)};
 websocket_handle({binary, [<<>>]}, State) ->
     {ok, ensure_stats_timer(State)};
-websocket_handle({binary, Data}, State = #state{parser_state = ParserState,
-                                                proto_state  = ProtoState}) ->
+websocket_handle({binary, Data}, State = #state{parse_state = ParseState,
+                                                proto_state = ProtoState}) ->
     ?LOG(debug, "RECV ~p", [Data]),
     BinSize = iolist_size(Data),
     emqx_pd:update_counter(recv_oct, BinSize),
     emqx_metrics:trans(inc, 'bytes/received', BinSize),
-    try emqx_frame:parse(iolist_to_binary(Data), ParserState) of
-        {more, ParserState1} ->
-            {ok, State#state{parser_state = ParserState1}};
+    try emqx_frame:parse(iolist_to_binary(Data), ParseState) of
+        {more, ParseState1} ->
+            {ok, State#state{parse_state = ParseState1}};
         {ok, Packet, Rest} ->
             emqx_metrics:received(Packet),
             emqx_pd:update_counter(recv_cnt, 1),
@@ -248,10 +246,10 @@ websocket_info({keepalive, check}, State = #state{keepalive = KeepAlive}) ->
         {ok, KeepAlive1} ->
             {ok, State#state{keepalive = KeepAlive1}};
         {error, timeout} ->
-            ?LOG(debug, "Keepalive Timeout!", []),
+            ?LOG(debug, "Keepalive Timeout!"),
             shutdown(keepalive_timeout, State);
         {error, Error} ->
-            ?LOG(warning, "Keepalive error - ~p", [Error]),
+            ?LOG(error, "Keepalive error - ~p", [Error]),
             shutdown(keepalive_error, State)
     end;
 
@@ -277,15 +275,14 @@ terminate(SockError, _Req, #state{keepalive   = Keepalive,
                                   proto_state = ProtoState,
                                   shutdown    = Shutdown}) ->
 
-    ?LOG(debug, "Terminated for ~p, sockerror: ~p",
-           [Shutdown, SockError]),
+    ?LOG(debug, "Terminated for ~p, sockerror: ~p", [Shutdown, SockError]),
     emqx_keepalive:cancel(Keepalive),
     case {ProtoState, Shutdown} of
         {undefined, _} -> ok;
         {_, {shutdown, Reason}} ->
-            emqx_protocol:shutdown(Reason, ProtoState);
+            emqx_protocol:terminate(Reason, ProtoState);
         {_, Error} ->
-            emqx_protocol:shutdown(Error, ProtoState)
+            emqx_protocol:terminate(Error, ProtoState)
     end.
 
 %%------------------------------------------------------------------------------
@@ -293,7 +290,7 @@ terminate(SockError, _Req, #state{keepalive   = Keepalive,
 %%------------------------------------------------------------------------------
 
 reset_parser(State = #state{proto_state = ProtoState}) ->
-    State#state{parser_state = emqx_protocol:parser(ProtoState)}.
+    State#state{parse_state = emqx_protocol:parser(ProtoState)}.
 
 ensure_stats_timer(State = #state{enable_stats = true,
                                   stats_timer  = undefined,

+ 3 - 2
test/emqx_connection_SUITE.erl

@@ -96,6 +96,7 @@ t_connect_api(_Config) ->
     ?STATS = emqx_connection:stats(CPid),
     ?ATTRS = emqx_connection:attrs(CPid),
     ?INFO = emqx_connection:info(CPid),
-    SessionPid = emqx_connection:session(CPid),
-    true = is_pid(SessionPid),
+    SPid = emqx_connection:session(CPid),
+    true = is_pid(SPid),
     emqx_client:disconnect(T1).
+

+ 12 - 15
test/emqx_mqtt_packet_SUITE.erl

@@ -1,18 +1,15 @@
-%%%===================================================================
-%%% Copyright (c) 2013-2019 EMQ Inc. All rights reserved.
-%%%
-%%% Licensed under the Apache License, Version 2.0 (the "License");
-%%% you may not use this file except in compliance with the License.
-%%% You may obtain a copy of the License at
-%%%
-%%%     http://www.apache.org/licenses/LICENSE-2.0
-%%%
-%%% Unless required by applicable law or agreed to in writing, software
-%%% distributed under the License is distributed on an "AS IS" BASIS,
-%%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-%%% See the License for the specific language governing permissions and
-%%% limitations under the License.
-%%%===================================================================
+%% Copyright (c) 2013-2019 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
 
 -module(emqx_mqtt_packet_SUITE).
 

+ 2 - 1
test/emqx_pool_SUITE.erl

@@ -62,7 +62,7 @@ async_submit_mfa(_Config) ->
     emqx_pool:async_submit(fun ?MODULE:test_mfa/0, []).
 
 async_submit_crash(_) ->
-    emqx_pool:async_submit(fun() -> error(test) end).
+    emqx_pool:async_submit(fun() -> error(unexpected_error) end).
 
 t_unexpected(_) ->
     Pid = emqx_pool:worker(),
@@ -73,3 +73,4 @@ t_unexpected(_) ->
 
 test_mfa() ->
     lists:foldl(fun(X, Sum) -> X + Sum end, 0, [1,2,3,4,5]).
+

+ 3 - 5
test/emqx_protocol_SUITE.erl

@@ -1,5 +1,4 @@
-%%--------------------------------------------------------------------
-%% Copyright (c) 2013-2019 EMQ Enterprise, Inc. (http://emqtt.io)
+%% Copyright (c) 2013-2019 EMQ Technologies Co., Ltd. All Rights Reserved.
 %%
 %% Licensed under the Apache License, Version 2.0 (the "License");
 %% you may not use this file except in compliance with the License.
@@ -12,7 +11,6 @@
 %% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 %% See the License for the specific language governing permissions and
 %% limitations under the License.
-%%--------------------------------------------------------------------
 
 -module(emqx_protocol_SUITE).
 
@@ -574,9 +572,9 @@ acl_deny_action_ct(_) ->
 acl_deny_action_eunit(_) ->
     PState = ?TEST_PSTATE(?MQTT_PROTO_V5, #{msg => 0, pkt => 0}),
     CodeName = emqx_reason_codes:name(?RC_NOT_AUTHORIZED, ?MQTT_PROTO_V5),
-    {error, CodeName, NEWPSTATE1} = emqx_protocol:process_packet(?PUBLISH_PACKET(?QOS_1, <<"acl_deny_action">>, 1, <<"payload">>), PState),
+    {error, CodeName, NEWPSTATE1} = emqx_protocol:process(?PUBLISH_PACKET(?QOS_1, <<"acl_deny_action">>, 1, <<"payload">>), PState),
     ?assertEqual(#{pkt => 1, msg => 0}, NEWPSTATE1#pstate.send_stats),
-    {error, CodeName, NEWPSTATE2} = emqx_protocol:process_packet(?PUBLISH_PACKET(?QOS_2, <<"acl_deny_action">>, 2, <<"payload">>), PState),
+    {error, CodeName, NEWPSTATE2} = emqx_protocol:process(?PUBLISH_PACKET(?QOS_2, <<"acl_deny_action">>, 2, <<"payload">>), PState),
     ?assertEqual(#{pkt => 1, msg => 0}, NEWPSTATE2#pstate.send_stats).
 
 will_topic_check(_) ->

+ 2 - 3
test/emqx_session_SUITE.erl

@@ -1,4 +1,3 @@
-
 %% Copyright (c) 2013-2019 EMQ Technologies Co., Ltd. All Rights Reserved.
 %%
 %% Licensed under the Apache License, Version 2.0 (the "License");
@@ -45,7 +44,7 @@ ignore_loop(_Config) ->
     application:set_env(emqx, mqtt_ignore_loop_deliver, false).
 
 t_session_all(_) ->
-    emqx_zone:set_env(internal, idle_timeout, 100),
+    emqx_zone:set_env(internal, idle_timeout, 1000),
     ClientId = <<"ClientId">>,
     {ok, ConnPid} = emqx_mock_client:start_link(ClientId),
     {ok, SPid} = emqx_mock_client:open_session(ConnPid, ClientId, internal),
@@ -56,7 +55,7 @@ t_session_all(_) ->
     [{<<"topic">>, _}] = emqx:subscriptions(SPid),
     emqx_session:publish(SPid, 1, Message1),
     timer:sleep(200),
-    {publish, 1, _} = emqx_mock_client:get_last_message(ConnPid),
+    [{publish, 1, _}] = emqx_mock_client:get_last_message(ConnPid),
     Attrs = emqx_session:attrs(SPid),
     Info = emqx_session:info(SPid),
     Stats = emqx_session:stats(SPid),

+ 5 - 5
test/emqx_shared_sub_SUITE.erl

@@ -59,7 +59,7 @@ t_random_basic(_) ->
     PacketId = 1,
     emqx_session:publish(SPid, PacketId, Message1),
     ?wait(case emqx_mock_client:get_last_message(ConnPid) of
-              {publish, 1, _} -> true;
+              [{publish, 1, _}] -> true;
               Other -> Other
           end, 1000),
     emqx_session:pubrec(SPid, PacketId, reasoncode),
@@ -105,7 +105,7 @@ t_no_connection_nack(_) ->
         fun(PacketId, ConnPid) ->
                 Payload = MkPayload(PacketId),
                 case emqx_mock_client:get_last_message(ConnPid) of
-                    {publish, _, #message{payload = Payload}} ->
+                    [{publish, _, #message{payload = Payload}}] ->
                         CasePid ! {Ref, PacketId, ConnPid},
                         true;
                     _Other ->
@@ -176,7 +176,7 @@ t_not_so_sticky(_) ->
     ?wait(subscribed(<<"group1">>, <<"foo/bar">>, SPid1), 1000),
     emqx_session:publish(SPid1, 1, Message1),
     ?wait(case emqx_mock_client:get_last_message(ConnPid1) of
-              {publish, _, #message{payload = <<"hello1">>}} -> true;
+              [{publish, _, #message{payload = <<"hello1">>}}] -> true;
               Other -> Other
           end, 1000),
     emqx_mock_client:close_session(ConnPid1),
@@ -185,7 +185,7 @@ t_not_so_sticky(_) ->
     ?wait(subscribed(<<"group1">>, <<"foo/#">>, SPid2), 1000),
     emqx_session:publish(SPid2, 2, Message2),
     ?wait(case emqx_mock_client:get_last_message(ConnPid2) of
-              {publish, _, #message{payload = <<"hello2">>}} -> true;
+              [{publish, _, #message{payload = <<"hello2">>}}] -> true;
               Other -> Other
           end, 1000),
     emqx_mock_client:close_session(ConnPid2),
@@ -240,7 +240,7 @@ test_two_messages(Strategy, WithAck) ->
 last_message(_ExpectedPayload, []) -> <<"not yet?">>;
 last_message(ExpectedPayload, [Pid | Pids]) ->
     case emqx_mock_client:get_last_message(Pid) of
-        {publish, _, #message{payload = ExpectedPayload}} -> {true, Pid};
+        [{publish, _, #message{payload = ExpectedPayload}}] -> {true, Pid};
         _Other -> last_message(ExpectedPayload, Pids)
     end.