Przeglądaj źródła

feat(shared-sub): add round_robin_per_group strategy

add round robin per group strategy that balances load in a more predictable fashion when using no replication
Benjamin Krenn 3 lat temu
rodzic
commit
dec892e867

+ 2 - 0
apps/emqx/i18n/emqx_schema_i18n.conf

@@ -1115,6 +1115,7 @@ special characters are allowed.
             en: """Dispatch strategy for shared subscription.
 - `random`: dispatch the message to a random selected subscriber
 - `round_robin`: select the subscribers in a round-robin manner
+- `round_robin_per_group`: select the subscribers in round-robin fashion within each shared subscriber group
 - `sticky`: always use the last selected subscriber to dispatch,
 until the subscriber disconnects.
 - `hash`: select the subscribers by the hash of `clientIds`
@@ -1124,6 +1125,7 @@ subscriber was not found, send to a random subscriber cluster-wide
             cn: """共享订阅的分发策略名称。
 - `random`: 随机选择一个组内成员;
 - `round_robin`: 循环选择下一个成员;
+- `round_robin_per_group`: 在共享组内循环选择下一个成员;
 - `sticky`: 使用上一次选中的成员;
 - `hash`: 根据 ClientID 哈希映射到一个成员;
 - `local`: 随机分发到节点本地成成员,如果本地成员不存在,则随机分发

+ 18 - 2
apps/emqx/src/emqx_schema.erl

