Parcourir la source

fix(session): respect existing session even if expiry interval = 0

If the original connection had Session-Expiry-Interval > 0, and the
new connection set Session-Expiry-Interval = 0, the MQTTv5 spec says
that (supposedly) we still have to continue with the existing session
(if it hasn't expired yet).

Co-Authored-By: Thales Macedo Garitezi <thalesmg@gmail.com>
Andrew Mayorov il y a 2 ans
Parent
commit
a2ddd9d5f5

+ 2 - 2
apps/emqx/integration_test/emqx_ds_SUITE.erl

@@ -245,8 +245,8 @@ t_session_subscription_idempotency(Config) ->
             ?assertEqual([{ClientId, SubTopicFilterWords}], get_all_iterator_refs(Node1)),
             ?assertMatch({ok, [_]}, get_all_iterator_ids(Node1)),
             ?assertMatch(
-                {_IsNew = false, #{}, #{SubTopicFilterWords := #{}}},
-                erpc:call(Node1, emqx_ds, session_open, [ClientId, #{}])
+                {ok, #{}, #{SubTopicFilterWords := #{}}},
+                erpc:call(Node1, emqx_ds, session_open, [ClientId])
             )
         end
     ),

+ 27 - 21
apps/emqx/src/emqx_persistent_session_ds.erl

@@ -24,7 +24,7 @@
 %% Session API
 -export([
     create/3,
-    open/3,
+    open/2,
     destroy/1
 ]).
 
@@ -98,12 +98,11 @@
     session().
 create(#{clientid := ClientID}, _ConnInfo, Conf) ->
     % TODO: expiration
-    {true, Session} = open_session(ClientID, Conf),
-    Session.
+    ensure_session(ClientID, Conf).
 
--spec open(clientinfo(), conninfo(), emqx_session:conf()) ->
-    {_IsPresent :: true, session(), []} | {_IsPresent :: false, session()}.
-open(#{clientid := ClientID}, _ConnInfo, Conf) ->
+-spec open(clientinfo(), conninfo()) ->
+    {_IsPresent :: true, session(), []} | false.
+open(#{clientid := ClientID}, _ConnInfo) ->
     %% NOTE
     %% The fact that we need to concern about discarding all live channels here
     %% is essentially a consequence of the in-memory session design, where we
@@ -111,24 +110,31 @@ open(#{clientid := ClientID}, _ConnInfo, Conf) ->
     %% somehow isolate those idling not-yet-expired sessions into a separate process
     %% space, and move this call back into `emqx_cm` where it belongs.
     ok = emqx_cm:discard_session(ClientID),
-    {IsNew, Session} = open_session(ClientID, Conf),
-    IsPresent = not IsNew,
-    case IsPresent of
-        true ->
-            {IsPresent, Session, []};
+    case open_session(ClientID) of
+        Session = #{} ->
+            {true, Session, []};
         false ->
-            {IsPresent, Session}
+            false
     end.
 
-open_session(ClientID, Conf) ->
-    {IsNew, Session, Iterators} = emqx_ds:session_open(ClientID, Conf),
-    {IsNew, Session#{
-        iterators => maps:fold(
-            fun(Topic, Iterator, Acc) -> Acc#{emqx_topic:join(Topic) => Iterator} end,
-            #{},
-            Iterators
-        )
-    }}.
+ensure_session(ClientID, Conf) ->
+    {ok, Session, #{}} = emqx_ds:session_ensure_new(ClientID, Conf),
+    Session#{iterators => #{}}.
+
+open_session(ClientID) ->
+    case emqx_ds:session_open(ClientID) of
+        {ok, Session, Iterators} ->
+            Session#{iterators => prep_iterators(Iterators)};
+        false ->
+            false
+    end.
+
+prep_iterators(Iterators) ->
+    maps:fold(
+        fun(Topic, Iterator, Acc) -> Acc#{emqx_topic:join(Topic) => Iterator} end,
+        #{},
+        Iterators
+    ).
 
 -spec destroy(session() | clientinfo()) -> ok.
 destroy(#{id := ClientID}) ->

+ 53 - 19
apps/emqx/src/emqx_session.erl

@@ -156,6 +156,15 @@
 
 -define(IMPL(S), (get_impl_mod(S))).
 
+%%--------------------------------------------------------------------
+%% Behaviour
+%% -------------------------------------------------------------------
+
+-callback create(clientinfo(), conninfo(), conf()) ->
+    t().
+-callback open(clientinfo(), conninfo()) ->
+    {_IsPresent :: true, t(), _ReplayContext} | false.
+
 %%--------------------------------------------------------------------
 %% Create a Session
 %%--------------------------------------------------------------------
@@ -167,7 +176,11 @@ create(ClientInfo, ConnInfo) ->
 
 create(ClientInfo, ConnInfo, Conf) ->
     % FIXME error conditions
-    Session = (choose_impl_mod(ConnInfo)):create(ClientInfo, ConnInfo, Conf),
+    create(choose_impl_mod(ConnInfo), ClientInfo, ConnInfo, Conf).
+
+create(Mod, ClientInfo, ConnInfo, Conf) ->
+    % FIXME error conditions
+    Session = Mod:create(ClientInfo, ConnInfo, Conf),
     ok = emqx_metrics:inc('session.created'),
     ok = emqx_hooks:run('session.created', [ClientInfo, info(Session)]),
     Session.
@@ -176,17 +189,29 @@ create(ClientInfo, ConnInfo, Conf) ->
     {_IsPresent :: true, t(), _ReplayContext} | {_IsPresent :: false, t()}.
 open(ClientInfo, ConnInfo) ->
     Conf = get_session_conf(ClientInfo, ConnInfo),
-    case (choose_impl_mod(ConnInfo)):open(ClientInfo, ConnInfo, Conf) of
-        {_IsPresent = true, Session, ReplayContext} ->
-            {true, Session, ReplayContext};
-        {_IsPresent = false, NewSession} ->
-            ok = emqx_metrics:inc('session.created'),
-            ok = emqx_hooks:run('session.created', [ClientInfo, info(NewSession)]),
-            {false, NewSession};
-        _IsPresent = false ->
-            {false, create(ClientInfo, ConnInfo, Conf)}
+    Mods = [Default | _] = choose_impl_candidates(ConnInfo),
+    %% NOTE
+    %% Try to look the existing session up in session stores corresponding to the given
+    %% `Mods` in order, starting from the last one.
+    case try_open(Mods, ClientInfo, ConnInfo) of
+        {_IsPresent = true, _, _} = Present ->
+            Present;
+        false ->
+            %% NOTE
+            %% Nothing was found, create a new session with the `Default` implementation.
+            {false, create(Default, ClientInfo, ConnInfo, Conf)}
     end.
 
+try_open([Mod | Rest], ClientInfo, ConnInfo) ->
+    case try_open(Rest, ClientInfo, ConnInfo) of
+        {_IsPresent = true, _, _} = Present ->
+            Present;
+        false ->
+            Mod:open(ClientInfo, ConnInfo)
+    end;
+try_open([], _ClientInfo, _ConnInfo) ->
+    false.
+
 -spec get_session_conf(clientinfo(), conninfo()) -> conf().
 get_session_conf(
     #{zone := Zone},
@@ -527,15 +552,24 @@ get_impl_mod(Session) when ?IS_SESSION_IMPL_DS(Session) ->
     emqx_persistent_session_ds.
 
 -spec choose_impl_mod(conninfo()) -> module().
-choose_impl_mod(#{expiry_interval := 0}) ->
-    emqx_session_mem;
-choose_impl_mod(#{expiry_interval := EI}) when EI > 0 ->
-    case emqx_persistent_message:is_store_enabled() of
-        true ->
-            emqx_persistent_session_ds;
-        false ->
-            emqx_session_mem
-    end.
+choose_impl_mod(#{expiry_interval := EI}) ->
+    hd(choose_impl_candidates(EI, emqx_persistent_message:is_store_enabled())).
+
+-spec choose_impl_candidates(conninfo()) -> [module()].
+choose_impl_candidates(#{expiry_interval := EI}) ->
+    choose_impl_candidates(EI, emqx_persistent_message:is_store_enabled()).
+
+choose_impl_candidates(_, _IsPSStoreEnabled = false) ->
+    [emqx_session_mem];
+choose_impl_candidates(0, _IsPSStoreEnabled = true) ->
+    %% NOTE
+    %% If ExpiryInterval is 0, the natural choice is `emqx_session_mem`. Yet we still
+    %% need to look the existing session up in the `emqx_persistent_session_ds` store
+    %% first, because previous connection may have set ExpiryInterval to a non-zero
+    %% value.
+    [emqx_session_mem, emqx_persistent_session_ds];
+choose_impl_candidates(EI, _IsPSStoreEnabled = true) when EI > 0 ->
+    [emqx_persistent_session_ds].
 
 -compile({inline, [run_hook/2]}).
 run_hook(Name, Args) ->

+ 3 - 3
apps/emqx/src/emqx_session_mem.erl

@@ -57,7 +57,7 @@
 
 -export([
     create/3,
-    open/3,
+    open/2,
     destroy/1
 ]).
 
@@ -193,9 +193,9 @@ destroy(_Session) ->
 %% Open a (possibly existing) Session
 %%--------------------------------------------------------------------
 
--spec open(clientinfo(), conninfo(), emqx_session:conf()) ->
+-spec open(clientinfo(), conninfo()) ->
     {_IsPresent :: true, session(), replayctx()} | _IsPresent :: false.
-open(ClientInfo = #{clientid := ClientId}, _ConnInfo, _Conf) ->
+open(ClientInfo = #{clientid := ClientId}, _ConnInfo) ->
     case emqx_cm:takeover_session_begin(ClientId) of
         {ok, SessionRemote, TakeoverState} ->
             Session = resume(ClientInfo, SessionRemote),

+ 22 - 3
apps/emqx/test/emqx_persistent_session_SUITE.erl

@@ -50,13 +50,14 @@ all() ->
 
 groups() ->
     TCs = emqx_common_test_helpers:all(?MODULE),
+    TCsNonGeneric = [t_choose_impl],
     [
         {persistent_store_disabled, [{group, no_kill_connection_process}]},
         {persistent_store_ds, [{group, no_kill_connection_process}]},
         {no_kill_connection_process, [], [{group, tcp}, {group, quic}, {group, ws}]},
         {tcp, [], TCs},
-        {quic, [], TCs},
-        {ws, [], TCs}
+        {quic, [], TCs -- TCsNonGeneric},
+        {ws, [], TCs -- TCsNonGeneric}
     ].
 
 init_per_group(persistent_store_disabled, Config) ->
@@ -276,6 +277,25 @@ do_publish(Payload, PublishFun, WaitForUnregister) ->
 %% Test Cases
 %%--------------------------------------------------------------------
 
+t_choose_impl(Config) ->
+    ClientId = ?config(client_id, Config),
+    ConnFun = ?config(conn_fun, Config),
+    {ok, Client} = emqtt:start_link([
+        {clientid, ClientId},
+        {proto_ver, v5},
+        {properties, #{'Session-Expiry-Interval' => 30}}
+        | Config
+    ]),
+    {ok, _} = emqtt:ConnFun(Client),
+    [ChanPid] = emqx_cm:lookup_channels(ClientId),
+    ?assertEqual(
+        case ?config(persistent_store, Config) of
+            false -> emqx_session_mem;
+            ds -> emqx_persistent_session_ds
+        end,
+        emqx_connection:info({channel, {session, impl}}, sys:get_state(ChanPid))
+    ).
+
 t_connect_discards_existing_client(Config) ->
     ClientId = ?config(client_id, Config),
     ConnFun = ?config(conn_fun, Config),
@@ -372,7 +392,6 @@ t_assigned_clientid_persistent_session(Config) ->
     {ok, Client2} = emqtt:start_link([
         {clientid, AssignedClientId},
         {proto_ver, v5},
-        {properties, #{'Session-Expiry-Interval' => 30}},
         {clean_start, false}
         | Config
     ]),

+ 22 - 10
apps/emqx_durable_storage/src/emqx_ds.erl

@@ -26,7 +26,8 @@
 -export([iterator_update/2, iterator_next/1, iterator_stats/0]).
 %%   Session:
 -export([
-    session_open/2,
+    session_open/1,
+    session_ensure_new/2,
     session_drop/1,
     session_suspend/1,
     session_add_iterator/3,
@@ -148,28 +149,36 @@ message_stats() ->
 %%--------------------------------------------------------------------------------
 
 %% @doc Called when a client connects. This function looks up a
-%% session or creates a new one if previous one couldn't be found.
+%% session or returns `false` if previous one couldn't be found.
 %%
 %% This function also spawns replay agents for each iterator.
 %%
 %% Note: session API doesn't handle session takeovers, it's the job of
 %% the broker.
--spec session_open(session_id(), _Props :: map()) ->
-    {_New :: boolean(), session(), iterators()}.
-session_open(SessionId, Props) ->
+-spec session_open(session_id()) ->
+    {ok, session(), iterators()} | false.
+session_open(SessionId) ->
     transaction(fun() ->
         case mnesia:read(?SESSION_TAB, SessionId, write) of
             [Record = #session{}] ->
                 Session = export_record(Record),
                 IteratorRefs = session_read_iterators(SessionId),
                 Iterators = export_iterators(IteratorRefs),
-                {false, Session, Iterators};
+                {ok, Session, Iterators};
             [] ->
-                Session = export_record(session_create(SessionId, Props)),
-                {true, Session, #{}}
+                false
         end
     end).
 
+-spec session_ensure_new(session_id(), _Props :: map()) ->
+    {ok, session(), iterators()}.
+session_ensure_new(SessionId, Props) ->
+    transaction(fun() ->
+        ok = session_drop_iterators(SessionId),
+        Session = export_record(session_create(SessionId, Props)),
+        {ok, Session, #{}}
+    end).
+
 session_create(SessionId, Props) ->
     Session = #session{
         id = SessionId,
@@ -186,11 +195,14 @@ session_create(SessionId, Props) ->
 session_drop(DSSessionId) ->
     transaction(fun() ->
         %% TODO: ensure all iterators from this clientid are closed?
-        IteratorRefs = session_read_iterators(DSSessionId),
-        ok = lists:foreach(fun session_del_iterator/1, IteratorRefs),
+        ok = session_drop_iterators(DSSessionId),
         ok = mnesia:delete(?SESSION_TAB, DSSessionId, write)
     end).
 
+session_drop_iterators(DSSessionId) ->
+    IteratorRefs = session_read_iterators(DSSessionId),
+    ok = lists:foreach(fun session_del_iterator/1, IteratorRefs).
+
 %% @doc Called when a client disconnects. This function terminates all
 %% active processes related to the session.
 -spec session_suspend(session_id()) -> ok | {error, session_not_found}.