Sfoglia il codice sorgente

fix issue #292 - async sub/unsub

Feng 10 anni fa
parent
commit
d5a400c308
4 ha cambiato i file con 160 aggiunte e 95 eliminazioni
  1. 42 21
      src/emqttd_client.erl
  2. 6 11
      src/emqttd_protocol.erl
  3. 51 45
      src/emqttd_session.erl
  4. 61 18
      src/emqttd_ws_client.erl

+ 42 - 21
src/emqttd_client.erl

@@ -34,7 +34,10 @@
 -include("emqttd_protocol.hrl").
 
 %% API Function Exports
--export([start_link/2, session/1, info/1, kick/1, subscribe/2]).
+-export([start_link/2, session/1, info/1, kick/1]).
+
+%% SUB/UNSUB Asynchronously
+-export([subscribe/2, unsubscribe/2]).
 
 -behaviour(gen_server).
 
@@ -59,7 +62,7 @@ start_link(SockArgs, MqttEnv) ->
     {ok, proc_lib:spawn_link(?MODULE, init, [[SockArgs, MqttEnv]])}.
 
 session(CPid) ->
-    gen_server:call(CPid, session).
+    gen_server:call(CPid, session, infinity).
 
 info(CPid) ->
     gen_server:call(CPid, info, infinity).
@@ -70,6 +73,9 @@ kick(CPid) ->
 subscribe(CPid, TopicTable) ->
     gen_server:cast(CPid, {subscribe, TopicTable}).
 
+unsubscribe(CPid, Topics) ->
+    gen_server:cast(CPid, {unsubscribe, Topics}).
+
 init([SockArgs = {Transport, Sock, _SockFun}, MqttEnv]) ->
     % Transform if ssl.
     {ok, NewSock} = esockd_connection:accept(SockArgs),
