Sfoglia il codice sorgente

Merge pull request #14138 from savonarola/1031-optimize-session-heartbeat

perf(dssess): avoid periodical last_alive_at bumping in ds sessions
Ilia Averianov 1 anno fa
parent
commit
16ae8a9e4a

+ 1 - 0
apps/emqx/include/emqx_durable_session_metadata.hrl

@@ -22,6 +22,7 @@
 %% Session metadata keys:
 -define(created_at, created_at).
 -define(last_alive_at, last_alive_at).
+-define(node_epoch_id, node_epoch_id).
 -define(expiry_interval, expiry_interval).
 %% Unique integer used to create unique identities:
 -define(last_id, last_id).

+ 2 - 2
apps/emqx/src/emqx_cm_sup.erl

@@ -52,7 +52,7 @@ init([]) ->
     Registry = child_spec(emqx_cm_registry, 5000, worker),
     RegistryKeeper = child_spec(emqx_cm_registry_keeper, 5000, worker),
     Manager = child_spec(emqx_cm, 5000, worker),
-    DSSessionGCSup = child_spec(emqx_persistent_session_ds_sup, infinity, supervisor),
+    DSSessionSup = child_spec(emqx_persistent_session_ds_sup, infinity, supervisor),
     DSSessionBookkeeper = child_spec(emqx_persistent_session_bookkeeper, 5_000, worker),
     Children =
         [
@@ -63,7 +63,7 @@ init([]) ->
             Registry,
             RegistryKeeper,
             Manager,
-            DSSessionGCSup,
+            DSSessionSup,
             DSSessionBookkeeper
         ],
     {ok, {SupFlags, Children}}.

+ 79 - 39
apps/emqx/src/emqx_persistent_session_ds.erl

@@ -199,14 +199,14 @@
 
 -define(TIMER_PULL, timer_pull).
 -define(TIMER_PUSH, timer_push).
--define(TIMER_BUMP_LAST_ALIVE_AT, timer_bump_last_alive_at).
+-define(TIMER_COMMIT, timer_commit).
 -define(TIMER_RETRY_REPLAY, timer_retry_replay).
 -define(TIMER_SHARED_SUB, timer_shared_sub).
 
 -type timer() ::
     ?TIMER_PULL
     | ?TIMER_PUSH
-    | ?TIMER_BUMP_LAST_ALIVE_AT
+    | ?TIMER_COMMIT
     | ?TIMER_RETRY_REPLAY
     | ?TIMER_SHARED_SUB.
 
@@ -234,7 +234,7 @@
     %% Timers:
     ?TIMER_PULL := timer_state(),
     ?TIMER_PUSH := timer_state(),
-    ?TIMER_BUMP_LAST_ALIVE_AT := timer_state(),
+    ?TIMER_COMMIT := timer_state(),
     ?TIMER_RETRY_REPLAY := timer_state(),
     ?TIMER_SHARED_SUB := timer_state()
 }.
