Преглед на файлове

Merge pull request #11533 from thalesmg/ds-unsubscribe-m-20230828

feat(ds): close iterators when handling `UNSUBSCRIBE` packets
Thales Macedo Garitezi преди 2 години
родител
ревизия
954b73a9aa

+ 1 - 1
Makefile

@@ -296,7 +296,7 @@ $(foreach tt,$(ALL_ELIXIR_TGZS),$(eval $(call gen-elixir-tgz-target,$(tt))))
 
 .PHONY: fmt
 fmt: $(REBAR)
-	@$(SCRIPTS)/erlfmt -w '{apps,lib-ee}/*/{src,include,priv,test}/**/*.{erl,hrl,app.src,eterm}'
+	@$(SCRIPTS)/erlfmt -w '{apps,lib-ee}/*/{src,include,priv,test,integration_test}/**/*.{erl,hrl,app.src,eterm}'
 	@$(SCRIPTS)/erlfmt -w 'rebar.config.erl'
 	@mix format
 

+ 5 - 1
apps/emqx/include/emqx_session.hrl

@@ -49,7 +49,11 @@
     %% Awaiting PUBREL Timeout (Unit: millisecond)
     await_rel_timeout :: timeout(),
     %% Created at
-    created_at :: pos_integer()
+    created_at :: pos_integer(),
+    %% Topic filter to iterator ID mapping.
+    %% Note: we shouldn't serialize this when persisting sessions, as this information
+    %% also exists in the `?ITERATOR_REF_TAB' table.
+    iterators = #{} :: #{emqx_topic:topic() => emqx_ds:iterator_id()}
 }).
 
 -endif.

+ 135 - 5
apps/emqx/integration_test/emqx_ds_SUITE.erl

@@ -9,8 +9,10 @@
 -include_lib("stdlib/include/assert.hrl").
 -include_lib("common_test/include/ct.hrl").
 -include_lib("snabbkaffe/include/snabbkaffe.hrl").
+-include_lib("emqx/include/emqx_mqtt.hrl").
 
 -define(DS_SHARD, <<"local">>).
+-define(ITERATOR_REF_TAB, emqx_ds_iterator_ref).
 
 %%------------------------------------------------------------------------------
 %% CT boilerplate
@@ -31,9 +33,12 @@ end_per_suite(Config) ->
     emqx_cth_suite:stop(TCApps),
     ok.
 
-init_per_testcase(t_session_subscription_idempotency = TC, Config) ->
+init_per_testcase(TestCase, Config) when
+    TestCase =:= t_session_subscription_idempotency;
+    TestCase =:= t_session_unsubscription_idempotency
+->
     Cluster = cluster(#{n => 1}),
-    ClusterOpts = #{work_dir => emqx_cth_suite:work_dir(TC, Config)},
+    ClusterOpts = #{work_dir => emqx_cth_suite:work_dir(TestCase, Config)},
     NodeSpecs = emqx_cth_cluster:mk_nodespecs(Cluster, ClusterOpts),
     Nodes = emqx_cth_cluster:start(Cluster, ClusterOpts),
     [
@@ -46,7 +51,10 @@ init_per_testcase(t_session_subscription_idempotency = TC, Config) ->
 init_per_testcase(_TestCase, Config) ->
     Config.
 
-end_per_testcase(t_session_subscription_idempotency, Config) ->
+end_per_testcase(TestCase, Config) when
+    TestCase =:= t_session_subscription_idempotency;
+    TestCase =:= t_session_unsubscription_idempotency
+->
     Nodes = ?config(nodes, Config),
     ok = emqx_cth_cluster:stop(Nodes),
     ok;
@@ -91,12 +99,21 @@ get_mqtt_port(Node, Type) ->
     {_IP, Port} = erpc:call(Node, emqx_config, get, [[listeners, Type, default, bind]]),
     Port.
 
+%% Returns the keys of all records currently stored in `?ITERATOR_REF_TAB' on
+%% `Node', read with a dirty mnesia operation executed over erpc.
+get_all_iterator_refs(Node) ->
+    Table = ?ITERATOR_REF_TAB,
+    erpc:call(Node, mnesia, dirty_all_keys, [Table]).
+
 %% Folds over every iterator stored under the empty prefix of `?DS_SHARD' on
 %% `Node', collecting the keys.  The result is whatever
 %% `emqx_ds_storage_layer:foldl_iterator_prefix/4' returns.
 get_all_iterator_ids(Node) ->
     CollectKey = fun(Key, _Val, Keys) -> [Key | Keys] end,
     Fold = fun() ->
         emqx_ds_storage_layer:foldl_iterator_prefix(?DS_SHARD, <<>>, CollectKey, [])
     end,
     erpc:call(Node, Fold).
 
