Преглед изворног кода

fix(persistent_session): Make sure to discard expired sessions on reconnect

Tobias Lindahl пре 4 година
родитељ
комит
fd71bc50ab

+ 14 - 13
apps/emqx/src/emqx_cm.erl

@@ -242,7 +242,13 @@ open_session(false, ClientInfo = #{clientid := ClientId}, ConnInfo) ->
                               {ok, #{session  => Session1,
                                      present  => true,
                                      pendings => Pendings}};
-                          {error, not_found} ->
+                          {expired, OldSession} ->
+                              _ = emqx_persistent_session:discard(ClientId, OldSession),
+                              Session = create_session(ClientInfo, ConnInfo),
+                              Session1 = emqx_persistent_session:persist(ClientInfo, ConnInfo, Session),
+                              register_channel(ClientId, Self, ConnInfo),
+                              {ok, #{session => Session1, present => false}};
+                          none ->
                               Session = create_session(ClientInfo, ConnInfo),
                               Session1 = emqx_persistent_session:persist(ClientInfo, ConnInfo, Session),
                               register_channel(ClientId, Self, ConnInfo),
@@ -282,17 +288,15 @@ get_mqtt_conf(Zone, Key) ->
     emqx_config:get_zone_conf(Zone, [mqtt, Key]).
 
 %% @doc Try to takeover a session.
--spec(takeover_session(emqx_types:clientid())
-      -> {error, term()}
-       | {living, atom(), pid(), emqx_session:session()}
-       | {persistent, emqx_session:session()}).
+-spec takeover_session(emqx_types:clientid()) ->
+          none
+        | {living, atom(), pid(), emqx_session:session()}
+        | {persistent, emqx_session:session()}
+        | {expired, emqx_session:session()}.
 takeover_session(ClientId) ->
     case lookup_channels(ClientId) of
         [] ->
-            case emqx_persistent_session:lookup(ClientId) of
-                [] -> {error, not_found};
-                [Session] -> {persistent, Session}
-            end;
+            emqx_persistent_session:lookup(ClientId);
         [ChanPid] ->
             takeover_session(ClientId, ChanPid);
         ChanPids ->
@@ -307,10 +311,7 @@ takeover_session(ClientId) ->
 takeover_session(ClientId, ChanPid) when node(ChanPid) == node() ->
     case get_chann_conn_mod(ClientId, ChanPid) of
         undefined ->
-            case emqx_persistent_session:lookup(ClientId) of
-                [] -> {error, not_found};
-                [Session] -> {persistent, Session}
-            end;
+            emqx_persistent_session:lookup(ClientId);
         ConnMod when is_atom(ConnMod) ->
             Session = ConnMod:call(ChanPid, {takeover, 'begin'}, ?T_TAKEOVER),
             {living, ConnMod, ChanPid, Session}

+ 6 - 7
apps/emqx/src/emqx_persistent_session.erl

@@ -181,20 +181,19 @@ timestamp_from_conninfo(ConnInfo) ->
 
 lookup(ClientID) when is_binary(ClientID) ->  %% New result shape: none | {expired, S} | {persistent, S}.
     case lookup_session_store(ClientID) of
-        none -> [];
+        none -> none;
         {value, #session_store{session = S} = SS} ->
             case persistent_session_status(SS) of  %% NOTE(review): if not_persistent can still be returned here, the new clauses raise case_clause - confirm.
-                not_persistent -> []; %% For completeness. Should not happen
-                expired        -> [];
-                persistent     -> [S]
+                expired        -> {expired, S};  %% Expired sessions are now handed back to the caller instead of being hidden.
+                persistent     -> {persistent, S}
             end
     end.
 
 -spec discard_if_present(binary()) -> 'ok'.
 discard_if_present(ClientID) ->  %% Discards whatever session lookup/1 finds for ClientID; both branches return ok.
     case lookup(ClientID) of
-        [] -> ok;
-        [Session] ->
+        none -> ok;
+        {Tag, Session} when Tag =:= persistent; Tag =:= expired ->  %% Either tag carries a stored session that is discarded.
             _ = discard(ClientID, Session),
             ok
     end.
@@ -354,7 +353,7 @@ do_mark_as_delivered(_SessionID, []) ->
 -spec pending(emqx_session:sessionID()) ->
           [{emqx_types:message(), STopic :: binary()}].
 pending(SessionID) ->
-    pending(SessionID, []).
+    pending_messages_in_db(SessionID, []).  %% Now calls pending_messages_in_db/2 directly instead of going through pending/2.
 
 -spec pending(emqx_session:sessionID(), MarkerIDs :: [emqx_guid:guid()]) ->
           [{emqx_types:message(), STopic :: binary()}].

+ 1 - 1
apps/emqx/test/emqx_cm_SUITE.erl

@@ -221,7 +221,7 @@ t_discard_session_race(_) ->
 
 t_takeover_session(_) ->
     #{conninfo := ConnInfo} = ?ChanInfo,
-    {error, not_found} = emqx_cm:takeover_session(<<"clientid">>),
+    none = emqx_cm:takeover_session(<<"clientid">>),
     erlang:spawn_link(fun() ->
         ok = emqx_cm:register_channel(<<"clientid">>, self(), ConnInfo),
         receive

+ 65 - 4
apps/emqx/test/emqx_persistent_session_SUITE.erl

@@ -81,7 +81,8 @@ init_per_group(persistent_store_enabled, Config) ->
                                    (Other) -> meck:passthrough([Other])
                                 end),
     emqx_common_test_helpers:start_apps([], fun set_special_confs/1),
-    Config;
+    ?assertEqual(true, emqx_persistent_session:is_store_enabled()),
+    [{persistent_store_enabled, true}|Config];
 init_per_group(persistent_store_disabled, Config) ->
     %% Start Apps
     emqx_common_test_helpers:boot_modules(all),
@@ -90,7 +91,8 @@ init_per_group(persistent_store_disabled, Config) ->
                                    (Other) -> meck:passthrough([Other])
                                 end),
     emqx_common_test_helpers:start_apps([], fun set_special_confs/1),
-    Config;
+    ?assertEqual(false, emqx_persistent_session:is_store_enabled()),
+    [{persistent_store_enabled, false}|Config];
 init_per_group(Group, Config) when Group == tcp; Group == tcp_snabbkaffe ->
     [ {port, 1883}, {conn_fun, connect}| Config];
 init_per_group(Group, Config) when Group == quic; Group == quic_snabbkaffe ->
@@ -382,30 +384,89 @@ t_persist_on_disconnect(Config) ->
     ?assertEqual(0, client_info(session_present, Client2)),
     ok = emqtt:disconnect(Client2).
 
+wait_for_pending(SId) ->  %% Polls for pending messages of session SId with a budget of 100 attempts.
+    wait_for_pending(SId, 100).
+
+wait_for_pending(_SId, 0) ->  %% Attempts exhausted: raise an error so the caller fails.
+    error(exhausted_wait_for_pending);
+wait_for_pending(SId, N) ->
+    case emqx_persistent_session:pending(SId) of
+        [] -> timer:sleep(1), wait_for_pending(SId, N - 1);  %% Nothing pending yet: pause 1 ms, then retry with one attempt fewer.
+        [_|_] = Pending -> Pending  %% Return the first non-empty pending list.
+    end.
+
 t_process_dies_session_expires(Config) ->
     %% Emulate an error in the connect process,
     %% or that the node of the process goes down.
     %% A persistent session should eventually expire.
     ConnFun = ?config(conn_fun, Config),
     ClientId = ?config(client_id, Config),
+    Topic = ?config(topic, Config),
+    STopic = ?config(stopic, Config),
+    Payload = <<"test">>,
     {ok, Client1} = emqtt:start_link([ {proto_ver, v5},
                                        {clientid, ClientId},
                                        {properties, #{'Session-Expiry-Interval' => 1}},  %% Short expiry (seconds per MQTT v5) so the session lapses within the test.
                                        {clean_start, true}
                                      | Config]),
     {ok, _} = emqtt:ConnFun(Client1),
+    {ok, _, [2]} = emqtt:subscribe(Client1, STopic, qos2),
     ok = emqtt:disconnect(Client1),
 
     maybe_kill_connection_process(ClientId, Config),
 
+    ok = publish(Topic, [Payload], Config),  %% Published after Client1 has disconnected.
+
+    SessionId =
+        case ?config(persistent_store_enabled, Config) of
+            false -> undefined;
+            true ->
+                %% The session should not be marked as expired.
+                {Tag, Session} = emqx_persistent_session:lookup(ClientId),
+                ?assertEqual(persistent, Tag),
+                SId = emqx_session:info(id, Session),
+                case ?config(kill_connection_process, Config) of
+                    true ->
+                        %% The session should have a pending message
+                        ?assertMatch([_], wait_for_pending(SId));
+                    false ->
+                        skip
+                end,
+                SId
+        end,
+
     timer:sleep(1100),  %% Wait past the Session-Expiry-Interval of 1 requested by Client1.
 
+    %% The session should now be marked as expired.
+    case (?config(kill_connection_process, Config) andalso
+          ?config(persistent_store_enabled, Config)) of
+        true  -> ?assertMatch({expired, _}, emqx_persistent_session:lookup(ClientId));
+        false -> skip
+    end,
+
     {ok, Client2} = emqtt:start_link([ {proto_ver, v5},
                                        {clientid, ClientId},
+                                       {properties, #{'Session-Expiry-Interval' => 30}},
                                        {clean_start, false}
                                      | Config]),
     {ok, _} = emqtt:ConnFun(Client2),
     ?assertEqual(0, client_info(session_present, Client2)),  %% The old session must not be reported as present on reconnect.
+
+    case (?config(kill_connection_process, Config) andalso
+          ?config(persistent_store_enabled, Config)) of
+        true ->
+            %% The session should be a fresh one
+            {persistent, NewSession} = emqx_persistent_session:lookup(ClientId),
+            ?assertNotEqual(SessionId, emqx_session:info(id, NewSession)),
+            %% The old session should now either be marked as abandoned or already be garbage collected.
+            ?assertMatch([], emqx_persistent_session:pending(SessionId));
+        false ->
+            skip
+    end,
+
+    %% We should not receive the pending message
+    ?assertEqual([], receive_messages(1)),
+
     emqtt:disconnect(Client2).
 
 t_publish_while_client_is_gone(Config) ->
@@ -520,7 +581,7 @@ t_unsubscribe(Config) ->
     {ok, _, [2]} = emqtt:subscribe(Client, STopic, qos2),
     case emqx_persistent_session:is_store_enabled() of
         true ->
-            [Session] = emqx_persistent_session:lookup(ClientId),
+            {persistent, Session} = emqx_persistent_session:lookup(ClientId),
             SessionID = emqx_session:info(id, Session),
             SessionIDs = [SId || #route{dest = SId} <- emqx_session_router:match_routes(Topic)],
             ?assert(lists:member(SessionID, SessionIDs)),
@@ -582,7 +643,7 @@ t_lost_messages_because_of_gc(init, Config) ->
             OldRetain = emqx_config:get(?msg_retain, Retain),
             emqx_config:put(?msg_retain, Retain),
             [{retain, Retain}, {old_retain, OldRetain}|Config];
-        false -> {skip, only_relevant_with_store}
+        false -> {skip, only_relevant_with_store_and_kill_process}
     end;
 t_lost_messages_because_of_gc('end', Config) ->
     OldRetain = ?config(old_retain, Config),