Browse Source

Merge pull request #6040 from HJianBo/refactor-stomp-gw

Refactor stomp gw
tigercl 4 years ago
parent
commit
e79085c259

+ 1 - 1
apps/emqx_stomp/src/emqx_stomp.app.src

@@ -1,6 +1,6 @@
 {application, emqx_stomp,
  [{description, "EMQ X Stomp Protocol Plugin"},
-  {vsn, "4.3.0"}, % strict semver, bump manually!
+  {vsn, "4.3.1"}, % strict semver, bump manually!
   {modules, []},
   {registered, [emqx_stomp_sup]},
   {applications, [kernel,stdlib]},

+ 8 - 0
apps/emqx_stomp/src/emqx_stomp.appup.src

@@ -0,0 +1,8 @@
+%% -*- mode: erlang -*-
+{"4.3.1",
+  [{"4.3.0",
+    [{restart_application,emqx_stomp}]},
+   {<<".*">>,[]}],
+  [{"4.3.0",
+     [{restart_application,emqx_stomp}]},
+   {<<".*">>,[]}]}.

+ 352 - 89
apps/emqx_stomp/src/emqx_stomp_connection.erl

@@ -20,9 +20,15 @@
 
 -include("emqx_stomp.hrl").
 -include_lib("emqx/include/logger.hrl").
+-include_lib("emqx/include/types.hrl").
+-include_lib("snabbkaffe/include/snabbkaffe.hrl").
 
 -logger_header("[Stomp-Conn]").
 
+-import(emqx_misc,
+        [ start_timer/2
+        ]).
+
 -export([ start_link/3
         , info/1
         ]).
@@ -37,56 +43,175 @@
         ]).
 
 %% for protocol
--export([send/4, heartbeat/2]).
-
--record(state, {transport, socket, peername, conn_name, conn_state,
-                await_recv, rate_limit, parser, pstate,
-                proto_env, heartbeat}).
-
--define(INFO_KEYS, [peername, await_recv, conn_state]).
--define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt]).
+-export([send/4, heartbeat/2, statfun/3]).
+
+%% for mgmt
+-export([call/2, call/3]).
+
+-record(state, {
+          %% TCP/TLS Transport
+          transport :: esockd:transport(),
+          %% TCP/TLS Socket
+          socket :: esockd:socket(),
+          %% Peername of the connection
+          peername :: emqx_types:peername(),
+          %% Sockname of the connection
+          sockname :: emqx_types:peername(),
+          %% Sock State
+          sockstate :: emqx_types:sockstate(),
+          %% The {active, N} option
+          active_n :: pos_integer(),
+          %% Limiter
+          limiter :: maybe(emqx_limiter:limiter()),
+          %% Limit Timer
+          limit_timer :: maybe(reference()),
+          %% GC State
+          gc_state :: maybe(emqx_gc:gc_state()),
+          %% Stats Timer
+          stats_timer :: disabled | maybe(reference()),
+          %% Parser State
+          parser :: emqx_stomp_frame:parser(),
+          %% Protocol State
+          pstate :: emqx_stomp_protocol:pstate(),
+          %% XXX: some common confs
+          proto_env :: list()
+         }).
+
+-type(state() :: #state{}).
+
+-define(DEFAULT_GC_POLICY, #{bytes => 16777216, count => 16000}).
+-define(DEFAULT_OOM_POLICY, #{ max_heap_size => 8388608,
+                               message_queue_len => 10000}).
+
+-define(ACTIVE_N, 100).
+-define(IDLE_TIMEOUT, 30000).
+-define(INFO_KEYS,  [socktype, peername, sockname, sockstate, active_n]).
+-define(CONN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]).
+-define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]).
+
+-define(ENABLED(X), (X =/= undefined)).
+
+-dialyzer({nowarn_function, [ ensure_stats_timer/2
+                            ]}).
+
+-dialyzer({no_return, [ init/1
+                      , init_state/3
+                      ]}).
 
 start_link(Transport, Sock, ProtoEnv) ->
     {ok, proc_lib:spawn_link(?MODULE, init, [[Transport, Sock, ProtoEnv]])}.
 
