Sfoglia il codice sorgente

Improve the channel design

Feng Lee 6 anni fa
parent
commit
4974eab20e

+ 1 - 1
src/emqx_broker.erl

@@ -286,7 +286,7 @@ dispatch(Topic, Delivery = #delivery{message = Msg, results = Results}) ->
 dispatch(SubPid, Topic, Msg) when is_pid(SubPid) ->
     case erlang:is_process_alive(SubPid) of
         true ->
-            SubPid ! {dispatch, Topic, Msg},
+            SubPid ! {deliver, Topic, Msg},
             1;
         false -> 0
     end;

+ 107 - 80
src/emqx_connection.erl

@@ -14,6 +14,7 @@
 %% limitations under the License.
 %%--------------------------------------------------------------------
 
+%% MQTT TCP/SSL Channel
 -module(emqx_channel).
 
 -behaviour(gen_statem).
@@ -21,6 +22,7 @@
 -include("emqx.hrl").
 -include("emqx_mqtt.hrl").
 -include("logger.hrl").
+-include("types.hrl").
 
 -logger_header("[Channel]").
 
@@ -32,16 +34,10 @@
         , stats/1
         ]).
 
--export([ kick/1
-        , discard/1
-        , takeover/1
-        ]).
-
--export([session/1]).
-
 %% gen_statem callbacks
 -export([ idle/3
         , connected/3
+        , disconnected/3
         ]).
 
 -export([ init/1
@@ -51,28 +47,32 @@
         ]).
 
 -record(state, {
-          transport,
-          socket,
-          peername,
-          sockname,
-          conn_state,
-          active_n,
-          proto_state,
-          parse_state,
-          gc_state,
-          keepalive,
-          rate_limit,
-          pub_limit,
-          limit_timer,
-          enable_stats,
-          stats_timer,
-          idle_timeout
+          transport    :: esockd:transport(),
+          socket       :: esockd:sock(),
+          peername     :: {inet:ip_address(), inet:port_number()},
+          sockname     :: {inet:ip_address(), inet:port_number()},
+          conn_state   :: running | blocked,
+          active_n     :: pos_integer(),
+          rate_limit   :: maybe(esockd_rate_limit:bucket()),
+          pub_limit    :: maybe(esockd_rate_limit:bucket()),
+          limit_timer  :: maybe(reference()),
+          serializer   :: emqx_frame:serializer(), %% TODO: remove it later.
+          parse_state  :: emqx_frame:parse_state(),
+          proto_state  :: emqx_protocol:protocol(),
+          gc_state     :: emqx_gc:gc_state(),
+          keepalive    :: maybe(reference()),
+          enable_stats :: boolean(),
+          stats_timer  :: maybe(reference()),
+          idle_timeout :: timeout()
          }).
 
 -define(ACTIVE_N, 100).
 -define(HANDLE(T, C, D), handle((T), (C), (D))).
+-define(CHAN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]).
 -define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]).
 
+-spec(start_link(esockd:transport(), esockd:sock(), proplists:proplist())
+      -> {ok, pid()}).
 start_link(Transport, Socket, Options) ->
     {ok, proc_lib:spawn_link(?MODULE, init, [{Transport, Socket, Options}])}.
 