+%% Reads the `iterators' info item of the session that belongs to `ClientId'
+%% on `Node'.  Crashes (badmatch) unless exactly one channel is registered for
+%% the client id.
+get_session_iterators(Node, ClientId) ->
+    Fetch = fun() ->
+        [ChanPid] = emqx_cm:lookup_channels(ClientId),
+        ConnState = sys:get_state(ChanPid),
+        emqx_connection:info({channel, {session, iterators}}, ConnState)
+    end,
+    erpc:call(Node, Fetch).
+
 wait_nodeup(Node) ->
     ?retry(
         _Sleep0 = 500,
@@ -159,6 +176,7 @@ t_session_subscription_idempotency(Config) ->
                 %% have to re-inject this so that we may stop the node successfully at the
                 %% end....
                 ok = emqx_cth_cluster:set_node_opts(Node1, Node1Spec),
+                ok = snabbkaffe:forward_trace(Node1),
                 ct:pal("node ~p restarted", [Node1]),
                 ?tp(restarted_node, #{}),
                 ok
@@ -191,14 +209,20 @@ t_session_subscription_idempotency(Config) ->
             {ok, _} = emqtt:connect(Client1),
             ct:pal("subscribing 2"),
             {ok, _, [2]} = emqtt:subscribe(Client1, SubTopicFilter, qos2),
+            SessionIterators = get_session_iterators(Node1, ClientId),
 
             ok = emqtt:stop(Client1),
 
-            ok
+            #{session_iterators => SessionIterators}
         end,
-        fun(Trace) ->
+        fun(Res, Trace) ->
             ct:pal("trace:\n  ~p", [Trace]),
+            #{session_iterators := SessionIterators} = Res,
             %% Exactly one iterator should have been opened.
+            ?assertEqual(1, map_size(SessionIterators), #{iterators => SessionIterators}),
+            ?assertMatch(#{SubTopicFilter := _}, SessionIterators),
+            SubTopicFilterWords = emqx_topic:words(SubTopicFilter),
+            ?assertEqual([{ClientId, SubTopicFilterWords}], get_all_iterator_refs(Node1)),
             ?assertMatch({ok, [_]}, get_all_iterator_ids(Node1)),
             ?assertMatch(
                 {_IsNew = false, ClientId},
@@ -208,3 +232,109 @@ t_session_subscription_idempotency(Config) ->
         end
     ),
     ok.
+
+%% Check that we close the iterators before deleting the iterator id entry.
+%%
+%% Scenario: a client subscribes and then unsubscribes while the node is
+%% restarted in the middle of the unsubscribe handling (after the iterators
+%% were closed, but before the iterator reference is deleted).  After the node
+%% is back, the client subscribes and unsubscribes again, and we verify that
+%% no iterator state is left behind anywhere.
+t_session_unsubscription_idempotency(Config) ->
+    [Node1Spec | _] = ?config(node_specs, Config),
+    [Node1] = ?config(nodes, Config),
+    Port = get_mqtt_port(Node1, tcp),
+    SubTopicFilter = <<"t/+">>,
+    ClientId = <<"myclientid">>,
+    ?check_trace(
+        begin
+            %% The node restart may only begin once the "close iterators" span
+            %% has completed...
+            ?force_ordering(
+                #{?snk_kind := persistent_session_ds_close_iterators, ?snk_span := {complete, _}},
+                _NEvents0 = 1,
+                #{?snk_kind := will_restart_node},
+                _Guard0 = true
+            ),
+            %% ... and the "iterator delete" span may only start after the node
+            %% has been restarted.
+            ?force_ordering(
+                #{?snk_kind := restarted_node},
+                _NEvents1 = 1,
+                #{?snk_kind := persistent_session_ds_iterator_delete, ?snk_span := start},
+                _Guard1 = true
+            ),
+
+            %% Helper process that restarts the node and brings the apps back up.
+            spawn_link(fun() ->
+                ?tp(will_restart_node, #{}),
+                ct:pal("restarting node ~p", [Node1]),
+                true = monitor_node(Node1, true),
+                ok = erpc:call(Node1, init, restart, []),
+                receive
+                    {nodedown, Node1} ->
+                        ok
+                after 10_000 ->
+                    ct:fail("node ~p didn't stop", [Node1])
+                end,
+                ct:pal("waiting for nodeup ~p", [Node1]),
+                wait_nodeup(Node1),
+                wait_gen_rpc_down(Node1Spec),
+                ct:pal("restarting apps on ~p", [Node1]),
+                Apps = maps:get(apps, Node1Spec),
+                ok = erpc:call(Node1, emqx_cth_suite, load_apps, [Apps]),
+                _ = erpc:call(Node1, emqx_cth_suite, start_apps, [Apps, Node1Spec]),
+                %% have to re-inject this so that we may stop the node successfully at the
+                %% end....
+                ok = emqx_cth_cluster:set_node_opts(Node1, Node1Spec),
+                ok = snabbkaffe:forward_trace(Node1),
+                ct:pal("node ~p restarted", [Node1]),
+                ?tp(restarted_node, #{}),
+                ok
+            end),
+
+            ct:pal("starting 1"),
+            {ok, Client0} = emqtt:start_link([
+                {port, Port},
+                {clientid, ClientId},
+                {proto_ver, v5}
+            ]),
+            {ok, _} = emqtt:connect(Client0),
+            ct:pal("subscribing 1"),
+            {ok, _, [?RC_GRANTED_QOS_2]} = emqtt:subscribe(Client0, SubTopicFilter, qos2),
+            ct:pal("unsubscribing 1"),
+            %% The client is linked to this process and is expected to go down
+            %% together with the node, so trap exits around the call and ignore
+            %% its outcome.
+            process_flag(trap_exit, true),
+            catch emqtt:unsubscribe(Client0, SubTopicFilter),
+            %% Drain the exit message of the first client, if it has already
+            %% arrived.  Trapped exit signals are delivered as the 3-tuple
+            %% `{'EXIT', FromPid, Reason}'.
+            receive
+                {'EXIT', Client0, {shutdown, _}} ->
+                    ok
+            after 0 -> ok
+            end,
+            process_flag(trap_exit, false),
+
+            {ok, _} = ?block_until(#{?snk_kind := restarted_node}, 15_000),
+            ct:pal("starting 2"),
+            {ok, Client1} = emqtt:start_link([
+                {port, Port},
+                {clientid, ClientId},
+                {proto_ver, v5}
+            ]),
+            {ok, _} = emqtt:connect(Client1),
+            ct:pal("subscribing 2"),
+            {ok, _, [?RC_GRANTED_QOS_2]} = emqtt:subscribe(Client1, SubTopicFilter, qos2),
+            ct:pal("unsubscribing 2"),
+            %% Wait until the iterator reference deletion has completed before
+            %% inspecting the session.
+            {{ok, _, [?RC_SUCCESS]}, {ok, _}} =
+                ?wait_async_action(
+                    emqtt:unsubscribe(Client1, SubTopicFilter),
+                    #{
+                        ?snk_kind := persistent_session_ds_iterator_delete,
+                        ?snk_span := {complete, _}
+                    },
+                    15_000
+                ),
+            SessionIterators = get_session_iterators(Node1, ClientId),
+
+            ok = emqtt:stop(Client1),
+
+            #{session_iterators => SessionIterators}
+        end,
+        fun(Res, Trace) ->
+            ct:pal("trace:\n  ~p", [Trace]),
+            #{session_iterators := SessionIterators} = Res,
+            %% No iterators remaining
+            ?assertEqual(#{}, SessionIterators),
+            ?assertEqual([], get_all_iterator_refs(Node1)),
+            ?assertEqual({ok, []}, get_all_iterator_ids(Node1)),
+            ok
+        end
+    ),
+    ok.

