فهرست منبع

Add acking mechanism for shared dispatch (#1872)

* Add acking mechanism for shared dispatch

For QoS0 messages, no acking
For QoS1/2 messages, 'ACK' at any of events below:
 - ACK when QoS is downgraded to 0
 - Message is sent to connection process
'NACK' at any of events below:
 - Message queue is full and the receiving session starts to drop old messages
 - The receiving session crashes
Upon 'NACK', messages are dispatched to the 'next' subscriber in the group,
depending on the shared subscription dispatch strategy.
spring2maz 7 سال پیش
والد
کامیت
a2c658ba19
10فایلهای تغییر یافته به همراه410 افزوده شده و 127 حذف شده
  1. 9 0
      etc/emqx.conf
  2. 6 0
      priv/emqx.schema
  3. 7 9
      src/emqx_packet.erl
  4. 64 31
      src/emqx_session.erl
  5. 152 22
      src/emqx_shared_sub.erl
  6. 2 2
      src/emqx_sm.erl
  7. 1 1
      test/emqx_broker_SUITE.erl
  8. 33 22
      test/emqx_mock_client.erl
  9. 1 1
      test/emqx_session_SUITE.erl
  10. 135 39
      test/emqx_shared_sub_SUITE.erl

+ 9 - 0
etc/emqx.conf

@@ -1904,6 +1904,15 @@ broker.session_locking_strategy = quorum
 ## - hash
 broker.shared_subscription_strategy = random
 
+## Enable/disable shared dispatch acknowledgement for QoS1 and QoS2 messages
+## This should allow messages to be dispatched to a different subscriber in
+## the group in case the picked one (based on shared_subscription_strategy) is offline
+##
+## Value: Enum
+## - true
+## - false
+broker.shared_dispatch_ack_enabled = false
+
 ## Enable batch clean for deleted routes.
 ##
 ## Value: Flag

+ 6 - 0
priv/emqx.schema

@@ -1719,6 +1719,12 @@ end}.
     ]}}
 ]}.
 
