Просмотр исходного кода

fix(gw): add takeover_session/3 for cm_proto_v1

JianBo He 4 лет назад
Родитель
Сommit
63ef00a208

+ 8 - 8
apps/emqx_gateway/src/emqx_gateway_cm.erl

@@ -78,6 +78,7 @@
         , do_get_chan_stats/3
         , do_set_chan_stats/4
         , do_kick_session/4
+        , do_takeover_session/3
         , do_get_chann_conn_mod/3
         , do_call/4
         , do_call/5
@@ -301,7 +302,7 @@ open_session(GwName, true = _CleanStart, ClientInfo, ConnInfo, CreateSessionFun,
     Self = self(),
     ClientId = maps:get(clientid, ClientInfo),
     Fun = fun(_) ->
-                  ok = discard_session(GwName, ClientId),
+                  _ = discard_session(GwName, ClientId),
                   Session = create_session(GwName,
                                            ClientInfo,
                                            ConnInfo,
@@ -394,7 +395,7 @@ takeover_session(GwName, ClientId) ->
                             , chan_pids => ChanPids
                             }),
             lists:foreach(fun(StalePid) ->
-                                  catch discard_session(ClientId, StalePid)
+                                  catch discard_session(GwName, ClientId, StalePid)
                           end, StalePids),
             do_takeover_session(GwName, ClientId, ChanPid)
     end.
@@ -415,21 +416,20 @@ do_takeover_session(GwName, ClientId, ChanPid) ->
     wrap_rpc(emqx_gateway_cm_proto_v1:takeover_session(GwName, ClientId, ChanPid)).
 
 %% @doc Discard all the sessions identified by the ClientId.
--spec discard_session(GwName :: gateway_name(), binary()) -> ok.
+-spec discard_session(GwName :: gateway_name(), binary()) -> ok | {error, not_found}.
 discard_session(GwName, ClientId) when is_binary(ClientId) ->
     case lookup_channels(GwName, ClientId) of
-        [] -> ok;
+        [] -> {error, not_found};
         ChanPids -> lists:foreach(fun(Pid) -> discard_session(GwName, ClientId, Pid) end, ChanPids)
     end.
 
 discard_session(GwName, ClientId, ChanPid) ->
     kick_session(GwName, discard, ClientId, ChanPid).
 
--spec kick_session(gateway_name(), emqx_types:clientid()) -> ok.
-
+-spec kick_session(gateway_name(), emqx_types:clientid()) -> ok | {error, not_found}.
 kick_session(GwName, ClientId) ->
     case lookup_channels(GwName, ClientId) of
-        [] -> ok;
+        [] -> {error, not_found};
         ChanPids ->
             ChanPids > 1 andalso begin
                 ?SLOG(warning, #{ msg => "more_than_one_channel_found"
@@ -438,7 +438,7 @@ kick_session(GwName, ClientId) ->
                       #{clientid => ClientId})
             end,
             lists:foreach(fun(Pid) ->
-                kick_session(GwName, ClientId, Pid)
+                _ = kick_session(GwName, ClientId, Pid)
             end, ChanPids)
     end.
 

+ 7 - 0
apps/emqx_gateway/src/proto/emqx_gateway_cm_proto_v1.erl

@@ -27,6 +27,7 @@
         , kick_session/4
         , get_chann_conn_mod/3
         , lookup_by_clientid/3
+        , takeover_session/3
         , call/4
         , call/5
         , cast/4
@@ -81,6 +82,12 @@ get_chann_conn_mod(GwName, ClientId, ChanPid) ->
     rpc:call(node(ChanPid), emqx_gateway_cm, do_get_chann_conn_mod,
              [GwName, ClientId, ChanPid]).
 
+-spec takeover_session(emqx_gateway_cm:gateway_name(),
+                       emqx_types:clientid(),
+                       pid()) -> boolean() | {badrpc, _}.
+takeover_session(GwName, ClientId, ChanPid) ->
+    rpc:call(node(ChanPid), emqx_gateway_cm, do_takeover_session, [GwName, ClientId, ChanPid]).
+
 -spec call(emqx_gateway_cm:gateway_name(),
            emqx_types:clientid(),
            pid(),

+ 5 - 4
apps/emqx_gateway/test/emqx_gateway_cm_SUITE.erl

@@ -61,9 +61,10 @@ end_per_testcase(_TestCase, Conf) ->
 %%--------------------------------------------------------------------
 
 t_open_session(_) ->
-    {error, not_supported_now} = emqx_gateway_cm:open_session(
-                                 ?GWNAME, false, clientinfo(), conninfo(),
-                                 fun(_, _) -> #{} end),
+    {ok, #{present := false,
+           session := #{}}} = emqx_gateway_cm:open_session(
+                                ?GWNAME, false, clientinfo(), conninfo(),
+                                fun(_, _) -> #{} end),
 
     {ok, SessionRes} = emqx_gateway_cm:open_session(
                          ?GWNAME, true, clientinfo(), conninfo(),
@@ -189,7 +190,7 @@ t_kick_session(_) ->
 
     ok = emqx_gateway_cm:kick_session(?GWNAME, ?CLIENTID),
 
-    receive discard -> ok
+    receive kick -> ok
     after 100 -> ?assert(false, "waiting discard msg timeout")
     end,
     receive