@@ -526,7 +526,7 @@ publish(
                 undefined ->
                     Results = emqx_broker:publish(Msg),
                     S = emqx_persistent_session_ds_state:put_awaiting_rel(PacketId, Ts, S0),
-                    {ok, Results, Session#{s := S}};
+                    {ok, Results, ensure_state_commit_timer(Session#{s := S})};
                 _Ts ->
                     {error, ?RC_PACKET_IDENTIFIER_IN_USE}
             end;
@@ -544,7 +544,7 @@ is_awaiting_full(#{s := S, props := Props}) ->
 -spec expire(emqx_types:clientinfo(), session()) ->
     {ok, [], timeout(), session()} | {ok, [], session()}.
 expire(ClientInfo, Session0 = #{props := Props}) ->
-    Session = #{s := S} = do_expire(ClientInfo, Session0),
+    Session = #{s := S} = ensure_state_commit_timer(do_expire(ClientInfo, Session0)),
     case emqx_persistent_session_ds_state:n_awaiting_rel(S) of
         0 ->
             {ok, [], Session};
@@ -606,7 +606,7 @@ puback(_ClientInfo, PacketId, Session0) ->
 pubrec(PacketId, Session0) ->
     case update_seqno(pubrec, PacketId, Session0) of
         {ok, Msg, Session} ->
-            {ok, Msg, Session};
+            {ok, Msg, ensure_state_commit_timer(Session)};
         Error = {error, _} ->
             Error
     end.
@@ -623,7 +623,7 @@ pubrel(PacketId, Session = #{s := S0}) ->
             {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND};
         _TS ->
             S = emqx_persistent_session_ds_state:del_awaiting_rel(PacketId, S0),
-            {ok, Session#{s := S}}
+            {ok, ensure_state_commit_timer(Session#{s := S})}
     end.
 
 %%--------------------------------------------------------------------
@@ -681,18 +681,12 @@ handle_timeout(ClientInfo, ?TIMER_PUSH, Session0) ->
             Timeout = get_config(ClientInfo, [idle_poll_interval]),
             PollOpts = #{timeout => Timeout},
             SchedS = emqx_persistent_session_ds_stream_scheduler:poll(PollOpts, SchedS0, S),
-            {ok, [], Session1#{stream_scheduler_s := SchedS}}
+            {ok, [], ensure_state_commit_timer(Session1#{stream_scheduler_s := SchedS})}
     end;
 handle_timeout(ClientInfo, ?TIMER_RETRY_REPLAY, Session0) ->
     Session = replay_streams(Session0, ClientInfo),
-    {ok, [], Session};
-handle_timeout(_ClientInfo, ?TIMER_BUMP_LAST_ALIVE_AT, Session0 = #{s := S0}) ->
-    S = bump_last_alive(S0),
-    Session = set_timer(
-        ?TIMER_BUMP_LAST_ALIVE_AT,
-        bump_interval(),
-        Session0#{s := S}
-    ),
+    {ok, [], ensure_state_commit_timer(Session)};
+handle_timeout(_ClientInfo, ?TIMER_COMMIT, Session) ->
     {ok, [], commit(Session)};
 handle_timeout(_ClientInfo, #req_sync{from = From, ref = Ref}, Session0) ->
     Session = commit(Session0),
@@ -728,7 +722,7 @@ handle_info(
     ?shared_sub_message(Msg), Session = #{s := S0, shared_sub_s := SharedSubS0}, _ClientInfo
 ) ->
     {S, SharedSubS} = emqx_persistent_session_ds_shared_subs:on_info(S0, SharedSubS0, Msg),
-    Session#{s := S, shared_sub_s := SharedSubS};
+    ensure_state_commit_timer(Session#{s := S, shared_sub_s := SharedSubS});
 handle_info(AsyncReply = #poll_reply{}, Session, ClientInfo) ->
     push_now(handle_ds_reply(AsyncReply, Session, ClientInfo));
 handle_info(#new_stream_event{subref = Ref}, Session, _ClientInfo) ->
@@ -751,14 +745,6 @@ handle_info(Msg, Session, _ClientInfo) ->
 shared_sub_opts(SessionId) ->
     #{session_id => SessionId}.
 
-bump_last_alive(S0) ->
-    %% Note: we take a pessimistic approach here and assume that the client will be alive
-    %% until the next bump timeout.  With this, we avoid garbage collecting this session
-    %% too early in case the session/connection/node crashes earlier without having time
-    %% to commit the time.
-    EstimatedLastAliveAt = now_ms() + bump_interval(),
-    emqx_persistent_session_ds_state:set_last_alive_at(EstimatedLastAliveAt, S0).
-
 -spec replay(clientinfo(), [], session()) ->
     {ok, replies(), session()}.
 replay(ClientInfo, [], Session0 = #{s := S0}) ->
@@ -908,9 +894,10 @@ disconnect(Session = #{id := Id, s := S0, shared_sub_s := SharedSubS0}, ConnInfo
     {shutdown, commit(Session#{s := S, shared_sub_s := SharedSubS})}.
 
 -spec terminate(Reason :: term(), session()) -> ok.
-terminate(_Reason, Session = #{id := Id}) ->
-    maybe_set_will_message_timer(Session),
-    _ = commit(Session),
+terminate(_Reason, Session = #{s := S0, id := Id}) ->
+    _ = maybe_set_will_message_timer(Session),
+    S = finalize_last_alive_at(S0),
+    _ = commit(Session#{s := S}),
     ?tp(debug, persistent_session_ds_terminate, #{id => Id}),
     ok.
 
@@ -1004,7 +991,7 @@ session_open(
     case emqx_persistent_session_ds_state:open(SessionId) of
         {ok, S0} ->
             EI = emqx_persistent_session_ds_state:get_expiry_interval(S0),
-            LastAliveAt = emqx_persistent_session_ds_state:get_last_alive_at(S0),
+            LastAliveAt = get_last_alive_at(S0),
             case NowMS >= LastAliveAt + EI of
                 true ->
                     session_drop(SessionId, expired),
@@ -1013,7 +1000,7 @@ session_open(
                     ?tp(open_session, #{ei => EI, now => NowMS, laa => LastAliveAt}),
                     %% New connection being established
                     S1 = emqx_persistent_session_ds_state:set_expiry_interval(EI, S0),
-                    S2 = emqx_persistent_session_ds_state:set_last_alive_at(NowMS, S1),
+                    S2 = init_last_alive_at(NowMS, S1),
                     S3 = emqx_persistent_session_ds_state:set_peername(
                         maps:get(peername, NewConnInfo), S2
                     ),
@@ -1040,7 +1027,7 @@ session_open(
                             new_stream_subs => NewStreamSubs,
                             ?TIMER_PULL => undefined,
                             ?TIMER_PUSH => undefined,
-                            ?TIMER_BUMP_LAST_ALIVE_AT => undefined,
+                            ?TIMER_COMMIT => undefined,
                             ?TIMER_RETRY_REPLAY => undefined,
                             ?TIMER_SHARED_SUB => undefined
                         }
@@ -1065,7 +1052,7 @@ session_ensure_new(
     Now = now_ms(),
     S0 = emqx_persistent_session_ds_state:create_new(Id),
     S1 = emqx_persistent_session_ds_state:set_expiry_interval(expiry_interval(ConnInfo), S0),
-    S2 = bump_last_alive(S1),
+    S2 = init_last_alive_at(S1),
     S3 = emqx_persistent_session_ds_state:set_created_at(Now, S2),
     S4 = lists:foldl(
         fun(Track, Acc) ->
@@ -1097,7 +1084,7 @@ session_ensure_new(
         replay => undefined,
         ?TIMER_PULL => undefined,
         ?TIMER_PUSH => undefined,
-        ?TIMER_BUMP_LAST_ALIVE_AT => undefined,
+        ?TIMER_COMMIT => undefined,
         ?TIMER_RETRY_REPLAY => undefined,
         ?TIMER_SHARED_SUB => undefined
     }.
@@ -1147,8 +1134,9 @@ do_ensure_all_iterators_closed(_DSSessionID) ->
 %% Normal replay:
 %%--------------------------------------------------------------------
 
-push_now(Session) ->
-    ensure_timer(?TIMER_PUSH, 0, Session).
+push_now(Session0) ->
+    Session1 = ensure_timer(?TIMER_PUSH, 0, Session0),
+    ensure_state_commit_timer(Session1).
 
 handle_ds_reply(AsyncReply, Session0 = #{s := S0, stream_scheduler_s := SchedS0}, ClientInfo) ->
     case emqx_persistent_session_ds_stream_scheduler:on_ds_reply(AsyncReply, S0, SchedS0) of
@@ -1434,7 +1422,7 @@ do_drain_buffer(Inflight0, S0, Acc) ->
 -spec post_init(session()) -> session().
 post_init(Session0) ->
     Session1 = renew_streams(all, Session0),
-    Session = set_timer(?TIMER_BUMP_LAST_ALIVE_AT, 100, Session1),
+    Session = set_timer(?TIMER_COMMIT, 100, Session1),
     maybe_set_shared_sub_timer(Session).
 
 %% This function triggers sending buffered packets to the client
@@ -1446,8 +1434,9 @@ post_init(Session0) ->
 %%
 %% - When the client releases a packet ID (via PUBACK or PUBCOMP)
 -spec pull_now(session()) -> session().
-pull_now(Session) ->
-    ensure_timer(?TIMER_PULL, 0, Session).
+pull_now(Session0) ->
+    Session1 = ensure_timer(?TIMER_PULL, 0, Session0),
+    ensure_state_commit_timer(Session1).
 
 -spec receive_maximum(conninfo()) -> pos_integer().
 receive_maximum(ConnInfo) ->
@@ -1466,6 +1455,9 @@ expiry_interval(ConnInfo) ->
 bump_interval() ->
     emqx_config:get([durable_sessions, heartbeat_interval]).
 
+commit_interval() ->
+    bump_interval().
+
 get_config(#{zone := Zone}, Key) ->
     emqx_config:get_zone_conf(Zone, [durable_sessions | Key]).
 
@@ -1701,7 +1693,7 @@ set_timer(Timer, Time, Session) ->
 
 commit(Session = #{s := S0}) ->
     S = emqx_persistent_session_ds_state:commit(S0),
-    Session#{s := S}.
+    cancel_state_commit_timer(Session#{s := S}).
 
 -spec maybe_set_shared_sub_timer(session()) -> session().
 maybe_set_shared_sub_timer(Session = #{s := S}) ->
@@ -1731,6 +1723,54 @@ has_shared_subs(S) ->
         has_shared_subs -> true
     end.
 
+%%--------------------------------------------------------------------
+%% Management of state commit timer
+%%--------------------------------------------------------------------
+
+-spec ensure_state_commit_timer(session()) -> session().
+ensure_state_commit_timer(#{s := S, ?TIMER_COMMIT := undefined} = Session) ->
+    case emqx_persistent_session_ds_state:is_dirty(S) of
+        true ->
+            set_timer(?TIMER_COMMIT, commit_interval(), Session);
+        false ->
+            Session
+    end;
+ensure_state_commit_timer(Session) ->
+    Session.
+
+-spec cancel_state_commit_timer(session()) -> session().
+cancel_state_commit_timer(#{?TIMER_COMMIT := TRef} = Session) ->
+    emqx_utils:cancel_timer(TRef),
+    Session#{?TIMER_COMMIT := undefined}.
+
+%%--------------------------------------------------------------------
+%% Management of the heartbeat
+%%--------------------------------------------------------------------
+
+init_last_alive_at(S) ->
+    init_last_alive_at(now_ms(), S).
+
+init_last_alive_at(NowMs, S0) ->
+    NodeEpochId = emqx_persistent_session_ds_node_heartbeat_worker:get_node_epoch_id(),
+    S1 = emqx_persistent_session_ds_state:set_node_epoch_id(NodeEpochId, S0),
+    emqx_persistent_session_ds_state:set_last_alive_at(NowMs + bump_interval(), S1).
+
+finalize_last_alive_at(S0) ->
+    S = emqx_persistent_session_ds_state:set_last_alive_at(now_ms(), S0),
+    emqx_persistent_session_ds_state:set_node_epoch_id(undefined, S).
+
+%% NOTE
+%% Here we ignore the case when:
+%% * the session is terminated abnormally, without running terminate callback,
+%% e.g. when the conection was brutally killed;
+%% * but its node and persistent session subsystem remained alive.
+%%
+%% In this case, the session's lifitime is prolonged till the node termination.
+get_last_alive_at(S) ->
+    LastAliveAt = emqx_persistent_session_ds_state:get_last_alive_at(S),
+    NodeEpochId = emqx_persistent_session_ds_state:get_node_epoch_id(S),
+    emqx_persistent_session_ds_gc_worker:session_last_alive_at(LastAliveAt, NodeEpochId).
+
 %%--------------------------------------------------------------------
 %% Tests
 %%--------------------------------------------------------------------

+ 70 - 21
apps/emqx/src/emqx_persistent_session_ds/emqx_persistent_session_ds_gc_worker.erl

@@ -18,8 +18,6 @@
 -behaviour(gen_server).
 
 -include_lib("snabbkaffe/include/snabbkaffe.hrl").
--include_lib("stdlib/include/qlc.hrl").
--include_lib("stdlib/include/ms_transform.hrl").
 
 -include("session_internals.hrl").
 
@@ -27,7 +25,8 @@
 -export([
     start_link/0,
     check_session/1,
-    check_session_after/2
+    check_session_after/2,
+    session_last_alive_at/2
 ]).
 
 %% `gen_server' API
@@ -60,6 +59,22 @@ check_session_after(SessionId, Time0) ->
     _ = erlang:send_after(Time, ?MODULE, #check_session{id = SessionId}),
     ok.
 
+-spec session_last_alive_at(
+    pos_integer(), emqx_persistent_session_ds_node_heartbeat_worker:epoch_id() | undefined
+) -> pos_integer().
+session_last_alive_at(LastAliveAt, undefined) ->
+    LastAliveAt;
+session_last_alive_at(LastAliveAt, NodeEpochId) ->
+    case emqx_persistent_session_ds_node_heartbeat_worker:get_last_alive_at(NodeEpochId) of
+        undefined ->
+            LastAliveAt;
+        NodeLastAliveAt ->
+            max(
+                LastAliveAt,
+                NodeLastAliveAt + emqx_config:get([durable_sessions, heartbeat_interval])
+            )
+    end.
+
 %%--------------------------------------------------------------------------------
 %% `gen_server' API
 %%--------------------------------------------------------------------------------
@@ -102,7 +117,7 @@ try_gc() ->
     CoreNodes = mria_membership:running_core_nodelist(),
     Res = global:trans(
         {?MODULE, self()},
-        fun() -> ?tp_span(debug, ds_session_gc, #{}, start_gc()) end,
+        fun() -> ?tp_span(debug, ds_session_gc, #{}, run_gc()) end,
         CoreNodes,
         %% Note: we set retries to 1 here because, in rare occasions, GC might start at the
         %% same time in more than one node, and each one will abort the other.  By allowing
@@ -123,40 +138,55 @@ try_gc() ->
 now_ms() ->
     erlang:system_time(millisecond).
 
-start_gc() ->
-    #{min_last_alive := MinLastAlive} = gc_context(),
-    gc_loop(MinLastAlive, emqx_persistent_session_ds_state:make_session_iterator()).
+run_gc() ->
+    NowMs = now_ms(),
+    SessionCounts0 = init_epoch_session_counters(NowMs),
+    SessionCounts = gc_loop(
+        gc_context(), SessionCounts0, emqx_persistent_session_ds_state:make_session_iterator()
+    ),
+    ok = clenup_inactive_epochs(SessionCounts).
 
 gc_context() ->
+    gc_context(now_ms()).
+
+gc_context(NowMs) ->
     GCInterval = emqx_config:get([durable_sessions, session_gc_interval]),
     BumpInterval = emqx_config:get([durable_sessions, heartbeat_interval]),
     SafetyMargin = BumpInterval * 3,
     #{
-        min_last_alive => now_ms() - SafetyMargin,
+        min_last_alive => NowMs - SafetyMargin,
         bump_interval => BumpInterval,
         gc_interval => GCInterval
     }.
 
-gc_loop(MinLastAlive, It0) ->
+gc_loop(GCContext, SessionCounts0, It0) ->
     GCBatchSize = emqx_config:get([durable_sessions, session_gc_batch_size]),
     case emqx_persistent_session_ds_state:session_iterator_next(It0, GCBatchSize) of
         {[], _It} ->
-            ok;
+            SessionCounts0;
         {Sessions, It} ->
-            [
-                do_gc(MinLastAlive, SessionId, Metadata)
-             || {SessionId, Metadata} <- Sessions
-            ],
-            gc_loop(MinLastAlive, It)
+            SessionCounts1 = lists:foldl(
+                fun({SessionId, Metadata}, SessionCountsAcc) ->
+                    do_gc(GCContext, SessionCountsAcc, SessionId, Metadata)
+                end,
+                SessionCounts0,
+                Sessions
+            ),
+            gc_loop(GCContext, SessionCounts1, It)
     end.
 
-do_gc(MinLastAlive, SessionId, Metadata) ->
+do_gc(GCContext, SessionId, Metadata) ->
+    do_gc(GCContext, _NoEpochCounters = #{}, SessionId, Metadata).
+
+do_gc(#{min_last_alive := MinLastAlive}, SessionCounts, SessionId, Metadata) ->
     #{
-        ?last_alive_at := LastAliveAt,
+        ?last_alive_at := SessionLastAliveAt,
+        ?node_epoch_id := NodeEpochId,
         ?expiry_interval := EI,
         ?will_message := MaybeWillMessage,
         ?clientinfo := ClientInfo
     } = Metadata,
+    LastAliveAt = session_last_alive_at(SessionLastAliveAt, NodeEpochId),
     IsExpired = LastAliveAt + EI < MinLastAlive,
     case
         should_send_will_message(
@@ -179,9 +209,10 @@ do_gc(MinLastAlive, SessionId, Metadata) ->
                 last_alive_at => LastAliveAt,
                 expiry_interval => EI,
                 min_last_alive => MinLastAlive
-            });
+            }),
+            SessionCounts;
         false ->
-            ok
+            inc_epoch_session_count(SessionCounts, NodeEpochId)
     end.
 
 should_send_will_message(undefined, _ClientInfo, _IsExpired, _LastAliveAt, _MinLastAlive) ->
@@ -205,8 +236,26 @@ should_send_will_message(WillMsg, ClientInfo, IsExpired, LastAliveAt, MinLastAli
 do_check_session(SessionId) ->
     case emqx_persistent_session_ds_state:print_session(SessionId) of
         #{metadata := Metadata} ->
-            #{min_last_alive := MinLastAlive} = gc_context(),
-            do_gc(MinLastAlive, SessionId, Metadata);
+            do_gc(gc_context(), SessionId, Metadata);
         _ ->
             ok
     end.
+
+init_epoch_session_counters(NowMs) ->
+    maps:from_keys(
+        emqx_persistent_session_ds_node_heartbeat_worker:inactive_epochs(NowMs), 0
+    ).
+
+inc_epoch_session_count(SessionCounts, NodeEpochId) when
+    is_map_key(NodeEpochId, SessionCounts)
+->
+    maps:update_with(NodeEpochId, fun(X) -> X + 1 end, 1, SessionCounts);
+inc_epoch_session_count(SessionCounts, _NodeEpochId) ->
+    SessionCounts.
+
+clenup_inactive_epochs(SessionCounts) ->
+    ?tp(debug, clenup_inactive_epochs, #{
+        session_counts => SessionCounts
+    }),
+    EmptyInactiveEpochIds = [NodeEpochId || {NodeEpochId, 0} <- maps:to_list(SessionCounts)],
+    ok = emqx_persistent_session_ds_node_heartbeat_worker:delete_epochs(EmptyInactiveEpochIds).

+ 172 - 0
apps/emqx/src/emqx_persistent_session_ds/emqx_persistent_session_ds_node_heartbeat_worker.erl

@@ -0,0 +1,172 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2023-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+-module(emqx_persistent_session_ds_node_heartbeat_worker).
+
+-behaviour(gen_server).
+
+-include("session_internals.hrl").
+-include_lib("snabbkaffe/include/snabbkaffe.hrl").
+-include_lib("stdlib/include/ms_transform.hrl").
+
+%% API
+-export([
+    create_tables/0,
+    start_link/0,
+    get_node_epoch_id/0,
+    get_last_alive_at/1,
+    inactive_epochs/1,
+    delete_epochs/1
+]).
+
+%% `gen_server' API
+-export([
+    init/1,
+    handle_call/3,
+    handle_cast/2,
+    handle_info/2,
+    terminate/2
+]).
+
+-export_type([epoch_id/0]).
+
+%% call/cast/info records
+-record(update_last_alive_at, {}).
+
+-define(epoch_id_pt_key, {?MODULE, epoch_id}).
+-define(node_epoch, node_epoch).
+-define(tab, ?node_epoch).
+
+-record(?node_epoch, {
+    epoch_id :: reference(),
+    node :: node(),
+    last_alive_at :: pos_integer()
+}).
+
+-type epoch_id() :: reference().
+
+%%--------------------------------------------------------------------------------
+%% API
+%%--------------------------------------------------------------------------------
+
+-spec start_link() -> {ok, pid()}.
+start_link() ->
+    ok = create_tables(),
+    gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
+
+-spec get_node_epoch_id() -> epoch_id().
+get_node_epoch_id() ->
+    persistent_term:get(?epoch_id_pt_key).
+
+-spec get_last_alive_at(epoch_id()) -> pos_integer() | undefined.
+get_last_alive_at(EpochId) ->
+    case mnesia:dirty_read(?tab, EpochId) of
+        [] -> undefined;
+        [#?node_epoch{last_alive_at = LastAliveAt}] -> LastAliveAt
+    end.
+
+-spec inactive_epochs(integer()) -> [epoch_id()].
+inactive_epochs(NowMs) ->
+    DeadLine = NowMs - 2 * heartbeat_interval(),
+    Ms = ets:fun2ms(
+        fun(#?node_epoch{last_alive_at = LastAliveAt, epoch_id = EpochId}) when
+            LastAliveAt < DeadLine
+        ->
+            EpochId
+        end
+    ),
+    mnesia:dirty_select(?tab, Ms).
+
+-spec delete_epochs([epoch_id()]) -> ok.
+delete_epochs(EpochIds) ->
+    ?tp(debug, persistent_session_ds_node_heartbeat_delete_epochs, #{
+        epoch_ids => EpochIds
+    }),
+    mria:async_dirty(?DS_MRIA_SHARD, fun() ->
+        lists:foreach(
+            fun(EpochId) ->
+                mnesia:delete(?tab, EpochId, write)
+            end,
+            EpochIds
+        )
+    end).
+
+%%--------------------------------------------------------------------------------
+%% `gen_server' API
+%%--------------------------------------------------------------------------------
+
+init(_Opts) ->
+    erlang:process_flag(trap_exit, true),
+    ok = generate_node_epoch_id(),
+    ok = update_last_alive_at(),
+    ok = ensure_heartbeat_timer(),
+    State = #{},
+    {ok, State}.
+
+handle_call(_Call, _From, State) ->
+    {reply, {error, not_implemented}, State}.
+
+handle_cast(_Cast, State) ->
+    {noreply, State}.
+
+handle_info(#update_last_alive_at{}, State) ->
+    ok = update_last_alive_at(),
+    ok = ensure_heartbeat_timer(),
+    {noreply, State}.
+
+terminate(_Reason, _State) ->
+    ok = delete_last_alive_at(),
+    ok.
+
+%%--------------------------------------------------------------------------------
+%% Internal functions
+%%--------------------------------------------------------------------------------
+
+create_tables() ->
+    ok = mria:create_table(?tab, [
+        {rlog_shard, ?DS_MRIA_SHARD},
+        {type, set},
+        {storage, disc_copies},
+        {record_name, ?node_epoch},
+        {attributes, record_info(fields, ?node_epoch)}
+    ]),
+    mria:wait_for_tables([?tab]).
+
+generate_node_epoch_id() ->
+    EpochId = erlang:make_ref(),
+    persistent_term:put(?epoch_id_pt_key, EpochId),
+    ok.
+
+ensure_heartbeat_timer() ->
+    _ = erlang:send_after(heartbeat_interval(), self(), #update_last_alive_at{}),
+    ok.
+
+update_last_alive_at() ->
+    EpochId = get_node_epoch_id(),
+    LastAliveAt = now_ms() + heartbeat_interval(),
+    ok = mria:dirty_write(?tab, #?node_epoch{
+        epoch_id = EpochId, node = node(), last_alive_at = LastAliveAt
+    }),
+    ok.
+
+delete_last_alive_at() ->
+    EpochId = get_node_epoch_id(),
+    ok = mria:dirty_delete(?tab, EpochId).
+
+heartbeat_interval() ->
+    emqx_config:get([durable_sessions, heartbeat_interval]).
+
+now_ms() ->
+    erlang:system_time(millisecond).

+ 18 - 0
apps/emqx/src/emqx_persistent_session_ds/emqx_persistent_session_ds_state.erl

@@ -42,8 +42,10 @@
 -export([
     open/1, create_new/1, delete/1, commit/1, commit/2, format/1, print_session/1, list_sessions/0
 ]).
+-export([is_dirty/1]).
 -export([get_created_at/1, set_created_at/2]).
 -export([get_last_alive_at/1, set_last_alive_at/2]).
+-export([get_node_epoch_id/1, set_node_epoch_id/2]).
 -export([get_expiry_interval/1, set_expiry_interval/2]).
 -export([get_clientinfo/1, set_clientinfo/2]).
 -export([get_will_message/1, set_will_message/2, clear_will_message/1, clear_will_message_now/1]).
@@ -171,6 +173,7 @@
     #{
         ?created_at => emqx_persistent_session_ds:timestamp(),
         ?last_alive_at => emqx_persistent_session_ds:timestamp(),
+        ?node_epoch_id => emqx_persistent_session_ds_node_heartbeat_worker:epoch_id() | undefined,
         ?expiry_interval => non_neg_integer(),
         ?last_id => integer(),
         ?peername => emqx_types:peername(),
@@ -616,6 +619,10 @@ create_new(SessionId) ->
 
 %%
 
+-spec is_dirty(t()) -> boolean().
+is_dirty(#{?dirty := Dirty}) ->
+    Dirty.
+
 -spec get_created_at(t()) -> emqx_persistent_session_ds:timestamp() | undefined.
 get_created_at(Rec) ->
     get_meta(?created_at, Rec).
@@ -632,6 +639,17 @@ get_last_alive_at(Rec) ->
 set_last_alive_at(Val, Rec) ->
     set_meta(?last_alive_at, Val, Rec).
 
+-spec get_node_epoch_id(t()) ->
+    emqx_persistent_session_ds_node_heartbeat_worker:epoch_id() | undefined.
+get_node_epoch_id(Rec) ->
+    get_meta(?node_epoch_id, Rec).
+
+-spec set_node_epoch_id(
+    emqx_persistent_session_ds_node_heartbeat_worker:epoch_id() | undefined, t()
+) -> t().
+set_node_epoch_id(Val, Rec) ->
+    set_meta(?node_epoch_id, Val, Rec).
+
 -spec get_expiry_interval(t()) -> non_neg_integer() | undefined.
 get_expiry_interval(Rec) ->
     get_meta(?expiry_interval, Rec).

+ 6 - 3
apps/emqx/src/emqx_persistent_session_ds/emqx_persistent_session_ds_sup.erl

@@ -53,14 +53,17 @@ do_init(_Opts) ->
         period => 2,
         auto_shutdown => never
     },
-    CoreChildren = [
+    CoreNodeChildren = [
         worker(session_gc_worker, emqx_persistent_session_ds_gc_worker, []),
         worker(message_gc_worker, emqx_persistent_message_ds_gc_worker, [])
     ],
+    AnyNodeChildren = [
+        worker(node_heartbeat, emqx_persistent_session_ds_node_heartbeat_worker, [])
+    ],
     Children =
         case mria_rlog:role() of
-            core -> CoreChildren;
-            replicant -> []
+            core -> CoreNodeChildren ++ AnyNodeChildren;
+            replicant -> AnyNodeChildren
         end,
     {ok, {SupFlags, Children}}.
 

+ 98 - 3
apps/emqx/test/emqx_persistent_session_ds_SUITE.erl

@@ -68,13 +68,17 @@ init_per_testcase(TestCase, Config) when
         {work_dir, WorkDir}
         | Config
     ];
-init_per_testcase(t_session_gc = TestCase, Config) ->
+init_per_testcase(TestCase, Config) when
+    TestCase =:= t_session_gc;
+    TestCase =:= t_crashed_node_session_gc;
+    TestCase =:= t_last_alive_at_cleanup
+->
     Opts = #{
         n => 3,
         roles => [core, core, core],
         extra_emqx_conf =>
             "\n durable_sessions {"
-            "\n   heartbeat_interval = 500ms "
+            "\n   heartbeat_interval = 50ms "
             "\n   session_gc_interval = 1s "
             "\n   session_gc_batch_size = 2 "
             "\n }"
@@ -102,7 +106,9 @@ end_per_testcase(TestCase, Config) when
     TestCase =:= t_subscription_state_change;
     TestCase =:= t_session_gc;
     TestCase =:= t_storage_generations;
-    TestCase =:= t_new_stream_notifications
+    TestCase =:= t_new_stream_notifications;
+    TestCase =:= t_crashed_node_session_gc;
+    TestCase =:= t_last_alive_at_cleanup
 ->
     Nodes = ?config(nodes, Config),
     emqx_common_test_helpers:call_janitor(60_000),
@@ -846,6 +852,95 @@ t_session_gc(Config) ->
     ),
     ok.
 
+t_crashed_node_session_gc(Config) ->
+    [Node1, Node2 | _] = ?config(nodes, Config),
+    Port = get_mqtt_port(Node1, tcp),
+    ct:pal("Port: ~p", [Port]),
+
+    ?check_trace(
+        #{timetrap => 30_000},
+        begin
+            ClientId = <<"session_on_crashed_node">>,
+            Client = start_client(#{
+                clientid => ClientId,
+                port => Port,
+                properties => #{'Session-Expiry-Interval' => 1},
+                clean_start => false,
+                proto_ver => v5
+            }),
+            {ok, _} = emqtt:connect(Client),
+            ct:sleep(1500),
+            emqx_cth_peer:kill(Node1),
+
+            %% Much time has passed since the session last reported its alive time (on start).
+            %% Last alive time was not persisted on connection shutdown since we brutally killed the node.
+            %% However, the session should not be expired,
+            %% because session's last alive time should be bumped by the node's last_alive_at, and
+            %% the node only recently crashed.
+            erpc:call(Node2, emqx_persistent_session_ds_gc_worker, check_session, [ClientId]),
+            %%% Wait for possible async dirty session delete
+            ct:sleep(100),
+            ?assertMatch([_], list_all_sessions(Node2), sessions),
+
+            %% But finally the session has to expire since the connection
+            %% is not re-established.
+            ?assertMatch(
+                {ok, _},
+                ?block_until(
+                    #{
+                        ?snk_kind := ds_session_gc_cleaned,
+                        session_id := ClientId
+                    }
+                )
+            ),
+            %%% Wait for possible async dirty session delete
+            ct:sleep(100),
+            ?assertMatch([], list_all_sessions(Node2), sessions)
+        end,
+        []
+    ),
+    ok.
+
+t_last_alive_at_cleanup(Config) ->
+    [Node1 | _] = ?config(nodes, Config),
+    Port = get_mqtt_port(Node1, tcp),
+    ?check_trace(
+        #{timetrap => 5_000},
+        begin
+            NodeEpochId = erpc:call(
+                Node1,
+                emqx_persistent_session_ds_node_heartbeat_worker,
+                get_node_epoch_id,
+                []
+            ),
+            ClientId = <<"session_on_crashed_node">>,
+            Client = start_client(#{
+                clientid => ClientId,
+                port => Port,
+                properties => #{'Session-Expiry-Interval' => 1},
+                clean_start => false,
+                proto_ver => v5
+            }),
+            {ok, _} = emqtt:connect(Client),
+
+            %% Kill node making its lifetime epoch invalid.
+            emqx_cth_peer:kill(Node1),
+
+            %% Wait till the node's epoch is cleaned up.
+            ?assertMatch(
+                {ok, _},
+                ?block_until(
+                    #{
+                        ?snk_kind := persistent_session_ds_node_heartbeat_delete_epochs,
+                        epoch_ids := [NodeEpochId]
+                    }
+                )
+            )
+        end,
+        []
+    ),
+    ok.
+
 t_session_replay_retry(_Config) ->
     %% Verify that the session recovers smoothly from transient errors during
     %% replay.