+%% @doc Enable or disable shared dispatch acknowledgement for QoS1 and QoS2 messages
+{mapping, "broker.shared_dispatch_ack_enabled", "emqx.shared_dispatch_ack_enabled",
+ [ {default, false},
+   {datatype, {enum, [true, false]}}
+ ]}.
+
 {mapping, "broker.route_batch_clean", "emqx.route_batch_clean", [
   {default, on},
   {datatype, flag}

+ 7 - 9
src/emqx_packet.erl

@@ -127,15 +127,13 @@ from_message(PacketId, #message{qos = QoS, flags = Flags, headers = Headers,
                  variable = Publish, payload = Payload}.
 
 publish_props(Headers) ->
-    maps:filter(fun('Payload-Format-Indicator', _) -> true;
-                   ('Response-Topic',           _) -> true;
-                   ('Correlation-Data',         _) -> true;
-                   ('User-Property',            _) -> true;
-                   ('Subscription-Identifier',  _) -> true;
-                   ('Content-Type',             _) -> true;
-                   ('Message-Expiry-Interval',  _) -> true;
-                   (_Key, _Val) -> false
-                end , Headers).
+    maps:with(['Payload-Format-Indicator',
+               'Response-Topic',
+               'Correlation-Data',
+               'User-Property',
+               'Subscription-Identifier',
+               'Content-Type',
+               'Message-Expiry-Interval'], Headers).
 
 %% @doc Message from Packet
 -spec(to_message(emqx_types:credentials(), emqx_mqtt_types:packet())

+ 64 - 31
src/emqx_session.erl

@@ -257,19 +257,21 @@ subscribe(SPid, PacketId, Properties, TopicFilters) ->
     SubReq = {PacketId, Properties, TopicFilters},
     gen_server:cast(SPid, {subscribe, self(), SubReq}).
 
+%% @doc Called by connection processes when publishing messages
 -spec(publish(spid(), emqx_mqtt_types:packet_id(), emqx_types:message())
       -> {ok, emqx_types:deliver_results()}).
 publish(_SPid, _PacketId, Msg = #message{qos = ?QOS_0}) ->
-    %% Publish QoS0 message to broker directly
+    %% Publish QoS0 message directly
     emqx_broker:publish(Msg);
-
 publish(_SPid, _PacketId, Msg = #message{qos = ?QOS_1}) ->
-    %% Publish QoS1 message to broker directly
+    %% Publish QoS1 message directly
     emqx_broker:publish(Msg);
-
-publish(SPid, PacketId, Msg = #message{qos = ?QOS_2}) ->
-    %% Publish QoS2 message to session
-    gen_server:call(SPid, {publish, PacketId, Msg}, infinity).
+publish(SPid, PacketId, Msg = #message{qos = ?QOS_2, timestamp = Ts}) ->
+    %% Register QoS2 message packet ID (and timestamp) to session, then publish
+    case gen_server:call(SPid, {register_publish_packet_id, PacketId, Ts}, infinity) of
+        ok -> emqx_broker:publish(Msg);
+        {error, Reason} -> {error, Reason}
+    end.
 
 -spec(puback(spid(), emqx_mqtt_types:packet_id()) -> ok).
 puback(SPid, PacketId) ->
@@ -405,8 +407,9 @@ handle_call({discard, ByPid}, _From, State = #state{client_id = ClientId, conn_p
     ConnPid ! {shutdown, discard, {ClientId, ByPid}},
     {stop, {shutdown, discard}, ok, State};
 
-%% PUBLISH:
-handle_call({publish, PacketId, Msg = #message{qos = ?QOS_2, timestamp = Ts}}, _From,
+%% PUBLISH: This is only to register packetId to session state.
+%% The actual message dispatching should be done by the caller (e.g. connection) process.
+handle_call({register_publish_packet_id, PacketId, Ts}, _From,
             State = #state{awaiting_rel = AwaitingRel}) ->
     reply(case is_awaiting_full(State) of
               false ->
@@ -415,7 +418,7 @@ handle_call({publish, PacketId, Msg = #message{qos = ?QOS_2, timestamp = Ts}}, _
                           {{error, ?RC_PACKET_IDENTIFIER_IN_USE}, State};
                       false ->
                           State1 = State#state{awaiting_rel = maps:put(PacketId, Ts, AwaitingRel)},
-                          {emqx_broker:publish(Msg), ensure_await_rel_timer(State1)}
+                          {ok, ensure_await_rel_timer(State1)}
                   end;
               true ->
                   emqx_metrics:inc('messages/qos2/dropped'),
@@ -575,22 +578,15 @@ handle_info({dispatch, Topic, Msgs}, State) when is_list(Msgs) ->
                           end, State, Msgs)};
 
 %% Dispatch message
-handle_info({dispatch, Topic, Msg = #message{headers = Headers}},
-            State = #state{subscriptions = SubMap,
-                           topic_alias_maximum = TopicAliasMaximum}) when is_record(Msg, message) ->
-    TopicAlias = maps:get('Topic-Alias', Headers, undefined),
-    if
-        TopicAlias =:= undefined orelse TopicAlias =< TopicAliasMaximum ->
-            noreply(case maps:find(Topic, SubMap) of
-                        {ok, #{nl := Nl, qos := QoS, rap := Rap, subid := SubId}} ->
-                            run_dispatch_steps([{nl, Nl}, {qos, QoS}, {rap, Rap}, {subid, SubId}], Msg, State);
-                        {ok, #{nl := Nl, qos := QoS, rap := Rap}} ->
-                            run_dispatch_steps([{nl, Nl}, {qos, QoS}, {rap, Rap}], Msg, State);
-                        error ->
-                            dispatch(emqx_message:unset_flag(dup, Msg), State)
-                    end);
+handle_info({dispatch, Topic, Msg = #message{}}, State) ->
+    case emqx_shared_sub:is_ack_required(Msg) andalso not has_connection(State) of
         true ->
-            noreply(State)
+            %% Require ack, but we do not have connection
+            %% negative ack the message so it can try the next subscriber in the group
+            ok = emqx_shared_sub:nack_no_connection(Msg),
+            noreply(State);
+        false ->
+            handle_dispatch(Topic, Msg, State)
     end;
 
 
@@ -644,7 +640,6 @@ handle_info({'EXIT', Pid, Reason}, State = #state{conn_pid = ConnPid}) ->
     ?LOG(error, "Unexpected EXIT: conn_pid=~p, exit_pid=~p, reason=~p",
          [ConnPid, Pid, Reason], State),
     {noreply, State};
-
 handle_info(Info, State) ->
     emqx_logger:error("[Session] unexpected info: ~p", [Info]),
     {noreply, State}.
@@ -667,6 +662,27 @@ code_change(_OldVsn, State, _Extra) ->
 %% Internal functions
 %%------------------------------------------------------------------------------
 
+%% @doc Return 'true' if the session currently has a connection process.
+%% erlang:is_process_alive/1 raises badarg for a pid on another node, so the
+%% liveness check is only done for local pids. For a remote connection
+%% (binding = remote) a pid being set is taken as "connected".
+%% NOTE(review): this assumes conn_pid is cleared when the remote connection
+%% exits -- confirm against the 'EXIT' handling in this module.
+has_connection(#state{conn_pid = Pid}) when is_pid(Pid), node(Pid) =:= node() ->
+    is_process_alive(Pid);
+has_connection(#state{conn_pid = Pid}) ->
+    is_pid(Pid).
+
+handle_dispatch(Topic, Msg = #message{headers = Headers},
+                State = #state{subscriptions = SubMap,
+                               topic_alias_maximum = TopicAliasMaximum
+                              }) ->
+    TopicAlias = maps:get('Topic-Alias', Headers, undefined),
+    if
+        TopicAlias =:= undefined orelse TopicAlias =< TopicAliasMaximum ->
+            noreply(case maps:find(Topic, SubMap) of
+                        {ok, #{nl := Nl, qos := QoS, rap := Rap, subid := SubId}} ->
+                            run_dispatch_steps([{nl, Nl}, {qos, QoS}, {rap, Rap}, {subid, SubId}], Msg, State);
+                        {ok, #{nl := Nl, qos := QoS, rap := Rap}} ->
+                            run_dispatch_steps([{nl, Nl}, {qos, QoS}, {rap, Rap}], Msg, State);
+                        error ->
+                            dispatch(emqx_message:unset_flag(dup, Msg), State)
+                    end);
+        true ->
+            noreply(State)
+    end.
+
 suback(_From, undefined, _ReasonCodes) ->
     ignore;
 suback(From, PacketId, ReasonCodes) ->
@@ -784,7 +800,12 @@ run_dispatch_steps([{nl, 1}|_Steps], #message{from = ClientId}, State = #state{c
     State;
 run_dispatch_steps([{nl, _}|Steps], Msg, State) ->
     run_dispatch_steps(Steps, Msg, State);
-run_dispatch_steps([{qos, SubQoS}|Steps], Msg = #message{qos = PubQoS}, State = #state{upgrade_qos = false}) ->
+run_dispatch_steps([{qos, SubQoS}|Steps], Msg0 = #message{qos = PubQoS}, State = #state{upgrade_qos = false}) ->
+    %% Ack immediately if a shared dispatch QoS is downgraded to 0
+    Msg = case SubQoS =:= ?QOS_0 of
+              true -> emqx_shared_sub:maybe_ack(Msg0);
+              false -> Msg0
+          end,
     run_dispatch_steps(Steps, Msg#message{qos = min(SubQoS, PubQoS)}, State);
 run_dispatch_steps([{qos, SubQoS}|Steps], Msg = #message{qos = PubQoS}, State = #state{upgrade_qos = true}) ->
     run_dispatch_steps(Steps, Msg#message{qos = max(SubQoS, PubQoS)}, State);
@@ -813,14 +834,16 @@ dispatch(Msg = #message{qos = QoS} = Msg,
          State = #state{next_pkt_id = PacketId, inflight = Inflight})
   when QoS =:= ?QOS_1 orelse QoS =:= ?QOS_2 ->
     case emqx_inflight:is_full(Inflight) of
-        true -> enqueue_msg(Msg, State);
+        true ->
+            enqueue_msg(Msg, State);
         false ->
             deliver(PacketId, Msg, State),
             await(PacketId, Msg, inc_stats(deliver, Msg, next_pkt_id(State)))
     end.
 
 enqueue_msg(Msg, State = #state{mqueue = Q}) ->
-    {_Dropped, NewQ} = emqx_mqueue:in(Msg, Q),
+    {Dropped, NewQ} = emqx_mqueue:in(Msg, Q),
+    Dropped =/= undefined andalso emqx_shared_sub:maybe_nack_dropped(Dropped),
     inc_stats(enqueue, Msg, State#state{mqueue = NewQ}).
 
 %%------------------------------------------------------------------------------
@@ -835,9 +858,19 @@ redeliver({PacketId, Msg = #message{qos = QoS}}, State) ->
 redeliver({pubrel, PacketId}, #state{conn_pid = ConnPid}) ->
     ConnPid ! {deliver, {pubrel, PacketId}}.
 
-deliver(PacketId, Msg, #state{conn_pid = ConnPid, binding = local}) ->
+deliver(PacketId, Msg, State) ->
+    %% Ack QoS1/QoS2 messages when message is delivered to connection.
+    %% NOTE: NOT to wait for PUBACK because:
+    %% The sender is monitoring this session process,
+    %% if the message is delivered to client but connection or session crashes,
+    %% sender will try to dispatch the message to the next shared subscriber.
+    %% This violates spec as QoS2 messages are not allowed to be sent to more
+    %% than one member in the group.
+    do_deliver(PacketId, emqx_shared_sub:maybe_ack(Msg), State).
+
+do_deliver(PacketId, Msg, #state{conn_pid = ConnPid, binding = local}) ->
     ConnPid ! {deliver, {publish, PacketId, Msg}};
-deliver(PacketId, Msg, #state{conn_pid = ConnPid, binding = remote}) ->
+do_deliver(PacketId, Msg, #state{conn_pid = ConnPid, binding = remote}) ->
     emqx_rpc:cast(node(ConnPid), erlang, send, [ConnPid, {deliver, {publish, PacketId, Msg}}]).
 
 %%------------------------------------------------------------------------------

+ 152 - 22
src/emqx_shared_sub.erl

@@ -27,7 +27,10 @@
 -export([start_link/0]).
 
 -export([subscribe/3, unsubscribe/3]).
--export([dispatch/3]).
+-export([dispatch/3, maybe_ack/1, maybe_nack_dropped/1, nack_no_connection/1, is_ack_required/1]).
+
+%% for testing
+-export([subscribers/2]).
 
 %% gen_server callbacks
 -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2,
@@ -36,10 +39,17 @@
 -define(SERVER, ?MODULE).
 -define(TAB, emqx_shared_subscription).
 -define(ALIVE_SUBS, emqx_alive_shared_subscribers).
+-define(SHARED_SUB_QOS1_DISPATCH_TIMEOUT_SECONDS, 5).
+-define(ack, shared_sub_ack).
+-define(nack(Reason), {shared_sub_nack, Reason}).
+-define(IS_LOCAL_PID(Pid), (is_pid(Pid) andalso node(Pid) =:= node())).
+-define(no_ack, no_ack).
 
 -record(state, {pmon}).
 -record(emqx_shared_subscription, {group, topic, subpid}).
 
+-include("emqx_mqtt.hrl").
+
 %%------------------------------------------------------------------------------
 %% Mnesia bootstrap
 %%------------------------------------------------------------------------------
@@ -62,10 +72,6 @@ mnesia(copy) ->
 start_link() ->
     gen_server:start_link({local, ?SERVER}, ?MODULE, [], []).
 
--spec(strategy() -> random | round_robin | sticky | hash).
-strategy() ->
-    emqx_config:get_env(shared_subscription_strategy, round_robin).
-
 subscribe(undefined, _Topic, _SubPid) ->
     ok;
 subscribe(Group, Topic, SubPid) when is_pid(SubPid) ->
@@ -80,33 +86,147 @@ unsubscribe(Group, Topic, SubPid) when is_pid(SubPid) ->
 record(Group, Topic, SubPid) ->
     #emqx_shared_subscription{group = Group, topic = Topic, subpid = SubPid}.
 
-dispatch(Group, Topic, Delivery = #delivery{message = Msg, results = Results}) ->
+dispatch(Group, Topic, Delivery) ->
+    dispatch(Group, Topic, Delivery, _FailedSubs = []).
+
+dispatch(Group, Topic, Delivery = #delivery{message = Msg, results = Results}, FailedSubs) ->
     #message{from = ClientId} = Msg,
-    case pick(strategy(), ClientId, Group, Topic) of
-        false  -> Delivery;
-        SubPid -> SubPid ! {dispatch, Topic, Msg},
-                  Delivery#delivery{results = [{dispatch, {Group, Topic}, 1} | Results]}
+    case pick(strategy(), ClientId, Group, Topic, FailedSubs) of
+        false ->
+            Delivery;
+        SubPid ->
+            case do_dispatch(SubPid, Topic, Msg) of
+                ok ->
+                    Delivery#delivery{results = [{dispatch, {Group, Topic}, 1} | Results]};
+                {error, _Reason} ->
+                    %% failed to dispatch to this sub, try next
+                    %% 'Reason' is discarded so far, meaning for QoS1/2 messages
+                    %% if all subscribers are offline, the dispatch would fail
+                    %% even if there are sessions not expired yet.
+                    %% If required, we can make use of the 'no_connection' reason to perform
+                    %% retry without requiring acks, so the messages can be delivered
+                    %% to sessions of offline clients
+                    dispatch(Group, Topic, Delivery, [SubPid | FailedSubs])
+            end
+    end.
+
+-spec(strategy() -> random | round_robin | sticky | hash).
+strategy() ->
+    emqx_config:get_env(shared_subscription_strategy, round_robin).
+
+-spec(ack_enabled() -> boolean()).
+ack_enabled() ->
+    emqx_config:get_env(shared_dispatch_ack_enabled, false).
+
+do_dispatch(SubPid, Topic, Msg) when SubPid =:= self() ->
+    %% Deadlock otherwise
+    _ = erlang:send(SubPid, {dispatch, Topic, Msg}),
+    ok;
+do_dispatch(SubPid, Topic, Msg) ->
+    dispatch_per_qos(SubPid, Topic, Msg).
+
+%% return either 'ok' (when everything is fine) or 'error'
+dispatch_per_qos(SubPid, Topic, #message{qos = ?QOS_0} = Msg) ->
+    %% For QoS 0 message, send it as regular dispatch
+    _ = erlang:send(SubPid, {dispatch, Topic, Msg}),
+    ok;
+dispatch_per_qos(SubPid, Topic, Msg) ->
+    case ack_enabled() of
+        true ->
+            dispatch_with_ack(SubPid, Topic, Msg);
+        false ->
+            _ = erlang:send(SubPid, {dispatch, Topic, Msg}),
+            ok
+    end.
+
+dispatch_with_ack(SubPid, Topic, Msg) ->
+    %% For QoS 1/2 message, expect an ack
+    Ref = erlang:monitor(process, SubPid),
+    Sender = self(),
+    _ = erlang:send(SubPid, {dispatch, Topic, with_ack_ref(Msg, {Sender, Ref})}),
+    Timeout = case Msg#message.qos of
+                  ?QOS_1 -> timer:seconds(?SHARED_SUB_QOS1_DISPATCH_TIMEOUT_SECONDS);
+                  ?QOS_2 -> infinity
+              end,
+    try
+        receive
+            {Ref, ?ack} ->
+                ok;
+            {Ref, ?nack(Reason)} ->
+                %% the receive session may nack this message when its queue is full
+                {error, Reason};
+            {'DOWN', Ref, process, SubPid, Reason} ->
+                {error, Reason}
+        after
+            Timeout ->
+                {error, timeout}
+        end
+    after
+        _ = erlang:demonitor(Ref, [flush])
+    end.
+
+with_ack_ref(Msg, SenderRef) ->
+    emqx_message:set_headers(#{shared_dispatch_ack => SenderRef}, Msg).
+
+without_ack_ref(Msg) ->
+    emqx_message:set_headers(#{shared_dispatch_ack => ?no_ack}, Msg).
+
+get_ack_ref(Msg) ->
+    emqx_message:get_header(shared_dispatch_ack, Msg, ?no_ack).
+
+-spec(is_ack_required(emqx_types:message()) -> boolean()).
+is_ack_required(Msg) -> ?no_ack =/= get_ack_ref(Msg).
+
+%% @doc Negative ack a message that was dropped because the session's
+%% message queue is full. Does nothing when the message carries no ack
+%% reference (i.e. it was not dispatched with ack required).
+-spec(maybe_nack_dropped(emqx_types:message()) -> ok).
+maybe_nack_dropped(Msg) ->
+    case get_ack_ref(Msg) of
+        ?no_ack -> ok;
+        %% the reason must be 'dropped' to agree with the nack/3 spec
+        {Sender, Ref} -> nack(Sender, Ref, dropped)
     end.
 
-pick(sticky, ClientId, Group, Topic) ->
+%% @doc Negative ack message due to connection down.
+%% Assuming this function is always called when ack is required
+%% i.e is_ack_required returned true.
+-spec(nack_no_connection(emqx_types:message()) -> ok).
+nack_no_connection(Msg) ->
+    {Sender, Ref} = get_ack_ref(Msg),
+    nack(Sender, Ref, no_connection).
+
+-spec(nack(pid(), reference(), dropped | no_connection) -> ok).
+nack(Sender, Ref, Reason) ->
+    erlang:send(Sender, {Ref, ?nack(Reason)}),
+    ok.
+
+-spec(maybe_ack(emqx_types:message()) -> emqx_types:message()).
+maybe_ack(Msg) ->
+    case get_ack_ref(Msg) of
+        ?no_ack ->
+            Msg;
+        {Sender, Ref} ->
+            erlang:send(Sender, {Ref, ?ack}),
+            without_ack_ref(Msg)
+    end.
+
+pick(sticky, ClientId, Group, Topic, FailedSubs) ->
     Sub0 = erlang:get({shared_sub_sticky, Group, Topic}),
-    case is_sub_alive(Sub0) of
+    case is_active_sub(Sub0, FailedSubs) of
         true ->
             %% the old subscriber is still alive
             %% keep using it for sticky strategy
             Sub0;
         false ->
             %% randomly pick one for the first message
-            Sub = do_pick(random, ClientId, Group, Topic),
+            Sub = do_pick(random, ClientId, Group, Topic, FailedSubs),
             %% stick to whatever pick result
             erlang:put({shared_sub_sticky, Group, Topic}, Sub),
             Sub
     end;
-pick(Strategy, ClientId, Group, Topic) ->
-    do_pick(Strategy, ClientId, Group, Topic).
+pick(Strategy, ClientId, Group, Topic, FailedSubs) ->
+    do_pick(Strategy, ClientId, Group, Topic, FailedSubs).
 
-do_pick(Strategy, ClientId, Group, Topic) ->
-    case subscribers(Group, Topic) of
+do_pick(Strategy, ClientId, Group, Topic, FailedSubs) ->
+    case subscribers(Group, Topic) -- FailedSubs of
         [] -> false;
         [Sub] -> Sub;
         All -> pick_subscriber(Group, Topic, Strategy, ClientId, All)
@@ -153,7 +273,7 @@ handle_call(Req, _From, State) ->
 
 handle_cast({monitor, SubPid}, State= #state{pmon = PMon}) ->
     NewPmon = emqx_pmon:monitor(SubPid, PMon),
-    ets:insert(?ALIVE_SUBS, {SubPid}),
+    ok = maybe_insert_alive_tab(SubPid),
     {noreply, update_stats(State#state{pmon = NewPmon})};
 handle_cast(Msg, State) ->
     emqx_logger:error("[SharedSub] unexpected cast: ~p", [Msg]),
@@ -189,8 +309,12 @@ code_change(_OldVsn, State, _Extra) ->
 %% Internal functions
 %%--------------------------------------------------------------------
 
+%% keep track of alive remote pids
+maybe_insert_alive_tab(Pid) when ?IS_LOCAL_PID(Pid) -> ok;
+maybe_insert_alive_tab(Pid) when is_pid(Pid) -> ets:insert(?ALIVE_SUBS, {Pid}), ok.
+
 cleanup_down(SubPid) ->
-    ets:delete(?ALIVE_SUBS, SubPid),
+    ?IS_LOCAL_PID(SubPid) orelse ets:delete(?ALIVE_SUBS, SubPid),
     lists:foreach(
         fun(Record) ->
             mnesia:dirty_delete_object(?TAB, Record)
@@ -199,7 +323,13 @@ cleanup_down(SubPid) ->
 update_stats(State) ->
     emqx_stats:setstat('subscriptions/shared/count', 'subscriptions/shared/max', ets:info(?TAB, size)), State.
 
-%% erlang:is_process_alive/1 is expensive
-%% and does not work with remote pids
-is_sub_alive(Sub) -> [] =/= ets:lookup(?ALIVE_SUBS, Sub).
+%% Return 'true' if the subscriber process is alive AND not in the failed list
+is_active_sub(Pid, FailedSubs) ->
+    is_alive_sub(Pid) andalso not lists:member(Pid, FailedSubs).
+
+%% erlang:is_process_alive/1 does not work with remote pid.
+is_alive_sub(Pid) when ?IS_LOCAL_PID(Pid) ->
+    erlang:is_process_alive(Pid);
+is_alive_sub(Pid) ->
+    [] =/= ets:lookup(?ALIVE_SUBS, Pid).
 

+ 2 - 2
src/emqx_sm.erl

@@ -59,8 +59,8 @@ open_session(SessAttrs = #{clean_start := true, client_id := ClientId, conn_pid
                  end,
     emqx_sm_locker:trans(ClientId, CleanStart);
 
-open_session(SessAttrs = #{clean_start          := false,
-                           client_id            := ClientId}) ->
+open_session(SessAttrs = #{clean_start := false,
+                           client_id   := ClientId}) ->
     ResumeStart = fun(_) ->
                       case resume_session(ClientId, SessAttrs) of
                           {ok, SPid} ->

+ 1 - 1
test/emqx_broker_SUITE.erl

@@ -164,7 +164,7 @@ start_session(_) ->
     emqx_session:publish(SessPid, 3, Message2),
     emqx_session:unsubscribe(SessPid, [{<<"topic/session">>, []}]),
     %% emqx_mock_client:stop(ClientPid).
-    emqx_mock_client:close_session(ClientPid, SessPid).
+    emqx_mock_client:close_session(ClientPid).
 
 %%--------------------------------------------------------------------
 %% Broker Group

+ 33 - 22
test/emqx_mock_client.erl

@@ -16,55 +16,54 @@
 
 -behaviour(gen_server).
 
--export([start_link/1, open_session/3, close_session/2, stop/1, get_last_message/1]).
+-export([start_link/1, open_session/3, open_session/4,
+         close_session/1, stop/1, get_last_message/1]).
 
 -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
          terminate/2, code_change/3]).
 
--record(state, {clean_start, client_id, client_pid, last_msg}).
+-record(state, {clean_start, client_id, client_pid, last_msg, session_pid}).
 
 start_link(ClientId) ->
     gen_server:start_link(?MODULE, [ClientId], []).
 
 open_session(ClientPid, ClientId, Zone) ->
-    gen_server:call(ClientPid, {start_session, ClientPid, ClientId, Zone}).
+    open_session(ClientPid, ClientId, Zone, _Attrs = #{}).
 
-close_session(ClientPid, SessPid) ->
-    gen_server:call(ClientPid, {stop_session, SessPid}).
+open_session(ClientPid, ClientId, Zone, Attrs0) ->
+    Attrs1 = default_session_attributes(Zone, ClientId, ClientPid),
+    Attrs = maps:merge(Attrs1, Attrs0),
+    gen_server:call(ClientPid, {start_session, ClientPid, ClientId, Attrs}).
+
+%% close session and terminate the client itself
+close_session(ClientPid) ->
+    gen_server:call(ClientPid, stop_session, infinity).
 
 stop(CPid) ->
-    gen_server:call(CPid, stop).
+    gen_server:call(CPid, stop, infinity).
 
 get_last_message(Pid) ->
-    gen_server:call(Pid, get_last_message).
+    gen_server:call(Pid, get_last_message, infinity).
 
 init([ClientId]) ->
+    erlang:process_flag(trap_exit, true),
     {ok, #state{clean_start = true,
                 client_id = ClientId,
                 last_msg = undefined
                }
     }.
 
-handle_call({start_session, ClientPid, ClientId, Zone}, _From, State) ->
-    Attrs = #{ zone                 => Zone,
-               client_id            => ClientId,
-               conn_pid             => ClientPid,
-               clean_start          => true,
-               username             => undefined,
-               expiry_interval      => 0,
-               max_inflight         => 0,
-               topic_alias_maximum  => 0,
-               will_msg             => undefined
-             },
+handle_call({start_session, ClientPid, ClientId, Attrs}, _From, State) ->
     {ok, SessPid} = emqx_sm:open_session(Attrs),
     {reply, {ok, SessPid},
      State#state{clean_start = true,
                  client_id = ClientId,
-                 client_pid = ClientPid
+                 client_pid = ClientPid,
+                 session_pid = SessPid
                 }};
-handle_call({stop_session, SessPid}, _From, State) ->
-    emqx_sm:close_session(SessPid),
-    {stop, normal, ok, State};
+handle_call(stop_session, _From, #state{session_pid = Pid} = State) ->
+    is_pid(Pid) andalso is_process_alive(Pid) andalso emqx_sm:close_session(Pid),
+    {stop, normal, ok, State#state{session_pid = undefined}};
 handle_call(get_last_message, _From, #state{last_msg = Msg} = State) ->
     {reply, Msg, State};
 handle_call(stop, _From, State) ->
@@ -86,3 +85,15 @@ terminate(_Reason, _State) ->
 code_change(_OldVsn, State, _Extra) ->
     {ok, State}.
 
+default_session_attributes(Zone, ClientId, ClientPid) ->
+    #{zone                => Zone,
+      client_id           => ClientId,
+      conn_pid            => ClientPid,
+      clean_start         => true,
+      username            => undefined,
+      expiry_interval     => 0,
+      max_inflight        => 0,
+      topic_alias_maximum => 0,
+      will_msg            => undefined
+     }.
+

+ 1 - 1
test/emqx_session_SUITE.erl

@@ -76,4 +76,4 @@ t_session_all(_) ->
     emqx_session:unsubscribe(SPid, [<<"topic">>]),
     timer:sleep(200),
     [] = emqx:subscriptions({SPid, <<"clientId">>}),
-    emqx_mock_client:close_session(ConnPid, SPid).
+    emqx_mock_client:close_session(ConnPid).

+ 135 - 39
test/emqx_shared_sub_SUITE.erl

@@ -16,7 +16,14 @@
 -module(emqx_shared_sub_SUITE).
 
 -export([all/0, init_per_suite/1, end_per_suite/1]).
--export([t_random_basic/1, t_random/1, t_round_robin/1, t_sticky/1, t_hash/1, t_not_so_sticky/1]).
+-export([t_random_basic/1,
+         t_random/1,
+         t_round_robin/1,
+         t_sticky/1,
+         t_hash/1,
+         t_not_so_sticky/1,
+         t_no_connection_nack/1
+        ]).
 
 -include("emqx.hrl").
 -include_lib("eunit/include/eunit.hrl").
@@ -24,7 +31,14 @@
 
 -define(wait(For, Timeout), wait_for(?FUNCTION_NAME, ?LINE, fun() -> For end, Timeout)).
 
-all() -> [t_random_basic, t_random, t_round_robin, t_sticky, t_hash, t_not_so_sticky].
+all() -> [t_random_basic,
+          t_random,
+          t_round_robin,
+          t_sticky,
+          t_hash,
+          t_not_so_sticky,
+          t_no_connection_nack
+         ].
 
 init_per_suite(Config) ->
     emqx_ct_broker_helpers:run_setup_steps(),
@@ -34,26 +48,91 @@ end_per_suite(_Config) ->
     emqx_ct_broker_helpers:run_teardown_steps().
 
 t_random_basic(_) ->
-    application:set_env(?APPLICATION, shared_subscription_strategy, random),
+    ok = ensure_config(random),
     ClientId = <<"ClientId">>,
     {ok, ConnPid} = emqx_mock_client:start_link(ClientId),
     {ok, SPid} = emqx_mock_client:open_session(ConnPid, ClientId, internal),
     Message1 = emqx_message:make(<<"ClientId">>, 2, <<"foo">>, <<"hello">>),
     emqx_session:subscribe(SPid, [{<<"foo">>, #{qos => 2, share => <<"group1">>}}]),
     %% wait for the subscription to show up
-    ?wait(ets:lookup(emqx_alive_shared_subscribers, SPid) =:= [{SPid}], 1000),
-    emqx_session:publish(SPid, 1, Message1),
+    ?wait(subscribed(<<"group1">>, <<"foo">>, SPid), 1000),
+    PacketId = 1,
+    emqx_session:publish(SPid, PacketId, Message1),
     ?wait(case emqx_mock_client:get_last_message(ConnPid) of
               {publish, 1, _} -> true;
               Other -> Other
           end, 1000),
-    emqx_session:puback(SPid, 2),
-    emqx_session:puback(SPid, 3, reasoncode),
-    emqx_session:pubrec(SPid, 4),
-    emqx_session:pubrec(SPid, 5, reasoncode),
-    emqx_session:pubrel(SPid, 6, reasoncode),
-    emqx_session:pubcomp(SPid, 7, reasoncode),
-    emqx_mock_client:close_session(ConnPid, SPid),
+    emqx_session:pubrec(SPid, PacketId, reasoncode),
+    emqx_session:pubcomp(SPid, PacketId, reasoncode),
+    emqx_mock_client:close_session(ConnPid),
+    ok.
+
+%% Start two subscribers share subscribe to "$share/g1/foo/bar"
+%% Set 'sticky' dispatch strategy, send 1st message to find
+%% out which member it picked, then close its connection
+%% send the second message, the message should be 'nack'ed
+%% by the sticky session and delivered to the 2nd session.
+t_no_connection_nack(_) ->
+    ok = ensure_config(sticky),
+    Publisher = <<"publisher">>,
+    Subscriber1 = <<"Subscriber1">>,
+    Subscriber2 = <<"Subscriber2">>,
+    QoS = 1,
+    Group = <<"g1">>,
+    Topic = <<"foo/bar">>,
+    {ok, PubConnPid} = emqx_mock_client:start_link(Publisher),
+    {ok, SubConnPid1} = emqx_mock_client:start_link(Subscriber1),
+    {ok, SubConnPid2} = emqx_mock_client:start_link(Subscriber2),
+    %% allow session to persist after connection shutdown
+    Attrs = #{expiry_interval => timer:seconds(30)},
+    {ok, P_Pid} = emqx_mock_client:open_session(PubConnPid, Publisher, internal, Attrs),
+    {ok, SPid1} = emqx_mock_client:open_session(SubConnPid1, Subscriber1, internal, Attrs),
+    {ok, SPid2} = emqx_mock_client:open_session(SubConnPid2, Subscriber2, internal, Attrs),
+    emqx_session:subscribe(SPid1, [{Topic, #{qos => QoS, share => Group}}]),
+    emqx_session:subscribe(SPid2, [{Topic, #{qos => QoS, share => Group}}]),
+    %% wait for the subscriptions to show up
+    ?wait(subscribed(Group, Topic, SPid1), 1000),
+    ?wait(subscribed(Group, Topic, SPid2), 1000),
+    MkPayload = fun(PacketId) -> iolist_to_binary(["hello-", integer_to_list(PacketId)]) end,
+    SendF = fun(PacketId) -> emqx_session:publish(P_Pid, PacketId, emqx_message:make(Publisher, QoS, Topic, MkPayload(PacketId))) end,
+    SendF(1),
+    Ref = make_ref(),
+    CasePid = self(),
+    Received =
+        fun(PacketId, ConnPid) ->
+                Payload = MkPayload(PacketId),
+                case emqx_mock_client:get_last_message(ConnPid) of
+                    {publish, _, #message{payload = Payload}} ->
+                        CasePid ! {Ref, PacketId, ConnPid},
+                        true;
+                    _Other ->
+                        false
+                end
+        end,
+    ?wait(Received(1, SubConnPid1) orelse Received(1, SubConnPid2), 1000),
+    %% This is the connection which was picked by broker to dispatch (sticky) for 1st message
+    ConnPid = receive {Ref, 1, Pid} -> Pid after 1000 -> error(timeout) end,
+    %% Now kill the connection, expect all following messages to be delivered to the other subscriber.
+    emqx_mock_client:stop(ConnPid),
+    %% sleep then make synced calls to session processes to ensure that
+    %% the connection pid's 'EXIT' message is propagated to the session process
+    %% also to be sure sessions are still alive
+    timer:sleep(5),
+    _ = emqx_session:info(SPid1),
+    _ = emqx_session:info(SPid2),
+    %% Now we know what is the other still alive connection
+    [TheOtherConnPid] = [SubConnPid1, SubConnPid2] -- [ConnPid],
+    %% Send some more messages
+    PacketIdList = lists:seq(2, 10),
+    lists:foreach(fun(Id) ->
+                          SendF(Id),
+                          ?wait(Received(Id, TheOtherConnPid), 1000)
+                  end, PacketIdList),
+    %% clean up
+    emqx_mock_client:close_session(PubConnPid),
+    emqx_sm:close_session(SPid1),
+    emqx_sm:close_session(SPid2),
+    emqx_mock_client:close_session(TheOtherConnPid),
     ok.
 
 t_random(_) ->
@@ -66,11 +145,11 @@ t_sticky(_) ->
     test_two_messages(sticky).
 
 t_hash(_) ->
-    test_two_messages(hash).
+    test_two_messages(hash, false).
 
 %% if the original subscriber dies, change to another one alive
 t_not_so_sticky(_) ->
-    application:set_env(?APPLICATION, shared_subscription_strategy, sticky),
+    ok = ensure_config(sticky),
     ClientId1 = <<"ClientId1">>,
     ClientId2 = <<"ClientId2">>,
     {ok, ConnPid1} = emqx_mock_client:start_link(ClientId1),
@@ -81,41 +160,45 @@ t_not_so_sticky(_) ->
     Message2 = emqx_message:make(ClientId1, 0, <<"foo/bar">>, <<"hello2">>),
     emqx_session:subscribe(SPid1, [{<<"foo/bar">>, #{qos => 0, share => <<"group1">>}}]),
     %% wait for the subscription to show up
-    ?wait(ets:lookup(emqx_alive_shared_subscribers, SPid1) =:= [{SPid1}], 1000),
+    ?wait(subscribed(<<"group1">>, <<"foo/bar">>, SPid1), 1000),
     emqx_session:publish(SPid1, 1, Message1),
     ?wait(case emqx_mock_client:get_last_message(ConnPid1) of
               {publish, _, #message{payload = <<"hello1">>}} -> true;
               Other -> Other
           end, 1000),
-    emqx_mock_client:close_session(ConnPid1, SPid1),
-    ?wait(ets:lookup(emqx_alive_shared_subscribers, SPid1) =:= [], 1000),
+    emqx_mock_client:close_session(ConnPid1),
+    ?wait(not subscribed(<<"group1">>, <<"foo/bar">>, SPid1), 1000),
     emqx_session:subscribe(SPid2, [{<<"foo/#">>, #{qos => 0, share => <<"group1">>}}]),
-    ?wait(ets:lookup(emqx_alive_shared_subscribers, SPid2) =:= [{SPid2}], 1000),
+    ?wait(subscribed(<<"group1">>, <<"foo/#">>, SPid2), 1000),
     emqx_session:publish(SPid2, 2, Message2),
     ?wait(case emqx_mock_client:get_last_message(ConnPid2) of
               {publish, _, #message{payload = <<"hello2">>}} -> true;
               Other -> Other
           end, 1000),
-    emqx_mock_client:close_session(ConnPid2, SPid2),
-    ?wait(ets:tab2list(emqx_alive_shared_subscribers) =:= [], 1000),
+    emqx_mock_client:close_session(ConnPid2),
+    ?wait(not subscribed(<<"group1">>, <<"foo/#">>, SPid2), 1000),
     ok.
 
 test_two_messages(Strategy) ->
-    application:set_env(?APPLICATION, shared_subscription_strategy, Strategy),
+    test_two_messages(Strategy, _WithAck = true).
+
+test_two_messages(Strategy, WithAck) ->
+    ok = ensure_config(Strategy, WithAck),
+    Topic = <<"foo/bar">>,
     ClientId1 = <<"ClientId1">>,
     ClientId2 = <<"ClientId2">>,
     {ok, ConnPid1} = emqx_mock_client:start_link(ClientId1),
     {ok, ConnPid2} = emqx_mock_client:start_link(ClientId2),
     {ok, SPid1} = emqx_mock_client:open_session(ConnPid1, ClientId1, internal),
     {ok, SPid2} = emqx_mock_client:open_session(ConnPid2, ClientId2, internal),
-    Message1 = emqx_message:make(ClientId1, 0, <<"foo/bar">>, <<"hello1">>),
-    Message2 = emqx_message:make(ClientId1, 0, <<"foo/bar">>, <<"hello2">>),
-    emqx_session:subscribe(SPid1, [{<<"foo/bar">>, #{qos => 0, share => <<"group1">>}}]),
-    emqx_session:subscribe(SPid2, [{<<"foo/bar">>, #{qos => 0, share => <<"group1">>}}]),
+    Message1 = emqx_message:make(ClientId1, 0, Topic, <<"hello1">>),
+    Message2 = emqx_message:make(ClientId1, 0, Topic, <<"hello2">>),
+    emqx_session:subscribe(SPid1, [{Topic, #{qos => 0, share => <<"group1">>}}]),
+    emqx_session:subscribe(SPid2, [{Topic, #{qos => 0, share => <<"group1">>}}]),
     %% wait for the subscription to show up
-    ?wait(ets:lookup(emqx_alive_shared_subscribers, SPid1) =:= [{SPid1}] andalso
-          ets:lookup(emqx_alive_shared_subscribers, SPid2) =:= [{SPid2}], 1000),
-    emqx_session:publish(SPid1, 1, Message1),
+    ?wait(subscribed(<<"group1">>, Topic, SPid1) andalso
+          subscribed(<<"group1">>, Topic, SPid2), 1000),
+    emqx_broker:publish(Message1),
     Me = self(),
     WaitF = fun(ExpectedPayload) ->
                     case last_message(ExpectedPayload, [ConnPid1, ConnPid2]) of
@@ -128,8 +211,7 @@ test_two_messages(Strategy) ->
             end,
     ?wait(WaitF(<<"hello1">>), 2000),
     UsedSubPid1 = receive {subscriber, P1} -> P1 end,
-    %% publish both messages with SPid1
-    emqx_session:publish(SPid1, 2, Message2),
+    emqx_broker:publish(Message2),
     ?wait(WaitF(<<"hello2">>), 2000),
     UsedSubPid2 = receive {subscriber, P2} -> P2 end,
     case Strategy of
@@ -138,8 +220,8 @@ test_two_messages(Strategy) ->
         hash -> ?assert(UsedSubPid1 =:= UsedSubPid2);
         _ -> ok
     end,
-    emqx_mock_client:close_session(ConnPid1, SPid1),
-    emqx_mock_client:close_session(ConnPid2, SPid2),
+    emqx_mock_client:close_session(ConnPid1),
+    emqx_mock_client:close_session(ConnPid2),
     ok.
 
 last_message(_ExpectedPayload, []) -> <<"not yet?">>;
@@ -153,6 +235,17 @@ last_message(ExpectedPayload, [Pid | Pids]) ->
 %% help functions
 %%------------------------------------------------------------------------------
 
+ensure_config(Strategy) ->
+    ensure_config(Strategy, _AckEnabled = true).
+
+ensure_config(Strategy, AckEnabled) ->
+    application:set_env(?APPLICATION, shared_subscription_strategy, Strategy),
+    application:set_env(?APPLICATION, shared_dispatch_ack_enabled, AckEnabled),
+    ok.
+
+subscribed(Group, Topic, Pid) ->
+    lists:member(Pid, emqx_shared_sub:subscribers(Group, Topic)).
+
 wait_for(Fn, Ln, F, Timeout) ->
     {Pid, Mref} = erlang:spawn_monitor(fun() -> wait_loop(F, catch_call(F)) end),
     wait_for_down(Fn, Ln, Timeout, Pid, Mref, false).
@@ -161,7 +254,9 @@ wait_for_down(Fn, Ln, Timeout, Pid, Mref, Kill) ->
     receive
         {'DOWN', Mref, process, Pid, normal} ->
             ok;
-        {'DOWN', Mref, process, Pid, {C, E, S}} ->
+        {'DOWN', Mref, process, Pid, {unexpected, Result}} ->
+            erlang:error({unexpected, Fn, Ln, Result});
+        {'DOWN', Mref, process, Pid, {crashed, {C, E, S}}} ->
             erlang:raise(C, {Fn, Ln, E}, S)
     after
         Timeout ->
@@ -176,23 +271,24 @@ wait_for_down(Fn, Ln, Timeout, Pid, Mref, Kill) ->
             end
     end.
 
-wait_loop(_F, true) -> exit(normal);
+wait_loop(_F, ok) -> exit(normal);
 wait_loop(F, LastRes) ->
-    Res = catch_call(F),
     receive
         stop -> erlang:exit(LastRes)
     after
-        100 -> wait_loop(F, Res)
+        100 ->
+            Res = catch_call(F),
+            wait_loop(F, Res)
     end.
 
 catch_call(F) ->
     try
         case F() of
-            true -> true;
-            Other -> erlang:error({unexpected, Other})
+            true -> ok;
+            Other -> {unexpected, Other}
         end
     catch
         C : E : S ->
-            {C, E, S}
+            {crashed, {C, E, S}}
     end.