@@ -107,9 +113,11 @@ handle_call(Req, _From, State = #state{peername = Peername}) ->
     lager:critical("Client ~s: unexpected request - ~p", [emqttd_net:format(Peername), Req]),
     {reply, {error, unsupported_request}, State}.    
 
-handle_cast({subscribe, TopicTable}, State = #state{proto_state = ProtoState}) ->
-    {ok, ProtoState1} = emqttd_protocol:handle({subscribe, TopicTable}, ProtoState),
-    noreply(State#state{proto_state = ProtoState1});
+handle_cast({subscribe, TopicTable}, State) ->
+    with_session(fun(SessPid) -> emqttd_session:subscribe(SessPid, TopicTable) end, State);
+
+handle_cast({unsubscribe, Topics}, State) ->
+    with_session(fun(SessPid) -> emqttd_session:unsubscribe(SessPid, Topics) end, State);
 
 handle_cast(Msg, State = #state{peername = Peername}) ->
     lager:critical("Client ~s: unexpected msg - ~p",[emqttd_net:format(Peername), Msg]),
@@ -149,17 +157,26 @@ handle_info({inet_reply, _Sock, {error, Reason}}, State = #state{peername = Peer
 
 handle_info({keepalive, start, TimeoutSec}, State = #state{transport = Transport, socket = Socket, peername = Peername}) ->
     lager:debug("Client ~s: Start KeepAlive with ~p seconds", [emqttd_net:format(Peername), TimeoutSec]),
-    KeepAlive = emqttd_keepalive:new({Transport, Socket}, TimeoutSec, {keepalive, timeout}),
+    StatFun = fun() ->
+            case Transport:getstat(Socket, [recv_oct]) of
+                {ok, [{recv_oct, RecvOct}]} -> {ok, RecvOct};
+                {error, Error} -> {error, Error}
+            end
+    end,
+    KeepAlive = emqttd_keepalive:start(StatFun, TimeoutSec, {keepalive, check}),
     noreply(State#state{keepalive = KeepAlive});
 
-handle_info({keepalive, timeout}, State = #state{peername = Peername, keepalive = KeepAlive}) ->
-    case emqttd_keepalive:resume(KeepAlive) of
-    timeout ->
+handle_info({keepalive, check}, State = #state{peername = Peername, keepalive = KeepAlive}) ->
+    case emqttd_keepalive:check(KeepAlive) of
+    {ok, KeepAlive1} ->
+        lager:debug("Client ~s: Keepalive Resumed", [emqttd_net:format(Peername)]),
+        noreply(State#state{keepalive = KeepAlive1});
+    {error, timeout} ->
         lager:debug("Client ~s: Keepalive Timeout!", [emqttd_net:format(Peername)]),
         stop({shutdown, keepalive_timeout}, State#state{keepalive = undefined});
-    {resumed, KeepAlive1} ->
-        lager:debug("Client ~s: Keepalive Resumed", [emqttd_net:format(Peername)]),
-        noreply(State#state{keepalive = KeepAlive1})
+    {error, Error} ->
+        lager:debug("Client ~s: Keepalive Error: ~p!", [emqttd_net:format(Peername), Error]),
+        stop({shutdown, keepalive_error}, State#state{keepalive = undefined})
     end;
 
 handle_info(Info, State = #state{peername = Peername}) ->
@@ -188,12 +205,20 @@ terminate(Reason, #state{peername = Peername,
 code_change(_OldVsn, State, _Extra) ->
     {ok, State}.
 
+%%%=============================================================================
+%%% Internal functions
+%%%=============================================================================
+
 noreply(State) ->
     {noreply, State, hibernate}.
-    
-%-------------------------------------------------------
-% receive and parse tcp data
-%-------------------------------------------------------
+
+stop(Reason, State) ->
+    {stop, Reason, State}.
+
+with_session(Fun, State = #state{proto_state = ProtoState}) ->
+    Fun(emqttd_protocol:session(ProtoState)), noreply(State).
+
+%% receive and parse tcp data
 received(<<>>, State) ->
     {noreply, State, hibernate};
 
@@ -244,12 +269,8 @@ control_throttle(State = #state{conn_state = Flow,
         {_,            _} -> run_socket(State)
     end.
 
-stop(Reason, State) ->
-    {stop, Reason, State}.
-
 received_stats(?PACKET(Type)) ->
-    emqttd_metrics:inc('packets/received'), 
-    inc(Type).
+    emqttd_metrics:inc('packets/received'), inc(Type).
 inc(?CONNECT) ->
     emqttd_metrics:inc('packets/connect');
 inc(?PUBLISH) ->

+ 6 - 11
src/emqttd_protocol.erl

@@ -239,16 +239,11 @@ handle(?SUBSCRIBE_PACKET(PacketId, TopicTable), State = #proto_state{client_id =
     case lists:member(deny, AllowDenies) of
         true ->
             %%TODO: return 128 QoS when deny... no need to SUBACK?
-            lager:error("SUBSCRIBE from '~s' Denied: ~p", [ClientId, TopicTable]),
-            {ok, State};
+            lager:error("SUBSCRIBE from '~s' Denied: ~p", [ClientId, TopicTable]);
         false ->
-            %%TODO: GrantedQos should be renamed.
-            {ok, GrantedQos} = emqttd_session:subscribe(Session, TopicTable),
-            send(?SUBACK_PACKET(PacketId, GrantedQos), State)
-    end;
-
-handle({subscribe, TopicTable}, State = #proto_state{session = Session}) ->
-    {ok, _GrantedQos} = emqttd_session:subscribe(Session, TopicTable),
+            Callback = fun(GrantedQos) -> send(?SUBACK_PACKET(PacketId, GrantedQos), State) end,
+            emqttd_session:subscribe(Session, TopicTable, Callback)
+    end,
     {ok, State};
 
 %% protect from empty topic list
@@ -256,7 +251,7 @@ handle(?UNSUBSCRIBE_PACKET(PacketId, []), State) ->
     send(?UNSUBACK_PACKET(PacketId), State);
 
 handle(?UNSUBSCRIBE_PACKET(PacketId, Topics), State = #proto_state{session = Session}) ->
-    ok = emqttd_session:unsubscribe(Session, Topics),
+    emqttd_session:unsubscribe(Session, Topics),
     send(?UNSUBACK_PACKET(PacketId), State);
 
 handle(?PACKET(?PINGREQ), State) ->
@@ -349,7 +344,7 @@ send_willmsg(ClientId, WillMsg) ->
 start_keepalive(0) -> ignore;
 
 start_keepalive(Sec) when Sec > 0 ->
-    self() ! {keepalive, start, round(Sec * 1.5)}.
+    self() ! {keepalive, start, round(Sec * 1.2)}.
 
 %%----------------------------------------------------------------------------
 %% Validate Packets

+ 51 - 45
src/emqttd_session.erl

@@ -59,7 +59,7 @@
 %% PubSub APIs
 -export([publish/2,
          puback/2, pubrec/2, pubrel/2, pubcomp/2,
-         subscribe/2, unsubscribe/2]).
+         subscribe/2, subscribe/3, unsubscribe/2]).
 
 -behaviour(gen_server2).
 
@@ -166,9 +166,13 @@ destroy(SessPid, ClientId) ->
 %% @doc Subscribe Topics
 %% @end
 %%------------------------------------------------------------------------------
--spec subscribe(pid(), [{binary(), mqtt_qos()}]) -> {ok, [mqtt_qos()]}.
+-spec subscribe(pid(), [{binary(), mqtt_qos()}]) -> ok.
 subscribe(SessPid, TopicTable) ->
-    gen_server2:call(SessPid, {subscribe, TopicTable}, ?PUBSUB_TIMEOUT).
+    subscribe(SessPid, TopicTable, fun(_) -> ok end).
+
+-spec subscribe(pid(), [{binary(), mqtt_qos()}], Callback :: fun()) -> ok.
+subscribe(SessPid, TopicTable, Callback) ->
+    gen_server2:cast(SessPid, {subscribe, TopicTable, Callback}).
 
 %%------------------------------------------------------------------------------
 %% @doc Publish message
@@ -213,7 +217,7 @@ pubcomp(SessPid, PktId) ->
 %%------------------------------------------------------------------------------
 -spec unsubscribe(pid(), [binary()]) -> ok.
 unsubscribe(SessPid, Topics) ->
-    gen_server2:call(SessPid, {unsubscribe, Topics}, ?PUBSUB_TIMEOUT).
+    gen_server2:cast(SessPid, {unsubscribe, Topics}).
 
 %%%=============================================================================
 %%% gen_server callbacks
@@ -247,26 +251,24 @@ init([CleanSess, ClientId, ClientPid]) ->
     {ok, start_collector(Session#session{client_mon = MRef}), hibernate}.
 
 prioritise_call(Msg, _From, _Len, _State) ->
-    case Msg of
-        {unsubscribe, _} -> 2;
-        {subscribe, _}   -> 1;
-        _                -> 0
-    end.
+    case Msg of _  -> 0 end.
 
 prioritise_cast(Msg, _Len, _State) ->
     case Msg of
-        {destroy, _}      -> 10;
-        {resume, _, _}    -> 9;
-        {pubrel,  _PktId} -> 8;
-        {pubcomp, _PktId} -> 8;
-        {pubrec,  _PktId} -> 8;
-        {puback,  _PktId} -> 7;
-        _                 -> 0
+        {destroy, _}        -> 10;
+        {resume, _, _}      -> 9;
+        {pubrel,  _PktId}   -> 8;
+        {pubcomp, _PktId}   -> 8;
+        {pubrec,  _PktId}   -> 8;
+        {puback,  _PktId}   -> 7;
+        {unsubscribe, _, _} -> 6;
+        {subscribe, _, _}   -> 5;
+        _                   -> 0
     end.
 
 prioritise_info(Msg, _Len, _State) ->
     case Msg of
-        {'DOWN', _, process, _, _} -> 10;
+        {'DOWN', _, _, _, _} -> 10;
         {'EXIT', _, _}  -> 10;
         session_expired -> 10;
         {timeout, _, _} -> 5;
@@ -275,17 +277,40 @@ prioritise_info(Msg, _Len, _State) ->
         _               -> 0
     end.
 
-handle_call({subscribe, TopicTable0}, _From, Session = #session{client_id = ClientId,
-                                                                subscriptions = Subscriptions}) ->
+handle_call({publish, Msg = #mqtt_message{qos = ?QOS_2, pktid = PktId}}, _From,
+                Session = #session{client_id         = ClientId,
+                                   awaiting_rel      = AwaitingRel,
+                                   await_rel_timeout = Timeout}) ->
+    case check_awaiting_rel(Session) of
+        true ->
+            TRef = timer(Timeout, {timeout, awaiting_rel, PktId}),
+            AwaitingRel1 = maps:put(PktId, {Msg, TRef}, AwaitingRel),
+            {reply, ok, Session#session{awaiting_rel = AwaitingRel1}};
+        false ->
+            lager:critical([{client, ClientId}], "Session(~s) dropped Qos2 message "
+                                "for too many awaiting_rel: ~p", [ClientId, Msg]),
+            {reply, {error, dropped}, Session}
+    end;
+
+handle_call(Req, _From, State) ->
+    lager:critical("Unexpected Request: ~p", [Req]),
+    {reply, ok, State}.
+
+handle_cast({subscribe, TopicTable0, Callback}, Session = #session{
+                client_id = ClientId, subscriptions = Subscriptions}) ->
 
-    case TopicTable0 -- Subscriptions of
+    TopicTable = emqttd_broker:foldl_hooks('client.subscribe', [ClientId], TopicTable0),
+
+    case TopicTable -- Subscriptions of
         [] ->
-            {reply, {ok, [Qos || {_, Qos} <- TopicTable0]}, Session};
+            catch Callback([Qos || {_, Qos} <- TopicTable]),
+            noreply(Session);
         _  ->
-            TopicTable = emqttd_broker:foldl_hooks('client.subscribe', [ClientId], TopicTable0),
             %% subscribe first and don't care if the subscriptions have been existed
             {ok, GrantedQos} = emqttd_pubsub:subscribe(TopicTable),
 
+            catch Callback(GrantedQos),
+
             emqttd_broker:foreach_hooks('client.subscribe.after', [ClientId, TopicTable]),
 
             lager:info([{client, ClientId}], "Session(~s): subscribe ~p, Granted QoS: ~p",
@@ -310,11 +335,11 @@ handle_call({subscribe, TopicTable0}, _From, Session = #session{client_id = Clie
                                     [{Topic, Qos} | Acc]
                             end
                         end, Subscriptions, TopicTable),
-            {reply, {ok, GrantedQos}, Session#session{subscriptions = Subscriptions1}}
+            noreply(Session#session{subscriptions = Subscriptions1})
     end;
 
-handle_call({unsubscribe, Topics0}, _From, Session = #session{client_id = ClientId,
-                                                             subscriptions = Subscriptions}) ->
+handle_cast({unsubscribe, Topics0}, Session = #session{client_id = ClientId,
+                                                       subscriptions = Subscriptions}) ->
 
     Topics = emqttd_broker:foldl_hooks('client.unsubscribe', [ClientId], Topics0),
 
@@ -333,26 +358,7 @@ handle_call({unsubscribe, Topics0}, _From, Session = #session{client_id = Client
                     end
                 end, Subscriptions, Topics),
 
-    {reply, ok, Session#session{subscriptions = Subscriptions1}};
-
-handle_call({publish, Msg = #mqtt_message{qos = ?QOS_2, pktid = PktId}}, _From, 
-            Session = #session{client_id = ClientId,
-                               awaiting_rel = AwaitingRel,
-                               await_rel_timeout = Timeout}) ->
-    case check_awaiting_rel(Session) of
-        true ->
-            TRef = timer(Timeout, {timeout, awaiting_rel, PktId}),
-            AwaitingRel1 = maps:put(PktId, {Msg, TRef}, AwaitingRel),
-            {reply, ok, Session#session{awaiting_rel = AwaitingRel1}};
-        false ->
-            lager:critical([{client, ClientId}], "Session(~s) dropped Qos2 message "
-                                "for too many awaiting_rel: ~p", [ClientId, Msg]),
-            {reply, {error, dropped}, Session}
-    end;
-
-handle_call(Req, _From, State) ->
-    lager:critical("Unexpected Request: ~p", [Req]),
-    {reply, ok, State}.
+    noreply(Session#session{subscriptions = Subscriptions1});
 
 handle_cast({destroy, ClientId}, Session = #session{client_id = ClientId}) ->
     lager:warning([{client, ClientId}], "Session(~s) destroyed", [ClientId]),

+ 61 - 18
src/emqttd_ws_client.erl

@@ -34,7 +34,10 @@
 -include("emqttd_protocol.hrl").
 
 %% API Exports
--export([start_link/1, ws_loop/3, subscribe/2]).
+-export([start_link/1, ws_loop/3, session/1, info/1, kick/1]).
+
+%% SUB/UNSUB Asynchronously
+-export([subscribe/2, unsubscribe/2]).
 
 -behaviour(gen_server).
 
@@ -61,9 +64,21 @@ start_link(Req) ->
                              packet_opts  = PktOpts,
                              parser       = emqttd_parser:new(PktOpts)}).
 
+session(CPid) ->
+    gen_server:call(CPid, session, infinity).
+
+info(CPid) ->
+    gen_server:call(CPid, info, infinity).
+
+kick(CPid) ->
+    gen_server:call(CPid, kick).
+
 subscribe(CPid, TopicTable) ->
     gen_server:cast(CPid, {subscribe, TopicTable}).
 
+unsubscribe(CPid, Topics) ->
+    gen_server:cast(CPid, {unsubscribe, Topics}).
+
 %%------------------------------------------------------------------------------
 %% @private
 %% @doc Start WebSocket client.
@@ -112,17 +127,30 @@ init([WsPid, Req, ReplyChannel, PktOpts]) ->
     ProtoState = emqttd_protocol:init(Peername, SendFun, [{ws_initial_headers, HeadersList}|PktOpts]),
     {ok, #client_state{ws_pid = WsPid, request = Req, proto_state = ProtoState}}.
 
+handle_call(session, _From, State = #client_state{proto_state = ProtoState}) ->
+    {reply, emqttd_protocol:session(ProtoState), State};
+
+handle_call(info, _From, State = #client_state{request = Req,
+                                               proto_state = ProtoState}) ->
+    {reply, [{websocket, true}, {peer, Req:get(peer)}
+             | emqttd_protocol:info(ProtoState)], State};
+
+handle_call(kick, _From, State) ->
+    {stop, {shutdown, kick}, ok, State};
+
 handle_call(_Req, _From, State) ->
     {reply, error, State}.
 
-handle_cast({subscribe, TopicTable}, State = #client_state{proto_state = ProtoState}) ->
-    {ok, ProtoState1} = emqttd_protocol:handle({subscribe, TopicTable}, ProtoState),
-    {noreply, State#client_state{proto_state = ProtoState1}, hibernate};
+handle_cast({subscribe, TopicTable}, State) ->
+    with_session(fun(SessPid) -> emqttd_session:subscribe(SessPid, TopicTable) end, State);
+
+handle_cast({unsubscribe, Topics}, State) ->
+    with_session(fun(SessPid) -> emqttd_session:unsubscribe(SessPid, Topics) end, State);
 
 handle_cast({received, Packet}, State = #client_state{proto_state = ProtoState}) ->
     case emqttd_protocol:received(Packet, ProtoState) of
     {ok, ProtoState1} ->
-        {noreply, State#client_state{proto_state = ProtoState1}};
+        noreply(State#client_state{proto_state = ProtoState1});
     {error, Error} ->
         lager:error("MQTT protocol error ~p", [Error]),
         stop({shutdown, Error}, State);
@@ -137,11 +165,11 @@ handle_cast(_Msg, State) ->
 
 handle_info({deliver, Message}, State = #client_state{proto_state = ProtoState}) ->
     {ok, ProtoState1} = emqttd_protocol:send(Message, ProtoState),
-    {noreply, State#client_state{proto_state = ProtoState1}};
+    noreply(State#client_state{proto_state = ProtoState1});
 
 handle_info({redeliver, {?PUBREL, PacketId}}, State = #client_state{proto_state = ProtoState}) ->
     {ok, ProtoState1} = emqttd_protocol:redeliver({?PUBREL, PacketId}, ProtoState),
-    {noreply, State#client_state{proto_state = ProtoState1}};
+    noreply(State#client_state{proto_state = ProtoState1});
 
 handle_info({stop, duplicate_id, _NewPid}, State = #client_state{proto_state = ProtoState}) ->
     lager:error("Shutdown for duplicate clientid: ~s", [emqttd_protocol:clientid(ProtoState)]), 
@@ -149,18 +177,27 @@ handle_info({stop, duplicate_id, _NewPid}, State = #client_state{proto_state = P
 
 handle_info({keepalive, start, TimeoutSec}, State = #client_state{request = Req}) ->
     lager:debug("Client(WebSocket) ~s: Start KeepAlive with ~p seconds", [Req:get(peer), TimeoutSec]),
-    KeepAlive = emqttd_keepalive:new({esockd_transport, Req:get(socket)},
-                                     TimeoutSec, {keepalive, timeout}),
-    {noreply, State#client_state{keepalive = KeepAlive}};
-
-handle_info({keepalive, timeout}, State = #client_state{request = Req, keepalive = KeepAlive}) ->
-    case emqttd_keepalive:resume(KeepAlive) of
-    timeout ->
+    Socket = Req:get(socket),
+    StatFun = fun() ->
+        case esockd_transport:getstat(Socket, [recv_oct]) of
+            {ok, [{recv_oct, RecvOct}]} -> {ok, RecvOct};
+            {error, Error}              -> {error, Error}
+        end
+    end,
+    KeepAlive = emqttd_keepalive:start(StatFun, TimeoutSec, {keepalive, check}),
+    noreply(State#client_state{keepalive = KeepAlive});
+
+handle_info({keepalive, check}, State = #client_state{request = Req, keepalive = KeepAlive}) ->
+    case emqttd_keepalive:check(KeepAlive) of
+    {ok, KeepAlive1} ->
+        lager:debug("Client(WebSocket) ~s: Keepalive Resumed", [Req:get(peer)]),
+        noreply(State#client_state{keepalive = KeepAlive1});
+    {error, timeout} ->
         lager:debug("Client(WebSocket) ~s: Keepalive Timeout!", [Req:get(peer)]),
         stop({shutdown, keepalive_timeout}, State#client_state{keepalive = undefined});
-    {resumed, KeepAlive1} ->
-        lager:debug("Client(WebSocket) ~s: Keepalive Resumed", [Req:get(peer)]),
-        {noreply, State#client_state{keepalive = KeepAlive1}}
+    {error, Error} ->
+        lager:debug("Client(WebSocket) ~s: Keepalive Error: ~p", [Req:get(peer), Error]),
+        stop({shutdown, keepalive_error}, State#client_state{keepalive = undefined})
     end;
 
 handle_info({'EXIT', WsPid, Reason}, State = #client_state{ws_pid = WsPid, proto_state = ProtoState}) ->
@@ -170,7 +207,7 @@ handle_info({'EXIT', WsPid, Reason}, State = #client_state{ws_pid = WsPid, proto
 
 handle_info(Info, State = #client_state{request = Req}) ->
     lager:critical("Client(WebSocket) ~s: Unexpected Info - ~p", [Req:get(peer), Info]),
-    {noreply, State}.
+    noreply(State).
 
 terminate(Reason, #client_state{proto_state = ProtoState, keepalive = KeepAlive}) ->
     lager:info("WebSocket client terminated: ~p", [Reason]),
@@ -189,6 +226,12 @@ code_change(_OldVsn, State, _Extra) ->
 %%% Internal functions
 %%%=============================================================================
 
+noreply(State) ->
+    {noreply, State, hibernate}.
+
 stop(Reason, State ) ->
     {stop, Reason, State}.
 
+with_session(Fun, State = #client_state{proto_state = ProtoState}) ->
+    Fun(emqttd_protocol:session(ProtoState)), noreply(State).
+