-info(CPid) ->
-    gen_server:call(CPid, info, infinity).
-
-init([Transport, Sock, ProtoEnv]) ->
-    process_flag(trap_exit, true),
-    case Transport:wait(Sock) of
-        {ok, NewSock} ->
-            {ok, Peername} = Transport:ensure_ok_or_exit(peername, [NewSock]),
-            ConnName = esockd:format(Peername),
-            SendFun = {fun ?MODULE:send/4, [Transport, Sock, self()]},
-            HrtBtFun = {fun ?MODULE:heartbeat/2, [Transport, Sock]},
-            Parser = emqx_stomp_frame:init_parer_state(ProtoEnv),
-            PState = emqx_stomp_protocol:init(#{peername => Peername,
-                                                sendfun => SendFun,
-                                                heartfun => HrtBtFun}, ProtoEnv),
-            RateLimit = init_rate_limit(proplists:get_value(rate_limit, ProtoEnv)),
-            State = run_socket(#state{transport   = Transport,
-                                      socket      = NewSock,
-                                      peername    = Peername,
-                                      conn_name   = ConnName,
-                                      conn_state  = running,
-                                      await_recv  = false,
-                                      rate_limit  = RateLimit,
-                                      parser      = Parser,
-                                      proto_env   = ProtoEnv,
-                                      pstate      = PState}),
-            emqx_logger:set_metadata_peername(esockd:format(Peername)),
-            gen_server:enter_loop(?MODULE, [{hibernate_after, 5000}], State, 20000);
+-spec info(pid()|state()) -> emqx_types:infos().
+info(CPid) when is_pid(CPid) ->
+    call(CPid, info);
+info(State = #state{pstate = PState}) ->
+    ChanInfo = emqx_stomp_protocol:info(PState),
+    SockInfo = maps:from_list(
+                 info(?INFO_KEYS, State)),
+    ChanInfo#{sockinfo => SockInfo}.
+
+info(Keys, State) when is_list(Keys) ->
+    [{Key, info(Key, State)} || Key <- Keys];
+info(socktype, #state{transport = Transport, socket = Socket}) ->
+    Transport:type(Socket);
+info(peername, #state{peername = Peername}) ->
+    Peername;
+info(sockname, #state{sockname = Sockname}) ->
+    Sockname;
+info(sockstate, #state{sockstate = SockSt}) ->
+    SockSt;
+info(active_n, #state{active_n = ActiveN}) ->
+    ActiveN.
+
+-spec stats(pid()|state()) -> emqx_types:stats().
+stats(CPid) when is_pid(CPid) ->
+    call(CPid, stats);
+stats(#state{transport = Transport,
+             socket    = Socket,
+             pstate    = PState}) ->
+    SockStats = case Transport:getstat(Socket, ?SOCK_STATS) of
+                    {ok, Ss}   -> Ss;
+                    {error, _} -> []
+                end,
+    ConnStats = emqx_pd:get_counters(?CONN_STATS),
+    ChanStats = emqx_stomp_protocol:stats(PState),
+    ProcStats = emqx_misc:proc_stats(),
+    lists:append([SockStats, ConnStats, ChanStats, ProcStats]).
+
+call(Pid, Req) ->
+    call(Pid, Req, infinity).
+call(Pid, Req, Timeout) ->
+    gen_server:call(Pid, Req, Timeout).
+
+init([Transport, RawSocket, ProtoEnv]) ->
+    case Transport:wait(RawSocket) of
+        {ok, Socket} ->
+            init_state(Transport, Socket, ProtoEnv);
         {error, Reason} ->
-            {stop, Reason}
+            ok = Transport:fast_close(RawSocket),
+            exit_on_sock_error(Reason)
     end.
 
-init_rate_limit(undefined) ->
-    undefined;
-init_rate_limit({Rate, Burst}) ->
-    esockd_rate_limit:new(Rate, Burst).
+init_state(Transport, Socket, ProtoEnv) ->
+    {ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]),
+    {ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]),
+
+    SendFun = {fun ?MODULE:send/4, [Transport, Socket, self()]},
+    StatFun = {fun ?MODULE:statfun/3, [Transport, Socket]},
+    HrtBtFun = {fun ?MODULE:heartbeat/2, [Transport, Socket]},
+    Parser = emqx_stomp_frame:init_parer_state(ProtoEnv),
+
+    ActiveN = proplists:get_value(active_n, ProtoEnv, ?ACTIVE_N),
+    GcState = emqx_gc:init(?DEFAULT_GC_POLICY),
+
+    Peercert = Transport:ensure_ok_or_exit(peercert, [Socket]),
+    ConnInfo = #{socktype => Transport:type(Socket),
+                 peername => Peername,
+                 sockname => Sockname,
+                 peercert => Peercert,
+                 statfun  => StatFun,
+                 sendfun  => SendFun,
+                 heartfun => HrtBtFun,
+                 conn_mod => ?MODULE
+                },
+    PState = emqx_stomp_protocol:init(ConnInfo, ProtoEnv),
+    State = #state{transport = Transport,
+                   socket    = Socket,
+                   peername  = Peername,
+                   sockname  = Sockname,
+                   sockstate = idle,
+                   active_n  = ActiveN,
+                   limiter   = undefined,
+                   parser    = Parser,
+                   proto_env = ProtoEnv,
+                   gc_state  = GcState,
+                   pstate    = PState},
+    case activate_socket(State) of
+        {ok, NState} ->
+            emqx_logger:set_metadata_peername(
+              esockd:format(Peername)),
+            gen_server:enter_loop(
+              ?MODULE, [{hibernate_after, 5000}], NState, 20000);
+        {error, Reason} ->
+            ok = Transport:fast_close(Socket),
+            exit_on_sock_error(Reason)
+    end.
 
-send(Data, Transport, Sock, ConnPid) ->
+-spec exit_on_sock_error(any()) -> no_return().
+exit_on_sock_error(Reason) when Reason =:= einval;
+                                Reason =:= enotconn;
+                                Reason =:= closed ->
+    erlang:exit(normal);
+exit_on_sock_error(timeout) ->
+    erlang:exit({shutdown, ssl_upgrade_timeout});
+exit_on_sock_error(Reason) ->
+    erlang:exit({shutdown, Reason}).
+
+send(Frame, Transport, Sock, ConnPid) ->
+    ?LOG(info, "SEND Frame: ~s", [emqx_stomp_frame:format(Frame)]),
+    ok = inc_outgoing_stats(Frame),
+    Data = emqx_stomp_frame:serialize(Frame),
+    ?LOG(debug, "SEND ~p", [Data]),
     try Transport:async_send(Sock, Data) of
         ok -> ok;
         {error, Reason} -> ConnPid ! {shutdown, Reason}
@@ -95,23 +220,27 @@ send(Data, Transport, Sock, ConnPid) ->
     end.
 
 heartbeat(Transport, Sock) ->
+    ?LOG(debug, "SEND heartbeat: \\n"),
     Transport:send(Sock, <<$\n>>).
 
-handle_call(info, _From, State = #state{transport   = Transport,
-                                        socket      = Sock,
-                                        peername    = Peername,
-                                        await_recv  = AwaitRecv,
-                                        conn_state  = ConnState,
-                                        pstate      = PState}) ->
-    ClientInfo = [{peername,  Peername}, {await_recv, AwaitRecv},
-                  {conn_state, ConnState}],
-    ProtoInfo  = emqx_stomp_protocol:info(PState),
-    case Transport:getstat(Sock, ?SOCK_STATS) of
-        {ok, SockStats} ->
-            {reply, lists:append([ClientInfo, ProtoInfo, SockStats]), State};
-        {error, Reason} ->
-            {stop, Reason, lists:append([ClientInfo, ProtoInfo]), State}
-    end;
+statfun(Stat, Transport, Sock) ->
+    case Transport:getstat(Sock, [Stat]) of
+        {ok, [{Stat, Val}]} -> {ok, Val};
+        {error, Error}      -> {error, Error}
+    end.
+
+handle_call(info, _From, State) ->
+    {reply, info(State), State};
+
+handle_call(stats, _From, State) ->
+    {reply, stats(State), State};
+
+handle_call(discard, _From, State) ->
+    %% TODO: send the DISCONNECT packet?
+    shutdown_and_reply(discared, ok, State);
+
+handle_call(kick, _From, State) ->
+    shutdown_and_reply(kicked, ok, State);
 
 handle_call(Req, _From, State) ->
     ?LOG(error, "unexpected request: ~p", [Req]),
@@ -121,6 +250,13 @@ handle_cast(Msg, State) ->
     ?LOG(error, "unexpected msg: ~p", [Msg]),
     noreply(State).
 
