Przeglądaj źródła

Fix(connect): fix the race condition for openning session

- Remove the register_channel/1,2 functions
JianBo He 5 lat temu
rodzic
commit
3fb82f7234
2 zmienionych plików z 125 dodań i 50 usunięć
  1. 39 24
      src/emqx_cm.erl
  2. 86 26
      test/emqx_cm_SUITE.erl

+ 39 - 24
src/emqx_cm.erl

@@ -27,9 +27,7 @@
 
 -export([start_link/0]).
 
--export([ register_channel/1
-        , register_channel/2
-        , register_channel/3
+-export([ register_channel/3
         , unregister_channel/1
         ]).
 
@@ -45,6 +43,8 @@
         , set_chan_stats/2
         ]).
 
+-export([get_chann_conn_mod/2]).
+
 -export([ open_session/3
         , discard_session/1
         , discard_session/2
@@ -100,28 +100,29 @@ start_link() ->
 %% API
 %%--------------------------------------------------------------------
 
-%% @doc Register a channel.
--spec(register_channel(emqx_types:clientid()) -> ok).
-register_channel(ClientId) ->
-    register_channel(ClientId, self()).
-
-%% @doc Register a channel with pid.
--spec(register_channel(emqx_types:clientid(), chan_pid()) -> ok).
-register_channel(ClientId, ChanPid) when is_pid(ChanPid) ->
-    Chan = {ClientId, ChanPid},
-    true = ets:insert(?CHAN_TAB, Chan),
-    true = ets:insert(?CHAN_CONN_TAB, Chan),
-    ok = emqx_cm_registry:register_channel(Chan),
-    cast({registered, Chan}).
-
 %% @doc Register a channel with info and stats.
 -spec(register_channel(emqx_types:clientid(),
                        emqx_types:infos(),
                        emqx_types:stats()) -> ok).
-register_channel(ClientId, Info, Stats) ->
+register_channel(ClientId, Info = #{conninfo := ConnInfo}, Stats) ->
     Chan = {ClientId, ChanPid = self()},
     true = ets:insert(?CHAN_INFO_TAB, {Chan, Info, Stats}),
-    register_channel(ClientId, ChanPid).
+    register_channel(ClientId, ChanPid, ConnInfo);
+
+%% @private
+%% @doc Register a channel with pid and conn_mod.
+%%
+%% There is a Race-Condition on one node or cluster when many connections
+%% login to Broker with the same clientid. We should register it and save
+%% the conn_mod first for taking up the clientid access right.
+%%
+%% Note that: It should be called on a lock transaction
+register_channel(ClientId, ChanPid, #{conn_mod := ConnMod}) when is_pid(ChanPid) ->
+    Chan = {ClientId, ChanPid},
+    true = ets:insert(?CHAN_TAB, Chan),
+    true = ets:insert(?CHAN_CONN_TAB, {Chan, ConnMod}),
+    ok = emqx_cm_registry:register_channel(Chan),
+    cast({registered, Chan}).
 
 %% @doc Unregister a channel.
 -spec(unregister_channel(emqx_types:clientid()) -> ok).
@@ -132,7 +133,7 @@ unregister_channel(ClientId) when is_binary(ClientId) ->
 %% @private
 do_unregister_channel(Chan) ->
     ok = emqx_cm_registry:unregister_channel(Chan),
-    true = ets:delete_object(?CHAN_CONN_TAB, Chan),
+    true = ets:delete(?CHAN_CONN_TAB, Chan),
     true = ets:delete(?CHAN_INFO_TAB, Chan),
     ets:delete_object(?CHAN_TAB, Chan).
 
@@ -206,24 +207,29 @@ set_chan_stats(ClientId, ChanPid, Stats) ->
                 pendings => list()}}
        | {error, Reason :: term()}).
 open_session(true, ClientInfo = #{clientid := ClientId}, ConnInfo) ->
+    Self = self(),
     CleanStart = fun(_) ->
                      ok = discard_session(ClientId),
                      Session = create_session(ClientInfo, ConnInfo),
+                     register_channel(ClientId, Self, ConnInfo),
                      {ok, #{session => Session, present => false}}
                  end,
     emqx_cm_locker:trans(ClientId, CleanStart);
 
 open_session(false, ClientInfo = #{clientid := ClientId}, ConnInfo) ->
+    Self = self(),
     ResumeStart = fun(_) ->
                       case takeover_session(ClientId) of
                           {ok, ConnMod, ChanPid, Session} ->
                               ok = emqx_session:resume(ClientInfo, Session),
                               Pendings = ConnMod:call(ChanPid, {takeover, 'end'}),
+                              register_channel(ClientId, Self, ConnInfo),
                               {ok, #{session  => Session,
                                      present  => true,
                                      pendings => Pendings}};
                           {error, not_found} ->
                               Session = create_session(ClientInfo, ConnInfo),
+                              register_channel(ClientId, Self, ConnInfo),
                               {ok, #{session => Session, present => false}}
                       end
                   end,
@@ -253,8 +259,8 @@ takeover_session(ClientId) ->
     end.
 
 takeover_session(ClientId, ChanPid) when node(ChanPid) == node() ->
-    case get_chan_info(ClientId, ChanPid) of
-        #{conninfo := #{conn_mod := ConnMod}} ->
+    case get_chann_conn_mod(ClientId, ChanPid) of
+        ConnMod when is_atom(ConnMod) ->
             Session = ConnMod:call(ChanPid, {takeover, 'begin'}),
             {ok, ConnMod, ChanPid, Session};
         undefined ->
@@ -284,8 +290,8 @@ discard_session(ClientId) when is_binary(ClientId) ->
     end.
 
 discard_session(ClientId, ChanPid) when node(ChanPid) == node() ->
-    case get_chan_info(ClientId, ChanPid) of
-        #{conninfo := #{conn_mod := ConnMod}} ->
+    case get_chann_conn_mod(ClientId, ChanPid) of
+        ConnMod when is_atom(ConnMod) ->
             ConnMod:call(ChanPid, discard);
         undefined -> ok
     end;
@@ -418,3 +424,12 @@ update_stats({Tab, Stat, MaxStat}) ->
         Size -> emqx_stats:setstat(Stat, MaxStat, Size)
     end.
 
+get_chann_conn_mod(ClientId, ChanPid) when node(ChanPid) == node() ->
+    Chan = {ClientId, ChanPid},
+    try [ConnMod] = ets:lookup_element(?CHAN_CONN_TAB, Chan, 2), ConnMod
+    catch
+        error:badarg -> undefined
+    end;
+get_chann_conn_mod(ClientId, ChanPid) ->
+    rpc_call(node(ChanPid), get_chann_conn_mod, [ClientId, ChanPid]).
+

+ 86 - 26
test/emqx_cm_SUITE.erl

@@ -23,6 +23,13 @@
 -include_lib("eunit/include/eunit.hrl").
 
 -define(CM, emqx_cm).
+-define(ChanInfo,#{conninfo =>
+                   #{socktype => tcp,
+                     peername => {{127,0,0,1}, 5000},
+                     sockname => {{127,0,0,1}, 1883},
+                     peercert => nossl,
+                     conn_mod => emqx_connection,
+                     receive_maximum => 100}}).
 
 %%--------------------------------------------------------------------
 %% CT callbacks
@@ -43,13 +50,13 @@ end_per_suite(_Config) ->
 %%--------------------------------------------------------------------
 
 t_reg_unreg_channel(_) ->
-    ok = emqx_cm:register_channel(<<"clientid">>),
+    ok = emqx_cm:register_channel(<<"clientid">>, ?ChanInfo, []),
     ?assertEqual([self()], emqx_cm:lookup_channels(<<"clientid">>)),
     ok = emqx_cm:unregister_channel(<<"clientid">>),
     ?assertEqual([], emqx_cm:lookup_channels(<<"clientid">>)).
 
 t_get_set_chan_info(_) ->
-    Info = #{proto_ver => 4, proto_name => <<"MQTT">>},
+    Info = ?ChanInfo,
     ok = emqx_cm:register_channel(<<"clientid">>, Info, []),
     ?assertEqual(Info, emqx_cm:get_chan_info(<<"clientid">>)),
     Info1 = Info#{proto_ver => 5},
@@ -60,7 +67,7 @@ t_get_set_chan_info(_) ->
 
 t_get_set_chan_stats(_) ->
     Stats = [{recv_oct, 10}, {send_oct, 8}],
-    ok = emqx_cm:register_channel(<<"clientid">>, #{}, Stats),
+    ok = emqx_cm:register_channel(<<"clientid">>, ?ChanInfo, Stats),
     ?assertEqual(Stats, emqx_cm:get_chan_stats(<<"clientid">>)),
     Stats1 = [{recv_oct, 10}|Stats],
     true = emqx_cm:set_chan_stats(<<"clientid">>, Stats1),
@@ -69,27 +76,89 @@ t_get_set_chan_stats(_) ->
     ?assertEqual(undefined, emqx_cm:get_chan_stats(<<"clientid">>)).
 
 t_open_session(_) ->
+    ok = meck:new(emqx_connection, [passthrough, no_history]),
+    ok = meck:expect(emqx_connection, call, fun(_, _) -> ok end),
+
     ClientInfo = #{zone => external,
                    clientid => <<"clientid">>,
                    username => <<"username">>,
                    peerhost => {127,0,0,1}},
-    ConnInfo = #{peername => {{127,0,0,1}, 5000},
+    ConnInfo = #{socktype => tcp,
+                 peername => {{127,0,0,1}, 5000},
+                 sockname => {{127,0,0,1}, 1883},
+                 peercert => nossl,
+                 conn_mod => emqx_connection,
                  receive_maximum => 100},
     {ok, #{session := Session1, present := false}}
         = emqx_cm:open_session(true, ClientInfo, ConnInfo),
     ?assertEqual(100, emqx_session:info(inflight_max, Session1)),
     {ok, #{session := Session2, present := false}}
-        = emqx_cm:open_session(false, ClientInfo, ConnInfo),
-    ?assertEqual(100, emqx_session:info(inflight_max, Session2)).
+        = emqx_cm:open_session(true, ClientInfo, ConnInfo),
+    ?assertEqual(100, emqx_session:info(inflight_max, Session2)),
+
+    emqx_cm:unregister_channel(<<"clientid">>),
+    ok = meck:unload(emqx_connection).
+
+t_open_session_race_condition(_) ->
+    ClientInfo = #{zone => external,
+                   clientid => <<"clientid">>,
+                   username => <<"username">>,
+                   peerhost => {127,0,0,1}},
+    ConnInfo = #{socktype => tcp,
+                 peername => {{127,0,0,1}, 5000},
+                 sockname => {{127,0,0,1}, 1883},
+                 peercert => nossl,
+                 conn_mod => emqx_connection,
+                 receive_maximum => 100},
+
+    Parent = self(),
+    OpenASession = fun() ->
+                     timer:sleep(rand:uniform(100)),
+                     OpenR = (emqx_cm:open_session(true, ClientInfo, ConnInfo)),
+                     Parent ! OpenR,
+                     case OpenR of
+                         {ok, _} ->
+                             receive
+                                 {'$gen_call', From, discard} ->
+                                     gen_server:reply(From, ok), ok
+                             end;
+                         {error, Reason} ->
+                             exit(Reason)
+                     end
+                   end,
+    [spawn(
+      fun() ->
+              spawn(OpenASession),
+              spawn(OpenASession)
+      end) || _ <- lists:seq(1, 1000)],
+
+    WaitingRecv = fun _Wr(N1, N2, 0) ->
+                          {N1, N2};
+                      _Wr(N1, N2, Rest) ->
+                          receive
+                              {ok, _} -> _Wr(N1+1, N2, Rest-1);
+                              {error, _} -> _Wr(N1, N2+1, Rest-1)
+                          end
+                  end,
+
+    ct:pal("Race condition status: ~p~n", [WaitingRecv(0, 0, 2000)]),
+
+    ?assertEqual(1, ets:info(emqx_channel, size)),
+    ?assertEqual(1, ets:info(emqx_channel_conn, size)),
+    ?assertEqual(1, ets:info(emqx_channel_registry, size)),
+
+    [Pid] = emqx_cm:lookup_channels(<<"clientid">>),
+    exit(Pid, kill), timer:sleep(100),
+    ?assertEqual([], emqx_cm:lookup_channels(<<"clientid">>)).
 
 t_discard_session(_) ->
     ok = meck:new(emqx_connection, [passthrough, no_history]),
     ok = meck:expect(emqx_connection, call, fun(_, _) -> ok end),
     ok = emqx_cm:discard_session(<<"clientid">>),
-    ok = emqx_cm:register_channel(<<"clientid">>),
+    ok = emqx_cm:register_channel(<<"clientid">>, ?ChanInfo, []),
     ok = emqx_cm:discard_session(<<"clientid">>),
     ok = emqx_cm:unregister_channel(<<"clientid">>),
-    ok = emqx_cm:register_channel(<<"clientid">>, #{conninfo => #{conn_mod => emqx_connection}}, []),
+    ok = emqx_cm:register_channel(<<"clientid">>, ?ChanInfo, []),
     ok = emqx_cm:discard_session(<<"clientid">>),
     ok = meck:expect(emqx_connection, call, fun(_, _) -> error(testing) end),
     ok = emqx_cm:discard_session(<<"clientid">>),
@@ -97,35 +166,26 @@ t_discard_session(_) ->
     ok = meck:unload(emqx_connection).
 
 t_takeover_session(_) ->
-    ok = meck:new(emqx_connection, [passthrough, no_history]),
-    ok = meck:expect(emqx_connection, call, fun(_, _) -> test end),
-    {error, not_found} = emqx_cm:takeover_session(<<"clientid">>),
-    ok = emqx_cm:register_channel(<<"clientid">>),
     {error, not_found} = emqx_cm:takeover_session(<<"clientid">>),
-    ok = emqx_cm:unregister_channel(<<"clientid">>),
-    ok = emqx_cm:register_channel(<<"clientid">>, #{conninfo => #{conn_mod => emqx_connection}}, []),
-    Pid = self(),
-    {ok, emqx_connection, Pid, test} = emqx_cm:takeover_session(<<"clientid">>),
     erlang:spawn(fun() ->
-                     ok = emqx_cm:register_channel(<<"clientid">>, #{conninfo => #{conn_mod => emqx_connection}}, []),
-                     timer:sleep(1000)
+                     ok = emqx_cm:register_channel(<<"clientid">>, ?ChanInfo, []),
+                     receive
+                         {'$gen_call', From, {takeover, 'begin'}} ->
+                             gen_server:reply(From, test), ok
+                     end
                  end),
-    ct:sleep(100),
+    timer:sleep(100),
     {ok, emqx_connection, _, test} = emqx_cm:takeover_session(<<"clientid">>),
-    ok = emqx_cm:unregister_channel(<<"clientid">>),
-    ok = meck:unload(emqx_connection).
+    emqx_cm:unregister_channel(<<"clientid">>).
 
 t_kick_session(_) ->
     ok = meck:new(emqx_connection, [passthrough, no_history]),
     ok = meck:expect(emqx_connection, call, fun(_, _) -> test end),
     {error, not_found} = emqx_cm:kick_session(<<"clientid">>),
-    ok = emqx_cm:register_channel(<<"clientid">>),
-    {error, not_found} = emqx_cm:kick_session(<<"clientid">>),
-    ok = emqx_cm:unregister_channel(<<"clientid">>),
-    ok = emqx_cm:register_channel(<<"clientid">>, #{conninfo => #{conn_mod => emqx_connection}}, []),
+    ok = emqx_cm:register_channel(<<"clientid">>, ?ChanInfo, []),
     test = emqx_cm:kick_session(<<"clientid">>),
     erlang:spawn(fun() ->
-                     ok = emqx_cm:register_channel(<<"clientid">>, #{conninfo => #{conn_mod => emqx_connection}}, []),
+                     ok = emqx_cm:register_channel(<<"clientid">>, ?ChanInfo, []),
                      timer:sleep(1000)
                  end),
     ct:sleep(100),