Просмотр исходного кода

refactor(cm): avoid deep indirection in `emqx_session_mem`

Andrew Mayorov 2 лет назад
Родитель
Сommit
2dae8020ec
3 измененных файлов с 45 добавлено и 42 удалено
  1. 22 15
      apps/emqx/src/emqx_cm.erl
  2. 10 11
      apps/emqx/src/emqx_session_mem.erl
  3. 13 16
      apps/emqx/test/emqx_cm_SUITE.erl

+ 22 - 15
apps/emqx/src/emqx_cm.erl

@@ -52,7 +52,8 @@
     open_session/3,
     discard_session/1,
     discard_session/2,
-    takeover_channel_session/2,
+    takeover_session_begin/1,
+    takeover_session_end/1,
     kick_session/1,
     kick_session/2
 ]).
@@ -118,6 +119,8 @@
     _Stats :: emqx_types:stats()
 }.
 
+-type takeover_state() :: {_ConnMod :: module(), _ChanPid :: pid()}.
+
 -define(CHAN_STATS, [
     {?CHAN_TAB, 'channels.count', 'channels.max'},
     {?CHAN_TAB, 'sessions.count', 'sessions.max'},
@@ -289,28 +292,32 @@ create_register_session(ClientInfo = #{clientid := ClientId}, ConnInfo, ChanPid)
     {ok, #{session => Session, present => false}}.
 
 %% @doc Try to takeover a session from existing channel.
-%% Naming is wierd, because `takeover_session/2` is an RPC target and cannot be renamed.
--spec takeover_channel_session(emqx_types:clientid(), _TODO) ->
-    {ok, emqx_session:session(), _ReplayContext} | none | {error, _Reason}.
-takeover_channel_session(ClientId, OnTakeover) ->
-    takeover_channel_session(ClientId, pick_channel(ClientId), OnTakeover).
+-spec takeover_session_begin(emqx_types:clientid()) ->
+    {ok, emqx_session_mem:session(), takeover_state()} | none.
+takeover_session_begin(ClientId) ->
+    takeover_session_begin(ClientId, pick_channel(ClientId)).
 
-takeover_channel_session(ClientId, ChanPid, OnTakeover) when is_pid(ChanPid) ->
+takeover_session_begin(ClientId, ChanPid) when is_pid(ChanPid) ->
     case takeover_session(ClientId, ChanPid) of
         {living, ConnMod, Session} ->
-            Session1 = OnTakeover(Session),
-            case wrap_rpc(emqx_cm_proto_v2:takeover_finish(ConnMod, ChanPid)) of
-                {ok, Pendings} ->
-                    {ok, Session1, Pendings};
-                {error, _} = Error ->
-                    Error
-            end;
+            {ok, Session, {ConnMod, ChanPid}};
         none ->
             none
     end;
-takeover_channel_session(_ClientId, undefined, _OnTakeover) ->
+takeover_session_begin(_ClientId, undefined) ->
     none.
 
+%% @doc Conclude the session takeover process.
+-spec takeover_session_end(takeover_state()) ->
+    {ok, _ReplayContext} | {error, _Reason}.
+takeover_session_end({ConnMod, ChanPid}) ->
+    case wrap_rpc(emqx_cm_proto_v2:takeover_finish(ConnMod, ChanPid)) of
+        {ok, Pendings} ->
+            {ok, Pendings};
+        {error, _} = Error ->
+            Error
+    end.
+
 -spec pick_channel(emqx_types:clientid()) ->
     maybe(pid()).
 pick_channel(ClientId) ->

+ 10 - 11
apps/emqx/src/emqx_session_mem.erl

@@ -196,17 +196,16 @@ destroy(_Session) ->
 -spec open(clientinfo(), emqx_types:conninfo()) ->
     {true, session(), replayctx()} | false.
 open(ClientInfo = #{clientid := ClientId}, _ConnInfo) ->
-    case
-        emqx_cm:takeover_channel_session(
-            ClientId,
-            fun(Session) -> resume(ClientInfo, Session) end
-        )
-    of
-        {ok, Session, Pendings} ->
-            clean_session(ClientInfo, Session, Pendings);
-        {error, _} ->
-            % TODO log error?
-            false;
+    case emqx_cm:takeover_session_begin(ClientId) of
+        {ok, SessionRemote, TakeoverState} ->
+            Session = resume(ClientInfo, SessionRemote),
+            case emqx_cm:takeover_session_end(TakeoverState) of
+                {ok, Pendings} ->
+                    clean_session(ClientInfo, Session, Pendings);
+                {error, _} ->
+                    % TODO log error?
+                    false
+            end;
         none ->
             false
     end.

+ 13 - 16
apps/emqx/test/emqx_cm_SUITE.erl

@@ -321,7 +321,7 @@ test_stepdown_session(Action, Reason) ->
             discard ->
                 emqx_cm:discard_session(ClientId);
             {takeover, _} ->
-                none = emqx_cm:takeover_channel_session(ClientId, fun ident/1),
+                none = emqx_cm:takeover_session_begin(ClientId),
                 ok
         end,
     case Reason =:= timeout orelse Reason =:= noproc of
@@ -381,10 +381,11 @@ t_discard_session_race(_) ->
 
 t_takeover_session(_) ->
     #{conninfo := ConnInfo} = ?ChanInfo,
-    none = emqx_cm:takeover_channel_session(<<"clientid">>, fun ident/1),
+    ClientId = <<"clientid">>,
+    none = emqx_cm:takeover_session_begin(ClientId),
     Parent = self(),
-    erlang:spawn_link(fun() ->
-        ok = emqx_cm:register_channel(<<"clientid">>, self(), ConnInfo),
+    ChanPid = erlang:spawn_link(fun() ->
+        ok = emqx_cm:register_channel(ClientId, self(), ConnInfo),
         Parent ! registered,
         receive
             {'$gen_call', From1, {takeover, 'begin'}} ->
@@ -398,16 +399,17 @@ t_takeover_session(_) ->
     receive
         registered -> ok
     end,
-    {ok, test, []} = emqx_cm:takeover_channel_session(<<"clientid">>, fun ident/1),
-    emqx_cm:unregister_channel(<<"clientid">>).
+    {ok, test, State = {emqx_connection, ChanPid}} = emqx_cm:takeover_session_begin(ClientId),
+    {ok, []} = emqx_cm:takeover_session_end(State),
+    emqx_cm:unregister_channel(ClientId).
 
 t_takeover_session_process_gone(_) ->
     #{conninfo := ConnInfo} = ?ChanInfo,
     ClientIDTcp = <<"clientidTCP">>,
     ClientIDWs = <<"clientidWs">>,
     ClientIDRpc = <<"clientidRPC">>,
-    none = emqx_cm:takeover_channel_session(ClientIDTcp, fun ident/1),
-    none = emqx_cm:takeover_channel_session(ClientIDWs, fun ident/1),
+    none = emqx_cm:takeover_session_begin(ClientIDTcp),
+    none = emqx_cm:takeover_session_begin(ClientIDWs),
     meck:new(emqx_connection, [passthrough, no_history]),
     meck:expect(
         emqx_connection,
@@ -420,7 +422,7 @@ t_takeover_session_process_gone(_) ->
         end
     ),
     ok = emqx_cm:register_channel(ClientIDTcp, self(), ConnInfo),
-    none = emqx_cm:takeover_channel_session(ClientIDTcp, fun ident/1),
+    none = emqx_cm:takeover_session_begin(ClientIDTcp),
     meck:expect(
         emqx_connection,
         call,
@@ -432,7 +434,7 @@ t_takeover_session_process_gone(_) ->
         end
     ),
     ok = emqx_cm:register_channel(ClientIDWs, self(), ConnInfo),
-    none = emqx_cm:takeover_channel_session(ClientIDWs, fun ident/1),
+    none = emqx_cm:takeover_session_begin(ClientIDWs),
     meck:expect(
         emqx_connection,
         call,
@@ -444,7 +446,7 @@ t_takeover_session_process_gone(_) ->
         end
     ),
     ok = emqx_cm:register_channel(ClientIDRpc, self(), ConnInfo),
-    none = emqx_cm:takeover_channel_session(ClientIDRpc, fun ident/1),
+    none = emqx_cm:takeover_session_begin(ClientIDRpc),
     emqx_cm:unregister_channel(ClientIDTcp),
     emqx_cm:unregister_channel(ClientIDWs),
     emqx_cm:unregister_channel(ClientIDRpc),
@@ -463,8 +465,3 @@ t_message(_) ->
     ?CM ! testing,
     gen_server:cast(?CM, testing),
     gen_server:call(?CM, testing).
-
-%%
-
-ident(V) ->
-    V.