Просмотр исходного кода

Merge pull request #12024 from thalesmg/ds-session-expiry-m-20231124

feat(ds): session expiry
Thales Macedo Garitezi 2 лет назад
Родитель
Сommit
26e59f9508

+ 84 - 2
apps/emqx/integration_test/emqx_persistent_session_ds_SUITE.erl

@@ -9,6 +9,7 @@
 -include_lib("stdlib/include/assert.hrl").
 -include_lib("common_test/include/ct.hrl").
 -include_lib("snabbkaffe/include/snabbkaffe.hrl").
+-include_lib("emqx/include/asserts.hrl").
 -include_lib("emqx/include/emqx_mqtt.hrl").
 
 -import(emqx_common_test_helpers, [on_exit/1]).
@@ -221,9 +222,10 @@ t_session_subscription_idempotency(Config) ->
         end,
         fun(Trace) ->
             ct:pal("trace:\n  ~p", [Trace]),
+            ConnInfo = #{},
             ?assertMatch(
                 #{subscriptions := #{SubTopicFilter := #{}}},
-                erpc:call(Node1, emqx_persistent_session_ds, session_open, [ClientId])
+                erpc:call(Node1, emqx_persistent_session_ds, session_open, [ClientId, ConnInfo])
             )
         end
     ),
@@ -294,9 +296,10 @@ t_session_unsubscription_idempotency(Config) ->
         end,
         fun(Trace) ->
             ct:pal("trace:\n  ~p", [Trace]),
+            ConnInfo = #{},
             ?assertMatch(
                 #{subscriptions := Subs = #{}} when map_size(Subs) =:= 0,
-                erpc:call(Node1, emqx_persistent_session_ds, session_open, [ClientId])
+                erpc:call(Node1, emqx_persistent_session_ds, session_open, [ClientId, ConnInfo])
             ),
             ok
         end
@@ -387,3 +390,82 @@ do_t_session_discard(Params) ->
         end
     ),
     ok.
