Browse Source

Improve the connection, ws_connection and channel modules

Feng Lee 6 years ago
parent
commit
cce0dbd3cf
5 changed files with 425 additions and 413 deletions
  1. 87 99
      src/emqx_channel.erl
  2. 191 165
      src/emqx_connection.erl
  3. 129 138
      src/emqx_ws_connection.erl
  4. 17 11
      test/emqx_channel_SUITE.erl
  5. 1 0
      test/emqx_pool_SUITE.erl

+ 87 - 99
src/emqx_channel.erl

@@ -40,14 +40,9 @@
         , handle_call/2
         , handle_info/2
         , handle_timeout/3
-        , disconnect/2
         , terminate/2
         ]).
 
--export([ received/2
-        , sent/2
-        ]).
-
 -import(emqx_misc,
         [ run_fold/3
         , pipeline/3
@@ -75,8 +70,8 @@
           pub_stats :: emqx_types:stats(),
           %% Timers
           timers :: #{atom() => disabled | maybe(reference())},
-          %% Fsm State
-          state :: fsm_state(),
+          %% Conn State
+          conn_state :: conn_state(),
           %% GC State
           gc_state :: maybe(emqx_gc:gc_state()),
           %% Takeover
@@ -89,13 +84,7 @@
 
 -opaque(channel() :: #channel{}).
 
--type(fsm_state() :: #{state_name := initialized
-                                   | connecting
-                                   | connected
-                                   | disconnected,
-                       connected_at := pos_integer(),
-                       disconnected_at := pos_integer()
-                      }).
+-type(conn_state() :: idle | connecting | connected | disconnected).
 
 -type(action() :: {enter, connected | disconnected}
                 | {close, Reason :: atom()}
@@ -113,7 +102,7 @@
           will_timer   => will_message
          }).
 
--define(ATTR_KEYS, [conninfo, clientinfo, state, session]).
+-define(ATTR_KEYS, [conninfo, clientinfo, session, conn_state]).
 
 -define(INFO_KEYS, ?ATTR_KEYS ++ [keepalive, will_msg, topic_aliases,
                                   alias_maximum, gc_state]).
@@ -136,8 +125,8 @@ info(clientinfo, #channel{clientinfo = ClientInfo}) ->
     ClientInfo;
 info(session, #channel{session = Session}) ->
     maybe_apply(fun emqx_session:info/1, Session);
-info(state, #channel{state = State}) ->
-    State;
+info(conn_state, #channel{conn_state = ConnState}) ->
+    ConnState;
 info(keepalive, #channel{keepalive = Keepalive}) ->
     maybe_apply(fun emqx_keepalive:info/1, Keepalive);
 info(topic_aliases, #channel{topic_aliases = Aliases}) ->
@@ -211,7 +200,7 @@ init(ConnInfo = #{peername := {PeerHost, _Port}}, Options) ->
              clientinfo = ClientInfo,
              pub_stats  = #{},
              timers     = #{stats_timer => StatsTimer},
-             state      = #{state_name => initialized},
+             conn_state = idle,
              gc_state   = init_gc_state(Zone),
              takeover   = false,
              resuming   = false,
@@ -228,12 +217,16 @@ init_gc_state(Zone) ->
 %% Handle incoming packet
 %%--------------------------------------------------------------------
 
--spec(handle_in(emqx_types:packet(), channel())
+-spec(handle_in(Bytes :: pos_integer() | emqx_types:packet(), channel())
       -> {ok, channel()}
        | {ok, output(), channel()}
        | {stop, Reason :: term(), channel()}
        | {stop, Reason :: term(), output(), channel()}).
-handle_in(?CONNECT_PACKET(_), Channel = #channel{state = #{state_name := connected}}) ->
+handle_in(Bytes, Channel) when is_integer(Bytes) ->
+    NChannel = maybe_gc_and_check_oom(Bytes, Channel),
+    {ok, ensure_timer(stats_timer, NChannel)};
+
+handle_in(?CONNECT_PACKET(_), Channel = #channel{conn_state = connected}) ->
      handle_out({disconnect, ?RC_PROTOCOL_ERROR}, Channel);
 
 handle_in(?CONNECT_PACKET(ConnPkt), Channel) ->
@@ -347,7 +340,7 @@ handle_in(Packet = ?UNSUBSCRIBE_PACKET(PacketId, Properties, TopicFilters),
     end;
 
 handle_in(?PACKET(?PINGREQ), Channel) ->
-    {ok, Channel, {outgoing, ?PACKET(?PINGRESP)}};
+    {ok, ?PACKET(?PINGRESP), Channel};
 
 handle_in(?DISCONNECT_PACKET(ReasonCode, Properties), Channel = #channel{conninfo = ConnInfo}) ->
     #{proto_ver := ProtoVer, expiry_interval := OldInterval} = ConnInfo,
@@ -371,19 +364,19 @@ handle_in(?DISCONNECT_PACKET(ReasonCode, Properties), Channel = #channel{conninf
 handle_in(?AUTH_PACKET(), Channel) ->
     handle_out({disconnect, ?RC_IMPLEMENTATION_SPECIFIC_ERROR}, Channel);
 
-handle_in({frame_error, Reason}, Channel = #channel{state = FsmState}) ->
-    case FsmState of
-        #{state_name := initialized} ->
-            {stop, {shutdown, Reason}, Channel};
-        #{state_name := connecting} ->
-            Packet = ?CONNACK_PACKET(?RC_MALFORMED_PACKET),
-            {stop, {shutdown, Reason}, Packet, Channel};
-        #{state_name := connected} ->
-            handle_out({disconnect, ?RC_MALFORMED_PACKET}, Channel);
-        #{state_name := disconnected} ->
-            ?LOG(error, "Unexpected frame error: ~p", [Reason]),
-            {ok, Channel}
-    end;
+handle_in({frame_error, Reason}, Channel = #channel{conn_state = idle}) ->
+    {stop, {shutdown, Reason}, Channel};
+
+handle_in({frame_error, Reason}, Channel = #channel{conn_state = connecting}) ->
+    Packet = ?CONNACK_PACKET(?RC_MALFORMED_PACKET),
+    {stop, {shutdown, Reason}, Packet, Channel};
+
+handle_in({frame_error, _Reason}, Channel = #channel{conn_state = connected}) ->
+    handle_out({disconnect, ?RC_MALFORMED_PACKET}, Channel);
+
+handle_in({frame_error, Reason}, Channel = #channel{conn_state = disconnected}) ->
+    ?LOG(error, "Unexpected frame error: ~p", [Reason]),
+    {ok, Channel};
 
 handle_in(Packet, Channel) ->
     ?LOG(error, "Unexpected incoming: ~p", [Packet]),
@@ -534,31 +527,59 @@ do_unsubscribe(TopicFilter, _SubOpts, Channel =
 %% Handle outgoing packet
 %%--------------------------------------------------------------------
 
-%% TODO: RunFold or Pipeline
+-spec(handle_out(integer()|term(), channel())
+      -> {ok, channel()}
+       | {ok, output(), channel()}
+       | {stop, Reason :: term(), channel()}
+       | {stop, Reason :: term(), output(), channel()}).
+handle_out(Bytes, Channel) when is_integer(Bytes) ->
+    NChannel = maybe_gc_and_check_oom(Bytes, Channel),
+    {ok, ensure_timer(stats_timer, NChannel)};
+
+handle_out(Delivers, Channel = #channel{conn_state = disconnected,
+                                        session = Session})
+  when is_list(Delivers) ->
+    NSession = emqx_session:enqueue(Delivers, Session),
+    {ok, Channel#channel{session = NSession}};
+
+handle_out(Delivers, Channel = #channel{takeover = true,
+                                        pendings = Pendings})
+  when is_list(Delivers) ->
+    {ok, Channel#channel{pendings = lists:append(Pendings, Delivers)}};
+
+handle_out(Delivers, Channel = #channel{session = Session}) when is_list(Delivers) ->
+    case emqx_session:deliver(Delivers, Session) of
+        {ok, Publishes, NSession} ->
+            NChannel = Channel#channel{session = NSession},
+            handle_out({publish, Publishes}, ensure_timer(retry_timer, NChannel));
+        {ok, NSession} ->
+            {ok, Channel#channel{session = NSession}}
+    end;
+
 handle_out({connack, ?RC_SUCCESS, SP, ConnPkt},
            Channel = #channel{conninfo   = ConnInfo,
-                              clientinfo = ClientInfo,
-                              state      = FsmState}) ->
+                              clientinfo = ClientInfo}) ->
     AckProps = run_fold([fun enrich_caps/2,
                          fun enrich_server_keepalive/2,
                          fun enrich_assigned_clientid/2], #{}, Channel),
-    FsmState1 = FsmState#{state_name => connected,
-                          connected_at => erlang:system_time(second)
-                         },
-    Channel1 = Channel#channel{state = FsmState1,
+    ConnInfo1 = ConnInfo#{connected_at => erlang:system_time(second)},
+    Channel1 = Channel#channel{conninfo = ConnInfo1,
                                will_msg = emqx_packet:will_msg(ConnPkt),
+                               conn_state = connected,
                                alias_maximum = init_alias_maximum(ConnPkt, ClientInfo)
                               },
     Channel2 = ensure_keepalive(AckProps, Channel1),
     ok = emqx_hooks:run('client.connected', [ClientInfo, ?RC_SUCCESS, ConnInfo]),
     AckPacket = ?CONNACK_PACKET(?RC_SUCCESS, SP, AckProps),
     case maybe_resume_session(Channel2) of
-        ignore -> {ok, AckPacket, Channel2};
+        ignore ->
+            Output = [{outgoing, AckPacket}, {enter, connected}],
+            {ok, Output, Channel2};
         {ok, Publishes, NSession} ->
             Channel3 = Channel2#channel{session  = NSession,
                                         resuming = false,
                                         pendings = []},
-            {ok, Packets, _} = handle_out({publish, Publishes}, Channel3),
+            {ok, {outgoing, Packets}, _} = handle_out({publish, Publishes}, Channel3),
             Output = [{outgoing, [AckPacket|Packets]}, {enter, connected}],
             {ok, Output, Channel3}
     end;
@@ -573,24 +594,6 @@ handle_out({connack, ReasonCode, _ConnPkt}, Channel = #channel{conninfo = ConnIn
     Reason = emqx_reason_codes:name(ReasonCode1, ProtoVer),
     {stop, {shutdown, Reason}, ?CONNACK_PACKET(ReasonCode1), Channel};
 
-handle_out({deliver, Delivers}, Channel = #channel{state   = #{state_name := disconnected},
-                                                   session = Session}) ->
-    NSession = emqx_session:enqueue(Delivers, Session),
-    {ok, Channel#channel{session = NSession}};
-
-handle_out({deliver, Delivers}, Channel = #channel{takeover = true,
-                                                   pendings = Pendings}) ->
-    {ok, Channel#channel{pendings = lists:append(Pendings, Delivers)}};
-
-handle_out({deliver, Delivers}, Channel = #channel{session = Session}) ->
-    case emqx_session:deliver(Delivers, Session) of
-        {ok, Publishes, NSession} ->
-            NChannel = Channel#channel{session = NSession},
-            handle_out({publish, Publishes}, ensure_timer(retry_timer, NChannel));
-        {ok, NSession} ->
-            {ok, Channel#channel{session = NSession}}
-    end;
-
 handle_out({publish, Publishes}, Channel) when is_list(Publishes) ->
     Packets = lists:foldl(
                 fun(Publish, Acc) ->
@@ -679,28 +682,33 @@ handle_out({Type, Data}, Channel) ->
 %% Handle call
 %%--------------------------------------------------------------------
 
+-spec(handle_call(Req :: term(), channel())
+      -> {reply, Reply :: term(), channel()}
+       | {stop, Reason :: term(), Reply :: term(), channel()}).
 handle_call(kick, Channel) ->
     {stop, {shutdown, kicked}, ok, Channel};
 
-handle_call(discard, Channel = #channel{state = #{state_name := connected}}) ->
+handle_call(discard, Channel = #channel{conn_state = connected}) ->
     Packet = ?DISCONNECT_PACKET(?RC_SESSION_TAKEN_OVER),
-    {stop, {shutdown, discarded}, Packet, ok, Channel};
-handle_call(discard, Channel = #channel{state = #{state_name := disconnected}}) ->
+    {stop, {shutdown, discarded}, ok, Packet, Channel};
+
+handle_call(discard, Channel = #channel{conn_state = disconnected}) ->
     {stop, {shutdown, discarded}, ok, Channel};
 
 %% Session Takeover
 handle_call({takeover, 'begin'}, Channel = #channel{session = Session}) ->
-    {ok, Session, Channel#channel{takeover = true}};
+    {reply, Session, Channel#channel{takeover = true}};
 
 handle_call({takeover, 'end'}, Channel = #channel{session  = Session,
                                                   pendings = Pendings}) ->
     ok = emqx_session:takeover(Session),
-    AllPendings = lists:append(emqx_misc:drain_deliver(), Pendings),
+    Delivers = emqx_misc:drain_deliver(),
+    AllPendings = lists:append(Delivers, Pendings),
     {stop, {shutdown, takeovered}, AllPendings, Channel};
 
 handle_call(Req, Channel) ->
     ?LOG(error, "Unexpected call: ~p", [Req]),
-    {ok, ignored, Channel}.
+    {reply, ignored, Channel}.
 
 %%--------------------------------------------------------------------
 %% Handle Info
@@ -727,26 +735,21 @@ handle_info({register, Attrs, Stats}, #channel{clientinfo = #{clientid := Client
     emqx_cm:set_chan_attrs(ClientId, Attrs),
     emqx_cm:set_chan_stats(ClientId, Stats);
 
-%%TODO: Fixme later
-%%handle_info(disconnected, Channel = #channel{connected = undefined}) ->
-%%    shutdown(closed, Channel);
-
-handle_info(disconnected, Channel = #channel{state = #{state_name := disconnected}}) ->
+handle_info({sock_closed, _Reason}, Channel = #channel{conn_state = disconnected}) ->
     {ok, Channel};
 
-handle_info(disconnected, Channel = #channel{conninfo = #{expiry_interval := ExpiryInterval},
-                                             clientinfo = ClientInfo = #{zone := Zone},
-                                             will_msg = WillMsg}) ->
+handle_info({sock_closed, _Reason}, Channel = #channel{conninfo = ConnInfo,
+                                                       clientinfo = ClientInfo = #{zone := Zone},
+                                                       will_msg = WillMsg}) ->
     emqx_zone:enable_flapping_detect(Zone) andalso emqx_flapping:detect(ClientInfo),
-    Channel1 = ensure_disconnected(Channel),
+    ConnInfo1 = ConnInfo#{disconnected_at => erlang:system_time(second)},
+    Channel1 = Channel#channel{conninfo = ConnInfo1, conn_state = disconnected},
     Channel2 = case timer:seconds(will_delay_interval(WillMsg)) of
-                   0 ->
-                       publish_will_msg(WillMsg),
-                       Channel1#channel{will_msg = undefined};
-                   _ ->
-                       ensure_timer(will_timer, Channel1)
+                   0 -> publish_will_msg(WillMsg),
+                        Channel1#channel{will_msg = undefined};
+                   _ -> ensure_timer(will_timer, Channel1)
                end,
-    case ExpiryInterval of
+    case maps:get(expiry_interval, ConnInfo) of
         ?UINT_MAX ->
             {ok, Channel2};
         Int when Int > 0 ->
@@ -757,6 +760,7 @@ handle_info(disconnected, Channel = #channel{conninfo = #{expiry_interval := Exp
 
 handle_info(Info, Channel) ->
     ?LOG(error, "Unexpected info: ~p~n", [Info]),
+    error(unexpected_info),
     {ok, Channel}.
 
 %%--------------------------------------------------------------------
@@ -870,14 +874,11 @@ will_delay_interval(undefined) -> 0;
 will_delay_interval(WillMsg) ->
     emqx_message:get_header('Will-Delay-Interval', WillMsg, 0).
 
-%% TODO: Implement later.
-disconnect(_Reason, Channel) -> {ok, Channel}.
-
 %%--------------------------------------------------------------------
 %% Terminate
 %%--------------------------------------------------------------------
 
-terminate(_, #channel{state = #{state_name := initialized}}) ->
+terminate(_, #channel{conn_state = idle}) ->
     ok;
 terminate(normal, #channel{conninfo = ConnInfo, clientinfo = ClientInfo}) ->
     ok = emqx_hooks:run('client.disconnected', [ClientInfo, normal, ConnInfo]);
@@ -888,14 +889,6 @@ terminate(Reason, #channel{conninfo = ConnInfo, clientinfo = ClientInfo, will_ms
     publish_will_msg(WillMsg),
     ok = emqx_hooks:run('client.disconnected', [ClientInfo, Reason, ConnInfo]).
 
--spec(received(pos_integer(), channel()) -> channel()).
-received(Oct, Channel) ->
-    ensure_timer(stats_timer, maybe_gc_and_check_oom(Oct, Channel)).
-
--spec(sent(pos_integer(), channel()) -> channel()).
-sent(Oct, Channel) ->
-    ensure_timer(stats_timer, maybe_gc_and_check_oom(Oct, Channel)).
-
 %% TODO: Improve will msg:)
 publish_will_msg(undefined) ->
     ok;
@@ -1153,11 +1146,6 @@ init_alias_maximum(#mqtt_packet_connect{proto_ver  = ?MQTT_PROTO_V5,
       inbound  => emqx_mqtt_caps:get_caps(Zone, max_topic_alias, 0)};
 init_alias_maximum(_ConnPkt, _ClientInfo) -> undefined.
 
-ensure_disconnected(Channel = #channel{state = FsmState}) ->
-    Channel#channel{state = FsmState#{state_name := disconnected,
-                                      disconnected_at => erlang:system_time(second)
-                                     }}.
-
 ensure_keepalive(#{'Server-Keep-Alive' := Interval}, Channel) ->
     ensure_keepalive_timer(Interval, Channel);
 ensure_keepalive(_AckProps, Channel = #channel{conninfo = ConnInfo}) ->

+ 191 - 165
src/emqx_connection.erl

@@ -14,6 +14,7 @@
 %% limitations under the License.
 %%--------------------------------------------------------------------
 
+%% MQTT/TCP Connection
 -module(emqx_connection).
 
 -include("emqx.hrl").
@@ -25,7 +26,6 @@
 
 %% API
 -export([ start_link/3
-        , call/2
         , stop/1
         ]).
 
@@ -33,6 +33,9 @@
         , stats/1
         ]).
 
+-export([call/2]).
+
+%% callback
 -export([init/4]).
 
 %% Sys callbacks
@@ -56,10 +59,10 @@
           peername :: emqx_types:peername(),
           %% Sockname of the connection
           sockname :: emqx_types:peername(),
+          %% Sock state
+          sockstate :: emqx_types:sockstate(),
           %% The {active, N} option
           active_n :: pos_integer(),
-          %% The active state
-          active_st :: idle | running | blocked | closed,
           %% Publish Limit
           pub_limit :: maybe(esockd_rate_limit:bucket()),
           %% Rate Limit
@@ -71,7 +74,7 @@
           %% Serialize function
           serialize :: emqx_frame:serialize_fun(),
           %% Channel State
-          chan_state :: emqx_channel:channel(),
+          channel :: emqx_channel:channel(),
           %% Idle timer
           idle_timer :: reference()
         }).
@@ -79,7 +82,7 @@
 -type(state() :: #state{}).
 
 -define(ACTIVE_N, 100).
--define(INFO_KEYS, [socktype, peername, sockname, active_n, active_state,
+-define(INFO_KEYS, [socktype, peername, sockname, sockstate, active_n,
                     pub_limit, rate_limit]).
 -define(CONN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]).
 -define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]).
@@ -98,8 +101,8 @@ start_link(Transport, Socket, Options) ->
 -spec(info(pid()|state()) -> emqx_types:infos()).
 info(CPid) when is_pid(CPid) ->
     call(CPid, info);
-info(State = #state{chan_state = ChanState}) ->
-    ChanInfo = emqx_channel:info(ChanState),
+info(State = #state{channel = Channel}) ->
+    ChanInfo = emqx_channel:info(Channel),
     SockInfo = maps:from_list(info(?INFO_KEYS, State)),
     maps:merge(ChanInfo, #{sockinfo => SockInfo}).
 
@@ -111,16 +114,16 @@ info(peername, #state{peername = Peername}) ->
     Peername;
 info(sockname, #state{sockname = Sockname}) ->
     Sockname;
+info(sockstate, #state{sockstate = SockSt}) ->
+    SockSt;
 info(active_n, #state{active_n = ActiveN}) ->
     ActiveN;
-info(active_st, #state{active_st= ActiveSt}) ->
-    ActiveSt;
 info(pub_limit, #state{pub_limit = PubLimit}) ->
     limit_info(PubLimit);
 info(rate_limit, #state{rate_limit = RateLimit}) ->
     limit_info(RateLimit);
-info(chan_state, #state{chan_state = ChanState}) ->
-    emqx_channel:info(ChanState).
+info(channel, #state{channel = Channel}) ->
+    emqx_channel:info(Channel).
 
 limit_info(Limit) ->
     emqx_misc:maybe_apply(fun esockd_rate_limit:info/1, Limit).
@@ -129,15 +132,15 @@ limit_info(Limit) ->
 -spec(stats(pid()|state()) -> emqx_types:stats()).
 stats(CPid) when is_pid(CPid) ->
     call(CPid, stats);
-stats(#state{transport  = Transport,
-             socket     = Socket,
-             chan_state = ChanState}) ->
+stats(#state{transport = Transport,
+             socket    = Socket,
+             channel   = Channel}) ->
     SockStats = case Transport:getstat(Socket, ?SOCK_STATS) of
                     {ok, Ss}   -> Ss;
                     {error, _} -> []
                 end,
     ConnStats = emqx_pd:get_counters(?CONN_STATS),
-    ChanStats = emqx_channel:stats(ChanState),
+    ChanStats = emqx_channel:stats(Channel),
     ProcStats = emqx_misc:proc_stats(),
     lists:append([SockStats, ConnStats, ChanStats, ProcStats]).
 
@@ -152,7 +155,23 @@ stop(Pid) ->
 %%--------------------------------------------------------------------
 
 init(Parent, Transport, RawSocket, Options) ->
-    {ok, Socket} = Transport:wait(RawSocket),
+    case Transport:wait(RawSocket) of
+        {ok, Socket} ->
+            do_init(Parent, Transport, Socket, Options);
+        {error, Reason} when Reason =:= enotconn;
+                             Reason =:= einval;
+                             Reason =:= closed ->
+            Transport:fast_close(RawSocket),
+            exit(normal);
+        {error, timeout} ->
+            Transport:fast_close(RawSocket),
+            exit({shutdown, ssl_upgrade_timeout});
+        {error, Reason} ->
+            Transport:fast_close(RawSocket),
+            exit(Reason)
+    end.
+
+do_init(Parent, Transport, Socket, Options) ->
     {ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]),
     {ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]),
     emqx_logger:set_metadata_peername(esockd_net:format(Peername)),
@@ -170,27 +189,32 @@ init(Parent, Transport, RawSocket, Options) ->
     FrameOpts = emqx_zone:frame_options(Zone),
     ParseState = emqx_frame:initial_parse_state(FrameOpts),
     Serialize = emqx_frame:serialize_fun(),
-    ChanState = emqx_channel:init(ConnInfo, Options),
+    Channel = emqx_channel:init(ConnInfo, Options),
     IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000),
     IdleTimer = emqx_misc:start_timer(IdleTimout, idle_timeout),
     HibAfterTimeout = emqx_zone:get_env(Zone, hibernate_after, IdleTimout*2),
-    State = #state{parent       = Parent,
-                   transport    = Transport,
-                   socket       = Socket,
-                   peername     = Peername,
-                   sockname     = Sockname,
-                   active_n     = ActiveN,
-                   active_st    = idle,
-                   pub_limit    = PubLimit,
-                   rate_limit   = RateLimit,
-                   parse_state  = ParseState,
-                   serialize    = Serialize,
-                   chan_state   = ChanState,
-                   idle_timer   = IdleTimer
+    State = #state{parent      = Parent,
+                   transport   = Transport,
+                   socket      = Socket,
+                   peername    = Peername,
+                   sockname    = Sockname,
+                   sockstate   = idle,
+                   active_n    = ActiveN,
+                   pub_limit   = PubLimit,
+                   rate_limit  = RateLimit,
+                   parse_state = ParseState,
+                   serialize   = Serialize,
+                   channel     = Channel,
+                   idle_timer  = IdleTimer
                   },
     case activate_socket(State) of
         {ok, NState} ->
             recvloop(NState, #{hibernate_after => HibAfterTimeout});
+        {error, Reason} when Reason =:= einval;
+                             Reason =:= enotconn;
+                             Reason =:= closed ->
+            Transport:fast_close(Socket),
+            exit(normal);
         {error, Reason} ->
             Transport:fast_close(Socket),
             erlang:exit({shutdown, Reason})
@@ -208,7 +232,8 @@ recvloop(State = #state{parent = Parent},
          Options = #{hibernate_after := HibAfterTimeout}) ->
     receive
         {system, From, Request} ->
-            sys:handle_system_msg(Request, From, Parent, ?MODULE, [], {State, Options});
+            sys:handle_system_msg(Request, From, Parent,
+                                  ?MODULE, [], {State, Options});
         {'EXIT', Parent, Reason} ->
             terminate(Reason, State);
         Msg ->
@@ -230,6 +255,7 @@ wakeup_from_hib(State, Options) ->
 
 process_msg([], State, Options) ->
     recvloop(State, Options);
+
 process_msg([Msg|More], State, Options) ->
     case catch handle_msg(Msg, State) of
         ok ->
@@ -246,11 +272,6 @@ process_msg([Msg|More], State, Options) ->
             terminate(Reason, State)
     end.
 
--compile({inline, [append_msg/2]}).
-append_msg(NextMsgs, L) when is_list(NextMsgs) ->
-    lists:append(NextMsgs, L);
-append_msg(NextMsg, L) -> [NextMsg|L].
-
 %%--------------------------------------------------------------------
 %% Handle a Msg
 
@@ -261,51 +282,37 @@ handle_msg({'$gen_call', From, Req}, State) ->
             {ok, NState};
         {stop, Reason, Reply, NState} ->
             gen_server:reply(From, Reply),
-            {stop, Reason, NState}
+            stop(Reason, NState)
     end;
 
-%% Handle incoming data
-handle_msg({Inet, _Sock, Data}, State = #state{chan_state = ChanState})
+handle_msg({Inet, _Sock, Data}, State = #state{channel = Channel})
   when Inet == tcp; Inet == ssl ->
     ?LOG(debug, "RECV ~p", [Data]),
     Oct = iolist_size(Data),
     emqx_pd:update_counter(incoming_bytes, Oct),
     ok = emqx_metrics:inc('bytes.received', Oct),
-    NChanState = emqx_channel:received(Oct, ChanState),
-    State1 = State#state{chan_state = NChanState},
-    {Packets, State2} = parse_incoming(Data, State1),
-    {ok, next_incoming_msgs(Packets), State2};
+    {ok, NChannel} = emqx_channel:handle_in(Oct, Channel),
+    process_incoming(Data, State#state{channel = NChannel});
 
-%% Handle incoming packets
 handle_msg({incoming, Packet = ?CONNECT_PACKET(ConnPkt)},
            State = #state{idle_timer = IdleTimer}) ->
     ok = emqx_misc:cancel_timer(IdleTimer),
-    NState = State#state{serialize  = emqx_frame:serialize_fun(ConnPkt),
+    Serialize = emqx_frame:serialize_fun(ConnPkt),
+    NState = State#state{serialize  = Serialize,
                          idle_timer = undefined
                         },
     handle_incoming(Packet, NState);
 
-handle_msg({incoming, Packet}, State) when is_record(Packet, mqtt_packet) ->
+handle_msg({incoming, Packet}, State) ->
     handle_incoming(Packet, State);
 
-handle_msg({enter, connected}, State = #state{active_n   = ActiveN,
-                                              active_st  = ActiveSt,
-                                              chan_state = ChanState
-                                             }) ->
-    ChanAttrs = emqx_channel:attrs(ChanState),
-    SockAttrs = #{active_n  => ActiveN,
-                  active_st => ActiveSt
-                 },
-    Attrs = maps:merge(ChanAttrs, #{sockinfo => SockAttrs}),
-    emqx_channel:handle_info({register, Attrs, stats(State)}, ChanState);
-
 handle_msg({Error, _Sock, Reason}, State)
   when Error == tcp_error; Error == ssl_error ->
-    handle_sockerr(Reason, State);
+    handle_info({sock_error, Reason}, State);
 
 handle_msg({Closed, _Sock}, State)
   when Closed == tcp_closed; Closed == ssl_closed ->
-    socket_closed(Closed, State);
+    handle_info(sock_closed, State);
 
 handle_msg({Passive, _Sock}, State)
   when Passive == tcp_passive; Passive == ssl_passive ->
@@ -314,73 +321,67 @@ handle_msg({Passive, _Sock}, State)
     case activate_socket(NState) of
         {ok, NState} -> {ok, NState};
         {error, Reason} ->
-            handle_sockerr(Reason, State)
+            {ok, {sock_error, Reason}, NState}
     end;
 
 %% Rate limit timer expired.
 handle_msg(activate_socket, State) ->
-    NState = State#state{active_st   = idle,
+    NState = State#state{sockstate   = idle,
                          limit_timer = undefined
                         },
     case activate_socket(NState) of
         {ok, NState} -> {ok, NState};
         {error, Reason} ->
-            handle_sockerr(Reason, State)
+            {ok, {sock_error, Reason}, State}
     end;
 
 handle_msg(Deliver = {deliver, _Topic, _Msg},
-           State = #state{chan_state = ChanState}) ->
+           State = #state{channel = Channel}) ->
     Delivers = emqx_misc:drain_deliver([Deliver]),
-    Result = emqx_channel:handle_out({deliver, Delivers}, ChanState),
-    handle_chan_return(Result, State);
+    Result = emqx_channel:handle_out(Delivers, Channel),
+    handle_return(Result, State);
 
 handle_msg({outgoing, Packets}, State) ->
-    handle_outgoing(Packets, State);
+    {ok, handle_outgoing(Packets, State)};
 
 %% something sent
 handle_msg({inet_reply, _Sock, ok}, _State) ->
     ok;
 
 handle_msg({inet_reply, _Sock, {error, Reason}}, State) ->
-    handle_sockerr(Reason, State);
+    handle_info({sock_error, Reason}, State);
 
-handle_msg({timeout, TRef, TMsg}, State) when is_reference(TRef) ->
+handle_msg({timeout, TRef, TMsg}, State) ->
     handle_timeout(TRef, TMsg, State);
 
 handle_msg(Shutdown = {shutdown, _Reason}, State) ->
-    {stop, Shutdown, State};
-
-handle_msg(Msg, State = #state{chan_state = ChanState}) ->
-    case emqx_channel:handle_info(Msg, ChanState) of
-        {ok, NChanState} ->
-            {ok, State#state{chan_state = NChanState}};
-        {stop, Reason, NChanState} ->
-            {stop, Reason, State#state{chan_state = NChanState}}
-    end.
+    stop(Shutdown, State);
+
+handle_msg(Msg, State) -> handle_info(Msg, State).
 
 %%--------------------------------------------------------------------
 %% Terminate
 
-terminate(Reason, #state{transport  = Transport,
-                         socket     = Socket,
-                         active_st  = ActiveSt,
-                         chan_state = ChanState}) ->
+terminate(Reason, #state{transport = Transport,
+                         socket    = Socket,
+                         sockstate = SockSt,
+                         channel   = Channel}) ->
     ?LOG(debug, "Terminated for ~p", [Reason]),
-    ActiveSt =:= closed orelse Transport:fast_close(Socket),
-    emqx_channel:terminate(Reason, ChanState),
+    SockSt =:= closed orelse Transport:fast_close(Socket),
+    emqx_channel:terminate(Reason, Channel),
     exit(Reason).
 
 %%--------------------------------------------------------------------
 %% Sys callbacks
 
 system_continue(_Parent, _Deb, {State, Options}) ->
-	recvloop(State, Options).
+    recvloop(State, Options).
 
 system_terminate(Reason, _Parent, _Deb, {State, _}) ->
-	terminate(Reason, State).
+    terminate(Reason, State).
 
 system_code_change(Misc, _, _, _) ->
-	{ok, Misc}.
+    {ok, Misc}.
 
 system_get_state({State, _Options}) ->
     {ok, State}.
@@ -394,24 +395,23 @@ handle_call(_From, info, State) ->
 handle_call(_From, stats, State) ->
     {reply, stats(State), State};
 
-%% TODO: the handle_outgoing is not right ...
-handle_call(_From, Req, State = #state{chan_state = ChanState}) ->
-    case emqx_channel:handle_call(Req, ChanState) of
-        {ok, Reply, NChanState} ->
-            {reply, Reply, State#state{chan_state = NChanState}};
-        {stop, Reason, Reply, NChanState} ->
-            {stop, Reason, Reply, State#state{chan_state = NChanState}};
-        {stop, Reason, Packet, Reply, NChanState} ->
-            State1 = State#state{chan_state = NChanState},
-            {ok, State2} = handle_outgoing(Packet, State1),
-            {stop, Reason, Reply, State2}
+handle_call(_From, Req, State = #state{channel = Channel}) ->
+    case emqx_channel:handle_call(Req, Channel) of
+        {reply, Reply, NChannel} ->
+            {reply, Reply, State#state{channel = NChannel}};
+        {stop, Reason, Reply, NChannel} ->
+            {stop, Reason, Reply, State#state{channel = NChannel}};
+        {stop, Reason, Reply, OutPacket, NChannel} ->
+            NState = State#state{channel = NChannel},
+            NState1 = handle_outgoing(OutPacket, NState),
+            {stop, Reason, Reply, NState1}
     end.
 
 %%--------------------------------------------------------------------
 %% Handle timeout
 
 handle_timeout(TRef, idle_timeout, State = #state{idle_timer = TRef}) ->
-    {stop, idle_timeout, State};
+    stop(idle_timeout, State);
 
 handle_timeout(TRef, emit_stats, State) ->
     handle_timeout(TRef, {emit_stats, stats(State)}, State);
@@ -422,16 +422,21 @@ handle_timeout(TRef, keepalive, State = #state{transport = Transport,
         {ok, [{recv_oct, RecvOct}]} ->
             handle_timeout(TRef, {keepalive, RecvOct}, State);
         {error, Reason} ->
-            handle_sockerr(Reason, State)
+            handle_info({sockerr, Reason}, State)
     end;
 
-handle_timeout(TRef, Msg, State = #state{chan_state = ChanState}) ->
-    Result = emqx_channel:handle_timeout(TRef, Msg, ChanState),
-    handle_chan_return(Result, State).
+handle_timeout(TRef, Msg, State = #state{channel = Channel}) ->
+    handle_return(emqx_channel:handle_timeout(TRef, Msg, Channel), State).
 
 %%--------------------------------------------------------------------
-%% Parse incoming data.
+%% Process/Parse incoming data.
 
+-compile({inline, [process_incoming/2]}).
+process_incoming(Data, State) ->
+    {Packets, NState} = parse_incoming(Data, State),
+    {ok, next_incoming_msgs(Packets), NState}.
+
+-compile({inline, [parse_incoming/2]}).
 parse_incoming(Data, State) ->
     parse_incoming(Data, [], State).
 
@@ -460,30 +465,30 @@ next_incoming_msgs(Packets) ->
 %%--------------------------------------------------------------------
 %% Handle incoming packet
 
-handle_incoming(Packet = ?PACKET(Type), State = #state{chan_state = ChanState}) ->
+handle_incoming(Packet = ?PACKET(Type), State = #state{channel = Channel}) ->
     _ = inc_incoming_stats(Type),
     ok = emqx_metrics:inc_recv(Packet),
     ?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]),
-    Result = emqx_channel:handle_in(Packet, ChanState),
-    handle_chan_return(Result, State);
-
-handle_incoming(FrameError = {frame_error, _Reason}, State = #state{chan_state = ChanState}) ->
-    Result = emqx_channel:handle_in(FrameError, ChanState),
-    handle_chan_return(Result, State).
-
-handle_chan_return({ok, NChanState}, State) ->
-    {ok, State#state{chan_state = NChanState}};
-handle_chan_return({ok, OutPacket, NChanState}, State)
-  when is_record(OutPacket, mqtt_packet) ->
-    {ok, {outgoing, OutPacket}, State#state{chan_state = NChanState}};
-handle_chan_return({ok, Actions, NChanState}, State) ->
-    {ok, Actions, State#state{chan_state = NChanState}};
-handle_chan_return({stop, Reason, NChanState}, State) ->
-    {stop, Reason, State#state{chan_state = NChanState}};
-handle_chan_return({stop, Reason, OutPackets, NChanState}, State) ->
-    NState = State#state{chan_state = NChanState},
-    {ok, NState1} = handle_outgoing(OutPackets, NState),
-    {stop, Reason, NState1}.
+    handle_return(emqx_channel:handle_in(Packet, Channel), State);
+
+handle_incoming(FrameError, State = #state{channel = Channel}) ->
+    handle_return(emqx_channel:handle_in(FrameError, Channel), State).
+
+%%--------------------------------------------------------------------
+%% Handle channel return
+
+handle_return(ok, State) ->
+    {ok, State};
+handle_return({ok, NChannel}, State) ->
+    {ok, State#state{channel = NChannel}};
+handle_return({ok, Replies, NChannel}, State) ->
+    {ok, next_msgs(Replies), State#state{channel = NChannel}};
+handle_return({stop, Reason, NChannel}, State) ->
+    stop(Reason, State#state{channel = NChannel});
+handle_return({stop, Reason, OutPacket, NChannel}, State) ->
+    NState = State#state{channel = NChannel},
+    NState1 = handle_outgoing(OutPacket, NState),
+    stop(Reason, NState1).
 
 %%--------------------------------------------------------------------
 %% Handle outgoing packets
@@ -510,70 +515,73 @@ serialize_and_inc_stats_fun(#state{serialize = Serialize}) ->
 %%--------------------------------------------------------------------
 %% Send data
 
-send(IoData, State = #state{transport  = Transport,
-                            socket     = Socket,
-                            chan_state = ChanState}) ->
+send(IoData, State = #state{transport = Transport,
+                            socket    = Socket,
+                            channel   = Channel}) ->
     Oct = iolist_size(IoData),
     ok = emqx_metrics:inc('bytes.sent', Oct),
     case Transport:async_send(Socket, IoData) of
         ok ->
-            NChanState = emqx_channel:sent(Oct, ChanState),
-            {ok, State#state{chan_state = NChanState}};
+            {ok, NChannel} = emqx_channel:handle_out(Oct, Channel),
+            State#state{channel = NChannel};
         Error = {error, _Reason} ->
             %% Simulate an inet_reply to postpone handling the error
-            self() ! {inet_reply, Socket, Error},
-            {ok, State}
+            self() ! {inet_reply, Socket, Error}, State
     end.
 
 %%--------------------------------------------------------------------
-%% Handle sockerr
+%% Handle Info
 
-handle_sockerr(_Reason, State = #state{active_st = closed}) ->
-    {ok, State};
+handle_info({enter, _}, State = #state{active_n  = ActiveN,
+                                       sockstate = SockSt,
+                                       channel   = Channel}) ->
+    ChanAttrs = emqx_channel:attrs(Channel),
+    SockAttrs = #{active_n  => ActiveN,
+                  sockstate => SockSt
+                 },
+    Attrs = maps:merge(ChanAttrs, #{sockinfo => SockAttrs}),
+    handle_info({register, Attrs, stats(State)}, State);
 
-handle_sockerr(Reason, State = #state{transport  = Transport,
-                                      socket     = Socket,
-                                      chan_state = ChanState}) ->
+handle_info({sockerr, _Reason}, #state{sockstate = closed}) -> ok;
+handle_info({sockerr, Reason}, State) ->
     ?LOG(debug, "Socket error: ~p", [Reason]),
-    ok = Transport:fast_close(Socket),
-    NState = State#state{active_st = closed},
-    case emqx_channel:handle_info({sockerr, Reason}, ChanState) of
-        {ok, NChanState} ->
-            {ok, NState#state{chan_state = NChanState}};
-        {stop, NChanState} ->
-            {stop, {shutdown, Reason}, NState#state{chan_state = NChanState}}
-    end.
+    handle_info({sock_closed, Reason}, close_socket(State));
 
-socket_closed(Closed, State = #state{transport  = Transport,
-                                     socket     = Socket,
-                                     chan_state = ChanState}) ->
-    ?LOG(debug, "Socket closed: ~p", [Closed]),
-    ok = Transport:fast_close(Socket),
-    NState = State#state{active_st = closed},
-    case emqx_channel:handle_info({sock_closed, Closed}, ChanState) of
-        {ok, NChanState} ->
-            {ok, NState#state{chan_state = NChanState}};
-        {stop, NChanState} ->
-            NState = NState#state{chan_state = NChanState},
-            {stop, {shutdown, Closed}, NState}
-    end.
+handle_info(sock_closed, #state{sockstate = closed}) -> ok;
+handle_info(sock_closed, State) ->
+    ?LOG(debug, "Socket closed"),
+    handle_info({sock_closed, closed}, close_socket(State));
+
+handle_info({close, Reason}, State) ->
+    ?LOG(debug, "Force close due to : ~p", [Reason]),
+    {ok, close_socket(State)};
+
+handle_info(Info, State = #state{channel = Channel}) ->
+    handle_return(emqx_channel:handle_info(Info, Channel), State).
 
 %%--------------------------------------------------------------------
 %% Activate Socket
 
 -compile({inline, [activate_socket/1]}).
-activate_socket(State = #state{active_st = closed}) ->
+activate_socket(State = #state{sockstate = closed}) ->
     {ok, State};
-activate_socket(State = #state{active_st = blocked}) ->
+activate_socket(State = #state{sockstate = blocked}) ->
     {ok, State};
 activate_socket(State = #state{transport = Transport,
                                socket    = Socket,
                                active_n  = N}) ->
     case Transport:setopts(Socket, [{active, N}]) of
-        ok -> {ok, State#state{active_st = running}};
+        ok -> {ok, State#state{sockstate = running}};
         Error -> Error
     end.
 
+%%--------------------------------------------------------------------
+%% Close Socket
+
+close_socket(State = #state{transport = Transport, socket = Socket}) ->
+    ok = Transport:fast_close(Socket),
+    State#state{sockstate = closed}.
+
 %%--------------------------------------------------------------------
 %% Ensure rate limit
 
@@ -595,7 +603,7 @@ ensure_rate_limit([{Rl, Pos, Cnt}|Limiters], State) ->
         {Pause, Rl1} ->
             ?LOG(debug, "Pause ~pms due to rate limit", [Pause]),
             TRef = erlang:send_after(Pause, self(), activate_socket),
-            NState = State#state{active_st = blocked, limit_timer = TRef},
+            NState = State#state{sockstate = blocked, limit_timer = TRef},
             setelement(Pos, NState, Rl1)
     end.
 
@@ -612,10 +620,28 @@ inc_incoming_stats(Type) when is_integer(Type) ->
         true -> ok
     end.
 
-
 -compile({inline, [inc_outgoing_stats/1]}).
 inc_outgoing_stats(Type) ->
     emqx_pd:update_counter(send_pkt, 1),
-    (Type == ?PUBLISH)
-        andalso emqx_pd:update_counter(send_msg, 1).
+    (Type == ?PUBLISH) andalso emqx_pd:update_counter(send_msg, 1).
+
+%%--------------------------------------------------------------------
+%% Helper functions
+
+-compile({inline, [append_msg/2]}).
+append_msg(Msgs, Q) when is_list(Msgs) ->
+    lists:append(Msgs, Q);
+append_msg(Msg, Q) -> [Msg|Q].
+
+-compile({inline, [next_msgs/1]}).
+next_msgs(Packet) when is_record(Packet, mqtt_packet) ->
+    {outgoing, Packet};
+next_msgs(Action) when is_tuple(Action) ->
+    Action;
+next_msgs(Actions) when is_list(Actions) ->
+    Actions.
+
+-compile({inline, [stop/2]}).
+stop(Reason, State) ->
+    {stop, Reason, State}.
 

+ 129 - 138
src/emqx_ws_connection.erl

@@ -14,7 +14,7 @@
 %% limitations under the License.
 %%--------------------------------------------------------------------
 
-%% MQTT WebSocket Connection
+%% MQTT/WS Connection
 -module(emqx_ws_connection).
 
 -include("emqx.hrl").
@@ -22,8 +22,9 @@
 -include("logger.hrl").
 -include("types.hrl").
 
--logger_header("[WsConnection]").
+-logger_header("[MQTT/WS]").
 
+%% API
 -export([ info/1
         , stats/1
         ]).
@@ -35,6 +36,7 @@
         , websocket_init/1
         , websocket_handle/2
         , websocket_info/2
+        , websocket_close/2
         , terminate/3
         ]).
 
@@ -43,14 +45,14 @@
           peername :: emqx_types:peername(),
           %% Sockname of the ws connection
           sockname :: emqx_types:peername(),
-          %% Conn state
-          conn_state :: idle | connected | disconnected,
+          %% Sock state
+          sockstate :: emqx_types:sockstate(),
           %% Parser State
           parse_state :: emqx_frame:parse_state(),
           %% Serialize function
           serialize :: emqx_frame:serialize_fun(),
-          %% Channel State
-          chan_state :: emqx_channel:channel(),
+          %% Channel
+          channel :: emqx_channel:channel(),
           %% Out Pending Packets
           pendings :: list(emqx_types:packet()),
           %% The stop reason
@@ -59,7 +61,7 @@
 
 -type(state() :: #state{}).
 
--define(INFO_KEYS, [socktype, peername, sockname, active_state]).
+-define(INFO_KEYS, [socktype, peername, sockname, sockstate]).
 -define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt]).
 -define(CONN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]).
 
@@ -70,8 +72,8 @@
 -spec(info(pid()|state()) -> emqx_types:infos()).
 info(WsPid) when is_pid(WsPid) ->
     call(WsPid, info);
-info(WsConn = #state{chan_state = ChanState}) ->
-    ChanInfo = emqx_channel:info(ChanState),
+info(WsConn = #state{channel = Channel}) ->
+    ChanInfo = emqx_channel:info(Channel),
     SockInfo = maps:from_list(info(?INFO_KEYS, WsConn)),
     maps:merge(ChanInfo, #{sockinfo => SockInfo}).
 
@@ -83,18 +85,18 @@ info(peername, #state{peername = Peername}) ->
     Peername;
 info(sockname, #state{sockname = Sockname}) ->
     Sockname;
-info(active_state, _State) ->
-    running;
-info(chan_state, #state{chan_state = ChanState}) ->
-    emqx_channel:info(ChanState).
+info(sockstate, #state{sockstate = SockSt}) ->
+    SockSt;
+info(channel, #state{channel = Channel}) ->
+    emqx_channel:info(Channel).
 
 -spec(stats(pid()|state()) -> emqx_types:stats()).
 stats(WsPid) when is_pid(WsPid) ->
     call(WsPid, stats);
-stats(#state{chan_state = ChanState}) ->
+stats(#state{channel = Channel}) ->
     SockStats = emqx_pd:get_counters(?SOCK_STATS),
     ConnStats = emqx_pd:get_counters(?CONN_STATS),
-    ChanStats = emqx_channel:stats(ChanState),
+    ChanStats = emqx_channel:stats(Channel),
     ProcStats = emqx_misc:proc_stats(),
     lists:append([SockStats, ConnStats, ChanStats, ProcStats]).
 
@@ -168,27 +170,26 @@ websocket_init([Req, Opts]) ->
     FrameOpts = emqx_zone:frame_options(Zone),
     ParseState = emqx_frame:initial_parse_state(FrameOpts),
     Serialize = emqx_frame:serialize_fun(),
-    ChanState = emqx_channel:init(ConnInfo, Opts),
+    Channel = emqx_channel:init(ConnInfo, Opts),
     emqx_logger:set_metadata_peername(esockd_net:format(Peername)),
     {ok, #state{peername    = Peername,
                 sockname    = Sockname,
-                conn_state  = idle,
+                sockstate   = idle,
                 parse_state = ParseState,
                 serialize   = Serialize,
-                chan_state  = ChanState,
+                channel     = Channel,
                 pendings    = []
                }}.
 
 websocket_handle({binary, Data}, State) when is_list(Data) ->
     websocket_handle({binary, iolist_to_binary(Data)}, State);
 
-websocket_handle({binary, Data}, State = #state{chan_state = ChanState}) ->
+websocket_handle({binary, Data}, State = #state{channel = Channel}) ->
     ?LOG(debug, "RECV ~p", [Data]),
     Oct = iolist_size(Data),
     ok = inc_recv_stats(1, Oct),
-    NChanState = emqx_channel:received(Oct, ChanState),
-    NState = State#state{chan_state = NChanState},
-    process_incoming(Data, NState);
+    {ok, NChannel} = emqx_channel:handle_in(Oct, Channel),
+    process_incoming(Data, State#state{channel = NChannel});
 
 %% Pings should be replied with pongs, cowboy does it automatically
 %% Pongs can be safely ignored. Clause here simply prevents crash.
@@ -203,56 +204,27 @@ websocket_handle({FrameType, _}, State) ->
     ?LOG(error, "Unexpected frame - ~p", [FrameType]),
     stop({shutdown, unexpected_ws_frame}, State).
 
-websocket_info({call, From, info}, State) ->
-    gen_server:reply(From, info(State)),
-    {ok, State};
+websocket_info({call, From, Req}, State) ->
+    handle_call(From, Req, State);
 
-websocket_info({call, From, stats}, State) ->
-    gen_server:reply(From, stats(State)),
-    {ok, State};
-
-websocket_info({call, From, state}, State) ->
-    gen_server:reply(From, State),
-    {ok, State};
-
-websocket_info({call, From, Req}, State = #state{chan_state = ChanState}) ->
-    case emqx_channel:handle_call(Req, ChanState) of
-        {ok, Reply, NChanState} ->
-            _ = gen_server:reply(From, Reply),
-            {ok, State#state{chan_state = NChanState}};
-        {stop, Reason, Reply, NChanState} ->
-            _ = gen_server:reply(From, Reply),
-            stop(Reason, State#state{chan_state = NChanState})
-    end;
-
-websocket_info({cast, Msg}, State = #state{chan_state = ChanState}) ->
-    case emqx_channel:handle_info(Msg, ChanState) of
-        ok -> {ok, State};
-        {ok, NChanState} ->
-            {ok, State#state{chan_state = NChanState}};
-        {stop, Reason, NChanState} ->
-            stop(Reason, State#state{chan_state = NChanState})
-    end;
+websocket_info({cast, Msg}, State = #state{channel = Channel}) ->
+    handle_return(emqx_channel:handle_info(Msg, Channel), State);
 
 websocket_info({incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, State) ->
-    NState = State#state{serialize = emqx_frame:serialize_fun(ConnPkt)},
-    handle_incoming(Packet, fun connected/1, NState);
-
-websocket_info({incoming, Packet}, State) when is_record(Packet, mqtt_packet) ->
-    handle_incoming(Packet, fun reply/1, State);
+    Serialize = emqx_frame:serialize_fun(ConnPkt),
+    NState = State#state{sockstate = running,
+                         serialize = Serialize
+                        },
+    handle_incoming(Packet, NState);
 
-websocket_info({incoming, FrameError = {frame_error, _Reason}}, State) ->
-    handle_incoming(FrameError, State);
+websocket_info({incoming, Packet}, State) ->
+    handle_incoming(Packet, State);
 
 websocket_info(Deliver = {deliver, _Topic, _Msg},
-               State = #state{chan_state = ChanState}) ->
+               State = #state{channel = Channel}) ->
     Delivers = emqx_misc:drain_deliver([Deliver]),
-    case emqx_channel:handle_out({deliver, Delivers}, ChanState) of
-        {ok, NChanState} ->
-            reply(State#state{chan_state = NChanState});
-        {ok, Packets, NChanState} ->
-            reply(enqueue(Packets, State#state{chan_state = NChanState}))
-    end;
+    Result = emqx_channel:handle_out(Delivers, Channel),
+    handle_return(Result, State);
 
 websocket_info({timeout, TRef, keepalive}, State) when is_reference(TRef) ->
     RecvOct = emqx_pd:get_counter(recv_oct),
@@ -264,60 +236,70 @@ websocket_info({timeout, TRef, emit_stats}, State) when is_reference(TRef) ->
 websocket_info({timeout, TRef, Msg}, State) when is_reference(TRef) ->
     handle_timeout(TRef, Msg, State);
 
+websocket_info({close, Reason}, State) ->
+    stop({shutdown, Reason}, State);
+
 websocket_info({shutdown, Reason}, State) ->
     stop({shutdown, Reason}, State);
 
 websocket_info({stop, Reason}, State) ->
     stop(Reason, State);
 
-websocket_info(Info, State = #state{chan_state = ChanState}) ->
-    case emqx_channel:handle_info(Info, ChanState) of
-        {ok, NChanState} ->
-            {ok, State#state{chan_state = NChanState}};
-        {stop, Reason, NChanState} ->
-            stop(Reason, State#state{chan_state = NChanState})
-    end.
+websocket_info(Info, State) ->
+    handle_info(Info, State).
 
-terminate(SockError, _Req, #state{chan_state  = ChanState,
+websocket_close(Reason, State) ->
+    ?LOG(debug, "WebSocket closed due to ~p~n", [Reason]),
+    handle_info({sock_closed, Reason}, State).
+
+terminate(SockError, _Req, #state{channel = Channel,
                                   stop_reason = Reason}) ->
     ?LOG(debug, "Terminated for ~p, sockerror: ~p", [Reason, SockError]),
-    emqx_channel:terminate(Reason, ChanState).
+    emqx_channel:terminate(Reason, Channel).
 
 %%--------------------------------------------------------------------
-%% Connected callback
+%% Handle call
 
-connected(State = #state{chan_state = ChanState}) ->
-    ChanAttrs = emqx_channel:attrs(ChanState),
-    SockAttrs = #{active_state => running},
-    Attrs = maps:merge(ChanAttrs, #{sockinfo => SockAttrs}),
-    ok = emqx_channel:handle_info({register, Attrs, stats(State)}, ChanState),
-    reply(State#state{conn_state = connected}).
+handle_call(From, info, State) ->
+    gen_server:reply(From, info(State)),
+    {ok, State};
 
-%%--------------------------------------------------------------------
-%% Close
+handle_call(From, stats, State) ->
+    gen_server:reply(From, stats(State)),
+    {ok, State};
 
-close(Reason, State) ->
-    ?LOG(warning, "Closed for ~p", [Reason]),
-    reply(State#state{conn_state = disconnected}).
+handle_call(From, Req, State = #state{channel = Channel}) ->
+    case emqx_channel:handle_call(Req, Channel) of
+        {reply, Reply, NChannel} ->
+            _ = gen_server:reply(From, Reply),
+            {ok, State#state{channel = NChannel}};
+        {stop, Reason, Reply, NChannel} ->
+            _ = gen_server:reply(From, Reply),
+            stop(Reason, State#state{channel = NChannel});
+        {stop, Reason, Reply, OutPacket, NChannel} ->
+            gen_server:reply(From, Reply),
+            NState = State#state{channel = NChannel},
+            stop(Reason, enqueue(OutPacket, NState))
+    end.
 
 %%--------------------------------------------------------------------
 %% Handle timeout
 
-handle_timeout(TRef, Msg, State = #state{chan_state = ChanState}) ->
-    case emqx_channel:handle_timeout(TRef, Msg, ChanState) of
-        {ok, NChanState} ->
-            {ok, State#state{chan_state = NChanState}};
-        {ok, Packets, NChanState} ->
-            NState = State#state{chan_state = NChanState},
-            reply(enqueue(Packets, NState));
-        {close, Reason, NChanState} ->
-            close(Reason, State#state{chan_state = NChanState});
-        {close, Reason, OutPackets, NChanState} ->
-            NState = State#state{chan_state= NChanState},
-            close(Reason, enqueue(OutPackets, NState));
-        {stop, Reason, NChanState} ->
-            stop(Reason, State#state{chan_state = NChanState})
-    end.
+handle_timeout(TRef, Msg, State = #state{channel = Channel}) ->
+    handle_return(emqx_channel:handle_timeout(TRef, Msg, Channel), State).
+
+%%--------------------------------------------------------------------
+%% Handle Info
+
+handle_info({enter, _}, State = #state{channel = Channel}) ->
+    ChanAttrs = emqx_channel:attrs(Channel),
+    SockAttrs = maps:from_list(info(?INFO_KEYS, State)),
+    Attrs = maps:merge(ChanAttrs, #{sockinfo => SockAttrs}),
+    ok = emqx_channel:handle_info({register, Attrs, stats(State)}, Channel),
+    reply(State);
+
+handle_info(Info, State = #state{channel = Channel}) ->
+    handle_return(emqx_channel:handle_info(Info, Channel), State).
 
 %%--------------------------------------------------------------------
 %% Process incoming data
@@ -343,48 +325,39 @@ process_incoming(Data, State = #state{parse_state = ParseState}) ->
 %%--------------------------------------------------------------------
 %% Handle incoming packets
 
-handle_incoming(Packet = ?PACKET(Type), SuccFun,
-                State = #state{chan_state = ChanState}) ->
+handle_incoming(Packet = ?PACKET(Type), State = #state{channel = Channel}) ->
     _ = inc_incoming_stats(Type),
     _ = emqx_metrics:inc_recv(Packet),
     ?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]),
-    case emqx_channel:handle_in(Packet, ChanState) of
-        {ok, NChanState} ->
-            SuccFun(State#state{chan_state= NChanState});
-        {ok, OutPackets, NChanState} ->
-            NState = State#state{chan_state= NChanState},
-            SuccFun(enqueue(OutPackets, NState));
-        {close, Reason, NChanState} ->
-            close(Reason, State#state{chan_state = NChanState});
-        {close, Reason, OutPackets, NChanState} ->
-            NState = State#state{chan_state= NChanState},
-            close(Reason, enqueue(OutPackets, NState));
-        {stop, Reason, NChanState} ->
-            stop(Reason, State#state{chan_state = NChanState});
-        {stop, Reason, OutPackets, NChanState} ->
-            NState = State#state{chan_state= NChanState},
-            stop(Reason, enqueue(OutPackets, NState))
-    end.
+    handle_return(emqx_channel:handle_in(Packet, Channel), State);
 
-handle_incoming(FrameError = {frame_error, _Reason},
-                State = #state{chan_state = ChanState}) ->
-    case emqx_channel:handle_in(FrameError, ChanState) of
-        {stop, Reason, NChanState} ->
-            stop(Reason, State#state{chan_state = NChanState});
-        {stop, Reason, OutPackets, NChanState} ->
-            NState = State#state{chan_state = NChanState},
-            stop(Reason, enqueue(OutPackets, NState))
-    end.
+handle_incoming(FrameError, State = #state{channel = Channel}) ->
+    handle_return(emqx_channel:handle_in(FrameError, Channel), State).
+
+%%--------------------------------------------------------------------
+%% Handle channel return
+
+handle_return(ok, State) ->
+    reply(State);
+handle_return({ok, NChannel}, State) ->
+    reply(State#state{channel= NChannel});
+handle_return({ok, Replies, NChannel}, State) ->
+    reply(Replies, State#state{channel= NChannel});
+handle_return({stop, Reason, NChannel}, State) ->
+    stop(Reason, State#state{channel = NChannel});
+handle_return({stop, Reason, OutPacket, NChannel}, State) ->
+    NState = State#state{channel = NChannel},
+    stop(Reason, enqueue(OutPacket, NState)).
 
 %%--------------------------------------------------------------------
 %% Handle outgoing packets
 
-handle_outgoing(Packets, State = #state{chan_state = ChanState}) ->
+handle_outgoing(Packets, State = #state{channel = Channel}) ->
     IoData = lists:map(serialize_and_inc_stats_fun(State), Packets),
     Oct = iolist_size(IoData),
     ok = inc_sent_stats(length(Packets), Oct),
-    NChanState = emqx_channel:sent(Oct, ChanState),
-    {{binary, IoData}, State#state{chan_state = NChanState}}.
+    {ok, NChannel} = emqx_channel:handle_out(Oct, Channel),
+    {{binary, IoData}, State#state{channel = NChannel}}.
 
 %% TODO: Duplicated with emqx_channel:serialize_and_inc_stats_fun/1
 serialize_and_inc_stats_fun(#state{serialize = Serialize}) ->
@@ -433,7 +406,25 @@ inc_sent_stats(Cnt, Oct) ->
 %%--------------------------------------------------------------------
 %% Reply or Stop
 
--compile({inline, [reply/1]}).
+reply(Packet, State) when is_record(Packet, mqtt_packet) ->
+    reply(enqueue(Packet, State));
+reply({outgoing, Packets}, State) ->
+    reply(enqueue(Packets, State));
+reply(Close = {close, _Reason}, State) ->
+    self() ! Close,
+    reply(State);
+
+reply([], State) ->
+    reply(State);
+reply([Packet|More], State) when is_record(Packet, mqtt_packet) ->
+    reply(More, enqueue(Packet, State));
+reply([{outgoing, Packets}|More], State) ->
+    reply(More, enqueue(Packets, State));
+reply([Other|More], State) ->
+    self() ! Other,
+    reply(More, State).
+
+-compile({inline, [reply/1, enqueue/2]}).
 
 reply(State = #state{pendings = []}) ->
     {ok, State};
@@ -441,6 +432,11 @@ reply(State = #state{pendings = Pendings}) ->
     {Reply, NState} = handle_outgoing(Pendings, State),
     {reply, Reply, NState#state{pendings = []}}.
 
+enqueue(Packet, State) when is_record(Packet, mqtt_packet) ->
+    enqueue([Packet], State);
+enqueue(Packets, State = #state{pendings = Pendings}) ->
+    State#state{pendings = lists:append(Pendings, Packets)}.
+
 stop(Reason, State = #state{pendings = []}) ->
     {stop, State#state{stop_reason = Reason}};
 stop(Reason, State = #state{pendings = Pendings}) ->
@@ -448,8 +444,3 @@ stop(Reason, State = #state{pendings = Pendings}) ->
     State2 = State1#state{pendings = [], stop_reason = Reason},
     {reply, [Reply, close], State2}.
 
-enqueue(Packet, State) when is_record(Packet, mqtt_packet) ->
-    enqueue([Packet], State);
-enqueue(Packets, State = #state{pendings = Pendings}) ->
-    State#state{pendings = lists:append(Pendings, Packets)}.
-

+ 17 - 11
test/emqx_channel_SUITE.erl

@@ -64,17 +64,18 @@ t_handle_connect(_) ->
                  is_bridge   = false,
                  clean_start = true,
                  keepalive   = 30,
-                 properties  = #{},
+                 properties  = undefined,
                  clientid    = <<"clientid">>,
                  username    = <<"username">>,
                  password    = <<"passwd">>
                 },
     with_channel(
       fun(Channel) ->
-              {ok, ?CONNACK_PACKET(?RC_SUCCESS), Channel1}
-                = handle_in(?CONNECT_PACKET(ConnPkt), Channel),
-              #{clientid := ClientId, username := Username}
-                = emqx_channel:info(clientinfo, Channel1),
+              ConnAck = ?CONNACK_PACKET(?RC_SUCCESS, 0, #{}),
+              ExpectedOutput = [{outgoing, ConnAck},{enter, connected}],
+              {ok, Output, Channel1} = handle_in(?CONNECT_PACKET(ConnPkt), Channel),
+              ?assertEqual(ExpectedOutput, Output),
+              #{clientid := ClientId, username := Username} = emqx_channel:info(clientinfo, Channel1),
               ?assertEqual(<<"clientid">>, ClientId),
               ?assertEqual(<<"username">>, Username)
       end).
@@ -180,7 +181,7 @@ t_handle_in_auth(_) ->
 %%--------------------------------------------------------------------
 
 t_handle_deliver(_) ->
-    with_channel(
+    with_connected_channel(
       fun(Channel) ->
               TopicFilters = [{<<"+">>, ?DEFAULT_SUBOPTS#{qos => ?QOS_2}}],
               {ok, ?SUBACK_PACKET(1, [?QOS_2]), Channel1}
@@ -188,7 +189,7 @@ t_handle_deliver(_) ->
               Msg0 = emqx_message:make(<<"clientx">>, ?QOS_0, <<"t0">>, <<"qos0">>),
               Msg1 = emqx_message:make(<<"clientx">>, ?QOS_1, <<"t1">>, <<"qos1">>),
               Delivers = [{deliver, <<"+">>, Msg0}, {deliver, <<"+">>, Msg1}],
-              {ok, Packets, _Ch} = emqx_channel:handle_out({deliver, Delivers}, Channel1),
+              {ok, {outgoing, Packets}, _Ch} = emqx_channel:handle_out(Delivers, Channel1),
               ?assertEqual([?QOS_0, ?QOS_1], [emqx_packet:qos(Pkt)|| Pkt <- Packets])
       end).
 
@@ -206,10 +207,9 @@ t_handle_out_connack(_) ->
                 },
     with_channel(
       fun(Channel) ->
-              {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP, _), _}
+              {ok, [{outgoing, ?CONNACK_PACKET(?RC_SUCCESS, SP, _)}, {enter, connected}], _Chan}
                 = handle_out({connack, ?RC_SUCCESS, 0, ConnPkt}, Channel),
-              {stop, {shutdown, not_authorized},
-               ?CONNACK_PACKET(?RC_NOT_AUTHORIZED), _}
+              {stop, {shutdown, not_authorized}, ?CONNACK_PACKET(?RC_NOT_AUTHORIZED), _}
                 = handle_out({connack, ?RC_NOT_AUTHORIZED, ConnPkt}, Channel)
       end).
 
@@ -220,7 +220,7 @@ t_handle_out_publish(_) ->
               Pub1 = {publish, 1, emqx_message:make(<<"c">>, ?QOS_1, <<"t">>, <<"qos1">>)},
               {ok, ?PUBLISH_PACKET(?QOS_0), Channel} = handle_out(Pub0, Channel),
               {ok, ?PUBLISH_PACKET(?QOS_1), Channel} = handle_out(Pub1, Channel),
-              {ok, Packets, Channel1} = handle_out({publish, [Pub0, Pub1]}, Channel),
+              {ok, {outgoing, Packets}, Channel1} = handle_out({publish, [Pub0, Pub1]}, Channel),
               ?assertEqual(2, length(Packets)),
               ?assertEqual(#{publish_out => 2}, emqx_channel:info(pub_stats, Channel1))
       end).
@@ -304,6 +304,12 @@ t_terminate(_) ->
 %% Helper functions
 %%--------------------------------------------------------------------
 
+with_connected_channel(TestFun) ->
+    with_channel(
+      fun(Channel) ->
+          TestFun(emqx_channel:set_field(conn_state, connected, Channel))
+      end).
+
 with_channel(TestFun) ->
     with_channel(#{}, TestFun).
 

+ 1 - 0
test/emqx_pool_SUITE.erl

@@ -39,6 +39,7 @@ groups() ->
     ].
 
 init_per_suite(Config) ->
+    ok = emqx_logger:set_log_level(emergency),
     application:ensure_all_started(gproc),
     Config.