@@ -1160,7 +1160,15 @@ fields("broker") ->
             )},
         {"shared_subscription_strategy",
             sc(
-                hoconsc:enum([random, round_robin, sticky, local, hash_topic, hash_clientid]),
+                hoconsc:enum([
+                    random,
+                    round_robin,
+                    round_robin_per_group,
+                    sticky,
+                    local,
+                    hash_topic,
+                    hash_clientid
+                ]),
                 #{
                     default => round_robin,
                     desc => ?DESC(broker_shared_subscription_strategy)
@@ -1200,7 +1208,15 @@ fields("shared_subscription_group") ->
     [
         {"strategy",
             sc(
-                hoconsc:enum([random, round_robin, sticky, local, hash_topic, hash_clientid]),
+                hoconsc:enum([
+                    random,
+                    round_robin,
+                    round_robin_per_group,
+                    sticky,
+                    local,
+                    hash_topic,
+                    hash_clientid
+                ]),
                 #{
                     default => random,
                     desc => ?DESC(shared_subscription_strategy_enum)

+ 37 - 6
apps/emqx/src/emqx_shared_sub.erl

@@ -72,6 +72,7 @@
 -type strategy() ::
     random
     | round_robin
+    | round_robin_per_group
     | sticky
     | local
     %% same as hash_clientid, backward compatible
@@ -81,6 +82,7 @@
 
 -define(SERVER, ?MODULE).
 -define(TAB, emqx_shared_subscription).
+-define(SHARED_SUBS_ROUND_ROBIN_COUNTER, emqx_shared_subscriber_round_robin_counter).
 -define(SHARED_SUBS, emqx_shared_subscriber).
 -define(ALIVE_SUBS, emqx_alive_shared_subscribers).
 -define(SHARED_SUB_QOS1_DISPATCH_TIMEOUT_SECONDS, 5).
@@ -315,7 +317,14 @@ do_pick_subscriber(Group, Topic, round_robin, _ClientId, _SourceTopic, Count) ->
             N -> (N + 1) rem Count
         end,
     _ = erlang:put({shared_sub_round_robin, Group, Topic}, Rem),
-    Rem + 1.
+    Rem + 1;
+do_pick_subscriber(Group, Topic, round_robin_per_group, _ClientId, _SourceTopic, Count) ->
+    %% reset the counter to 1 if counter > subscriber count to avoid the counter to grow larger
+    %% than the current subscriber count.
+    %% if no counter for the given group topic exists - due to a configuration change - create a new one starting at 0
+    ets:update_counter(?SHARED_SUBS_ROUND_ROBIN_COUNTER, {Group, Topic}, {2, 1, Count, 1}, {
+        {Group, Topic}, 0
+    }).
 
 subscribers(Group, Topic) ->
     ets:select(?TAB, [{{emqx_shared_subscription, Group, Topic, '$1'}, [], ['$1']}]).
@@ -330,6 +339,7 @@ init([]) ->
     {atomic, PMon} = mria:transaction(?SHARED_SUB_SHARD, fun init_monitors/0),
     ok = emqx_tables:new(?SHARED_SUBS, [protected, bag]),
     ok = emqx_tables:new(?ALIVE_SUBS, [protected, set, {read_concurrency, true}]),
+    ok = emqx_tables:new(?SHARED_SUBS_ROUND_ROBIN_COUNTER, [public, set, {write_concurrency, true}]),
     {ok, update_stats(#state{pmon = PMon})}.
 
 init_monitors() ->
@@ -348,12 +358,14 @@ handle_call({subscribe, Group, Topic, SubPid}, _From, State = #state{pmon = PMon
         false -> ok = emqx_router:do_add_route(Topic, {Group, node()})
     end,
     ok = maybe_insert_alive_tab(SubPid),
+    ok = maybe_insert_round_robin_count({Group, Topic}),
     true = ets:insert(?SHARED_SUBS, {{Group, Topic}, SubPid}),
     {reply, ok, update_stats(State#state{pmon = emqx_pmon:monitor(SubPid, PMon)})};
 handle_call({unsubscribe, Group, Topic, SubPid}, _From, State) ->
     mria:dirty_delete_object(?TAB, record(Group, Topic, SubPid)),
     true = ets:delete_object(?SHARED_SUBS, {{Group, Topic}, SubPid}),
     delete_route_if_needed({Group, Topic}),
+    maybe_delete_round_robin_count({Group, Topic}),
     {reply, ok, State};
 handle_call(Req, _From, State) ->
     ?SLOG(error, #{msg => "unexpected_call", req => Req}),
@@ -395,6 +407,25 @@ code_change(_OldVsn, State, _Extra) ->
 %% Internal functions
 %%--------------------------------------------------------------------
 
+maybe_insert_round_robin_count({Group, _Topic} = GroupTopic) ->
+    strategy(Group) =:= round_robin_per_group andalso
+        ets:insert(?SHARED_SUBS_ROUND_ROBIN_COUNTER, {GroupTopic, 0}),
+    ok.
+
+maybe_delete_round_robin_count({Group, _Topic} = GroupTopic) ->
+    strategy(Group) =:= round_robin_per_group andalso
+        if_no_more_subscribers(GroupTopic, fun() ->
+            ets:delete(?SHARED_SUBS_ROUND_ROBIN_COUNTER, GroupTopic)
+        end),
+    ok.
+
+if_no_more_subscribers(GroupTopic, Fn) ->
+    case ets:member(?SHARED_SUBS, GroupTopic) of
+        true -> ok;
+        false -> Fn()
+    end,
+    ok.
+
 %% keep track of alive remote pids
 maybe_insert_alive_tab(Pid) when ?IS_LOCAL_PID(Pid) -> ok;
 maybe_insert_alive_tab(Pid) when is_pid(Pid) ->
@@ -407,6 +438,7 @@ cleanup_down(SubPid) ->
         fun(Record = #emqx_shared_subscription{topic = Topic, group = Group}) ->
             ok = mria:dirty_delete_object(?TAB, Record),
             true = ets:delete_object(?SHARED_SUBS, {{Group, Topic}, SubPid}),
+            maybe_delete_round_robin_count({Group, Topic}),
             delete_route_if_needed({Group, Topic})
         end,
         mnesia:dirty_match_object(#emqx_shared_subscription{_ = '_', subpid = SubPid})
@@ -430,8 +462,7 @@ is_alive_sub(Pid) when ?IS_LOCAL_PID(Pid) ->
 is_alive_sub(Pid) ->
     [] =/= ets:lookup(?ALIVE_SUBS, Pid).
 
-delete_route_if_needed({Group, Topic}) ->
-    case ets:member(?SHARED_SUBS, {Group, Topic}) of
-        true -> ok;
-        false -> ok = emqx_router:do_delete_route(Topic, {Group, node()})
-    end.
+delete_route_if_needed({Group, Topic} = GroupTopic) ->
+    if_no_more_subscribers(GroupTopic, fun() ->
+        ok = emqx_router:do_delete_route(Topic, {Group, node()})
+    end).

+ 283 - 3
apps/emqx/test/emqx_shared_sub_SUITE.erl

@@ -195,6 +195,266 @@ t_round_robin(_) ->
     ok = ensure_config(round_robin, true),
     test_two_messages(round_robin).
 
+t_round_robin_per_group(_) ->
+    ok = ensure_config(round_robin_per_group, true),
+    test_two_messages(round_robin_per_group).
+
+%% this would fail if executed with the standard round_robin strategy
+t_round_robin_per_group_even_distribution_one_group(_) ->
+    ok = ensure_config(round_robin_per_group, true),
+    Topic = <<"foo/bar">>,
+    Group = <<"group1">>,
+    {ok, ConnPid1} = emqtt:start_link([{clientid, <<"C0">>}]),
+    {ok, ConnPid2} = emqtt:start_link([{clientid, <<"C1">>}]),
+    {ok, _} = emqtt:connect(ConnPid1),
+    {ok, _} = emqtt:connect(ConnPid2),
+
+    emqtt:subscribe(ConnPid1, {<<"$share/", Group/binary, "/", Topic/binary>>, 0}),
+    emqtt:subscribe(ConnPid2, {<<"$share/", Group/binary, "/", Topic/binary>>, 0}),
+
+    %% publisher with persistent connection
+    {ok, PublisherPid} = emqtt:start_link(),
+    {ok, _} = emqtt:connect(PublisherPid),
+
+    lists:foreach(
+        fun(I) ->
+            Message = erlang:integer_to_binary(I),
+            emqtt:publish(PublisherPid, Topic, Message)
+        end,
+        lists:seq(0, 9)
+    ),
+
+    AllReceivedMessages = lists:map(
+        fun(#{client_pid := SubscriberPid, payload := Payload}) -> {SubscriberPid, Payload} end,
+        lists:reverse(recv_msgs(10))
+    ),
+    MessagesReceivedSubscriber1 = lists:filter(
+        fun({P, _Payload}) -> P == ConnPid1 end, AllReceivedMessages
+    ),
+    MessagesReceivedSubscriber2 = lists:filter(
+        fun({P, _Payload}) -> P == ConnPid2 end, AllReceivedMessages
+    ),
+
+    emqtt:stop(ConnPid1),
+    emqtt:stop(ConnPid2),
+    emqtt:stop(PublisherPid),
+
+    %% ensure each subscriber received 5 messages in alternating fashion
+    %% one receives all even and the other all uneven payloads
+    ?assertEqual(
+        [
+            {ConnPid1, <<"0">>},
+            {ConnPid1, <<"2">>},
+            {ConnPid1, <<"4">>},
+            {ConnPid1, <<"6">>},
+            {ConnPid1, <<"8">>}
+        ],
+        MessagesReceivedSubscriber1
+    ),
+
+    ?assertEqual(
+        [
+            {ConnPid2, <<"1">>},
+            {ConnPid2, <<"3">>},
+            {ConnPid2, <<"5">>},
+            {ConnPid2, <<"7">>},
+            {ConnPid2, <<"9">>}
+        ],
+        MessagesReceivedSubscriber2
+    ),
+    ok.
+
+t_round_robin_per_group_even_distribution_two_groups(_) ->
+    ok = ensure_config(round_robin_per_group, true),
+    Topic = <<"foo/bar">>,
+    {ok, ConnPid1} = emqtt:start_link([{clientid, <<"C0">>}]),
+    {ok, ConnPid2} = emqtt:start_link([{clientid, <<"C1">>}]),
+    {ok, ConnPid3} = emqtt:start_link([{clientid, <<"C2">>}]),
+    {ok, ConnPid4} = emqtt:start_link([{clientid, <<"C3">>}]),
+    ConnPids = [ConnPid1, ConnPid2, ConnPid3, ConnPid4],
+    lists:foreach(fun(P) -> emqtt:connect(P) end, ConnPids),
+
+    %% group1 subscribers
+    emqtt:subscribe(ConnPid1, {<<"$share/group1/", Topic/binary>>, 0}),
+    emqtt:subscribe(ConnPid2, {<<"$share/group1/", Topic/binary>>, 0}),
+    %% group2 subscribers
+    emqtt:subscribe(ConnPid3, {<<"$share/group2/", Topic/binary>>, 0}),
+    emqtt:subscribe(ConnPid4, {<<"$share/group2/", Topic/binary>>, 0}),
+
+    publish_fire_and_forget(10, Topic),
+
+    AllReceivedMessages = lists:map(
+        fun(#{client_pid := SubscriberPid, payload := Payload}) -> {SubscriberPid, Payload} end,
+        lists:reverse(recv_msgs(20))
+    ),
+    MessagesReceivedSubscriber1 = lists:filter(
+        fun({P, _Payload}) -> P == ConnPid1 end, AllReceivedMessages
+    ),
+    MessagesReceivedSubscriber2 = lists:filter(
+        fun({P, _Payload}) -> P == ConnPid2 end, AllReceivedMessages
+    ),
+    MessagesReceivedSubscriber3 = lists:filter(
+        fun({P, _Payload}) -> P == ConnPid3 end, AllReceivedMessages
+    ),
+    MessagesReceivedSubscriber4 = lists:filter(
+        fun({P, _Payload}) -> P == ConnPid4 end, AllReceivedMessages
+    ),
+
+    lists:foreach(fun(P) -> emqtt:stop(P) end, ConnPids),
+
+    %% ensure each subscriber received 5 messages in alternating fashion in each group
+    %% subscriber 1 and 3 should receive all even messages
+    %% subscriber 2 and 4 should receive all uneven messages
+    ?assertEqual(
+        [
+            {ConnPid3, <<"0">>},
+            {ConnPid3, <<"2">>},
+            {ConnPid3, <<"4">>},
+            {ConnPid3, <<"6">>},
+            {ConnPid3, <<"8">>}
+        ],
+        MessagesReceivedSubscriber3
+    ),
+
+    ?assertEqual(
+        [
+            {ConnPid2, <<"1">>},
+            {ConnPid2, <<"3">>},
+            {ConnPid2, <<"5">>},
+            {ConnPid2, <<"7">>},
+            {ConnPid2, <<"9">>}
+        ],
+        MessagesReceivedSubscriber2
+    ),
+
+    ?assertEqual(
+        [
+            {ConnPid4, <<"1">>},
+            {ConnPid4, <<"3">>},
+            {ConnPid4, <<"5">>},
+            {ConnPid4, <<"7">>},
+            {ConnPid4, <<"9">>}
+        ],
+        MessagesReceivedSubscriber4
+    ),
+
+    ?assertEqual(
+        [
+            {ConnPid1, <<"0">>},
+            {ConnPid1, <<"2">>},
+            {ConnPid1, <<"4">>},
+            {ConnPid1, <<"6">>},
+            {ConnPid1, <<"8">>}
+        ],
+        MessagesReceivedSubscriber1
+    ),
+    ok.
+
+t_round_robin_per_group_two_nodes_publish_to_same_node(_) ->
+    ensure_config(round_robin_per_group),
+    Node = start_slave('rr_p_g_t_n', 31337),
+    ensure_node_config(Node, round_robin_per_group),
+
+    %% connect two subscribers on each node
+    Topic = <<"foo/bar">>,
+    {ok, Subscriber0} = emqtt:start_link([{clientid, <<"C0">>}]),
+    {ok, Subscriber1} = emqtt:start_link([{clientid, <<"C1">>}]),
+    {ok, Subscriber2} = emqtt:start_link([{clientid, <<"C2">>}, {port, 31337}]),
+    {ok, Subscriber3} = emqtt:start_link([{clientid, <<"C3">>}, {port, 31337}]),
+    SubscriberPids = [Subscriber0, Subscriber1, Subscriber2, Subscriber3],
+    lists:foreach(fun(P) -> emqtt:connect(P) end, SubscriberPids),
+
+    %% node 1 subscribers
+    emqtt:subscribe(Subscriber0, {<<"$share/group1/", Topic/binary>>, 0}),
+    emqtt:subscribe(Subscriber1, {<<"$share/group1/", Topic/binary>>, 0}),
+    %% node 2 subscribers
+    emqtt:subscribe(Subscriber2, {<<"$share/group1/", Topic/binary>>, 0}),
+    emqtt:subscribe(Subscriber3, {<<"$share/group1/", Topic/binary>>, 0}),
+
+    publish_fire_and_forget(10, Topic),
+
+    AllMessages = recv_msgs(10),
+    MessagesBySubscriber = lists:foldl(
+        fun(#{client_pid := Subscriber, payload := Payload}, Acc) ->
+            maps:update_with(Subscriber, fun(T) -> [Payload | T] end, [Payload], Acc)
+        end,
+        maps:new(),
+        AllMessages
+    ),
+    lists:foreach(fun(Pid) -> emqtt:stop(Pid) end, SubscriberPids),
+    stop_slave(Node),
+
+    ?assertEqual(
+        #{
+            Subscriber0 => [<<"0">>, <<"4">>, <<"8">>],
+            Subscriber1 => [<<"1">>, <<"5">>, <<"9">>],
+            Subscriber2 => [<<"2">>, <<"6">>],
+            Subscriber3 => [<<"3">>, <<"7">>]
+        },
+        MessagesBySubscriber
+    ).
+
+t_round_robin_per_group_two_nodes_alternating_publish(_) ->
+    ensure_config(round_robin_per_group),
+    Node = start_slave('rr_p_g_t_n_2', 41338),
+    ensure_node_config(Node, round_robin_per_group),
+
+    %% connect two subscribers on each node
+    Topic = <<"foo/bar">>,
+    {ok, Subscriber0} = emqtt:start_link([{clientid, <<"C0">>}]),
+    {ok, Subscriber1} = emqtt:start_link([{clientid, <<"C1">>}]),
+    {ok, Subscriber2} = emqtt:start_link([{clientid, <<"C2">>}, {port, 41338}]),
+    {ok, Subscriber3} = emqtt:start_link([{clientid, <<"C3">>}, {port, 41338}]),
+    SubscriberPids = [Subscriber0, Subscriber1, Subscriber2, Subscriber3],
+    lists:foreach(fun(P) -> emqtt:connect(P) end, SubscriberPids),
+
+    %% node 1 subscribers
+    emqtt:subscribe(Subscriber0, {<<"$share/group1/", Topic/binary>>, 0}),
+    emqtt:subscribe(Subscriber1, {<<"$share/group1/", Topic/binary>>, 0}),
+    %% node 2 subscribers
+    emqtt:subscribe(Subscriber2, {<<"$share/group1/", Topic/binary>>, 0}),
+    emqtt:subscribe(Subscriber3, {<<"$share/group1/", Topic/binary>>, 0}),
+
+    %% alternate publish messages between the nodes
+    lists:foreach(
+        fun(I) ->
+            Message = erlang:integer_to_binary(I),
+            {ok, PublisherPid} =
+                case I rem 2 of
+                    0 -> emqtt:start_link();
+                    1 -> emqtt:start_link([{port, 41338}])
+                end,
+            {ok, _} = emqtt:connect(PublisherPid),
+            emqtt:publish(PublisherPid, Topic, Message),
+            emqtt:stop(PublisherPid),
+            ct:sleep(50)
+        end,
+        lists:seq(0, 9)
+    ),
+
+    AllMessages = recv_msgs(10),
+    MessagesBySubscriber = lists:foldl(
+        fun(#{client_pid := Subscriber, payload := Payload}, Acc) ->
+            maps:update_with(Subscriber, fun(T) -> [Payload | T] end, [Payload], Acc)
+        end,
+        maps:new(),
+        AllMessages
+    ),
+    lists:foreach(fun(Pid) -> emqtt:stop(Pid) end, SubscriberPids),
+    stop_slave(Node),
+
+    %% this result show that when clustered round_robin_per_group behaves like the normal round_robin
+    %% strategy meaning that subscribers receive two consecutive messages which is not ideal
+    ?assertEqual(
+        #{
+            Subscriber0 => [<<"0">>, <<"1">>, <<"8">>, <<"9">>],
+            Subscriber1 => [<<"2">>, <<"3">>],
+            Subscriber2 => [<<"4">>, <<"5">>],
+            Subscriber3 => [<<"6">>, <<"7">>]
+        },
+        MessagesBySubscriber
+    ).
+
 t_sticky(_) ->
     ok = ensure_config(sticky, true),
     test_two_messages(sticky).
@@ -292,7 +552,7 @@ test_two_messages(Strategy, Group) ->
     emqtt:subscribe(ConnPid2, {<<"$share/", Group/binary, "/", Topic/binary>>, 0}),
 
     Message1 = emqx_message:make(ClientId1, 0, Topic, <<"hello1">>),
-    Message2 = emqx_message:make(ClientId1, 0, Topic, <<"hello2">>),
+    Message2 = emqx_message:make(ClientId2, 0, Topic, <<"hello2">>),
     ct:sleep(100),
 
     emqx:publish(Message1),
@@ -307,6 +567,7 @@ test_two_messages(Strategy, Group) ->
     case Strategy of
         sticky -> ?assertEqual(UsedSubPid1, UsedSubPid2);
         round_robin -> ?assertNotEqual(UsedSubPid1, UsedSubPid2);
+        round_robin_per_group -> ?assertNotEqual(UsedSubPid1, UsedSubPid2);
         hash -> ?assertEqual(UsedSubPid1, UsedSubPid2);
         _ -> ok
     end,
@@ -348,7 +609,8 @@ t_per_group_config(_) ->
     ok = ensure_group_config(#{
         <<"local_group">> => local,
         <<"round_robin_group">> => round_robin,
-        <<"sticky_group">> => sticky
+        <<"sticky_group">> => sticky,
+        <<"round_robin_per_group_group">> => round_robin_per_group
     }),
     %% Each test is repeated 4 times because random strategy may technically pass the test
     %% so we run 8 tests to make random pass in only 1/256 runs
@@ -360,7 +622,9 @@ t_per_group_config(_) ->
     test_two_messages(sticky, <<"sticky_group">>),
     test_two_messages(sticky, <<"sticky_group">>),
     test_two_messages(round_robin, <<"round_robin_group">>),
-    test_two_messages(round_robin, <<"round_robin_group">>).
+    test_two_messages(round_robin, <<"round_robin_group">>),
+    test_two_messages(round_robin_per_group, <<"round_robin_per_group_group">>),
+    test_two_messages(round_robin_per_group, <<"round_robin_per_group_group">>).
 
 t_local(_) ->
     GroupConfig = #{
@@ -482,6 +746,9 @@ ensure_config(Strategy, AckEnabled) ->
     emqx_config:put([broker, shared_dispatch_ack_enabled], AckEnabled),
     ok.
 
+ensure_node_config(Node, Strategy) ->
+    rpc:call(Node, emqx_config, force_put, [[broker, shared_subscription_strategy], Strategy]).
+
 ensure_group_config(Group2Strategy) ->
     lists:foreach(
         fun({Group, Strategy}) ->
@@ -505,6 +772,19 @@ ensure_group_config(Node, Group2Strategy) ->
         maps:to_list(Group2Strategy)
     ).
 
+publish_fire_and_forget(Count, Topic) when Count > 1 ->
+    lists:foreach(
+        fun(I) ->
+            Message = erlang:integer_to_binary(I),
+            {ok, PublisherPid} = emqtt:start_link(),
+            {ok, _} = emqtt:connect(PublisherPid),
+            emqtt:publish(PublisherPid, Topic, Message),
+            emqtt:stop(PublisherPid),
+            ct:sleep(50)
+        end,
+        lists:seq(0, Count - 1)
+    ).
+
 subscribed(Group, Topic, Pid) ->
     lists:member(Pid, emqx_shared_sub:subscribers(Group, Topic)).