@@ -126,28 +126,13 @@ attrs(#state{peername = Peername,
 stats(CPid) when is_pid(CPid) ->
     call(CPid, stats);
 
-stats(#state{transport = Transport,
-             socket = Socket,
-             proto_state = ProtoState}) ->
+stats(#state{transport = Transport, socket = Socket}) ->
     SockStats = case Transport:getstat(Socket, ?SOCK_STATS) of
                     {ok, Ss}   -> Ss;
                     {error, _} -> []
                 end,
-    lists:append([SockStats,
-                  emqx_misc:proc_stats(),
-                  emqx_protocol:stats(ProtoState)]).
-
-kick(CPid) ->
-    call(CPid, kick).
-
-discard(CPid) ->
-    call(CPid, discard).
-
-takeover(CPid) ->
-    call(CPid, takeover).
-
-session(CPid) ->
-    call(CPid, session).
+    ChanStats = [{Name, emqx_pd:get_counter(Name)} || Name <- ?CHAN_STATS],
+    lists:append([SockStats, ChanStats, emqx_misc:proc_stats()]).
 
 call(CPid, Req) ->
     gen_statem:call(CPid, Req, infinity).
@@ -166,23 +151,15 @@ init({Transport, RawSocket, Options}) ->
     RateLimit = init_limiter(proplists:get_value(rate_limit, Options)),
     PubLimit = init_limiter(emqx_zone:get_env(Zone, publish_limit)),
     ActiveN = proplists:get_value(active_n, Options, ?ACTIVE_N),
-    SendFun = fun(Packet, Opts) ->
-                      Data = emqx_frame:serialize(Packet, Opts),
-                      case Transport:async_send(Socket, Data) of
-                          ok -> {ok, Data};
-                          {error, Reason} ->
-                              {error, Reason}
-                      end
-              end,
+    MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE),
+    ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}),
     ProtoState = emqx_protocol:init(#{peername => Peername,
                                       sockname => Sockname,
                                       peercert => Peercert,
-                                      sendfun  => SendFun,
                                       conn_mod => ?MODULE}, Options),
-    MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE),
-    ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}),
     GcPolicy = emqx_zone:get_env(Zone, force_gc_policy, false),
     GcState = emqx_gc:init(GcPolicy),
+    ok = emqx_misc:init_proc_mng_policy(Zone),
     EnableStats = emqx_zone:get_env(Zone, enable_stats, true),
     IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000),
     State = #state{transport    = Transport,
@@ -192,13 +169,12 @@ init({Transport, RawSocket, Options}) ->
                    active_n     = ActiveN,
                    rate_limit   = RateLimit,
                    pub_limit    = PubLimit,
-                   proto_state  = ProtoState,
                    parse_state  = ParseState,
+                   proto_state  = ProtoState,
                    gc_state     = GcState,
                    enable_stats = EnableStats,
                    idle_timeout = IdleTimout
                   },
-    ok = emqx_misc:init_proc_mng_policy(Zone),
     gen_statem:enter_loop(?MODULE, [{hibernate_after, 2 * IdleTimout}],
                           idle, State, self(), [IdleTimout]).
 
@@ -218,12 +194,17 @@ idle(enter, _, State) ->
     keep_state_and_data;
 
 idle(timeout, _Timeout, State) ->