+ 61 - 2
apps/emqx/src/emqx_persistent_session_ds.erl

@@ -23,7 +23,8 @@
 -export([
     persist_message/1,
     open_session/1,
-    add_subscription/2
+    add_subscription/2,
+    del_subscription/3
 ]).
 
 -export([
@@ -32,7 +33,15 @@
 ]).
 
 %% RPC
--export([do_open_iterator/3]).
+-export([
+    ensure_iterator_closed_on_all_shards/1,
+    ensure_all_iterators_closed/1
+]).
+-export([
+    do_open_iterator/3,
+    do_ensure_iterator_closed/1,
+    do_ensure_all_iterators_closed/1
+]).
 
 %% FIXME
 -define(DS_SHARD, <<"local">>).
@@ -130,6 +139,56 @@ do_open_iterator(TopicFilter, StartMS, IteratorID) ->
     {ok, _It} = emqx_ds_storage_layer:ensure_iterator(?DS_SHARD, IteratorID, Replay),
     ok.
 
+%% Removes the durable-storage state of a subscription in two traced steps:
+%% first the iterator is closed via `ensure_iterator_closed_on_all_shards/1'
+%% (success is asserted), then the session's iterator reference for the topic
+%% filter is deleted.  The whole body is wrapped in `?WHEN_ENABLED'.
+-spec del_subscription(emqx_ds:iterator_id() | undefined, emqx_types:topic(), emqx_ds:session_id()) ->
+    ok | {skipped, disabled}.
+del_subscription(IteratorID, TopicFilterBin, DSSessionID) ->
+    ?WHEN_ENABLED(
+        begin
+            %% Metadata shared by both trace spans.
+            SpanMeta = #{iterator_id => IteratorID},
+            TopicWords = emqx_topic:words(TopicFilterBin),
+            ?tp_span(
+                persistent_session_ds_close_iterators,
+                SpanMeta,
+                ok = ensure_iterator_closed_on_all_shards(IteratorID)
+            ),
+            ?tp_span(
+                persistent_session_ds_iterator_delete,
+                SpanMeta,
+                emqx_ds:session_del_iterator(DSSessionID, TopicWords)
+            )
+        end
+    ).
+
+%% Asks every running node to close the given iterator and crashes (badmatch)
+%% unless all of them reply `{ok, ok}'.
+-spec ensure_iterator_closed_on_all_shards(emqx_ds:iterator_id()) -> ok.
+ensure_iterator_closed_on_all_shards(IteratorID) ->
+    %% Note: currently, shards map 1:1 to nodes, but this will change in the future.
+    Replies = emqx_persistent_session_ds_proto_v1:close_iterator(
+        emqx:running_nodes(), IteratorID
+    ),
+    IsSuccess = fun
+        ({ok, ok}) -> true;
+        (_Other) -> false
+    end,
+    %% TODO: handle errors
+    true = lists:all(IsSuccess, Replies),
+    ok.
+
+%% RPC target: discards the given iterator from the local `?DS_SHARD',
+%% asserting that the storage layer answers `ok'.
+-spec do_ensure_iterator_closed(emqx_ds:iterator_id()) -> ok.
+do_ensure_iterator_closed(IteratorID) ->
+    %% The match both asserts success and yields the `ok' return value.
+    ok = emqx_ds_storage_layer:discard_iterator(?DS_SHARD, IteratorID).
+
+%% Asks every running node to close all iterators of the given session and
+%% crashes (badmatch) unless all of them reply `{ok, ok}'.
+-spec ensure_all_iterators_closed(emqx_ds:session_id()) -> ok.
+ensure_all_iterators_closed(DSSessionID) ->
+    %% Note: currently, shards map 1:1 to nodes, but this will change in the future.
+    Replies = emqx_persistent_session_ds_proto_v1:close_all_iterators(
+        emqx:running_nodes(), DSSessionID
+    ),
+    IsSuccess = fun
+        ({ok, ok}) -> true;
+        (_Other) -> false
+    end,
+    %% TODO: handle errors
+    true = lists:all(IsSuccess, Replies),
+    ok.
+
+%% RPC target: discards, from the local `?DS_SHARD', every iterator whose id
+%% is prefixed by the session id, asserting that the storage layer answers `ok'.
+-spec do_ensure_all_iterators_closed(emqx_ds:session_id()) -> ok.
+do_ensure_all_iterators_closed(DSSessionID) ->
+    %% The match both asserts success and yields the `ok' return value.
+    ok = emqx_ds_storage_layer:discard_iterator_prefix(?DS_SHARD, DSSessionID).
+
 %%
 
 serialize_message(Msg) ->

