Просмотр исходного кода

Merge pull request #7354 from HJianBo/send_disconnect_pkt_while_kicked_5

fix(channel): send DISCONNECT packet if connection has been kicked
JianBo He 3 лет назад
Родитель
Сommit
36e84ff8cd
2 измененных файлов с 68 добавлено и 11 удалено
  1. 34 10
      apps/emqx/src/emqx_channel.erl
  2. 34 1
      apps/emqx/test/emqx_channel_SUITE.erl

+ 34 - 10
apps/emqx/src/emqx_channel.erl

@@ -1141,9 +1141,31 @@ return_sub_unsub_ack(Packet, Channel) ->
     {reply, Reply :: term(), channel()}
     {reply, Reply :: term(), channel()}
     | {shutdown, Reason :: term(), Reply :: term(), channel()}
     | {shutdown, Reason :: term(), Reply :: term(), channel()}
     | {shutdown, Reason :: term(), Reply :: term(), emqx_types:packet(), channel()}.
     | {shutdown, Reason :: term(), Reply :: term(), emqx_types:packet(), channel()}.
-handle_call(kick, Channel) ->
-    Channel1 = ensure_disconnected(kicked, Channel),
-    disconnect_and_shutdown(kicked, ok, Channel1);
+handle_call(
+    kick,
+    Channel = #channel{
+        conn_state = ConnState,
+        will_msg = WillMsg,
+        conninfo = #{proto_ver := ProtoVer}
+    }
+) ->
+    (WillMsg =/= undefined) andalso publish_will_msg(WillMsg),
+    Channel1 =
+        case ConnState of
+            connected -> ensure_disconnected(kicked, Channel);
+            _ -> Channel
+        end,
+    case ProtoVer == ?MQTT_PROTO_V5 andalso ConnState == connected of
+        true ->
+            shutdown(
+                kicked,
+                ok,
+                ?DISCONNECT_PACKET(?RC_ADMINISTRATIVE_ACTION),
+                Channel1
+            );
+        _ ->
+            shutdown(kicked, ok, Channel1)
+    end;
 handle_call(discard, Channel) ->
 handle_call(discard, Channel) ->
     disconnect_and_shutdown(discarded, ok, Channel);
     disconnect_and_shutdown(discarded, ok, Channel);
 %% Session Takeover
 %% Session Takeover