-    {stop, idle_timeout, State};
+    stop(idle_timeout, State);
+
+idle(cast, {incoming, Packet = ?CONNECT_PACKET(ConnVar)}, State) ->
+    #mqtt_packet_connect{proto_ver = ProtoVer} = ConnVar,
+    Serializer = emqx_frame:init_serializer(#{version => ProtoVer}),
+    NState = State#state{serializer = Serializer},
+    handle_incoming(Packet, fun(St) -> {next_state, connected, St} end, NState);
 
 idle(cast, {incoming, Packet}, State) ->
-    handle_incoming(Packet, fun(NState) ->
-                                    {next_state, connected, NState}
-                            end, State);
+    ?LOG(warning, "Unexpected incoming: ~p", [Packet]),
+    shutdown(unexpected_incoming_packet, State);
 
 idle(EventType, Content, State) ->
     ?HANDLE(EventType, Content, State).
@@ -235,18 +216,23 @@ connected(enter, _, _State) ->
     %% What to do?
     keep_state_and_data;
 
-%% Handle Input
+connected(cast, {incoming, Packet = ?PACKET(?CONNECT)}, State) ->
+    ?LOG(warning, "Unexpected connect: ~p", [Packet]),
+    shutdown(unexpected_incoming_connect, State);
+
 connected(cast, {incoming, Packet = ?PACKET(Type)}, State) ->
     ok = emqx_metrics:inc_recv(Packet),
     (Type == ?PUBLISH) andalso emqx_pd:update_counter(incoming_pubs, 1),
-    handle_incoming(Packet, fun(NState) -> {keep_state, NState} end, State);
+    handle_incoming(Packet, fun(St) -> {keep_state, St} end, State);
 
-%% Handle Output
-connected(info, {deliver, PubOrAck}, State = #state{proto_state = ProtoState}) ->
-    case emqx_protocol:deliver(PubOrAck, ProtoState) of
+%% Handle delivery
+connected(info, Devliery = {deliver, _Topic, Msg}, State = #state{proto_state = ProtoState}) ->
+    case emqx_protocol:handle_out(Devliery, ProtoState) of
         {ok, NProtoState} ->
+            {keep_state, State#state{proto_state = NProtoState}};
+        {ok, Packet, NProtoState} ->
             NState = State#state{proto_state = NProtoState},
-            {keep_state, maybe_gc(PubOrAck, NState)};
+            handle_outgoing(Packet, fun(St) -> {keep_state, St} end, NState);
         {error, Reason} ->
             shutdown(Reason, State)
     end;
@@ -281,6 +267,16 @@ connected(info, {keepalive, check}, State = #state{keepalive = KeepAlive}) ->
 connected(EventType, Content, State) ->
     ?HANDLE(EventType, Content, State).
 
+%%--------------------------------------------------------------------
+%% Disconnected State
+
+disconnected(enter, _, _State) ->
+    %% TODO: What to do?
+    keep_state_and_data;
+
+disconnected(EventType, Content, State) ->
+    ?HANDLE(EventType, Content, State).
+
 %% Handle call
 handle({call, From}, info, State) ->
     reply(From, info(State), State);
@@ -299,9 +295,6 @@ handle({call, From}, discard, State) ->
     ok = gen_statem:reply(From, ok),
     shutdown(discard, State);
 
-handle({call, From}, session, State = #state{proto_state = ProtoState}) ->
-    reply(From, emqx_protocol:session(ProtoState), State);
-
 handle({call, From}, Req, State) ->
     ?LOG(error, "Unexpected call: ~p", [Req]),
     reply(From, ignored, State);
@@ -312,7 +305,8 @@ handle(cast, Msg, State) ->
     {keep_state, State};
 
 %% Handle Incoming
-handle(info, {Inet, _Sock, Data}, State) when Inet == tcp; Inet == ssl ->
+handle(info, {Inet, _Sock, Data}, State) when Inet == tcp;
+                                              Inet == ssl ->
     Oct = iolist_size(Data),
     ?LOG(debug, "RECV ~p", [Data]),
     emqx_pd:update_counter(incoming_bytes, Oct),
@@ -350,7 +344,7 @@ handle(info, {inet_reply, _Sock, {error, Reason}}, State) ->
 handle(info, {timeout, Timer, emit_stats},
        State = #state{stats_timer = Timer,
                       proto_state = ProtoState,
-                      gc_state = GcState}) ->
+                      gc_state    = GcState}) ->
     ClientId = emqx_protocol:client_id(ProtoState),
     emqx_cm:set_conn_stats(ClientId, stats(State)),
     NState = State#state{stats_timer = undefined},
@@ -390,15 +384,9 @@ terminate(Reason, _StateName, #state{transport = Transport,
                                      keepalive = KeepAlive,
                                      proto_state = ProtoState}) ->
     ?LOG(debug, "Terminated for ~p", [Reason]),
-    Transport:fast_close(Socket),
-    emqx_keepalive:cancel(KeepAlive),
-    case {ProtoState, Reason} of
-        {undefined, _} -> ok;
-        {_, {shutdown, Error}} ->
-            emqx_protocol:terminate(Error, ProtoState);
-        {_, Reason} ->
-            emqx_protocol:terminate(Reason, ProtoState)
-    end.
+    ok = Transport:fast_close(Socket),
+    ok = emqx_keepalive:cancel(KeepAlive),
+    emqx_protocol:terminate(Reason, ProtoState).
 
 %%--------------------------------------------------------------------
 %% Process incoming data
@@ -431,10 +419,16 @@ next_events(Packet) ->
 %%--------------------------------------------------------------------
 %% Handle incoming packet
 
-handle_incoming(Packet, SuccFun, State = #state{proto_state = ProtoState}) ->
-    case emqx_protocol:received(Packet, ProtoState) of
+handle_incoming(Packet = ?PACKET(Type), SuccFun,
+                State = #state{proto_state = ProtoState}) ->
+    _ = inc_incoming_stats(Type),
+    ?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]),
+    case emqx_protocol:handle_in(Packet, ProtoState) of
         {ok, NProtoState} ->
             SuccFun(State#state{proto_state = NProtoState});
+        {ok, OutPacket, NProtoState} ->
+            handle_outgoing(OutPacket, SuccFun,
+                            State#state{proto_state = NProtoState});
         {error, Reason} ->
             shutdown(Reason, State);
         {error, Reason, NProtoState} ->
@@ -443,6 +437,22 @@ handle_incoming(Packet, SuccFun, State = #state{proto_state = ProtoState}) ->
             stop(Error, State#state{proto_state = NProtoState})
     end.
 
+%%--------------------------------------------------------------------
+%% Handle outgoing packet
+
+handle_outgoing(Packet = ?PACKET(Type), SuccFun,
+                State = #state{transport = Transport,
+                               socket = Socket,
+                               serializer = Serializer}) ->
+    _ = inc_outgoing_stats(Type),
+    ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]),
+    Data = Serializer(Packet),
+    case Transport:async_send(Socket, Data) of
+        ok -> SuccFun(State);
+        {error, Reason} ->
+            shutdown(Reason, State)
+    end.
+
 %%--------------------------------------------------------------------
 %% Ensure rate limit
 
@@ -465,6 +475,12 @@ ensure_rate_limit([{Rl, Pos, Cnt}|Limiters], State) ->
            setelement(Pos, State#state{conn_state = blocked, limit_timer = TRef}, Rl1)
    end.
 
+%% start_keepalive(0, _PState) ->
+%%     ignore;
+%% start_keepalive(Secs, #pstate{zone = Zone}) when Secs > 0 ->
+%%     Backoff = emqx_zone:get_env(Zone, keepalive_backoff, 0.75),
+%%     self() ! {keepalive, start, round(Secs * Backoff)}.
+
 %%--------------------------------------------------------------------
 %% Activate socket
 
@@ -479,6 +495,17 @@ activate_socket(#state{transport = Transport, socket = Socket, active_n = N}) ->
             ok
     end.
 
+%%--------------------------------------------------------------------
+%% Inc incoming/outgoing stats
+
+inc_incoming_stats(Type) ->
+    emqx_pd:update_counter(recv_pkt, 1),
+    Type =:= ?PUBLISH andalso emqx_pd:update_counter(recv_msg, 1).
+
+inc_outgoing_stats(Type) ->
+    emqx_pd:update_counter(send_pkt, 1),
+    Type =:= ?PUBLISH andalso emqx_pd:update_counter(send_msg, 1).
+
 %%--------------------------------------------------------------------
 %% Ensure stats timer
 

+ 13 - 9
src/emqx_frame.erl

@@ -21,6 +21,7 @@
 
 -export([ initial_parse_state/0
         , initial_parse_state/1
+        , init_serializer/1
         ]).
 
 -export([ parse/1
@@ -29,22 +30,22 @@
         , serialize/2
         ]).
 
+-export_type([ options/0
+             , parse_state/0
+             , parse_result/0
+             ]).
+
 -type(options() :: #{max_size => 1..?MAX_PACKET_SIZE,
-                     version  => emqx_mqtt_types:version()
+                     version  => emqx_mqtt:version()
                     }).
 
 -opaque(parse_state() :: {none, options()} | {more, cont_fun()}).
 
 -opaque(parse_result() :: {ok, parse_state()}
-                        | {ok, emqx_mqtt_types:packet(), binary(), parse_state()}).
+                        | {ok, emqx_mqtt:packet(), binary(), parse_state()}).
 
 -type(cont_fun() :: fun((binary()) -> parse_result())).
 
--export_type([ options/0
-             , parse_state/0
-             , parse_result/0
-             ]).
-
 -define(none(Opts), {none, Opts}).
 -define(more(Cont), {more, Cont}).
 -define(DEFAULT_OPTIONS,
@@ -385,11 +386,14 @@ parse_binary_data(<<Len:16/big, Data:Len/binary, Rest/binary>>) ->
 %% Serialize MQTT Packet
 %%--------------------------------------------------------------------
 
--spec(serialize(emqx_mqtt_types:packet()) -> iodata()).
+init_serializer(Options) ->
+    fun(Packet) -> serialize(Packet, Options) end.
+
+-spec(serialize(emqx_mqtt:packet()) -> iodata()).
 serialize(Packet) ->
     serialize(Packet, ?DEFAULT_OPTIONS).
 
--spec(serialize(emqx_mqtt_types:packet(), options()) -> iodata()).
+-spec(serialize(emqx_mqtt:packet(), options()) -> iodata()).
 serialize(#mqtt_packet{header   = Header,
                        variable = Variable,
                        payload  = Payload}, Options) when is_map(Options) ->

+ 2 - 2
src/emqx_inflight.erl

@@ -33,6 +33,8 @@
         , window/1
         ]).
 
+-export_type([inflight/0]).
+
 -type(key() :: term()).
 
 -type(max_size() :: pos_integer()).
@@ -43,8 +45,6 @@
 
 -define(Inflight(MaxSize, Tree), {?MODULE, MaxSize, (Tree)}).
 
--export_type([inflight/0]).
-
 %%--------------------------------------------------------------------
 %% APIs
 %%--------------------------------------------------------------------

File diff suppressed because it is too large
+ 253 - 321
src/emqx_protocol.erl


File diff suppressed because it is too large
+ 313 - 576
src/emqx_session.erl


+ 5 - 5
src/emqx_shared_sub.erl

@@ -135,7 +135,7 @@ ack_enabled() ->
 
 do_dispatch(SubPid, Topic, Msg, _Type) when SubPid =:= self() ->
     %% Deadlock otherwise
-    _ = erlang:send(SubPid, {dispatch, Topic, Msg}),
+    _ = erlang:send(SubPid, {deliver, Topic, Msg}),
     ok;
 do_dispatch(SubPid, Topic, Msg, Type) ->
     dispatch_per_qos(SubPid, Topic, Msg, Type).
@@ -143,18 +143,18 @@ do_dispatch(SubPid, Topic, Msg, Type) ->
 %% return either 'ok' (when everything is fine) or 'error'
 dispatch_per_qos(SubPid, Topic, #message{qos = ?QOS_0} = Msg, _Type) ->
     %% For QoS 0 message, send it as regular dispatch
-    _ = erlang:send(SubPid, {dispatch, Topic, Msg}),
+    _ = erlang:send(SubPid, {deliver, Topic, Msg}),
     ok;
 dispatch_per_qos(SubPid, Topic, Msg, retry) ->
     %% Retry implies all subscribers nack:ed, send again without ack
-    _ = erlang:send(SubPid, {dispatch, Topic, Msg}),
+    _ = erlang:send(SubPid, {deliver, Topic, Msg}),
     ok;
 dispatch_per_qos(SubPid, Topic, Msg, fresh) ->
     case ack_enabled() of
         true ->
             dispatch_with_ack(SubPid, Topic, Msg);
         false ->
-            _ = erlang:send(SubPid, {dispatch, Topic, Msg}),
+            _ = erlang:send(SubPid, {deliver, Topic, Msg}),
             ok
     end.
 
@@ -162,7 +162,7 @@ dispatch_with_ack(SubPid, Topic, Msg) ->
     %% For QoS 1/2 message, expect an ack
     Ref = erlang:monitor(process, SubPid),
     Sender = self(),
-    _ = erlang:send(SubPid, {dispatch, Topic, with_ack_ref(Msg, {Sender, Ref})}),
+    _ = erlang:send(SubPid, {deliver, Topic, with_ack_ref(Msg, {Sender, Ref})}),
     Timeout = case Msg#message.qos of
                   ?QOS_1 -> timer:seconds(?SHARED_SUB_QOS1_DISPATCH_TIMEOUT_SECONDS);
                   ?QOS_2 -> infinity

+ 22 - 9
src/emqx_ws_connection.erl

@@ -14,6 +14,7 @@
 %% limitations under the License.
 %%--------------------------------------------------------------------
 
+%% MQTT WebSocket Channel
 -module(emqx_ws_channel).
 
 -include("emqx.hrl").
@@ -170,7 +171,8 @@ websocket_init(#state{request = Req, options = Options}) ->
                 parse_state  = ParseState,
                 proto_state  = ProtoState,
                 enable_stats = EnableStats,
-                idle_timeout = IdleTimout}}.
+                idle_timeout = IdleTimout
+               }}.
 
 send_fun(WsPid) ->
     fun(Packet, Options) ->