+ 27 - 6
apps/emqx/src/emqx_session.erl

@@ -269,7 +269,9 @@ info(awaiting_rel_max, #session{max_awaiting_rel = Max}) ->
 info(await_rel_timeout, #session{await_rel_timeout = Timeout}) ->
     Timeout;
 info(created_at, #session{created_at = CreatedAt}) ->
-    CreatedAt.
+    CreatedAt;
+info(iterators, #session{iterators = Iterators}) ->
+    Iterators.
 
 %% @doc Get stats of the session.
 -spec stats(session()) -> emqx_types:stats().
@@ -318,8 +320,13 @@ is_subscriptions_full(#session{
 -spec add_persistent_subscription(emqx_types:topic(), emqx_types:clientid(), session()) ->
     session().
 add_persistent_subscription(TopicFilterBin, ClientId, Session) ->
-    _ = emqx_persistent_session_ds:add_subscription(TopicFilterBin, ClientId),
-    Session.
+    case emqx_persistent_session_ds:add_subscription(TopicFilterBin, ClientId) of
+        {ok, IteratorId, _IsNew} ->
+            Iterators = Session#session.iterators,
+            Session#session{iterators = Iterators#{TopicFilterBin => IteratorId}};
+        _ ->
+            Session
+    end.
 
 %%--------------------------------------------------------------------
 %% Client -> Broker: UNSUBSCRIBE
@@ -328,23 +335,37 @@ add_persistent_subscription(TopicFilterBin, ClientId, Session) ->
 -spec unsubscribe(emqx_types:clientinfo(), emqx_types:topic(), emqx_types:subopts(), session()) ->
     {ok, session()} | {error, emqx_types:reason_code()}.
 unsubscribe(
-    ClientInfo,
+    ClientInfo = #{clientid := ClientId},
     TopicFilter,
     UnSubOpts,
-    Session = #session{subscriptions = Subs}
+    Session0 = #session{subscriptions = Subs}
 ) ->
     case maps:find(TopicFilter, Subs) of
         {ok, SubOpts} ->
             ok = emqx_broker:unsubscribe(TopicFilter),
+            Session1 = remove_persistent_subscription(Session0, TopicFilter, ClientId),
             ok = emqx_hooks:run(
                 'session.unsubscribed',
                 [ClientInfo, TopicFilter, maps:merge(SubOpts, UnSubOpts)]
             ),
-            {ok, Session#session{subscriptions = maps:remove(TopicFilter, Subs)}};
+            {ok, Session1#session{subscriptions = maps:remove(TopicFilter, Subs)}};
         error ->
             {error, ?RC_NO_SUBSCRIPTION_EXISTED}
     end.
 
+%% Drops the iterator tracked for `TopicFilterBin' from the session.  When an
+%% iterator id is known for the topic filter, the persistent subscription is
+%% deleted first (its result is ignored); the topic filter is removed from the
+%% session's iterator map in either case.
+-spec remove_persistent_subscription(session(), emqx_types:topic(), emqx_types:clientid()) ->
+    session().
+remove_persistent_subscription(Session, TopicFilterBin, ClientId) ->
+    Iterators0 = Session#session.iterators,
+    _ =
+        case maps:get(TopicFilterBin, Iterators0, undefined) of
+            undefined ->
+                ok;
+            ItId ->
+                emqx_persistent_session_ds:del_subscription(ItId, TopicFilterBin, ClientId)
+        end,
+    Iterators1 = maps:remove(TopicFilterBin, Iterators0),
+    Session#session{iterators = Iterators1}.
+
 %%--------------------------------------------------------------------
 %% Client -> Broker: PUBLISH
 %%--------------------------------------------------------------------

+ 31 - 1
apps/emqx/src/proto/emqx_persistent_session_ds_proto_v1.erl

@@ -21,7 +21,9 @@
 -export([
     introduced_in/0,
 
-    open_iterator/4
+    open_iterator/4,
+    close_iterator/2,
+    close_all_iterators/2
 ]).
 
 -include_lib("emqx/include/bpapi.hrl").
@@ -47,3 +49,31 @@ open_iterator(Nodes, TopicFilter, StartMS, IteratorID) ->
         [TopicFilter, StartMS, IteratorID],
         ?TIMEOUT
     ).
+
+%% Calls `emqx_persistent_session_ds:do_ensure_iterator_closed/1' on each of
+%% `Nodes' with the given iterator id and returns the per-node results.
+-spec close_iterator([node()], emqx_ds:iterator_id()) -> emqx_rpc:erpc_multicall(ok).
+close_iterator(Nodes, IteratorID) ->
+    erpc:multicall(
+        Nodes, emqx_persistent_session_ds, do_ensure_iterator_closed, [IteratorID], ?TIMEOUT
+    ).
+
+%% Calls `emqx_persistent_session_ds:do_ensure_all_iterators_closed/1' on each
+%% of `Nodes' with the given session id and returns the per-node results.
+-spec close_all_iterators([node()], emqx_ds:session_id()) -> emqx_rpc:erpc_multicall(ok).
+close_all_iterators(Nodes, DSSessionID) ->
+    erpc:multicall(
+        Nodes, emqx_persistent_session_ds, do_ensure_all_iterators_closed, [DSSessionID], ?TIMEOUT
+    ).

+ 25 - 9
apps/emqx_durable_storage/src/emqx_ds.erl

@@ -30,6 +30,7 @@
     session_drop/1,
     session_suspend/1,
     session_add_iterator/2,
+    session_get_iterator_id/2,
     session_del_iterator/2,
     session_stats/0
 ]).
@@ -57,7 +58,9 @@
 %% Type declarations
 %%================================================================================
 
--type session_id() :: emqx_types:clientid().
+%% Currently, this is the clientid.  We avoid `emqx_types:clientid()' because that can be
+%% an atom, in theory (?).
+-type session_id() :: binary().
 
 -type iterator() :: term().
 
@@ -156,6 +159,7 @@ session_drop(ClientID) ->
     {atomic, ok} = mria:transaction(
         ?DS_SHARD,
         fun() ->
+            %% TODO: ensure all iterators from this clientid are closed?
             mnesia:delete({?SESSION_TAB, ClientID})
         end
     ),
@@ -201,14 +205,26 @@ session_add_iterator(DSSessionId, TopicFilter) ->
         end),
     Res.
 
-%% @doc Called when a client unsubscribes from a topic. Returns `true'
-%% if the session contained the subscription or `false' if it wasn't
-%% subscribed.
--spec session_del_iterator(session_id(), emqx_topic:words()) ->
-    {ok, boolean()} | {error, session_not_found}.
-session_del_iterator(_SessionId, _TopicFilter) ->
-    %% TODO
-    {ok, false}.
+%% @doc Looks up (dirty read) the iterator id registered for the given session
+%% and topic filter in `?ITERATOR_REF_TAB'.
+-spec session_get_iterator_id(session_id(), emqx_topic:words()) ->
+    {ok, iterator_id()} | {error, not_found}.
+session_get_iterator_id(DSSessionId, TopicFilter) ->
+    %% Iterator references are keyed by `{SessionId, TopicFilter}'.
+    case mnesia:dirty_read(?ITERATOR_REF_TAB, {DSSessionId, TopicFilter}) of
+        [#iterator_ref{it_id = ItId}] ->
+            {ok, ItId};
+        [] ->
+            {error, not_found}
+    end.
+
+%% @doc Called when a client unsubscribes from a topic.  Deletes the
+%% `{SessionId, TopicFilter}' entry from `?ITERATOR_REF_TAB' inside a
+%% transaction on `?DS_SHARD' and asserts that the transaction committed.
+-spec session_del_iterator(session_id(), emqx_topic:words()) -> ok.
+session_del_iterator(DSSessionId, TopicFilter) ->
+    RefKey = {DSSessionId, TopicFilter},
+    DeleteRef = fun() ->
+        mnesia:delete(?ITERATOR_REF_TAB, RefKey, write)
+    end,
+    {atomic, ok} = mria:transaction(?DS_SHARD, DeleteRef),
+    ok.
 
 -spec session_stats() -> #{}.
 session_stats() ->