@@ -1220,7 +1242,7 @@ handle_info(
 ->
 ->
     emqx_config:get_zone_conf(Zone, [flapping_detect, enable]) andalso
     emqx_config:get_zone_conf(Zone, [flapping_detect, enable]) andalso
         emqx_flapping:detect(ClientInfo),
         emqx_flapping:detect(ClientInfo),
-    Channel1 = ensure_disconnected(Reason, mabye_publish_will_msg(Channel)),
+    Channel1 = ensure_disconnected(Reason, maybe_publish_will_msg(Channel)),
     case maybe_shutdown(Reason, Channel1) of
     case maybe_shutdown(Reason, Channel1) of
         {ok, Channel2} -> {ok, {event, disconnected}, Channel2};
         {ok, Channel2} -> {ok, {event, disconnected}, Channel2};
         Shutdown -> Shutdown
         Shutdown -> Shutdown
@@ -2044,9 +2066,9 @@ ensure_disconnected(
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 %% Maybe Publish will msg
 %% Maybe Publish will msg
 
 
-mabye_publish_will_msg(Channel = #channel{will_msg = undefined}) ->
+maybe_publish_will_msg(Channel = #channel{will_msg = undefined}) ->
     Channel;
     Channel;
-mabye_publish_will_msg(Channel = #channel{will_msg = WillMsg}) ->
+maybe_publish_will_msg(Channel = #channel{will_msg = WillMsg}) ->
     case will_delay_interval(WillMsg) of
     case will_delay_interval(WillMsg) of
         0 ->
         0 ->
             ok = publish_will_msg(WillMsg),
             ok = publish_will_msg(WillMsg),
@@ -2056,10 +2078,13 @@ mabye_publish_will_msg(Channel = #channel{will_msg = WillMsg}) ->
     end.
     end.
 
 
 will_delay_interval(WillMsg) ->
 will_delay_interval(WillMsg) ->
-    maps:get('Will-Delay-Interval', emqx_message:get_header(properties, WillMsg), 0).
+    maps:get(
+        'Will-Delay-Interval',
+        emqx_message:get_header(properties, WillMsg, #{}),
+        0
+    ).
 
 
 publish_will_msg(Msg) ->
 publish_will_msg(Msg) ->
-    %% TODO check if we should discard result here
     _ = emqx_broker:publish(Msg),
     _ = emqx_broker:publish(Msg),
     ok.
     ok.
 
 
@@ -2070,8 +2095,7 @@ disconnect_reason(?RC_SUCCESS) -> normal;
 disconnect_reason(ReasonCode) -> emqx_reason_codes:name(ReasonCode).
 disconnect_reason(ReasonCode) -> emqx_reason_codes:name(ReasonCode).
 
 
 reason_code(takenover) -> ?RC_SESSION_TAKEN_OVER;
 reason_code(takenover) -> ?RC_SESSION_TAKEN_OVER;
-reason_code(discarded) -> ?RC_SESSION_TAKEN_OVER;
-reason_code(_) -> ?RC_NORMAL_DISCONNECTION.
+reason_code(discarded) -> ?RC_SESSION_TAKEN_OVER.
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 %% Helper functions
 %% Helper functions

+ 34 - 1
apps/emqx/test/emqx_channel_SUITE.erl

@@ -907,7 +907,32 @@ t_handle_out_unexpected(_) ->
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
 t_handle_call_kick(_) ->
 t_handle_call_kick(_) ->
-    {shutdown, kicked, ok, _Chan} = emqx_channel:handle_call(kick, channel()).
+    Channelv5 = channel(),
+    Channelv4 = v4(Channelv5),
+    {shutdown, kicked, ok, _} = emqx_channel:handle_call(kick, Channelv4),
+    {shutdown, kicked, ok, ?DISCONNECT_PACKET(?RC_ADMINISTRATIVE_ACTION), _} = emqx_channel:handle_call(
+        kick, Channelv5
+    ),
+
+    DisconnectedChannelv5 = channel(#{conn_state => disconnected}),
+    DisconnectedChannelv4 = v4(DisconnectedChannelv5),
+
+    {shutdown, kicked, ok, _} = emqx_channel:handle_call(kick, DisconnectedChannelv5),
+    {shutdown, kicked, ok, _} = emqx_channel:handle_call(kick, DisconnectedChannelv4).
+
+t_handle_kicked_publish_will_msg(_) ->
+    Self = self(),
+    ok = meck:expect(emqx_broker, publish, fun(M) -> Self ! {pub, M} end),
+
+    Msg = emqx_message:make(test, <<"will_topic">>, <<"will_payload">>),
+
+    {shutdown, kicked, ok, ?DISCONNECT_PACKET(?RC_ADMINISTRATIVE_ACTION), _} = emqx_channel:handle_call(
+        kick, channel(#{will_msg => Msg})
+    ),
+    receive
+        {pub, Msg} -> ok
+    after 200 -> exit(will_message_not_published)
+    end.
 
 
 t_handle_call_discard(_) ->
 t_handle_call_discard(_) ->
     Packet = ?DISCONNECT_PACKET(?RC_SESSION_TAKEN_OVER),
     Packet = ?DISCONNECT_PACKET(?RC_SESSION_TAKEN_OVER),
@@ -1243,3 +1268,11 @@ quota() ->
     emqx_limiter_container:get_limiter_by_names([message_routing], limiter_cfg()).
     emqx_limiter_container:get_limiter_by_names([message_routing], limiter_cfg()).
 
 
 limiter_cfg() -> #{message_routing => default}.
 limiter_cfg() -> #{message_routing => default}.
+
+v4(Channel) ->
+    ConnInfo = emqx_channel:info(conninfo, Channel),
+    emqx_channel:set_field(
+        conninfo,
+        maps:put(proto_ver, ?MQTT_PROTO_V4, ConnInfo),
+        Channel
+    ).