@@ -242,10 +244,13 @@ websocket_info({call, From, session}, State = #state{proto_state = ProtoState})
     gen_server:reply(From, emqx_protocol:session(ProtoState)),
     {ok, State};
 
-websocket_info({deliver, PubOrAck}, State = #state{proto_state = ProtoState}) ->
-    case emqx_protocol:deliver(PubOrAck, ProtoState) of
-        {ok, ProtoState1} ->
-            {ok, ensure_stats_timer(State#state{proto_state = ProtoState1})};
+websocket_info(Delivery, State = #state{proto_state = ProtoState})
+  when element(1, Delivery) =:= deliver ->
+    case emqx_protocol:handle_out(Delivery, ProtoState) of
+        {ok, NProtoState} ->
+            {ok, State#state{proto_state = NProtoState}};
+        {ok, Packet, NProtoState} ->
+            handle_outgoing(Packet, State#state{proto_state = NProtoState});
         {error, Reason} ->
             shutdown(Reason, State)
     end;
@@ -285,8 +290,8 @@ websocket_info({shutdown, conflict, {ClientId, NewPid}}, State) ->
     ?LOG(warning, "Clientid '~s' conflict with ~p", [ClientId, NewPid]),
     shutdown(conflict, State);
 
-websocket_info({binary, Data}, State) ->
-    {reply, {binary, Data}, State};
+%% websocket_info({binary, Data}, State) ->
+%%    {reply, {binary, Data}, State};
 
 websocket_info({shutdown, Reason}, State) ->
     shutdown(Reason, State);
@@ -317,9 +322,12 @@ terminate(SockError, _Req, #state{keepalive   = Keepalive,
 %%--------------------------------------------------------------------
 
 handle_incoming(Packet, SuccFun, State = #state{proto_state = ProtoState}) ->
-    case emqx_protocol:received(Packet, ProtoState) of
+    case emqx_protocol:handle_in(Packet, ProtoState) of
         {ok, NProtoState} ->
             SuccFun(State#state{proto_state = NProtoState});
+        {ok, OutPacket, NProtoState} ->
+            %% TODO: How to call SuccFun???
+            handle_outgoing(OutPacket, State#state{proto_state = NProtoState});
         {error, Reason} ->
             ?LOG(error, "Protocol error: ~p", [Reason]),
             shutdown(Reason, State);
@@ -329,7 +337,12 @@ handle_incoming(Packet, SuccFun, State = #state{proto_state = ProtoState}) ->
             shutdown(Error, State#state{proto_state = NProtoState})
     end.
 
-
+handle_outgoing(Packet, State = #state{proto_state = _NProtoState}) ->
+    Data = emqx_frame:serialize(Packet), %% TODO:, Options),
+    BinSize = iolist_size(Data),
+    emqx_pd:update_counter(send_cnt, 1),
+    emqx_pd:update_counter(send_oct, BinSize),
+    {reply, {binary, Data}, ensure_stats_timer(State)}.
 
 ensure_stats_timer(State = #state{enable_stats = true,
                                   stats_timer  = undefined,