+
+t_session_expiration1(Config) ->
+    ClientId = atom_to_binary(?FUNCTION_NAME),
+    Opts = #{
+        clientid => ClientId,
+        sequence => [
+            {#{clean_start => false, properties => #{'Session-Expiry-Interval' => 30}}, #{}},
+            {#{clean_start => false, properties => #{'Session-Expiry-Interval' => 1}}, #{}},
+            {#{clean_start => false, properties => #{'Session-Expiry-Interval' => 30}}, #{}}
+        ]
+    },
+    do_t_session_expiration(Config, Opts).
+
+t_session_expiration2(Config) ->
+    ClientId = atom_to_binary(?FUNCTION_NAME),
+    Opts = #{
+        clientid => ClientId,
+        sequence => [
+            {#{clean_start => false, properties => #{'Session-Expiry-Interval' => 30}}, #{}},
+            {#{clean_start => false, properties => #{'Session-Expiry-Interval' => 30}}, #{
+                'Session-Expiry-Interval' => 1
+            }},
+            {#{clean_start => false, properties => #{'Session-Expiry-Interval' => 30}}, #{}}
+        ]
+    },
+    do_t_session_expiration(Config, Opts).
+
+do_t_session_expiration(_Config, Opts) ->
+    #{
+        clientid := ClientId,
+        sequence := [
+            {FirstConn, FirstDisconn},
+            {SecondConn, SecondDisconn},
+            {ThirdConn, ThirdDisconn}
+        ]
+    } = Opts,
+    CommonParams = #{proto_ver => v5, clientid => ClientId},
+    ?check_trace(
+        begin
+            Topic = <<"some/topic">>,
+            Params0 = maps:merge(CommonParams, FirstConn),
+            Client0 = start_client(Params0),
+            {ok, _} = emqtt:connect(Client0),
+            {ok, _, [?RC_GRANTED_QOS_2]} = emqtt:subscribe(Client0, Topic, ?QOS_2),
+            Subs0 = emqx_persistent_session_ds:list_all_subscriptions(),
+            ?assertEqual(1, map_size(Subs0), #{subs => Subs0}),
+            Info0 = maps:from_list(emqtt:info(Client0)),
+            ?assertEqual(0, maps:get(session_present, Info0), #{info => Info0}),
+            emqtt:disconnect(Client0, ?RC_NORMAL_DISCONNECTION, FirstDisconn),
+
+            Params1 = maps:merge(CommonParams, SecondConn),
+            Client1 = start_client(Params1),
+            {ok, _} = emqtt:connect(Client1),
+            Info1 = maps:from_list(emqtt:info(Client1)),
+            ?assertEqual(1, maps:get(session_present, Info1), #{info => Info1}),
+            Subs1 = emqtt:subscriptions(Client1),
+            ?assertEqual([], Subs1),
+            emqtt:disconnect(Client1, ?RC_NORMAL_DISCONNECTION, SecondDisconn),
+
+            ct:sleep(1_500),
+
+            Params2 = maps:merge(CommonParams, ThirdConn),
+            Client2 = start_client(Params2),
+            {ok, _} = emqtt:connect(Client2),
+            Info2 = maps:from_list(emqtt:info(Client2)),
+            ?assertEqual(0, maps:get(session_present, Info2), #{info => Info2}),
+            Subs2 = emqtt:subscriptions(Client2),
+            ?assertEqual([], Subs2),
+            emqtt:publish(Client2, Topic, <<"payload">>),
+            ?assertNotReceive({publish, #{topic := Topic}}),
+            %% ensure subscriptions are absent from table.
+            ?assertEqual(#{}, emqx_persistent_session_ds:list_all_subscriptions()),
+            emqtt:disconnect(Client2, ?RC_NORMAL_DISCONNECTION, ThirdDisconn),
+
+            ok
+        end,
+        []
+    ),
+    ok.

+ 4 - 2
apps/emqx/src/emqx_channel.erl

@@ -1204,12 +1204,13 @@ handle_info(
         #channel{
             conn_state = ConnState,
             clientinfo = ClientInfo,
+            conninfo = ConnInfo,
             session = Session
         }
 ) when
     ConnState =:= connected orelse ConnState =:= reauthenticating
 ->
-    {Intent, Session1} = emqx_session:disconnect(ClientInfo, Session),
+    {Intent, Session1} = emqx_session:disconnect(ClientInfo, ConnInfo, Session),
     Channel1 = ensure_disconnected(Reason, maybe_publish_will_msg(Channel)),
     Channel2 = Channel1#channel{session = Session1},
     case maybe_shutdown(Reason, Intent, Channel2) of
@@ -1321,7 +1322,8 @@ handle_timeout(
         {ok, Replies, NSession} ->
             handle_out(publish, Replies, Channel#channel{session = NSession})
     end;
-handle_timeout(_TRef, expire_session, Channel) ->
+handle_timeout(_TRef, expire_session, Channel = #channel{session = Session}) ->
+    ok = emqx_session:destroy(Session),
     shutdown(expired, Channel);
 handle_timeout(
     _TRef,

+ 115 - 47
apps/emqx/src/emqx_persistent_session_ds.erl

@@ -56,7 +56,7 @@
     deliver/3,
     replay/3,
     handle_timeout/3,
-    disconnect/1,
+    disconnect/2,
     terminate/2
 ]).
 
@@ -74,7 +74,7 @@
 
 -ifdef(TEST).
 -export([
-    session_open/1,
+    session_open/2,
     list_all_sessions/0,
     list_all_subscriptions/0,
     list_all_streams/0,
@@ -98,22 +98,26 @@
     id := id(),
     %% When the session was created
     created_at := timestamp(),
-    %% When the session should expire
-    expires_at := timestamp() | never,
+    %% When the client was last considered alive
+    last_alive_at := timestamp(),
     %% Client’s Subscriptions.
     subscriptions := #{topic_filter() => subscription()},
     %% Inflight messages
     inflight := emqx_persistent_message_ds_replayer:inflight(),
     %% Receive maximum
     receive_maximum := pos_integer(),
+    %% Connection Info
+    conninfo := emqx_types:conninfo(),
     %%
     props := map()
 }.
 
 -type timestamp() :: emqx_utils_calendar:epoch_millisecond().
+-type millisecond() :: non_neg_integer().
 -type clientinfo() :: emqx_types:clientinfo().
 -type conninfo() :: emqx_session:conninfo().
 -type replies() :: emqx_session:replies().
+-type timer() :: pull | get_streams | bump_last_alive_at.
 
 -define(STATS_KEYS, [
     subscriptions_cnt,
@@ -123,6 +127,12 @@
     next_pkt_id
 ]).
 
+-define(IS_EXPIRED(NOW_MS, LAST_ALIVE_AT, EI),
+    (is_number(LAST_ALIVE_AT) andalso
+        is_number(EI) andalso
+        (NOW_MS >= LAST_ALIVE_AT + EI))
+).
+
 -export_type([id/0]).
 
 %%
@@ -144,26 +154,24 @@ open(#{clientid := ClientID} = _ClientInfo, ConnInfo) ->
     %% somehow isolate those idling not-yet-expired sessions into a separate process
     %% space, and move this call back into `emqx_cm` where it belongs.
     ok = emqx_cm:discard_session(ClientID),
-    case maps:get(clean_start, ConnInfo, false) of
+    case session_open(ClientID, ConnInfo) of
+        Session0 = #{} ->
+            ensure_timers(),
+            ReceiveMaximum = receive_maximum(ConnInfo),
+            Session = Session0#{receive_maximum => ReceiveMaximum},
+            {true, Session, []};
         false ->
-            case session_open(ClientID) of
-                Session0 = #{} ->
-                    ensure_timers(),
-                    ReceiveMaximum = receive_maximum(ConnInfo),
-                    Session = Session0#{receive_maximum => ReceiveMaximum},
-                    {true, Session, []};
-                false ->
-                    false
-            end;
-        true ->
-            session_drop(ClientID),
             false
     end.
 
 ensure_session(ClientID, ConnInfo, Conf) ->
-    Session = session_ensure_new(ClientID, Conf),
+    Session = session_ensure_new(ClientID, ConnInfo, Conf),
     ReceiveMaximum = receive_maximum(ConnInfo),
-    Session#{subscriptions => #{}, receive_maximum => ReceiveMaximum}.
+    Session#{
+        conninfo => ConnInfo,
+        receive_maximum => ReceiveMaximum,
+        subscriptions => #{}
+    }.
 
 -spec destroy(session() | clientinfo()) -> ok.
 destroy(#{id := ClientID}) ->
@@ -389,6 +397,11 @@ handle_timeout(
 handle_timeout(_ClientInfo, get_streams, Session) ->
     renew_streams(Session),
     ensure_timer(get_streams),
+    {ok, [], Session};
+handle_timeout(_ClientInfo, bump_last_alive_at, Session0) ->
+    NowMS = now_ms(),
+    Session = session_set_last_alive_at_trans(Session0, NowMS),
+    ensure_timer(bump_last_alive_at),
     {ok, [], Session}.
 
 -spec replay(clientinfo(), [], session()) ->
@@ -399,8 +412,9 @@ replay(_ClientInfo, [], Session = #{inflight := Inflight0}) ->
 
 %%--------------------------------------------------------------------
 
--spec disconnect(session()) -> {shutdown, session()}.
-disconnect(Session = #{}) ->
+-spec disconnect(session(), emqx_types:conninfo()) -> {shutdown, session()}.
+disconnect(Session0, ConnInfo) ->
+    Session = session_set_last_alive_at_trans(Session0, ConnInfo, now_ms()),
     {shutdown, Session}.
 
 -spec terminate(Reason :: term(), session()) -> ok.
@@ -530,47 +544,84 @@ storage() ->
 %%
 %% Note: session API doesn't handle session takeovers, it's the job of
 %% the broker.
--spec session_open(id()) ->
+-spec session_open(id(), emqx_types:conninfo()) ->
     session() | false.
-session_open(SessionId) ->
-    ro_transaction(fun() ->
+session_open(SessionId, NewConnInfo) ->
+    NowMS = now_ms(),
+    transaction(fun() ->
         case mnesia:read(?SESSION_TAB, SessionId, write) of
-            [Record = #session{}] ->
-                Session = export_session(Record),
-                DSSubs = session_read_subscriptions(SessionId),
-                Subscriptions = export_subscriptions(DSSubs),
-                Inflight = emqx_persistent_message_ds_replayer:open(SessionId),
-                Session#{
-                    subscriptions => Subscriptions,
-                    inflight => Inflight
-                };
-            [] ->
+            [Record0 = #session{last_alive_at = LastAliveAt, conninfo = ConnInfo}] ->
+                EI = expiry_interval(ConnInfo),
+                case ?IS_EXPIRED(NowMS, LastAliveAt, EI) of
+                    true ->
+                        session_drop(SessionId),
+                        false;
+                    false ->
+                        %% new connection being established
+                        Record1 = Record0#session{conninfo = NewConnInfo},
+                        Record = session_set_last_alive_at(Record1, NowMS),
+                        Session = export_session(Record),
+                        DSSubs = session_read_subscriptions(SessionId),
+                        Subscriptions = export_subscriptions(DSSubs),
+                        Inflight = emqx_persistent_message_ds_replayer:open(SessionId),
+                        Session#{
+                            conninfo => NewConnInfo,
+                            inflight => Inflight,
+                            subscriptions => Subscriptions
+                        }
+                end;
+            _ ->
                 false
         end
     end).
 
--spec session_ensure_new(id(), _Props :: map()) ->
+-spec session_ensure_new(id(), emqx_types:conninfo(), _Props :: map()) ->
     session().
-session_ensure_new(SessionId, Props) ->
+session_ensure_new(SessionId, ConnInfo, Props) ->
     transaction(fun() ->
         ok = session_drop_subscriptions(SessionId),
-        Session = export_session(session_create(SessionId, Props)),
+        Session = export_session(session_create(SessionId, ConnInfo, Props)),
         Session#{
             subscriptions => #{},
             inflight => emqx_persistent_message_ds_replayer:new()
         }
     end).
 
-session_create(SessionId, Props) ->
+session_create(SessionId, ConnInfo, Props) ->
     Session = #session{
         id = SessionId,
-        created_at = erlang:system_time(millisecond),
-        expires_at = never,
+        created_at = now_ms(),
+        last_alive_at = now_ms(),
+        conninfo = ConnInfo,
         props = Props
     },
     ok = mnesia:write(?SESSION_TAB, Session, write),
     Session.
 
+session_set_last_alive_at_trans(Session, LastAliveAt) ->
+    #{conninfo := ConnInfo} = Session,
+    session_set_last_alive_at_trans(Session, ConnInfo, LastAliveAt).
+
+session_set_last_alive_at_trans(Session, NewConnInfo, LastAliveAt) ->
+    #{id := SessionId} = Session,
+    transaction(fun() ->
+        case mnesia:read(?SESSION_TAB, SessionId, write) of
+            [#session{} = SessionRecord0] ->
+                SessionRecord = SessionRecord0#session{conninfo = NewConnInfo},
+                _ = session_set_last_alive_at(SessionRecord, LastAliveAt),
+                ok;
+            _ ->
+                %% log and crash?
+                ok
+        end
+    end),
+    Session#{conninfo := NewConnInfo, last_alive_at := LastAliveAt}.
+
+session_set_last_alive_at(SessionRecord0, LastAliveAt) ->
+    SessionRecord = SessionRecord0#session{last_alive_at = LastAliveAt},
+    ok = mnesia:write(?SESSION_TAB, SessionRecord, write),
+    SessionRecord.
+
 %% @doc Called when a client reconnects with `clean session=true' or
 %% during session GC
 -spec session_drop(id()) -> ok.
@@ -673,7 +724,7 @@ session_read_pubranges(DSSessionId, LockKind) ->
 new_subscription_id(DSSessionId, TopicFilter) ->
     %% Note: here we use _milliseconds_ to match with the timestamp
     %% field of `#message' record.
-    NowMS = erlang:system_time(millisecond),
+    NowMS = now_ms(),
     DSSubId = {DSSessionId, TopicFilter},
     {DSSubId, NowMS}.
 
@@ -681,6 +732,9 @@ new_subscription_id(DSSessionId, TopicFilter) ->
 subscription_id_to_topic_filter({_DSSessionId, TopicFilter}) ->
     TopicFilter.
 
+now_ms() ->
+    erlang:system_time(millisecond).
+
 %%--------------------------------------------------------------------
 %% RPC targets (v1)
 %%--------------------------------------------------------------------
@@ -781,8 +835,13 @@ session_drop_pubranges(DSSessionId) ->
 %%--------------------------------------------------------------------------------
 
 transaction(Fun) ->
-    {atomic, Res} = mria:transaction(?DS_MRIA_SHARD, Fun),
-    Res.
+    case mnesia:is_transaction() of
+        true ->
+            Fun();
+        false ->
+            {atomic, Res} = mria:transaction(?DS_MRIA_SHARD, Fun),
+            Res
+    end.
 
 ro_transaction(Fun) ->
     {atomic, Res} = mria:ro_transaction(?DS_MRIA_SHARD, Fun),
@@ -800,7 +859,7 @@ export_subscriptions(DSSubs) ->
     ).
 
 export_session(#session{} = Record) ->
-    export_record(Record, #session.id, [id, created_at, expires_at, props], #{}).
+    export_record(Record, #session.id, [id, created_at, last_alive_at, conninfo, props], #{}).
 
 export_subscription(#ds_sub{} = Record) ->
     export_record(Record, #ds_sub.start_time, [start_time, props, extra], #{}).
@@ -814,13 +873,17 @@ export_record(_, _, [], Acc) ->
 %% effects. Add `CBM:init' callback to the session behavior?
 ensure_timers() ->
     ensure_timer(pull),
-    ensure_timer(get_streams).
+    ensure_timer(get_streams),
+    ensure_timer(bump_last_alive_at).
 
--spec ensure_timer(pull | get_streams) -> ok.
+-spec ensure_timer(timer()) -> ok.
+ensure_timer(bump_last_alive_at = Type) ->
+    BumpInterval = emqx_config:get([session_persistence, last_alive_update_interval]),
+    ensure_timer(Type, BumpInterval);
 ensure_timer(Type) ->
     ensure_timer(Type, 100).
 
--spec ensure_timer(pull | get_streams, non_neg_integer()) -> ok.
+-spec ensure_timer(timer(), non_neg_integer()) -> ok.
 ensure_timer(Type, Timeout) ->
     _ = emqx_utils:start_timer(Timeout, {emqx_session, Type}),
     ok.
@@ -832,11 +895,16 @@ receive_maximum(ConnInfo) ->
     %% indicates that it's optional.
     maps:get(receive_maximum, ConnInfo, 65_535).
 
+-spec expiry_interval(conninfo()) -> millisecond().
+expiry_interval(ConnInfo) ->
+    maps:get(expiry_interval, ConnInfo, 0).
+
 -ifdef(TEST).
 list_all_sessions() ->
     DSSessionIds = mnesia:dirty_all_keys(?SESSION_TAB),
+    ConnInfo = #{},
     Sessions = lists:map(
-        fun(SessionID) -> {SessionID, session_open(SessionID)} end,
+        fun(SessionID) -> {SessionID, session_open(SessionID, ConnInfo)} end,
         DSSessionIds
     ),
     maps:from_list(Sessions).

+ 2 - 1
apps/emqx/src/emqx_persistent_session_ds.hrl

@@ -73,7 +73,8 @@
     id :: emqx_persistent_session_ds:id(),
     %% creation time
     created_at :: _Millisecond :: non_neg_integer(),
-    expires_at = never :: _Millisecond :: non_neg_integer() | never,
+    last_alive_at :: _Millisecond :: non_neg_integer(),
+    conninfo :: emqx_types:conninfo(),
     %% for future usage
     props = #{} :: map()
 }).

+ 8 - 0
apps/emqx/src/emqx_schema.erl

@@ -1781,6 +1781,14 @@ fields("session_persistence") ->
                     desc => ?DESC(session_ds_idle_poll_interval)
                 }
             )},
+        {"last_alive_update_interval",
+            sc(
+                timeout_duration(),
+                #{
+                    default => <<"5000ms">>,
+                    desc => ?DESC(session_ds_last_alive_update_interval)
+                }
+            )},
         {"force_persistence",
             sc(
                 boolean(),

+ 4 - 4
apps/emqx/src/emqx_session.erl

@@ -84,7 +84,7 @@
 -export([
     deliver/3,
     handle_timeout/3,
-    disconnect/2,
+    disconnect/3,
     terminate/3
 ]).
 
@@ -503,10 +503,10 @@ cancel_timer(Name, Timers) ->
 
 %%--------------------------------------------------------------------
 
--spec disconnect(clientinfo(), t()) ->
+-spec disconnect(clientinfo(), eqmx_types:conninfo(), t()) ->
     {idle | shutdown, t()}.
-disconnect(_ClientInfo, Session) ->
-    ?IMPL(Session):disconnect(Session).
+disconnect(_ClientInfo, ConnInfo, Session) ->
+    ?IMPL(Session):disconnect(Session, ConnInfo).
 
 -spec terminate(clientinfo(), Reason :: term(), t()) ->
     ok.

+ 3 - 3
apps/emqx/src/emqx_session_mem.erl

@@ -87,7 +87,7 @@
     deliver/3,
     replay/3,
     handle_timeout/3,
-    disconnect/1,
+    disconnect/2,
     terminate/2
 ]).
 
@@ -725,8 +725,8 @@ append(L1, L2) -> L1 ++ L2.
 
 %%--------------------------------------------------------------------
 
--spec disconnect(session()) -> {idle, session()}.
-disconnect(Session = #session{}) ->
+-spec disconnect(session(), emqx_types:conninfo()) -> {idle, session()}.
+disconnect(Session = #session{}, _ConnInfo) ->
     % TODO: isolate expiry timer / timeout handling here?
     {idle, Session}.
 

+ 39 - 2
apps/emqx/test/emqx_persistent_session_SUITE.erl

@@ -347,8 +347,6 @@ t_connect_discards_existing_client(Config) ->
     end.
 
 %% [MQTT-3.1.2-23]
-t_connect_session_expiry_interval(init, Config) -> skip_ds_tc(Config);
-t_connect_session_expiry_interval('end', _Config) -> ok.
 t_connect_session_expiry_interval(Config) ->
     ConnFun = ?config(conn_fun, Config),
     Topic = ?config(topic, Config),
@@ -356,6 +354,45 @@ t_connect_session_expiry_interval(Config) ->
     Payload = <<"test message">>,
     ClientId = ?config(client_id, Config),
 
+    {ok, Client1} = emqtt:start_link([
+        {clientid, ClientId},
+        {proto_ver, v5},
+        {properties, #{'Session-Expiry-Interval' => 30}}
+        | Config
+    ]),
+    {ok, _} = emqtt:ConnFun(Client1),
+    {ok, _, [?RC_GRANTED_QOS_1]} = emqtt:subscribe(Client1, STopic, ?QOS_1),
+    ok = emqtt:disconnect(Client1),
+
+    maybe_kill_connection_process(ClientId, Config),
+
+    publish(Topic, Payload, ?QOS_1),
+
+    {ok, Client2} = emqtt:start_link([
+        {clientid, ClientId},
+        {proto_ver, v5},
+        {properties, #{'Session-Expiry-Interval' => 30}},
+        {clean_start, false}
+        | Config
+    ]),
+    {ok, _} = emqtt:ConnFun(Client2),
+    [Msg | _] = receive_messages(1),
+    ?assertEqual({ok, iolist_to_binary(Topic)}, maps:find(topic, Msg)),
+    ?assertEqual({ok, iolist_to_binary(Payload)}, maps:find(payload, Msg)),
+    ?assertEqual({ok, ?QOS_1}, maps:find(qos, Msg)),
+    ok = emqtt:disconnect(Client2).
+
+%% [MQTT-3.1.2-23]
+%% TODO: un-skip after QoS 2 support is implemented in DS.
+t_connect_session_expiry_interval_qos2(init, Config) -> skip_ds_tc(Config);
+t_connect_session_expiry_interval_qos2('end', _Config) -> ok.
+t_connect_session_expiry_interval_qos2(Config) ->
+    ConnFun = ?config(conn_fun, Config),
+    Topic = ?config(topic, Config),
+    STopic = ?config(stopic, Config),
+    Payload = <<"test message">>,
+    ClientId = ?config(client_id, Config),
+
     {ok, Client1} = emqtt:start_link([
         {clientid, ClientId},
         {proto_ver, v5},