Jelajahi Sumber

Improve mechanism of waiting for session to expire

zhouzb 6 tahun lalu
induk
melakukan
a9dd94b2b5
4 mengubah file dengan 114 tambahan dan 97 penghapusan
  1. 36 32
      src/emqx_channel.erl
  2. 36 35
      src/emqx_connection.erl
  3. 41 29
      src/emqx_ws_connection.erl
  4. 1 1
      test/emqx_channel_SUITE.erl

+ 36 - 32
src/emqx_channel.erl

@@ -328,32 +328,25 @@ handle_in(Packet = ?UNSUBSCRIBE_PACKET(PacketId, Properties, TopicFilters),
 handle_in(?PACKET(?PINGREQ), Channel) ->
     {ok, ?PACKET(?PINGRESP), Channel};
 
-handle_in(?DISCONNECT_PACKET(RC, Properties), Channel = #channel{session = Session, protocol = Protocol}) ->
+handle_in(?DISCONNECT_PACKET(ReasonCode, Properties), Channel = #channel{session = Session, protocol = Protocol}) ->
     OldInterval = emqx_session:info(expiry_interval, Session),
     Interval = get_property('Session-Expiry-Interval', Properties, OldInterval),
     case OldInterval =:= 0 andalso Interval =/= OldInterval of
         true ->
             handle_out({disconnect, ?RC_PROTOCOL_ERROR}, Channel);
         false ->
-            Channel1 = case RC of
-                           ?RC_SUCCESS -> Channel#channel{protocol = emqx_protocol:clear_will_msg(Protocol)};
-                           _ -> Channel
-                       end,
-            Channel2 = Channel1#channel{session = emqx_session:update_expiry_interval(Interval, Session)},
-            case Interval of
-                ?UINT_MAX ->
-                    {ok, ensure_timer(will_timer, Channel2)};
-                Int when Int > 0 ->
-                    {ok, ensure_timer([will_timer, expire_timer], Channel2)};
-                _Other ->
-                    Reason = case RC of
-                                 ?RC_SUCCESS -> normal;
-                                 _ ->
-                                     Ver = emqx_protocol:info(proto_ver, Protocol),
-                                     emqx_reason_codes:name(RC, Ver)
-                             end,
-                    {stop, {shutdown, Reason}, Channel2}
-            end
+            Reason = case ReasonCode of
+                         ?RC_SUCCESS -> normal;
+                         _ ->
+                             ProtoVer = emqx_protocol:info(proto_ver, Protocol),
+                             emqx_reason_codes:name(ReasonCode, ProtoVer)
+                     end,
+            {wait_session_expire, {shutdown, Reason},
+             Channel#channel{session = emqx_session:update_expiry_interval(Interval, Session),
+                             protocol = case ReasonCode of
+                                            ?RC_SUCCESS -> emqx_protocol:clear_will_msg(Protocol);
+                                            _ -> Protocol
+                                        end}}
     end;
 
 handle_in(?AUTH_PACKET(), Channel) ->
@@ -362,7 +355,7 @@ handle_in(?AUTH_PACKET(), Channel) ->
 
 handle_in(Packet, Channel) ->
     ?LOG(error, "Unexpected incoming: ~p", [Packet]),
-    {stop, {shutdown, unexpected_incoming_packet}, Channel}.
+    handle_out({disconnect, ?RC_MALFORMED_PACKET}, Channel).
 
 %%--------------------------------------------------------------------
 %% Process Connect
@@ -599,10 +592,10 @@ handle_out({disconnect, ReasonCode}, Channel = #channel{protocol = Protocol}) ->
         ?MQTT_PROTO_V5 ->
             Reason = emqx_reason_codes:name(ReasonCode),
             Packet = ?DISCONNECT_PACKET(ReasonCode),
-            {stop, {shutdown, Reason}, Packet, Channel};
+            {wait_session_expire, {shutdown, Reason}, Packet, Channel};
         ProtoVer ->
             Reason = emqx_reason_codes:name(ReasonCode, ProtoVer),
-            {stop, {shutdown, Reason}, Channel}
+            {wait_session_expire, {shutdown, Reason}, Channel}
     end;
 
 handle_out({Type, Data}, Channel) ->
@@ -674,18 +667,26 @@ handle_info({unsubscribe, TopicFilters}, Channel = #channel{client = Client}) ->
 handle_info(disconnected, Channel = #channel{connected = undefined}) ->
     shutdown(closed, Channel);
 
+handle_info(disconnected, Channel = #channel{connected = false}) ->
+    {ok, Channel};
+
 handle_info(disconnected, Channel = #channel{protocol = Protocol,
                                              session  = Session}) ->
-    %% TODO: Why handle will_msg here?
-    publish_will_msg(emqx_protocol:info(will_msg, Protocol)),
-    NChannel = Channel#channel{protocol = emqx_protocol:clear_will_msg(Protocol)},
-    Interval = emqx_session:info(expiry_interval, Session),
-    case Interval of
+    Channel1 = ensure_disconnected(Channel),
+    Channel2 = case timer:seconds(emqx_protocol:info(will_delay_interval, Protocol)) of
+                   0 ->
+                       publish_will_msg(emqx_protocol:info(will_msg, Protocol)),
+                       Channel1#channel{protocol = emqx_protocol:clear_will_msg(Protocol)};
+                   _ ->
+                       ensure_timer(will_timer, Channel1)
+               end,
+    case emqx_session:info(expiry_interval, Session) of
         ?UINT_MAX ->
-            {ok, ensure_disconnected(NChannel)};
+            {ok, Channel2};
         Int when Int > 0 ->
-            {ok, ensure_timer(expire_timer, ensure_disconnected(NChannel))};
-        _Other -> shutdown(closed, NChannel)
+            {ok, ensure_timer(expire_timer, Channel2)};
+        _Other ->
+            shutdown(closed, Channel2)
     end;
 
 handle_info(Info, Channel) ->
@@ -715,7 +716,7 @@ timeout(TRef, {keepalive, StatVal}, Channel = #channel{keepalive = Keepalive,
             NChannel = Channel#channel{keepalive = NKeepalive},
             {ok, reset_timer(alive_timer, NChannel)};
         {error, timeout} ->
-            {stop, {shutdown, keepalive_timeout}, Channel}
+            {wait_session_expire, {shutdown, keepalive_timeout}, Channel}
     end;
 
 timeout(TRef, retry_delivery, Channel = #channel{session = Session,
@@ -804,6 +805,9 @@ interval(will_timer, #channel{protocol = Protocol}) ->
 
 terminate(normal, #channel{client = Client}) ->
     ok = emqx_hooks:run('client.disconnected', [Client, normal]);
+terminate({shutdown, Reason}, #channel{client = Client})
+    when Reason =:= kicked orelse Reason =:= discarded orelse Reason =:= takeovered ->
+    ok = emqx_hooks:run('client.disconnected', [Client, Reason]);
 terminate(Reason, #channel{client = Client,
                            protocol = Protocol
                           }) ->

+ 36 - 35
src/emqx_connection.erl

@@ -224,10 +224,13 @@ idle(cast, {incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, State) ->
     SuccFun = fun(NewSt) -> {next_state, connected, NewSt} end,
     handle_incoming(Packet, SuccFun, NState);
 
-idle(cast, {incoming, Packet}, State) ->
+idle(cast, {incoming, Packet}, State) when is_record(Packet, mqtt_packet) ->
     ?LOG(warning, "Unexpected incoming: ~p", [Packet]),
     shutdown(unexpected_incoming_packet, State);
 
+idle(cast, {incoming, {error, Reason}}, State) ->
+    shutdown(Reason, State);
+
 idle(EventType, Content, State) ->
     ?HANDLE(EventType, Content, State).
 
@@ -241,6 +244,17 @@ connected(enter, _PrevSt, State) ->
 connected(cast, {incoming, Packet}, State) when is_record(Packet, mqtt_packet) ->
     handle_incoming(Packet, fun keep_state/1, State);
 
+connected(cast, {incoming, {error, Reason}}, State = #connection{chan_state = ChanState}) ->
+    case emqx_channel:handle_out({disconnect, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState) of
+        {wait_session_expire, _, NChanState} ->
+            ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]),
+            {next_state, disconnected, State#connection{chan_state= NChanState}};
+        {wait_session_expire, _, OutPackets, NChanState} ->
+            ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]),
+            NState = State#connection{chan_state= NChanState},
+            {next_state, disconnected, handle_outgoing(OutPackets, fun(NewSt) -> NewSt end, NState)}
+    end;
+
 connected(info, Deliver = {deliver, _Topic, _Msg}, State) ->
     handle_deliver(emqx_misc:drain_deliver([Deliver]), State);
 
@@ -408,8 +422,7 @@ process_incoming(Data, State) ->
 process_incoming(<<>>, Packets, State) ->
     {keep_state, State, next_incoming_events(Packets)};
 
-process_incoming(Data, Packets, State = #connection{parse_state = ParseState,
-                                                    chan_state  = ChanState}) ->
+process_incoming(Data, Packets, State = #connection{parse_state = ParseState}) ->
     try emqx_frame:parse(Data, ParseState) of
         {more, NParseState} ->
             NState = State#connection{parse_state = NParseState},
@@ -418,32 +431,16 @@ process_incoming(Data, Packets, State = #connection{parse_state = ParseState,
             NState = State#connection{parse_state = NParseState},
             process_incoming(Rest, [Packet|Packets], NState);
         {error, Reason} ->
-            shutdown(Reason, State)
+            {keep_state, State, next_incoming_events({error, Reason})}
     catch
         error:Reason:Stk ->
-            ?LOG(error, "Parse failed for ~p~nStacktrace:~p~nError data:~p", [Reason, Stk, Data]),
-            Result = 
-                case emqx_channel:info(connected, ChanState) of
-                    undefined ->
-                        emqx_channel:handle_out({connack, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState);
-                    true ->
-                        emqx_channel:handle_out({disconnect, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState);
-                    _ ->
-                        ignore
-                end,
-            case Result of
-                {stop, Reason0, OutPackets, NChanState} ->
-                    Shutdown = fun(NewSt) -> stop(Reason0, NewSt) end,
-                    NState = State#connection{chan_state = NChanState},
-                    handle_outgoing(OutPackets, Shutdown, NState);
-                {stop, Reason0, NChanState} ->
-                    stop(Reason0, State#connection{chan_state = NChanState});
-                ignore ->
-                    keep_state(State)
-            end
+            ?LOG(error, "~nParse failed for ~p~nStacktrace: ~p~nError data:~p", [Reason, Stk, Data]),
+            {keep_state, State, next_incoming_events({error, Reason})}
     end.
 
 -compile({inline, [next_incoming_events/1]}).
+next_incoming_events({error, Reason}) ->
+    [next_event(cast, {incoming, {error, Reason}})];
 next_incoming_events(Packets) ->
     [next_event(cast, {incoming, Packet}) || Packet <- Packets].
 
@@ -459,14 +456,19 @@ handle_incoming(Packet = ?PACKET(Type), SuccFun,
         {ok, NChanState} ->
             SuccFun(State#connection{chan_state= NChanState});
         {ok, OutPackets, NChanState} ->
-            handle_outgoing(OutPackets, SuccFun,
-                            State#connection{chan_state = NChanState});
+            handle_outgoing(OutPackets, SuccFun, State#connection{chan_state = NChanState});
+        {wait_session_expire, Reason, NChanState} ->
+            ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]),
+            {next_state, disconnected, State#connection{chan_state = NChanState}};
+        {wait_session_expire, Reason, OutPackets, NChanState} ->
+            ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]),
+            NState = State#connection{chan_state= NChanState},
+            {next_state, disconnected, handle_outgoing(OutPackets, fun(NewSt) -> NewSt end, NState)};
         {stop, Reason, NChanState} ->
             stop(Reason, State#connection{chan_state = NChanState});
         {stop, Reason, OutPackets, NChanState} ->
-            Shutdown = fun(NewSt) -> stop(Reason, NewSt) end,
-            NState = State#connection{chan_state = NChanState},
-            handle_outgoing(OutPackets, Shutdown, NState)
+            NState = State#connection{chan_state= NChanState},
+            stop(Reason, handle_outgoing(OutPackets, fun(NewSt) -> NewSt end, NState))
     end.
 
 %%-------------------------------------------------------------------
@@ -477,10 +479,7 @@ handle_deliver(Delivers, State = #connection{chan_state = ChanState}) ->
         {ok, NChanState} ->
             keep_state(State#connection{chan_state = NChanState});
         {ok, Packets, NChanState} ->
-            NState = State#connection{chan_state = NChanState},
-            handle_outgoing(Packets, fun keep_state/1, NState);
-        {stop, Reason, NChanState} ->
-            stop(Reason, State#connection{chan_state = NChanState})
+            handle_outgoing(Packets, fun keep_state/1, State#connection{chan_state = NChanState})
     end.
 
 %%--------------------------------------------------------------------
@@ -534,8 +533,10 @@ handle_timeout(TRef, Msg, State = #connection{chan_state = ChanState}) ->
         {ok, NChanState} ->
             keep_state(State#connection{chan_state = NChanState});
         {ok, Packets, NChanState} ->
-            handle_outgoing(Packets, fun keep_state/1,
-                            State#connection{chan_state = NChanState});
+            handle_outgoing(Packets, fun keep_state/1, State#connection{chan_state = NChanState});
+        {wait_session_expire, Reason, NChanState} ->
+            ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]),
+            {next_state, disconnected, State#connection{chan_state = NChanState}};
         {stop, Reason, NChanState} ->
             stop(Reason, State#connection{chan_state = NChanState})
     end.

+ 41 - 29
src/emqx_ws_connection.erl

@@ -254,6 +254,22 @@ websocket_info({cast, Msg}, State = #ws_connection{chan_state = ChanState}) ->
             stop(Reason, State#ws_connection{chan_state = NChanState})
     end;
 
+websocket_info({incoming, {error, Reason}}, State = #ws_connection{fsm_state = idle}) ->
+    stop({shutdown, Reason}, State);
+
+websocket_info({incoming, {error, Reason}}, State = #ws_connection{fsm_state = connected, chan_state = ChanState}) ->
+    case emqx_channel:handle_out({disconnect, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState) of
+        {wait_session_expire, _, NChanState} ->
+            ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]),
+            disconnected(State#ws_connection{chan_state= NChanState});
+        {wait_session_expire, _, OutPackets, NChanState} ->
+            ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]),
+            disconnected(enqueue(OutPackets, State#ws_connection{chan_state = NChanState}))
+    end;
+
+websocket_info({incoming, {error, _Reason}}, State = #ws_connection{fsm_state = disconnected}) ->
+    reply(State);
+    
 websocket_info({incoming, Packet = ?CONNECT_PACKET(ConnPkt)},
                 State = #ws_connection{fsm_state = idle}) ->
     #mqtt_packet_connect{proto_ver = ProtoVer, properties = Properties} = ConnPkt,
@@ -276,9 +292,7 @@ websocket_info(Deliver = {deliver, _Topic, _Msg},
         {ok, NChanState} ->
             reply(State#ws_connection{chan_state = NChanState});
         {ok, Packets, NChanState} ->
-            reply(enqueue(Packets, State#ws_connection{chan_state = NChanState}));
-        {stop, Reason, NChanState} ->
-            stop(Reason, State#ws_connection{chan_state = NChanState})
+            reply(enqueue(Packets, State#ws_connection{chan_state = NChanState}))
     end;
 
 websocket_info({timeout, TRef, keepalive}, State) when is_reference(TRef) ->
@@ -307,8 +321,7 @@ websocket_info(Info, State = #ws_connection{chan_state = ChanState}) ->
 
 terminate(SockError, _Req, #ws_connection{chan_state  = ChanState,
                                           stop_reason = Reason}) ->
-    ?LOG(debug, "Terminated for ~p, sockerror: ~p",
-         [Reason, SockError]),
+    ?LOG(debug, "Terminated for ~p, sockerror: ~p", [Reason, SockError]),
     emqx_channel:terminate(Reason, ChanState).
 
 %%--------------------------------------------------------------------
@@ -318,6 +331,12 @@ connected(State = #ws_connection{chan_state = ChanState}) ->
     ok = emqx_channel:handle_cast({register, attrs(State), stats(State)}, ChanState),
     reply(State#ws_connection{fsm_state = connected}).
 
+%%--------------------------------------------------------------------
+%% Disconnected callback
+
+disconnected(State) ->
+    reply(State#ws_connection{fsm_state = disconnected}).
+
 %%--------------------------------------------------------------------
 %% Handle timeout
 
@@ -328,6 +347,9 @@ handle_timeout(TRef, Msg, State = #ws_connection{chan_state = ChanState}) ->
         {ok, Packets, NChanState} ->
             NState = State#ws_connection{chan_state = NChanState},
             reply(enqueue(Packets, NState));
+        {wait_session_expire, Reason, NChanState} ->
+            ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]),
+            disconnected(State#ws_connection{chan_state = NChanState});
         {stop, Reason, NChanState} ->
             stop(Reason, State#ws_connection{chan_state = NChanState})
     end.
@@ -347,29 +369,13 @@ process_incoming(Data, State = #ws_connection{parse_state = ParseState,
             self() ! {incoming, Packet},
             process_incoming(Rest, State#ws_connection{parse_state = NParseState});
         {error, Reason} ->
-            ?LOG(error, "Frame error: ~p", [Reason]),
-            stop(Reason, State)
+            self() ! {incoming, {error, Reason}},
+            {ok, State}
     catch
         error:Reason:Stk ->
-            ?LOG(error, "Parse failed for ~p~nStacktrace:~p~nFrame data: ~p", [Reason, Stk, Data]),
-            Result = 
-                case emqx_channel:info(connected, ChanState) of
-                    undefined ->
-                        emqx_channel:handle_out({connack, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState);
-                    true ->
-                        emqx_channel:handle_out({disconnect, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState);
-                    _ ->
-                        ignore
-                end,
-            case Result of
-                {stop, Reason0, OutPackets, NChanState} ->
-                    NState = State#ws_connection{chan_state = NChanState},
-                    stop(Reason0, enqueue(OutPackets, NState));
-                {stop, Reason0, NChanState} ->
-                    stop(Reason0, State#ws_connection{chan_state = NChanState});
-                ignore ->
-                    {ok, State}
-            end
+            ?LOG(error, "~nParse failed for ~p~nStacktrace: ~p~nFrame data: ~p", [Reason, Stk, Data]),
+            self() ! {incoming, {error, Reason}},
+            {ok, State}
     end.
 
 %%--------------------------------------------------------------------
@@ -386,11 +392,17 @@ handle_incoming(Packet = ?PACKET(Type), SuccFun,
         {ok, OutPackets, NChanState} ->
             NState = State#ws_connection{chan_state= NChanState},
             SuccFun(enqueue(OutPackets, NState));
+        {wait_session_expire, Reason, NChanState} ->
+            ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]),
+            disconnected(State#ws_connection{chan_state = NChanState});
+        {wait_session_expire, Reason, OutPackets, NChanState} ->
+            ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]),
+            disconnected(enqueue(OutPackets, State#ws_connection{chan_state = NChanState}));
         {stop, Reason, NChanState} ->
-            stop(Reason, State#ws_connection{chan_state= NChanState});
-        {stop, Reason, OutPacket, NChanState} ->
+            stop(Reason, State#ws_connection{chan_state = NChanState});
+        {stop, Reason, OutPackets, NChanState} ->
             NState = State#ws_connection{chan_state= NChanState},
-            stop(Reason, enqueue(OutPacket, NState))
+            stop(Reason, enqueue(OutPackets, NState))
     end.
 
 %%--------------------------------------------------------------------

+ 1 - 1
test/emqx_channel_SUITE.erl

@@ -144,7 +144,7 @@ t_handle_pingreq(_) ->
 t_handle_disconnect(_) ->
     with_channel(
       fun(Channel) ->
-              {stop, {shutdown, normal}, Channel1} = handle_in(?DISCONNECT_PACKET(?RC_SUCCESS), Channel),
+              {wait_session_expire, {shutdown, normal}, Channel1} = handle_in(?DISCONNECT_PACKET(?RC_SUCCESS), Channel),
               ?assertMatch(#{will_msg := undefined}, emqx_channel:info(protocol, Channel1))
       end).