Feng 10 yıl önce
ebeveyn
işleme
a1d778b081
2 değiştirilmiş dosya ile 59 ekleme ve 62 silme
  1. 4 0
      apps/emqtt/include/emqtt.hrl
  2. 55 62
      apps/emqttd/src/emqttd_session.erl

+ 4 - 0
apps/emqtt/include/emqtt.hrl

@@ -58,6 +58,10 @@
 -record(mqtt_message, {
     %% topic is first for message may be retained
     topic           :: binary(),
+    %% clientid from
+    from            :: binary() | atom(),
+    %% sender pid ??
+    sender          :: pid(),
     qos    = ?QOS_0 :: mqtt_qos(),
     retain = false  :: boolean(),
     dup    = false  :: boolean(),

+ 55 - 62
apps/emqttd/src/emqttd_session.erl

@@ -225,7 +225,7 @@ puback(Session = #session{clientid = ClientId, awaiting_ack = Awaiting}, {?PUBAC
     Session#session{awaiting_ack = maps:remove(PacketId, Awaiting)};
 
 puback(SessPid, {?PUBACK, PacketId}) when is_pid(SessPid) ->
-    gen_server:cast(SessPid, {puback, PacketId});
+    gen_server:cast(SessPid, {puback, {?PUBACK, PacketId});
 
 %% PUBREC
 puback(Session = #session{clientid = ClientId, 
@@ -239,7 +239,7 @@ puback(Session = #session{clientid = ClientId,
                       awaiting_comp  = maps:put(PacketId, true, AwaitingComp)};
 
 puback(SessPid, {?PUBREC, PacketId}) when is_pid(SessPid) ->
-        gen_server:cast(SessPid, {pubrec, PacketId}), SessPid;
+    gen_server:cast(SessPid, {puback, {?PUBREC, PacketId});
 
 %% PUBREL
 puback(Session = #session{clientid = ClientId, awaiting_rel = Awaiting}, {?PUBREL, PacketId}) ->
@@ -253,7 +253,7 @@ puback(Session = #session{clientid = ClientId, awaiting_rel = Awaiting}, {?PUBRE
     Session#session{awaiting_rel = maps:remove(PacketId, Awaiting)};
 
 puback(SessPid, {?PUBREL, PacketId}) when is_pid(SessPid) ->
-    cast(SessPid, {pubrel, PacketId});
+    gen_server:cast(SessPid, {puback, {?PUBREL, PacketId});
 
 %% PUBCOMP
 puback(Session = #session{clientid = ClientId, 
@@ -265,7 +265,9 @@ puback(Session = #session{clientid = ClientId,
     Session#session{awaiting_comp = maps:remove(PacketId, AwaitingComp)};
 
 puback(SessPid, {?PUBCOMP, PacketId}) when is_pid(SessPid) ->
-    cast(SessPid, {pubcomp, PacketId}).
+    gen_server:cast(SessPid, {puback, {?PUBCOMP, PacketId});
+
+wait_ack
 
 timeout(awaiting_rel, MsgId, Session = #session{clientid = ClientId, awaiting_rel = Awaiting}) ->
     case maps:find(MsgId, Awaiting) of
@@ -440,48 +442,34 @@ handle_cast({resume, ClientId, ClientPid}, State = #session{
         end, emqttd_queue:all(Queue)),
 
     {noreply, State#session{client_pid   = ClientPid,
-                                  msg_queue    = emqttd_queue:clear(Queue),
-                                  expire_timer = undefined}, hibernate};
-
-handle_cast({publish, ClientId, {?QOS_2, Message}}, State) ->
-    NewState = publish(State, ClientId, {?QOS_2, Message}),
-    {noreply, NewState};
-
-handle_cast({puback, PacketId}, State) ->
-    NewState = puback(State, {?PUBACK, PacketId}),
-    {noreply, NewState};
-
-handle_cast({pubrec, PacketId}, State) ->
-    NewState = puback(State, {?PUBREC, PacketId}),
-    {noreply, NewState};
+                            msg_queue    = emqttd_queue:clear(Queue),
+                            expire_timer = undefined}, hibernate};
 
-handle_cast({pubrel, PacketId}, State) ->
-    NewState = puback(State, {?PUBREL, PacketId}),
-    {noreply, NewState};
+handle_cast({publish, ClientId, {?QOS_2, Message}}, Session) ->
+    {noreply, publish(Session, ClientId, {?QOS_2, Message})};
 
-handle_cast({pubcomp, PacketId}, State) ->
-    NewState = puback(State, {?PUBCOMP, PacketId}),
-    {noreply, NewState};
+handle_cast({puback, {PubAck, PacketId}, Session) ->
+    {noreply, puback(Session, {PubAck, PacketId})};
 
-handle_cast({destroy, ClientId}, State = #session{clientid = ClientId}) ->
+handle_cast({destroy, ClientId}, Session = #session{clientid = ClientId}) ->
     lager:warning("Session ~s destroyed", [ClientId]),
-    {stop, normal, State};
+    {stop, normal, Session};
 
 handle_cast(Msg, State) ->
     lager:critical("Unexpected Msg: ~p, State: ~p", [Msg, State]), 
     {noreply, State}.
 
-handle_info({dispatch, {_From, Messages}}, State) when is_list(Messages) ->
+handle_info({dispatch, {_From, Messages}}, Session) when is_list(Messages) ->
     F = fun(Message, S) -> dispatch(Message, S) end,
-    {noreply, lists:foldl(F, State, Messages)};
+    {noreply, lists:foldl(F, Session, Messages)};
 
 handle_info({dispatch, {_From, Message}}, State) ->
     {noreply, dispatch(Message, State)};
 
-handle_info({'EXIT', ClientPid, Reason}, State = #session{clientid = ClientId,
-                                                          client_pid = ClientPid}) ->
+handle_info({'EXIT', ClientPid, Reason}, Session = #session{clientid = ClientId,
+                                                            client_pid = ClientPid}) ->
     lager:info("Session: client ~s@~p exited for ~p", [ClientId, ClientPid, Reason]),
-    {noreply, start_expire_timer(State#session{client_pid = undefined})};
+    {noreply, start_expire_timer(Session#session{client_pid = undefined})};
 
 handle_info({'EXIT', ClientPid0, _Reason}, State = #session{client_pid = ClientPid}) ->
     lager:error("Unexpected Client EXIT: pid=~p, pid(state): ~p", [ClientPid0, ClientPid]),
@@ -491,51 +479,55 @@ handle_info(session_expired, State = #session{clientid = ClientId}) ->
     lager:warning("Session ~s expired!", [ClientId]),
     {stop, {shutdown, expired}, State};
 
-handle_info({timeout, awaiting_rel, MsgId}, SessState) ->
-    NewState = timeout(awaiting_rel, MsgId, SessState),
-    {noreply, NewState};
+handle_info({timeout, awaiting_rel, MsgId}, Session) ->
+    {noreply, timeout(awaiting_rel, MsgId, Session)};
 
-handle_info(Info, State) ->
-    lager:critical("Unexpected Info: ~p, State: ~p", [Info, State]),
-    {noreply, State}.
+handle_info(Info, Session) ->
+    lager:critical("Unexpected Info: ~p, Session: ~p", [Info, Session]),
+    {noreply, Session}.
 
-terminate(_Reason, _State) ->
+terminate(_Reason, _Session) ->
     ok.
 
-code_change(_OldVsn, State, _Extra) ->
-    {ok, State}.
-
-
-
+code_change(_OldVsn, Session, _Extra) ->
+    {ok, Session}.
 
 %%%=============================================================================
-%%% Internal functions
+%%% Dispatch message from broker -> client.
 %%%=============================================================================
 
-%% client is offline
+%% queued the message if client is offline
 dispatch(Msg, Session = #session{client_pid = undefined}) ->
     queue(Msg, Session);
 
-%% dispatch qos0 directly
+%% dispatch qos0 directly to client process
 dispatch(Msg = #mqtt_message{qos = ?QOS_0}, Session = #session{client_pid = ClientPid}) ->
     ClientPid ! {dispatch, {self(), Msg}}, Session;
 
-%% queue if inflight_queue is full
-dispatch(Msg = #mqtt_message{qos = Qos}, Session = #session{inflight_window = InflightWin,
-                                                              inflight_queue  = InflightQ})
-        when (Qos > ?QOS_0) andalso (length(InflightQ) >= InflightWin) ->
-    %%TODO: set alarms
-    lager:error([{clientid, ClientId}], "Session ~s inflight_queue is full!", [ClientId]),
-    queue(Msg, Session);
-
-%% dispatch and await ack
-dispatch(Msg = #mqtt_message{qos = Qos}, Session = #session{client_pid = ClientPid})
+%% dispatch qos1/2 messages and wait for puback
+dispatch(Msg = #mqtt_message{qos = Qos}, Session = #session{clientid = ClientId,
+                                                            message_id = MsgId,
+                                                            pending_queue = Q,
+                                                            inflight_window = Win})
     when (Qos =:= ?QOS_1) orelse (Qos =:= ?QOS_2) ->
-    %% assign msgid and await
-    {NewMsg, NewState} = await_ack(Msg, Session),
-    ClientPid ! {dispatch, {self(), NewMsg}},
 
-queue(Msg, Session = #session{pending_queue = Queue}) ->
+    case emqttd_mqwin:is_full(InflightWin) of
+        true  ->
+            lager:error("Session ~s inflight window is full!", [ClientId]),
+            Session#session{pending_queue = emqttd_mqueue:in(Msg, Q)};
+        false ->
+            Msg1 = Msg#mqtt_message{msgid = MsgId},
+            Msg2 =
+            if
+                Qos =:= ?QOS_2 -> Msg1#mqtt_message{dup = false};
+                true -> Msg1
+            end,
+            ClientPid ! {dispatch, {self(), Msg2}},
+            NewWin = emqttd_mqwin:in(Msg2, Win),
+            await_ack(Msg2, next_msgid(Session#session{inflight_window = NewWin}))
+    end.
+
+queue(Msg, Session = #session{pending_queue= Queue}) ->
     Session#session{pending_queue = emqttd_mqueue:in(Msg, Queue)}.
 
 next_msgid(State = #session{message_id = 16#ffff}) ->
@@ -544,8 +536,9 @@ next_msgid(State = #session{message_id = 16#ffff}) ->
 next_msgid(State = #session{message_id = MsgId}) ->
     State#session{message_id = MsgId + 1}.
 
-start_expire_timer(State = #session{expires = Expires, expire_timer = OldTimer}) ->
+start_expire_timer(Session = #session{expired_after = Expires,
+                                      expired_timer = OldTimer}) ->
     emqttd_util:cancel_timer(OldTimer),
     Timer = erlang:send_after(Expires * 1000, self(), session_expired),
-    State#session{expire_timer = Timer}.
+    Session#session{expired_timer = Timer}.