+handle_info({event, Name}, State = #state{pstate = PState})
+  when Name == connected;
+       Name == updated ->
+    ClientId = emqx_stomp_protocol:info(clientid, PState),
+    emqx_cm:insert_channel_info(ClientId, info(State), stats(State)),
+    noreply(State);
+
 handle_info(timeout, State) ->
     shutdown(idle_timeout, State);
 
@@ -141,26 +277,73 @@ handle_info({timeout, TRef, TMsg}, State) when TMsg =:= incoming;
             shutdown({sock_error, Reason}, State)
     end;
 
-handle_info({timeout, TRef, TMsg}, State) ->
-    with_proto(timeout, [TRef, TMsg], State);
-
-handle_info({'EXIT', HbProc, Error}, State = #state{heartbeat = HbProc}) ->
-    stop(Error, State);
+handle_info({timeout, _TRef, limit_timeout}, State) ->
+    NState = State#state{sockstate   = idle,
+                         limit_timer = undefined
+                        },
+    handle_info(activate_socket, NState);
 
-handle_info(activate_sock, State) ->
-    noreply(run_socket(State#state{conn_state = running}));
+handle_info({timeout, _TRef, emit_stats},
+            State = #state{pstate = PState}) ->
+    ClientId = emqx_stomp_protocol:info(clientid, PState),
+    emqx_cm:set_chan_stats(ClientId, stats(State)),
+    noreply(State#state{stats_timer = undefined});
 
-handle_info({inet_async, _Sock, _Ref, {ok, Bytes}}, State) ->
-    ?LOG(debug, "RECV ~p", [Bytes]),
-    received(Bytes, rate_limit(size(Bytes), State#state{await_recv = false}));
+handle_info({timeout, TRef, TMsg}, State) ->
+    with_proto(timeout, [TRef, TMsg], State);
 
-handle_info({inet_async, _Sock, _Ref, {error, Reason}}, State) ->
-    shutdown(Reason, State);
+handle_info(activate_socket, State = #state{sockstate = OldSst}) ->
+    case activate_socket(State) of
+        {ok, NState = #state{sockstate = NewSst}} ->
+            case OldSst =/= NewSst of
+                true -> {ok, {event, NewSst}, NState};
+                false -> {ok, NState}
+            end;
+        {error, Reason} ->
+            handle_info({sock_error, Reason}, State)
+    end;
 
 handle_info({inet_reply, _Ref, ok}, State) ->
     noreply(State);
 
+handle_info({Inet, _Sock, Data}, State) when Inet == tcp; Inet == ssl ->
+    ?LOG(debug, "RECV ~0p", [Data]),
+    Oct = iolist_size(Data),
+    inc_counter(incoming_bytes, Oct),
+    ok = emqx_metrics:inc('bytes.received', Oct),
+    received(Data, ensure_stats_timer(?IDLE_TIMEOUT, State));
+
+handle_info({Passive, _Sock}, State)
+  when Passive == tcp_passive; Passive == ssl_passive ->
+    %% In Stats
+    Pubs = emqx_pd:reset_counter(incoming_pubs),
+    Bytes = emqx_pd:reset_counter(incoming_bytes),
+    InStats = #{cnt => Pubs, oct => Bytes},
+    %% Ensure Rate Limit
+    NState = ensure_rate_limit(InStats, State),
+    %% Run GC and Check OOM
+    NState1 = check_oom(run_gc(InStats, NState)),
+    handle_info(activate_socket, NState1);
+
+handle_info({Error, _Sock, Reason}, State)
+  when Error == tcp_error; Error == ssl_error ->
+    handle_info({sock_error, Reason}, State);
+
+handle_info({Closed, _Sock}, State)
+  when Closed == tcp_closed; Closed == ssl_closed ->
+    handle_info({sock_closed, Closed}, close_socket(State));
+
 handle_info({inet_reply, _Sock, {error, Reason}}, State) ->
+    handle_info({sock_error, Reason}, State);
+
+handle_info({sock_error, Reason}, State) ->
+    case Reason =/= closed andalso Reason =/= einval of
+        true -> ?LOG(warning, "socket_error: ~p", [Reason]);
+        false -> ok
+    end,
+    handle_info({sock_closed, Reason}, close_socket(State));
+
+handle_info({sock_closed, Reason}, State) ->
     shutdown(Reason, State);
 
 handle_info({deliver, _Topic, Msg}, State = #state{pstate = PState}) ->
@@ -172,8 +355,7 @@ handle_info({deliver, _Topic, Msg}, State = #state{pstate = PState}) ->
                                  end});
 
 handle_info(Info, State) ->
-    ?LOG(error, "Unexpected info: ~p", [Info]),
-    noreply(State).
+    with_proto(handle_info, [Info], State).
 
 terminate(Reason, #state{transport = Transport,
                          socket    = Sock,
@@ -197,6 +379,8 @@ code_change(_OldVsn, State, _Extra) ->
 
 with_proto(Fun, Args, State = #state{pstate = PState}) ->
     case erlang:apply(emqx_stomp_protocol, Fun, Args ++ [PState]) of
+        ok ->
+            noreply(State);
         {ok, NPState} ->
             noreply(State#state{pstate = NPState});
         {F, Reason, NPState} when F == stop;
@@ -208,13 +392,14 @@ with_proto(Fun, Args, State = #state{pstate = PState}) ->
 received(<<>>, State) ->
     noreply(State);
 
-received(Bytes, State = #state{parser   = Parser,
+received(Bytes, State = #state{parser = Parser,
                                pstate = PState}) ->
     try emqx_stomp_frame:parse(Bytes, Parser) of
         {more, NewParser} ->
             noreply(State#state{parser = NewParser});
         {ok, Frame, Rest} ->
             ?LOG(info, "RECV Frame: ~s", [emqx_stomp_frame:format(Frame)]),
+            ok = inc_incoming_stats(Frame),
             case emqx_stomp_protocol:received(Frame, PState) of
                 {ok, PState1}           ->
                     received(Rest, reset_parser(State#state{pstate = PState1}));
@@ -237,25 +422,97 @@ received(Bytes, State = #state{parser   = Parser,
 reset_parser(State = #state{proto_env = ProtoEnv}) ->
     State#state{parser = emqx_stomp_frame:init_parer_state(ProtoEnv)}.
 
-rate_limit(_Size, State = #state{rate_limit = undefined}) ->
-    run_socket(State);
-rate_limit(Size, State = #state{rate_limit = Rl}) ->
-    case esockd_rate_limit:check(Size, Rl) of
-        {0, Rl1} ->
-            run_socket(State#state{conn_state = running, rate_limit = Rl1});
-        {Pause, Rl1} ->
-            ?LOG(error, "Rate limiter pause for ~p", [Pause]),
-            erlang:send_after(Pause, self(), activate_sock),
-            State#state{conn_state = blocked, rate_limit = Rl1}
+activate_socket(State = #state{sockstate = closed}) ->
+    {ok, State};
+activate_socket(State = #state{sockstate = blocked}) ->
+    {ok, State};
+activate_socket(State = #state{transport = Transport,
+                               socket    = Socket,
+                               active_n  = N}) ->
+    case Transport:setopts(Socket, [{active, N}]) of
+        ok -> {ok, State#state{sockstate = running}};
+        Error -> Error
+    end.
+
+close_socket(State = #state{sockstate = closed}) -> State;
+close_socket(State = #state{transport = Transport, socket = Socket}) ->
+    ok = Transport:fast_close(Socket),
+    State#state{sockstate = closed}.
+
+%%--------------------------------------------------------------------
+%% Inc incoming/outgoing stats
+
+inc_incoming_stats(#stomp_frame{command = Cmd}) ->
+    inc_counter(recv_pkt, 1),
+    case Cmd of
+        <<"SEND">> ->
+            inc_counter(recv_msg, 1),
+            inc_counter(incoming_pubs, 1),
+            emqx_metrics:inc('messages.received'),
+            emqx_metrics:inc('messages.qos1.received');
+        _ ->
+            ok
+    end,
+    emqx_metrics:inc('packets.received').
+
+inc_outgoing_stats(#stomp_frame{command = Cmd}) ->
+    inc_counter(send_pkt, 1),
+    case Cmd of
+        <<"MESSAGE">> ->
+            inc_counter(send_msg, 1),
+            inc_counter(outgoing_pubs, 1),
+            emqx_metrics:inc('messages.sent'),
+            emqx_metrics:inc('messages.qos1.sent');
+        _ ->
+            ok
+    end,
+    emqx_metrics:inc('packets.sent').
+
+%%--------------------------------------------------------------------
+%% Ensure rate limit
+
+ensure_rate_limit(Stats, State = #state{limiter = Limiter}) ->
+    case ?ENABLED(Limiter) andalso emqx_limiter:check(Stats, Limiter) of
+        false -> State;
+        {ok, Limiter1} ->
+            State#state{limiter = Limiter1};
+        {pause, Time, Limiter1} ->
+            ?LOG(warning, "Pause ~pms due to rate limit", [Time]),
+            TRef = start_timer(Time, limit_timeout),
+            State#state{sockstate   = blocked,
+                        limiter     = Limiter1,
+                        limit_timer = TRef
+                       }
     end.
 
-run_socket(State = #state{conn_state = blocked}) ->
-    State;
-run_socket(State = #state{await_recv = true}) ->
-    State;
-run_socket(State = #state{transport = Transport, socket = Sock}) ->
-    Transport:async_recv(Sock, 0, infinity),
-    State#state{await_recv = true}.
+%%--------------------------------------------------------------------
+%% Run GC and Check OOM
+
+run_gc(Stats, State = #state{gc_state = GcSt}) ->
+    case ?ENABLED(GcSt) andalso emqx_gc:run(Stats, GcSt) of
+        false -> State;
+        {_IsGC, GcSt1} ->
+            State#state{gc_state = GcSt1}
+    end.
+
+check_oom(State) ->
+    OomPolicy = ?DEFAULT_OOM_POLICY,
+    ?tp(debug, check_oom, #{policy => OomPolicy}),
+    case ?ENABLED(OomPolicy) andalso emqx_misc:check_oom(OomPolicy) of
+        {shutdown, Reason} ->
+            %% triggers terminate/2 callback immediately
+            erlang:exit({shutdown, Reason});
+        _Other ->
+            ok
+    end,
+    State.
+
+%%--------------------------------------------------------------------
+%% Ensure/cancel stats timer
+
+ensure_stats_timer(Timeout, State = #state{stats_timer = undefined}) ->
+    State#state{stats_timer = start_timer(Timeout, emit_stats)};
+ensure_stats_timer(_Timeout, State) -> State.
 
 getstat(Stat, #state{transport = Transport, socket = Sock}) ->
     case Transport:getstat(Sock, [Stat]) of
@@ -272,3 +529,9 @@ stop(Reason, State) ->
 shutdown(Reason, State) ->
     stop({shutdown, Reason}, State).
 
+shutdown_and_reply(Reason, Reply, State) ->
+    {stop, {shutdown, Reason}, Reply, State}.
+
+inc_counter(Key, Inc) ->
+    _ = emqx_pd:inc_counter(Key, Inc),
+    ok.

+ 7 - 0
apps/emqx_stomp/src/emqx_stomp_frame.erl

@@ -126,6 +126,13 @@ parse(Bytes, #{phase := body, len := Len, state := State}) ->
 
 parse(Bytes, Parser = #{pre := Pre}) ->
     parse(<<Pre/binary, Bytes/binary>>, maps:without([pre], Parser));
+parse(<<?CR, Rest/binary>>, Parser = #{phase := none}) ->
+    parse(Rest, Parser);
+parse(<<?LF, Rest/binary>>, Parser = #{phase := none}) ->
+    case byte_size(Rest) of
+        0 -> {more, Parser};
+        _ -> parse(Rest, Parser)
+    end;
 parse(<<?CR, ?LF, Rest/binary>>, #{phase := Phase, state := State}) ->
     parse(Phase, <<?LF, Rest/binary>>, State);
 parse(<<?CR>>, Parser) ->

+ 21 - 9
apps/emqx_stomp/src/emqx_stomp_heartbeat.erl

@@ -23,9 +23,10 @@
         , check/3
         , info/1
         , interval/2
+        , reset/3
         ]).
 
--record(heartbeater, {interval, statval, repeat}).
+-record(heartbeater, {interval, statval, repeat, repeat_max}).
 
 -type name() :: incoming | outgoing.
 
@@ -33,7 +34,6 @@
                        outgoing => #heartbeater{}
                       }.
 
-
 %%--------------------------------------------------------------------
 %% APIs
 %%--------------------------------------------------------------------
@@ -43,19 +43,23 @@ init({0, 0}) ->
     #{};
 init({Cx, Cy}) ->
     maps:filter(fun(_, V) -> V /= undefined end,
-      #{incoming => heartbeater(Cx),
-        outgoing => heartbeater(Cy)
+      #{incoming => heartbeater(incoming, Cx),
+        outgoing => heartbeater(outgoing, Cy)
        }).
 
-heartbeater(0) ->
+heartbeater(_, 0) ->
     undefined;
-heartbeater(I) ->
+heartbeater(N, I) ->
     #heartbeater{
        interval = I,
        statval = 0,
-       repeat = 0
+       repeat = 0,
+       repeat_max = repeat_max(N)
       }.
 
+repeat_max(incoming) -> 1;
+repeat_max(outgoing) -> 0.
+
 -spec check(name(), pos_integer(), heartbeat())
     -> {ok, heartbeat()}
      | {error, timeout}.
@@ -68,11 +72,12 @@ check(Name, NewVal, HrtBt) ->
     end.
 
 check(NewVal, HrtBter = #heartbeater{statval = OldVal,
-                                     repeat = Repeat}) ->
+                                     repeat = Repeat,
+                                     repeat_max = Max}) ->
     if
         NewVal =/= OldVal ->
             {ok, HrtBter#heartbeater{statval = NewVal, repeat = 0}};
-        Repeat < 1 ->
+        Repeat < Max ->
             {ok, HrtBter#heartbeater{repeat = Repeat + 1}};
         true -> {error, timeout}
     end.
@@ -90,3 +95,10 @@ interval(Type, HrtBt) ->
         undefined -> undefined;
         #heartbeater{interval = Intv} -> Intv
     end.
+
+reset(Type, StatVal, HrtBt) ->
+    case maps:get(Type, HrtBt, undefined) of
+        undefined -> HrtBt;
+        HrtBter ->
+            HrtBt#{Type => HrtBter#heartbeater{statval = StatVal, repeat = 0}}
+    end.

+ 409 - 86
apps/emqx_stomp/src/emqx_stomp_protocol.erl

@@ -20,6 +20,7 @@
 -include("emqx_stomp.hrl").
 
 -include_lib("emqx/include/emqx.hrl").
+-include_lib("emqx/include/types.hrl").
 -include_lib("emqx/include/logger.hrl").
 -include_lib("emqx/include/emqx_mqtt.hrl").
 
@@ -30,6 +31,8 @@
 %% API
 -export([ init/2
         , info/1
+        , info/2
+        , stats/1
         ]).
 
 -export([ received/2
@@ -38,6 +41,9 @@
         , timeout/3
         ]).
 
+-export([ handle_info/2
+        ]).
+
 %% for trans callback
 -export([ handle_recv_send_frame/2
         , handle_recv_ack_frame/2
@@ -45,21 +51,37 @@
         ]).
 
 -record(pstate, {
-          peername,
-          heartfun,
-          sendfun,
+          %% Stomp ConnInfo
+          conninfo :: emqx_types:conninfo(),
+          %% Stomp ClientInfo
+          clientinfo :: emqx_types:clientinfo(),
+          %% Stomp Heartbeats
+          heart_beats :: maybe(emqx_stomp_hearbeat:heartbeat()),
+          %% Stomp Connection State
           connected = false,
-          proto_ver,
-          proto_name,
-          heart_beats,
-          login,
-          allow_anonymous,
-          default_user,
-          subscriptions = [],
+          %% Timers
           timers :: #{atom() => disable | undefined | reference()},
-          transaction :: #{binary() => list()}
+          %% Transaction
+          transaction :: #{binary() => list()},
+          %% Subscriptions
+          subscriptions = #{},
+          %% Send function
+          sendfun :: {function(), list()},
+          %% Heartbeat function
+          heartfun :: {function(), list()},
+          %% Get Socket stat function
+          statfun :: {function(), list()},
+          %% The confs for the connection
+          %% TODO: put these configs into a public mem?
+          allow_anonymous :: maybe(boolean()),
+          default_user :: maybe(list())
          }).
 
+-define(DEFAULT_SUB_ACK, <<"auto">>).
+
+-define(INCOMING_TIMER_BACKOFF, 1.25).
+-define(OUTCOMING_TIMER_BACKOFF, 0.75).
+
 -define(TIMER_TABLE, #{
           incoming_timer => incoming,
           outgoing_timer => outgoing,
@@ -68,34 +90,135 @@
 
 -define(TRANS_TIMEOUT, 60000).
 
+-define(INFO_KEYS, [conninfo, conn_state, clientinfo, session, will_msg]).
+
+-define(STATS_KEYS, [subscriptions_cnt,
+                     subscriptions_max,
+                     inflight_cnt,
+                     inflight_max,
+                     mqueue_len,
+                     mqueue_max,
+                     mqueue_dropped,
+                     next_pkt_id,
+                     awaiting_rel_cnt,
+                     awaiting_rel_max
+                    ]).
+
+-dialyzer({nowarn_function, [ check_acl/3
+                            , init/2
+                            ]}).
+
 -type(pstate() :: #pstate{}).
 
 %% @doc Init protocol
-init(#{peername := Peername,
-       sendfun := SendFun,
-       heartfun := HeartFun}, Env) ->
-    AllowAnonymous = get_value(allow_anonymous, Env, false),
-    DefaultUser = get_value(default_user, Env),
-	#pstate{peername = Peername,
-                 heartfun = HeartFun,
-                 sendfun = SendFun,
-                 timers = #{},
-                 transaction = #{},
-                 allow_anonymous = AllowAnonymous,
-                 default_user = DefaultUser}.
-
-info(#pstate{connected     = Connected,
-                  proto_ver     = ProtoVer,
-                  proto_name    = ProtoName,
-                  heart_beats   = Heartbeats,
-                  login         = Login,
-                  subscriptions = Subscriptions}) ->
-    [{connected, Connected},
-     {proto_ver, ProtoVer},
-     {proto_name, ProtoName},
-     {heart_beats, Heartbeats},
-     {login, Login},
-     {subscriptions, Subscriptions}].
+init(ConnInfo = #{peername := {PeerHost, _Port},
+                  sockname := {_Host, SockPort},
+                  statfun  := StatFun,
+                  sendfun  := SendFun,
+                  heartfun := HeartFun}, Opts) ->
+
+    NConnInfo = default_conninfo(ConnInfo),
+
+    ClientInfo = #{zone => undefined,
+                   protocol => stomp,
+                   peerhost => PeerHost,
+                   sockport => SockPort,
+                   clientid => undefined,
+                   username => undefined,
+                   mountpoint => undefined, %% XXX: not supported now
+                   is_bridge => false,
+                   is_superuser => false
+                  },
+
+    AllowAnonymous = get_value(allow_anonymous, Opts, false),
+    DefaultUser = get_value(default_user, Opts),
+
+	#pstate{
+       conninfo = NConnInfo,
+       clientinfo = ClientInfo,
+       heartfun = HeartFun,
+       sendfun = SendFun,
+       statfun = StatFun,
+       timers = #{},
+       transaction = #{},
+       allow_anonymous = AllowAnonymous,
+       default_user = DefaultUser
+      }.
+
+default_conninfo(ConnInfo) ->
+    NConnInfo = maps:without([sendfun, heartfun], ConnInfo),
+    NConnInfo#{
+      proto_name => <<"STOMP">>,
+      proto_ver => <<"1.2">>,
+      clean_start => true,
+      clientid => undefined,
+      username => undefined,
+      conn_props => [],
+      connected => false,
+      connected_at => undefined,
+      keepalive => undefined,
+      receive_maximum => 0,
+      expiry_interval => 0
+     }.
+
+-spec info(pstate()) -> emqx_types:infos().
+info(State) ->
+    maps:from_list(info(?INFO_KEYS, State)).
+
+-spec info(list(atom())|atom(), pstate()) -> term().
+info(Keys, State) when is_list(Keys) ->
+    [{Key, info(Key, State)} || Key <- Keys];
+info(conninfo, #pstate{conninfo = ConnInfo}) ->
+    ConnInfo;
+info(socktype, #pstate{conninfo = ConnInfo}) ->
+    maps:get(socktype, ConnInfo, undefined);
+info(peername, #pstate{conninfo = ConnInfo}) ->
+    maps:get(peername, ConnInfo, undefined);
+info(sockname, #pstate{conninfo = ConnInfo}) ->
+    maps:get(sockname, ConnInfo, undefined);
+info(proto_name, #pstate{conninfo = ConnInfo}) ->
+    maps:get(proto_name, ConnInfo, undefined);
+info(proto_ver, #pstate{conninfo = ConnInfo}) ->
+    maps:get(proto_ver, ConnInfo, undefined);
+info(connected_at, #pstate{conninfo = ConnInfo}) ->
+    maps:get(connected_at, ConnInfo, undefined);
+info(clientinfo, #pstate{clientinfo = ClientInfo}) ->
+    ClientInfo;
+info(zone, _) ->
+    undefined;
+info(clientid, #pstate{clientinfo = ClientInfo}) ->
+    maps:get(clientid, ClientInfo, undefined);
+info(username, #pstate{clientinfo = ClientInfo}) ->
+    maps:get(username, ClientInfo, undefined);
+info(session, State) ->
+    session_info(State);
+info(conn_state, #pstate{connected = true}) ->
+    connected;
+info(conn_state, _) ->
+    disconnected;
+info(will_msg, _) ->
+    undefined.
+
+session_info(#pstate{conninfo = ConnInfo, subscriptions = Subs}) ->
+    #{subscriptions => Subs,
+      upgrade_qos => false,
+      retry_interval => 0,
+      await_rel_timeout => 0,
+      created_at => maps:get(connected_at, ConnInfo, 0)
+     }.
+
+-spec stats(pstate()) -> emqx_types:stats().
+stats(#pstate{subscriptions = Subs}) ->
+    [{subscriptions_cnt, maps:size(Subs)},
+     {subscriptions_max, 0},
+     {inflight_cnt, 0},
+     {inflight_max, 0},
+     {mqueue_len, 0},
+     {mqueue_max, 0},
+     {mqueue_dropped, 0},
+     {next_pkt_id, 0},
+     {awaiting_rel_cnt, 0},
+     {awaiting_rel_max, 0}].
 
 -spec(received(stomp_frame(), pstate())
     -> {ok, pstate()}
@@ -105,20 +228,49 @@ received(Frame = #stomp_frame{command = <<"STOMP">>}, State) ->
     received(Frame#stomp_frame{command = <<"CONNECT">>}, State);
 
 received(#stomp_frame{command = <<"CONNECT">>, headers = Headers},
-         State = #pstate{connected = false, allow_anonymous = AllowAnonymous, default_user = DefaultUser}) ->
+         State = #pstate{connected = false}) ->
     case negotiate_version(header(<<"accept-version">>, Headers)) of
         {ok, Version} ->
             Login = header(<<"login">>, Headers),
             Passc = header(<<"passcode">>, Headers),
-            case check_login(Login, Passc, AllowAnonymous, DefaultUser) of
+            case check_login(Login, Passc,
+                             allow_anonymous(State),
+                             default_user(State)
+                            ) of
                 true ->
-                    emqx_logger:set_metadata_clientid(Login),
-
-                    Heartbeats = parse_heartbeats(header(<<"heart-beat">>, Headers, <<"0,0">>)),
-                    NState = start_heartbeart_timer(Heartbeats, State#pstate{connected = true,
-                                                                                  proto_ver = Version, login = Login}),
-                    send(connected_frame([{<<"version">>, Version},
-                                          {<<"heart-beat">>, reverse_heartbeats(Heartbeats)}]), NState);
+                    Heartbeats = parse_heartbeats(
+                                   header(<<"heart-beat">>, Headers, <<"0,0">>)),
+                    ClientId = emqx_guid:to_base62(emqx_guid:gen()),
+                    emqx_logger:set_metadata_clientid(ClientId),
+                    ConnInfo = State#pstate.conninfo,
+                    ClitInfo = State#pstate.clientinfo,
+                    NConnInfo = ConnInfo#{
+                                  proto_ver => Version,
+                                  clientid => ClientId,
+                                  keepalive => element(1, Heartbeats) div 1000,
+                                  username => Login
+                                 },
+                    NClitInfo = ClitInfo#{
+                                  clientid => ClientId,
+                                  username => Login
+                                 },
+
+                    ConnPid = self(),
+                    _ = emqx_cm_locker:trans(ClientId, fun(_) ->
+                        emqx_cm:discard_session(ClientId),
+                        emqx_cm:register_channel(ClientId, ConnPid, NConnInfo)
+                    end),
+                    NState = start_heartbeart_timer(
+                               Heartbeats,
+                               State#pstate{
+                                 conninfo = NConnInfo,
+                                 clientinfo = NClitInfo}
+                              ),
+                    ConnectedFrame = connected_frame(
+                                       [{<<"version">>, Version},
+                                        {<<"heart-beat">>, reverse_heartbeats(Heartbeats)}
+                                       ]),
+                    send(ConnectedFrame, ensure_connected(NState));
                 false ->
                     _ = send(error_frame(undefined, <<"Login or passcode error!">>), State),
                     {error, login_or_passcode_error, State}
@@ -130,6 +282,7 @@ received(#stomp_frame{command = <<"CONNECT">>, headers = Headers},
     end;
 
 received(#stomp_frame{command = <<"CONNECT">>}, State = #pstate{connected = true}) ->
+    ?LOG(error, "Received CONNECT frame on connected=true state"),
     {error, unexpected_connect, State};
 
 received(Frame = #stomp_frame{command = <<"SEND">>, headers = Headers}, State) ->
@@ -139,31 +292,51 @@ received(Frame = #stomp_frame{command = <<"SEND">>, headers = Headers}, State) -
     end;
 
 received(#stomp_frame{command = <<"SUBSCRIBE">>, headers = Headers},
-            State = #pstate{subscriptions = Subscriptions}) ->
+            State = #pstate{subscriptions = Subs}) ->
     Id    = header(<<"id">>, Headers),
     Topic = header(<<"destination">>, Headers),
-    Ack   = header(<<"ack">>, Headers, <<"auto">>),
-    {ok, State1} = case lists:keyfind(Id, 1, Subscriptions) of
-                       {Id, Topic, Ack} ->
-                           {ok, State};
-                       false ->
-                           emqx_broker:subscribe(Topic),
-                           {ok, State#pstate{subscriptions = [{Id, Topic, Ack}|Subscriptions]}}
-                   end,
-    maybe_send_receipt(receipt_id(Headers), State1);
+    Ack   = header(<<"ack">>, Headers, ?DEFAULT_SUB_ACK),
+
+    case find_sub_by_id(Id, Subs) of
+        {Topic, #{sub_props := #{id := Id}}} ->
+            ?LOG(info, "Subscription has established: ~s", [Topic]),
+            maybe_send_receipt(receipt_id(Headers), State);
+        {InuseTopic, #{sub_props := #{id := InuseId}}} ->
+            ?LOG(info, "Subscription id ~p inused by topic: ~s, "
+                       "request topic: ~s", [InuseId, InuseTopic, Topic]),
+            send(error_frame(receipt_id(Headers),
+                             ["Request sub-id ", Id, " inused "]), State);
+        undefined ->
+            case check_acl(subscribe, Topic, State) of
+                allow ->
+                    ClientInfo = State#pstate.clientinfo,
+
+                    [{TopicFilter, SubOpts}] = parse_topic_filters(
+                                                 [{Topic, ?DEFAULT_SUBOPTS}
+                                               ]),
+                    NSubOpts = SubOpts#{sub_props => #{id => Id, ack => Ack}},
+                    _ = run_hooks('client.subscribe',
+                                  [ClientInfo, _SubProps = #{}],
+                                  [{TopicFilter, NSubOpts}]),
+                    NState = do_subscribe(TopicFilter, NSubOpts, State),
+                    maybe_send_receipt(receipt_id(Headers), NState)
+            end
+    end;
 
 received(#stomp_frame{command = <<"UNSUBSCRIBE">>, headers = Headers},
-            State = #pstate{subscriptions = Subscriptions}) ->
+            State = #pstate{subscriptions = Subs, clientinfo = ClientInfo}) ->
     Id = header(<<"id">>, Headers),
-
-    {ok, State1} = case lists:keyfind(Id, 1, Subscriptions) of
-                       {Id, Topic, _Ack} ->
-                           ok = emqx_broker:unsubscribe(Topic),
-                           {ok, State#pstate{subscriptions = lists:keydelete(Id, 1, Subscriptions)}};
-                       false ->
-                           {ok, State}
-                   end,
-    maybe_send_receipt(receipt_id(Headers), State1);
+    {ok, NState} = case find_sub_by_id(Id, Subs) of
+            {Topic, #{sub_props := #{id := Id}}} ->
+                _ = run_hooks('client.unsubscribe',
+                              [ClientInfo, #{}],
+                              [{Topic, #{}}]),
+                State1 = do_unsubscribe(Topic, ?DEFAULT_SUBOPTS, State),
+                {ok, State1};
+            undefined ->
+                {ok, State}
+        end,
+    maybe_send_receipt(receipt_id(Headers), NState);
 
 %% ACK
 %% id:12345
@@ -239,10 +412,15 @@ received(#stomp_frame{command = <<"DISCONNECT">>, headers = Headers}, State) ->
     _ = maybe_send_receipt(receipt_id(Headers), State),
     {stop, normal, State}.
 
-send(Msg = #message{topic = Topic, headers = Headers, payload = Payload},
-     State = #pstate{subscriptions = Subscriptions}) ->
-    case lists:keyfind(Topic, 2, Subscriptions) of
-        {Id, Topic, Ack} ->
+send(Msg0 = #message{},
+     State = #pstate{clientinfo = ClientInfo, subscriptions = Subs}) ->
+    ok = emqx_metrics:inc('messages.delivered'),
+    Msg = emqx_hooks:run_fold('message.delivered', [ClientInfo], Msg0),
+    #message{topic = Topic,
+             headers = Headers,
+             payload = Payload} = Msg,
+    case find_sub_by_topic(Topic, Subs) of
+        {Topic, #{sub_props := #{id := Id, ack := Ack}}} ->
             Headers0 = [{<<"subscription">>, Id},
                         {<<"message-id">>, next_msgid()},
                         {<<"destination">>, Topic},
@@ -256,19 +434,21 @@ send(Msg = #message{topic = Topic, headers = Headers, payload = Payload},
             Frame = #stomp_frame{command = <<"MESSAGE">>,
                                  headers = Headers1 ++ maps:get(stomp_headers, Headers, []),
                                  body = Payload},
+
+
             send(Frame, State);
-        false ->
+        undefined ->
             ?LOG(error, "Stomp dropped: ~p", [Msg]),
             {error, dropped, State}
     end;
 
-send(Frame, State = #pstate{sendfun = {Fun, Args}}) ->
-    ?LOG(info, "SEND Frame: ~s", [emqx_stomp_frame:format(Frame)]),
-    Data = emqx_stomp_frame:serialize(Frame),
-    ?LOG(debug, "SEND ~p", [Data]),
-    erlang:apply(Fun, [Data] ++ Args),
+send(Frame, State = #pstate{sendfun = {Fun, Args}}) when is_record(Frame, stomp_frame) ->
+    erlang:apply(Fun, [Frame] ++ Args),
     {ok, State}.
 
+shutdown(Reason, State = #pstate{connected = true}) ->
+    _ = ensure_disconnected(Reason, State),
+    ok;
 shutdown(_Reason, _State) ->
     ok.
 
@@ -283,11 +463,18 @@ timeout(_TRef, {incoming, NewVal},
 
 timeout(_TRef, {outgoing, NewVal},
         State = #pstate{heart_beats = HrtBt,
-                             heartfun = {Fun, Args}}) ->
+                        statfun = {StatFun, StatArgs},
+                        heartfun = {Fun, Args}}) ->
     case emqx_stomp_heartbeat:check(outgoing, NewVal, HrtBt) of
         {error, timeout} ->
             _ = erlang:apply(Fun, Args),
-            {ok, State};
+            case erlang:apply(StatFun, [send_oct] ++ StatArgs) of
+                {ok, NewVal2} ->
+                    NHrtBt = emqx_stomp_heartbeat:reset(outgoing, NewVal2, HrtBt),
+                    {ok, reset_timer(outgoing_timer, State#pstate{heart_beats = NHrtBt})};
+                {error, Reason} ->
+                    {shutdown, {error, {get_stats_error, Reason}}, State}
+            end;
         {ok, NHrtBt} ->
             {ok, reset_timer(outgoing_timer, State#pstate{heart_beats = NHrtBt})}
     end;
@@ -297,6 +484,28 @@ timeout(_TRef, clean_trans, State = #pstate{transaction = Trans}) ->
     NTrans = maps:filter(fun(_, {Ts, _}) -> Ts + ?TRANS_TIMEOUT < Now end, Trans),
     {ok, ensure_clean_trans_timer(State#pstate{transaction = NTrans})}.
 
+
+-spec(handle_info(Info :: term(), pstate())
+      -> ok | {ok, pstate()} | {shutdown, Reason :: term(), pstate()}).
+
+handle_info({subscribe, TopicFilters}, State) ->
+    NState = lists:foldl(
+        fun({TopicFilter, SubOpts}, StateAcc = #pstate{subscriptions = Subs}) ->
+            NSubOpts = enrich_sub_opts(SubOpts, Subs),
+            do_subscribe(TopicFilter, NSubOpts, StateAcc)
+        end, State, parse_topic_filters(TopicFilters)),
+    {ok, NState};
+
+handle_info({unsubscribe, TopicFilters}, State) ->
+    NState = lists:foldl(fun({TopicFilter, SubOpts}, StateAcc) ->
+                do_unsubscribe(TopicFilter, SubOpts, StateAcc)
+             end, State, parse_topic_filters(TopicFilters)),
+    {ok, NState};
+
+handle_info(Info, State) ->
+    ?LOG(warning, "Unexpected info ~p", [Info]),
+    {ok, State}.
+
 negotiate_version(undefined) ->
     {ok, <<"1.0">>};
 negotiate_version(Accepts) ->
@@ -312,13 +521,15 @@ negotiate_version(Ver, [AcceptVer|_]) when Ver >= AcceptVer ->
 negotiate_version(Ver, [_|T]) ->
     negotiate_version(Ver, T).
 
-check_login(undefined, _, AllowAnonymous, _) ->
+check_login(Login, _, AllowAnonymous, _)
+  when Login == <<>>;
+       Login == undefined ->
     AllowAnonymous;
 check_login(_, _, _, undefined) ->
     false;
 check_login(Login, Passcode, _, DefaultUser) ->
-    case {list_to_binary(get_value(login, DefaultUser)),
-          list_to_binary(get_value(passcode, DefaultUser))} of
+    case {iolist_to_binary(get_value(login, DefaultUser)),
+          iolist_to_binary(get_value(passcode, DefaultUser))} of
         {Login, Passcode} -> true;
         {_,     _       } -> false
     end.
@@ -396,11 +607,18 @@ receipt_id(Headers) ->
 
 handle_recv_send_frame(#stomp_frame{command = <<"SEND">>, headers = Headers, body = Body}, State) ->
     Topic = header(<<"destination">>, Headers),
-    _ = maybe_send_receipt(receipt_id(Headers), State),
-    _ = emqx_broker:publish(
-        make_mqtt_message(Topic, Headers, iolist_to_binary(Body))
-    ),
-    State.
+    case check_acl(publish, Topic, State) of
+        allow ->
+            _ = maybe_send_receipt(receipt_id(Headers), State),
+            _ = emqx_broker:publish(
+                make_mqtt_message(Topic, Headers, iolist_to_binary(Body))
+            ),
+            State;
+        deny ->
+            ErrFrame = error_frame(receipt_id(Headers), <<"Not Authorized">>),
+            {ok, NState} = send(ErrFrame, State),
+            NState
+    end.
 
 handle_recv_ack_frame(#stomp_frame{command = <<"ACK">>, headers = Headers}, State) ->
     Id = header(<<"id">>, Headers),
@@ -431,7 +649,111 @@ reverse_heartbeats({Cx, Cy}) ->
 start_heartbeart_timer(Heartbeats, State) ->
     ensure_timer(
       [incoming_timer, outgoing_timer],
-      State#pstate{heart_beats = emqx_stomp_heartbeat:init(Heartbeats)}).
+      State#pstate{heart_beats = emqx_stomp_heartbeat:init(backoff(Heartbeats))}).
+
+backoff({Cx, Cy}) ->
+    {erlang:ceil(Cx * ?INCOMING_TIMER_BACKOFF),
+     erlang:ceil(Cy * ?OUTCOMING_TIMER_BACKOFF)}.
+
+%%--------------------------------------------------------------------
+%% pub & sub helpers
+
+parse_topic_filters(TopicFilters) ->
+    lists:map(fun emqx_topic:parse/1, TopicFilters).
+
+check_acl(PubSub, Topic, State = #pstate{clientinfo = ClientInfo}) ->
+    case is_acl_enabled(State) andalso
+         emqx_access_control:check_acl(ClientInfo, PubSub, Topic) of
+        false -> allow;
+        Res   -> Res
+    end.
+
+do_subscribe(TopicFilter, SubOpts,
+             State = #pstate{clientinfo = ClientInfo, subscriptions = Subs}) ->
+    ClientId = maps:get(clientid, ClientInfo),
+    _ = emqx_broker:subscribe(TopicFilter, ClientId),
+    NSubOpts = SubOpts#{is_new => true},
+    _ = run_hooks('session.subscribed',
+                  [ClientInfo, TopicFilter, NSubOpts]),
+    send_event_to_self(updated),
+    State#pstate{subscriptions = maps:put(TopicFilter, SubOpts, Subs)}.
+
+do_unsubscribe(TopicFilter, SubOpts,
+               State = #pstate{clientinfo = ClientInfo, subscriptions = Subs}) ->
+    ok = emqx_broker:unsubscribe(TopicFilter),
+    _ = run_hooks('session.unsubscribe',
+                  [ClientInfo, TopicFilter, SubOpts]),
+    send_event_to_self(updated),
+    State#pstate{subscriptions = maps:remove(TopicFilter, Subs)}.
+
+find_sub_by_topic(Topic, Subs) ->
+    case maps:get(Topic, Subs, undefined) of
+        undefined -> undefined;
+        SubOpts -> {Topic, SubOpts}
+    end.
+
+find_sub_by_id(Id, Subs) ->
+    Found = maps:filter(fun(_, SubOpts) ->
+               %% FIXME: datatype??
+               maps:get(id, maps:get(sub_props, SubOpts, #{}), -1) == Id
+            end, Subs),
+    case maps:to_list(Found) of
+        [] -> undefined;
+        [Sub|_] -> Sub
+    end.
+
+is_acl_enabled(_) ->
+    %% TODO: configs from somewhere
+    true.
+
+%% automaticly fill the next sub-id and ack if sub-id is absent
+enrich_sub_opts(SubOpts0, Subs) ->
+    SubOpts = maps:merge(?DEFAULT_SUBOPTS, SubOpts0),
+    SubProps = maps:get(sub_props, SubOpts, #{}),
+    SubOpts#{sub_props =>
+             maps:merge(#{id => next_sub_id(Subs),
+                          ack => ?DEFAULT_SUB_ACK}, SubProps)}.
+
+next_sub_id(Subs) ->
+    Ids = maps:fold(fun(_, SubOpts, Acc) ->
+        [binary_to_integer(
+           maps:get(id, maps:get(sub_props, SubOpts, #{}), <<"0">>)) | Acc]
+    end, [], Subs),
+    integer_to_binary(lists:max(Ids) + 1).
+
+%%--------------------------------------------------------------------
+%% helpers
+
+default_user(#pstate{default_user = DefaultUser}) ->
+    DefaultUser.
+allow_anonymous(#pstate{allow_anonymous = AllowAnonymous}) ->
+    AllowAnonymous.
+
+ensure_connected(State = #pstate{conninfo = ConnInfo,
+                                 clientinfo = ClientInfo}) ->
+    NConnInfo = ConnInfo#{
+                  connected => true,
+                  connected_at => erlang:system_time(millisecond)
+                 },
+    send_event_to_self(connected),
+    ok = run_hooks('client.connected', [ClientInfo, NConnInfo]),
+    State#pstate{conninfo  = NConnInfo,
+                 connected = true
+                }.
+
+ensure_disconnected(Reason, State = #pstate{conninfo = ConnInfo, clientinfo = ClientInfo}) ->
+    NConnInfo = ConnInfo#{disconnected_at => erlang:system_time(millisecond)},
+    ok = run_hooks('client.disconnected', [ClientInfo, Reason, NConnInfo]),
+    State#pstate{conninfo = NConnInfo, connected = false}.
+
+send_event_to_self(Name) ->
+    self() ! {event, Name}, ok.
+
+run_hooks(Name, Args) ->
+    emqx_hooks:run(Name, Args).
+
+run_hooks(Name, Args, Acc) ->
+    emqx_hooks:run_fold(Name, Args, Acc).
 
 %%--------------------------------------------------------------------
 %% Timer
@@ -466,3 +788,4 @@ interval(outgoing_timer, #pstate{heart_beats = HrtBt}) ->
     emqx_stomp_heartbeat:interval(outgoing, HrtBt);
 interval(clean_trans_timer, _) ->
     ?TRANS_TIMEOUT.
+

+ 1 - 1
apps/emqx_stomp/test/emqx_stomp_SUITE.erl

@@ -100,7 +100,7 @@ t_heartbeat(_) ->
                                                      {<<"host">>, <<"127.0.0.1:61613">>},
                                                      {<<"login">>, <<"guest">>},
                                                      {<<"passcode">>, <<"guest">>},
-                                                     {<<"heart-beat">>, <<"1000,800">>}])),
+                                                     {<<"heart-beat">>, <<"1000,2000">>}])),
                         {ok, Data} = gen_tcp:recv(Sock, 0),
                         {ok, #stomp_frame{command = <<"CONNECTED">>,
                                           headers = _,

+ 1 - 2
apps/emqx_stomp/test/emqx_stomp_heartbeat_SUITE.erl

@@ -35,8 +35,7 @@ t_check_1(_) ->
     {ok, HrtBt1} = emqx_stomp_heartbeat:check(incoming, 0, HrtBt),
     {error, timeout} = emqx_stomp_heartbeat:check(incoming, 0, HrtBt1),
 
-    {ok, HrtBt2} = emqx_stomp_heartbeat:check(outgoing, 0, HrtBt1),
-    {error, timeout} = emqx_stomp_heartbeat:check(outgoing, 0, HrtBt2),
+    {error, timeout} = emqx_stomp_heartbeat:check(outgoing, 0, HrtBt1),
     ok.
 
 t_check_2(_) ->

+ 4 - 1
src/emqx_types.erl

@@ -94,7 +94,10 @@
 -type(ver() :: ?MQTT_PROTO_V3
              | ?MQTT_PROTO_V4
              | ?MQTT_PROTO_V5
-             | non_neg_integer()).
+             | non_neg_integer()
+             %% Some non-MQTT versions of protocol may be a binary type
+             | binary()
+             ).
 
 -type(qos() :: ?QOS_0 | ?QOS_1 | ?QOS_2).
 -type(qos_name() :: qos0 | at_most_once |