Jelajahi Sumber

Cache publishes before receiving the REGACK (#4695)

* refactor(emqx_sn): return new state from send_message

* fix(emqx_sn): send publish only after regack received
Shawn 4 tahun lalu
induk
melakukan
cb31d66bf2
2 mengubah file dengan 127 tambahan dan 104 penghapusan
  1. 106 96
      apps/emqx_sn/src/emqx_sn_gateway.erl
  2. 21 8
      apps/emqx_sn/test/emqx_sn_protocol_SUITE.erl

+ 106 - 96
apps/emqx_sn/src/emqx_sn_gateway.erl

@@ -67,6 +67,8 @@
 
 -type(maybe(T) :: T | undefined).
 
+-type(pending_msgs() :: #{integer() => [#mqtt_sn_message{}]}).
+
 -record(will_msg, {retain = false  :: boolean(),
                    qos    = ?QOS_0 :: emqx_mqtt_types:qos(),
                    topic           :: maybe(binary()),
@@ -92,7 +94,8 @@
                 stats_timer          :: maybe(reference()),
                 idle_timeout         :: integer(),
                 enable_qos3 = false  :: boolean(),
-                has_pending_pingresp = false :: boolean()
+                has_pending_pingresp = false :: boolean(),
+                pending_topic_ids = #{} :: pending_msgs()
                }).
 
 -define(INFO_KEYS, [socktype, peername, sockname, sockstate]). %, active_n]).
@@ -180,8 +183,8 @@ init([{_, SockPid, Sock}, Peername, Options]) ->
 callback_mode() -> state_functions.
 
 idle(cast, {incoming, ?SN_SEARCHGW_MSG(_Radius)}, State = #state{gwid = GwId}) ->
-    send_message(?SN_GWINFO_MSG(GwId, <<>>), State),
-    {keep_state, State, State#state.idle_timeout};
+    State0 = send_message(?SN_GWINFO_MSG(GwId, <<>>), State),
+    {keep_state, State0, State0#state.idle_timeout};
 
 idle(cast, {incoming, ?SN_CONNECT_MSG(Flags, _ProtoId, Duration, ClientId)}, State) ->
     #mqtt_sn_flags{will = Will, clean_start = CleanStart} = Flags,
@@ -221,12 +224,10 @@ idle(cast, {incoming, PingReq = ?SN_PINGREQ_MSG(_ClientId)}, State) ->
     handle_ping(PingReq, State);
 
 idle(cast, {outgoing, Packet}, State) ->
-    ok = handle_outgoing(Packet, State),
-    {keep_state, State};
+    {keep_state, handle_outgoing(Packet, State)};
 
 idle(cast, {connack, ConnAck}, State) ->
-    ok = handle_outgoing(ConnAck, State),
-    {next_state, connected, State};
+    {next_state, connected, handle_outgoing(ConnAck, State)};
 
 idle(timeout, _Timeout, State) ->
     stop(idle_timeout, State);
@@ -245,8 +246,8 @@ wait_for_will_topic(cast, {incoming, ?SN_WILLTOPIC_EMPTY_MSG}, State = #state{co
 wait_for_will_topic(cast, {incoming, ?SN_WILLTOPIC_MSG(Flags, Topic)}, State) ->
     #mqtt_sn_flags{qos = QoS, retain = Retain} = Flags,
     WillMsg = #will_msg{retain = Retain, qos = QoS, topic = Topic},
-    send_message(?SN_WILLMSGREQ_MSG(), State),
-    {next_state, wait_for_will_msg, State#state{will_msg = WillMsg}};
+    State0 = send_message(?SN_WILLMSGREQ_MSG(), State),
+    {next_state, wait_for_will_msg, State0#state{will_msg = WillMsg}};
 
 wait_for_will_topic(cast, {incoming, ?SN_ADVERTISE_MSG(_GwId, _Radius)}, _State) ->
     % ignore
@@ -256,12 +257,10 @@ wait_for_will_topic(cast, {incoming, ?SN_CONNECT_MSG(Flags, _ProtoId, Duration,
     do_2nd_connect(Flags, Duration, ClientId, State);
 
 wait_for_will_topic(cast, {outgoing, Packet}, State) ->
-    ok = handle_outgoing(Packet, State),
-    {keep_state, State};
+    {keep_state, handle_outgoing(Packet, State)};
 
 wait_for_will_topic(cast, {connack, ConnAck}, State) ->
-    ok = handle_outgoing(ConnAck, State),
-    {next_state, connected, State};
+    {next_state, connected, handle_outgoing(ConnAck, State)};
 
 wait_for_will_topic(cast, Event, _State) ->
     ?LOG(error, "wait_for_will_topic UNEXPECTED Event: ~p", [Event]),
@@ -284,18 +283,17 @@ wait_for_will_msg(cast, {incoming, ?SN_CONNECT_MSG(Flags, _ProtoId, Duration, Cl
     do_2nd_connect(Flags, Duration, ClientId, State);
 
 wait_for_will_msg(cast, {outgoing, Packet}, State) ->
-    ok = handle_outgoing(Packet, State),
-    {keep_state, State};
+    {keep_state, handle_outgoing(Packet, State)};
 
 wait_for_will_msg(cast, {connack, ConnAck}, State) ->
-    ok = handle_outgoing(ConnAck, State),
-    {next_state, connected, State};
+    {next_state, connected, handle_outgoing(ConnAck, State)};
 
 wait_for_will_msg(EventType, EventContent, State) ->
     handle_event(EventType, EventContent, wait_for_will_msg, State).
 
 connected(cast, {incoming, ?SN_REGISTER_MSG(_TopicId, MsgId, TopicName)},
           State = #state{clientid = ClientId, registry = Registry}) ->
+    State0 =
     case emqx_sn_registry:register_topic(Registry, self(), TopicName) of
         TopicId when is_integer(TopicId) ->
             ?LOG(debug, "register ClientId=~p, TopicName=~p, TopicId=~p", [ClientId, TopicName, TopicId]),
@@ -307,7 +305,7 @@ connected(cast, {incoming, ?SN_REGISTER_MSG(_TopicId, MsgId, TopicName)},
             ?LOG(error, "wildcard topic can not be registered! ClientId=~p, TopicName=~p", [ClientId, TopicName]),
             send_message(?SN_REGACK_MSG(?SN_INVALID_TOPIC_ID, MsgId, ?SN_RC_NOT_SUPPORTED), State)
     end,
-    {keep_state, State};
+    {keep_state, State0};
 
 connected(cast, {incoming, ?SN_PUBLISH_MSG(Flags, TopicId, MsgId, Data)},
           State = #state{enable_qos3 = EnableQoS3}) ->
@@ -339,19 +337,19 @@ connected(cast, {incoming, ?SN_UNSUBSCRIBE_MSG(Flags, MsgId, TopicId)}, State) -
 connected(cast, {incoming, PingReq = ?SN_PINGREQ_MSG(_ClientId)}, State) ->
     handle_ping(PingReq, State);
 
-connected(cast, {incoming, ?SN_REGACK_MSG(_TopicId, _MsgId, ?SN_RC_ACCEPTED)}, State) ->
-    {keep_state, State};
+connected(cast, {incoming, ?SN_REGACK_MSG(TopicId, _MsgId, ?SN_RC_ACCEPTED)}, State) ->
+    {keep_state, replay_no_reg_pending_publishes(TopicId, State)};
 connected(cast, {incoming, ?SN_REGACK_MSG(TopicId, MsgId, ReturnCode)}, State) ->
     ?LOG(error, "client does not accept register TopicId=~p, MsgId=~p, ReturnCode=~p",
          [TopicId, MsgId, ReturnCode]),
     {keep_state, State};
 
 connected(cast, {incoming, ?SN_DISCONNECT_MSG(Duration)}, State) ->
-    ok = send_message(?SN_DISCONNECT_MSG(undefined), State),
+    State0 = send_message(?SN_DISCONNECT_MSG(undefined), State),
     case Duration of
         undefined ->
-            handle_incoming(?DISCONNECT_PACKET(), State);
-        _Other -> goto_asleep_state(Duration, State)
+            handle_incoming(?DISCONNECT_PACKET(), State0);
+        _Other -> goto_asleep_state(Duration, State0)
     end;
 
 connected(cast, {incoming, ?SN_WILLTOPICUPD_MSG(Flags, Topic)}, State = #state{will_msg = WillMsg}) ->
@@ -359,12 +357,12 @@ connected(cast, {incoming, ?SN_WILLTOPICUPD_MSG(Flags, Topic)}, State = #state{w
                    undefined -> undefined;
                    _         -> update_will_topic(WillMsg, Flags, Topic)
                end,
-    send_message(?SN_WILLTOPICRESP_MSG(0), State),
-    {keep_state, State#state{will_msg = WillMsg1}};
+    State0 = send_message(?SN_WILLTOPICRESP_MSG(0), State),
+    {keep_state, State0#state{will_msg = WillMsg1}};
 
 connected(cast, {incoming, ?SN_WILLMSGUPD_MSG(Payload)}, State = #state{will_msg = WillMsg}) ->
-    ok = send_message(?SN_WILLMSGRESP_MSG(0), State),
-    {keep_state, State#state{will_msg = update_will_msg(WillMsg, Payload)}};
+    State0 = send_message(?SN_WILLMSGRESP_MSG(0), State),
+    {keep_state, State0#state{will_msg = update_will_msg(WillMsg, Payload)}};
 
 connected(cast, {incoming, ?SN_ADVERTISE_MSG(_GwId, _Radius)}, State) ->
     % ignore
@@ -374,17 +372,14 @@ connected(cast, {incoming, ?SN_CONNECT_MSG(Flags, _ProtoId, Duration, ClientId)}
     do_2nd_connect(Flags, Duration, ClientId, State);
 
 connected(cast, {outgoing, Packet}, State) ->
-    ok = handle_outgoing(Packet, State),
-    {keep_state, State};
+    {keep_state, handle_outgoing(Packet, State)};
 
 %% XXX: It's so strange behavior!!!
 connected(cast, {connack, ConnAck}, State) ->
-    ok = handle_outgoing(ConnAck, State),
-    {keep_state, State};
+    {keep_state, handle_outgoing(ConnAck, State)};
 
 connected(cast, {shutdown, Reason, Packet}, State) ->
-    ok = handle_outgoing(Packet, State),
-    stop(Reason, State);
+    stop(Reason, handle_outgoing(Packet, State));
 
 connected(cast, {shutdown, Reason}, State) ->
     stop(Reason, State);
@@ -397,12 +392,12 @@ connected(EventType, EventContent, State) ->
     handle_event(EventType, EventContent, connected, State).
 
 asleep(cast, {incoming, ?SN_DISCONNECT_MSG(Duration)}, State) ->
-    ok = send_message(?SN_DISCONNECT_MSG(undefined), State),
+    State0 = send_message(?SN_DISCONNECT_MSG(undefined), State),
     case Duration of
         undefined ->
-            handle_incoming(?DISCONNECT_PACKET(), State);
+            handle_incoming(?DISCONNECT_PACKET(), State0);
         _Other ->
-            goto_asleep_state(Duration, State)
+            goto_asleep_state(Duration, State0)
     end;
 
 asleep(cast, {incoming, ?SN_PINGREQ_MSG(undefined)}, State) ->
@@ -411,13 +406,13 @@ asleep(cast, {incoming, ?SN_PINGREQ_MSG(undefined)}, State) ->
 
 asleep(cast, {incoming, ?SN_PINGREQ_MSG(ClientIdPing)},
        State = #state{clientid = ClientId, channel = Channel}) ->
+    inc_ping_counter(),
     case ClientIdPing of
         ClientId ->
-            inc_ping_counter(),
             case emqx_session:replay(emqx_channel:get_session(Channel)) of
                 {ok, [], Session0} ->
-                    send_message(?SN_PINGRESP_MSG(), State),
-                    {keep_state, State#state{
+                    State0 = send_message(?SN_PINGRESP_MSG(), State),
+                    {keep_state, State0#state{
                         channel = emqx_channel:set_session(Session0, Channel)}};
                 {ok, Publishes, Session0} ->
                     {Packets, Channel1} = emqx_channel:do_deliver(Publishes,
@@ -449,14 +444,13 @@ asleep(cast, {incoming, ?SN_CONNECT_MSG(_Flags, _ProtoId, _Duration, _ClientId)}
     % keepalive timer may timeout in asleep state and delete itself, need to restart keepalive
     % TODO: Fixme later.
     %% self() ! {keepalive, start, Interval},
-    send_connack(State),
-    {next_state, connected, State};
+    {next_state, connected, send_connack(State)};
 
 asleep(EventType, EventContent, State) ->
     handle_event(EventType, EventContent, asleep, State).
 
-awake(cast, {incoming, ?SN_REGACK_MSG(_TopicId, _MsgId, ?SN_RC_ACCEPTED)}, State) ->
-    {keep_state, State};
+awake(cast, {incoming, ?SN_REGACK_MSG(TopicId, _MsgId, ?SN_RC_ACCEPTED)}, State) ->
+    {keep_state, replay_no_reg_pending_publishes(TopicId, State)};
 
 awake(cast, {incoming, ?SN_REGACK_MSG(TopicId, MsgId, ReturnCode)}, State) ->
     ?LOG(error, "client does not accept register TopicId=~p, MsgId=~p, ReturnCode=~p",
@@ -467,8 +461,7 @@ awake(cast, {incoming, PingReq = ?SN_PINGREQ_MSG(_ClientId)}, State) ->
     handle_ping(PingReq, State);
 
 awake(cast, {outgoing, Packet}, State) ->
-    ok = handle_outgoing(Packet, State),
-    {keep_state, State};
+    {keep_state, handle_outgoing(Packet, State)};
 
 awake(cast, {incoming, ?SN_PUBACK_MSG(TopicId, MsgId, ReturnCode)}, State) ->
     do_puback(TopicId, MsgId, ReturnCode, awake, State);
@@ -482,8 +475,8 @@ awake(cast, try_goto_asleep, State=#state{channel = Channel,
     Inflight = emqx_session:info(inflight, emqx_channel:get_session(Channel)),
     case emqx_inflight:size(Inflight) of
         0 when PingPending =:= true ->
-            send_message(?SN_PINGRESP_MSG(), State),
-            goto_asleep_state(State#state{has_pending_pingresp = false});
+            State0 = send_message(?SN_PINGRESP_MSG(), State),
+            goto_asleep_state(State0#state{has_pending_pingresp = false});
         0 when PingPending =:= false ->
             goto_asleep_state(State);
         _Size ->
@@ -499,13 +492,13 @@ handle_event({call, From}, Req, _StateName, State) ->
             gen_server:reply(From, Reply),
             {keep_state, NState};
         {stop, Reason, Reply, NState} ->
-            case NState#state.sockstate of
+            State0 = case NState#state.sockstate of
                 running ->
                     send_message(?SN_DISCONNECT_MSG(undefined), NState);
-                _ -> ok
+                _ -> NState
             end,
             gen_server:reply(From, Reply),
-            stop(Reason, NState)
+            stop(Reason, State0)
     end;
 
 handle_event(info, {datagram, SockPid, Data}, StateName,
@@ -526,9 +519,10 @@ handle_event(info, {datagram, SockPid, Data}, StateName,
     end;
 
 handle_event(info, {deliver, _Topic, Msg}, asleep,
-             State = #state{channel = Channel}) ->
+             State = #state{channel = Channel, pending_topic_ids = Pendings}) ->
     % section 6.14, Support of sleeping clients
-    ?LOG(debug, "enqueue downlink message in asleep state Msg=~0p", [Msg]),
+    ?LOG(debug, "enqueue downlink message in asleep state, msg: ~0p, pending_topic_ids: ~0p",
+         [Msg, Pendings]),
     Session = emqx_session:enqueue(Msg, emqx_channel:get_session(Channel)),
     {keep_state, State#state{channel = emqx_channel:set_session(Session, Channel)}};
 
@@ -537,8 +531,7 @@ handle_event(info, Deliver = {deliver, _Topic, _Msg}, _StateName,
     handle_return(emqx_channel:handle_deliver([Deliver], Channel), State);
 
 handle_event(info, {redeliver, {?PUBREL, MsgId}}, _StateName, State) ->
-    send_message(?SN_PUBREC_MSG(?SN_PUBREL, MsgId), State),
-    {keep_state, State};
+    {keep_state, send_message(?SN_PUBREC_MSG(?SN_PUBREL, MsgId), State)};
 
 %% FIXME: Is not unused in v4.x
 handle_event(info, {timeout, TRef, emit_stats}, _StateName,
@@ -634,8 +627,7 @@ handle_return({shutdown, Reason, NChannel}, State, _AddEvents) ->
     stop(Reason, State#state{channel = NChannel});
 handle_return({shutdown, Reason, OutPacket, NChannel}, State, _AddEvents) ->
     NState = State#state{channel = NChannel},
-    ok = handle_outgoing(OutPacket, NState),
-    stop(Reason, NState).
+    stop(Reason, handle_outgoing(OutPacket, NState)).
 
 outgoing_events(Actions) ->
     lists:map(fun outgoing_event/1, Actions).
@@ -702,9 +694,9 @@ call(Pid, Req, Timeout) ->
 %% Internal Functions
 %%--------------------------------------------------------------------
 handle_ping(_PingReq, State) ->
-    ok = send_message(?SN_PINGRESP_MSG(), State),
+    State0 = send_message(?SN_PINGRESP_MSG(), State),
     inc_ping_counter(),
-    {keep_state, State}.
+    {keep_state, State0}.
 
 inc_ping_counter() ->
     inc_counter(recv_msg, 1).
@@ -768,13 +760,13 @@ send_connack(State) ->
     send_message(?SN_CONNACK_MSG(?SN_RC_ACCEPTED), State).
 
 send_message(Msg = #mqtt_sn_message{type = Type},
-             #state{sockpid = SockPid, peername = Peername}) ->
+             State = #state{sockpid = SockPid, peername = Peername}) ->
     ?LOG(debug, "SEND ~s~n", [emqx_sn_frame:format(Msg)]),
     inc_outgoing_stats(Type),
     Data = emqx_sn_frame:serialize(Msg),
     ok = emqx_metrics:inc('bytes.sent', iolist_size(Data)),
     SockPid ! {datagram, Peername, Data},
-    ok.
+    State.
 
 goto_asleep_state(State) ->
     goto_asleep_state(undefined, State).
@@ -834,8 +826,8 @@ do_connect(ClientId, CleanStart, WillFlag, Duration, State) ->
                                    properties  = OnlyOneInflight
                                   },
     case WillFlag of
-        true -> send_message(?SN_WILLTOPICREQ_MSG(), State),
-                NState = State#state{connpkt  = ConnPkt,
+        true -> State0 = send_message(?SN_WILLTOPICREQ_MSG(), State),
+                NState = State0#state{connpkt  = ConnPkt,
                                      clientid = ClientId,
                                      keepalive_interval = Duration
                                     },
@@ -872,11 +864,11 @@ handle_subscribe(?SN_NORMAL_TOPIC, TopicName, QoS, MsgId,
                  State=#state{registry = Registry}) ->
     case emqx_sn_registry:register_topic(Registry, self(), TopicName) of
         {error, too_large} ->
-            ok = send_message(?SN_SUBACK_MSG(#mqtt_sn_flags{qos = QoS},
+            State0 = send_message(?SN_SUBACK_MSG(#mqtt_sn_flags{qos = QoS},
                                              ?SN_INVALID_TOPIC_ID,
                                              MsgId,
                                              ?SN_RC_INVALID_TOPIC_ID), State),
-            {keep_state, State};
+            {keep_state, State0};
         {error, wildcard_topic} ->
             proto_subscribe(TopicName, QoS, MsgId, ?SN_INVALID_TOPIC_ID, State);
         NewTopicId when is_integer(NewTopicId) ->
@@ -887,11 +879,11 @@ handle_subscribe(?SN_PREDEFINED_TOPIC, TopicId, QoS, MsgId,
                  State = #state{registry = Registry}) ->
     case emqx_sn_registry:lookup_topic(Registry, self(), TopicId) of
         undefined ->
-            ok = send_message(?SN_SUBACK_MSG(#mqtt_sn_flags{qos = QoS},
+            State0 = send_message(?SN_SUBACK_MSG(#mqtt_sn_flags{qos = QoS},
                                              TopicId,
                                              MsgId,
                                              ?SN_RC_INVALID_TOPIC_ID), State),
-            {next_state, connected, State};
+            {next_state, connected, State0};
         PredefinedTopic ->
             proto_subscribe(PredefinedTopic, QoS, MsgId, TopicId, State)
     end;
@@ -904,11 +896,11 @@ handle_subscribe(?SN_SHORT_TOPIC, TopicId, QoS, MsgId, State) ->
     proto_subscribe(TopicName, QoS, MsgId, ?SN_INVALID_TOPIC_ID, State);
 
 handle_subscribe(_, _TopicId, QoS, MsgId, State) ->
-    ok = send_message(?SN_SUBACK_MSG(#mqtt_sn_flags{qos = QoS},
+    State0 = send_message(?SN_SUBACK_MSG(#mqtt_sn_flags{qos = QoS},
                                      ?SN_INVALID_TOPIC_ID,
                                      MsgId,
                                      ?SN_RC_INVALID_TOPIC_ID), State),
-    {keep_state, State}.
+    {keep_state, State0}.
 
 handle_unsubscribe(?SN_NORMAL_TOPIC, TopicId, MsgId, State) ->
     proto_unsubscribe(TopicId, MsgId, State);
@@ -917,8 +909,7 @@ handle_unsubscribe(?SN_PREDEFINED_TOPIC, TopicId, MsgId,
                    State = #state{registry = Registry}) ->
     case emqx_sn_registry:lookup_topic(Registry, self(), TopicId) of
         undefined ->
-            ok = send_message(?SN_UNSUBACK_MSG(MsgId), State),
-            {keep_state, State};
+            {keep_state, send_message(?SN_UNSUBACK_MSG(MsgId), State)};
         PredefinedTopic ->
             proto_unsubscribe(PredefinedTopic, MsgId, State)
     end;
@@ -931,8 +922,7 @@ handle_unsubscribe(?SN_SHORT_TOPIC, TopicId, MsgId, State) ->
     proto_unsubscribe(TopicName, MsgId, State);
 
 handle_unsubscribe(_, _TopicId, MsgId, State) ->
-    send_message(?SN_UNSUBACK_MSG(MsgId), State),
-    {keep_state, State}.
+    {keep_state, send_message(?SN_UNSUBACK_MSG(MsgId), State)}.
 
 do_publish(?SN_NORMAL_TOPIC, TopicName, Data, Flags, MsgId, State) ->
     %% XXX: Handle normal topic id as predefined topic id, to be compatible with paho mqtt-sn library
@@ -944,25 +934,26 @@ do_publish(?SN_PREDEFINED_TOPIC, TopicId, Data, Flags, MsgId,
     NewQoS = get_corrected_qos(QoS),
     case emqx_sn_registry:lookup_topic(Registry, self(), TopicId) of
         undefined ->
-            (NewQoS =/= ?QOS_0) andalso send_message(?SN_PUBACK_MSG(TopicId, MsgId, ?SN_RC_INVALID_TOPIC_ID), State),
-            {keep_state, State};
+            {keep_state, maybe_send_puback(NewQoS, TopicId, MsgId, ?SN_RC_INVALID_TOPIC_ID,
+                State)};
         TopicName ->
             proto_publish(TopicName, Data, Dup, NewQoS, Retain, MsgId, TopicId, State)
     end;
+
 do_publish(?SN_SHORT_TOPIC, STopicName, Data, Flags, MsgId, State) ->
     #mqtt_sn_flags{qos = QoS, dup = Dup, retain = Retain} = Flags,
     NewQoS = get_corrected_qos(QoS),
     <<TopicId:16>> = STopicName ,
     case emqx_topic:wildcard(STopicName) of
         true ->
-            (NewQoS =/= ?QOS_0) andalso send_message(?SN_PUBACK_MSG(TopicId, MsgId, ?SN_RC_NOT_SUPPORTED), State),
-            {keep_state, State};
+            {keep_state, maybe_send_puback(NewQoS, TopicId, MsgId, ?SN_RC_NOT_SUPPORTED,
+                State)};
         false ->
             proto_publish(STopicName, Data, Dup, NewQoS, Retain, MsgId, TopicId, State)
     end;
 do_publish(_, TopicId, _Data, #mqtt_sn_flags{qos = QoS}, MsgId, State) ->
-    (QoS =/= ?QOS_0) andalso send_message(?SN_PUBACK_MSG(TopicId, MsgId, ?SN_RC_NOT_SUPPORTED), State),
-    {keep_state, State}.
+    {keep_state, maybe_send_puback(QoS, TopicId, MsgId, ?SN_RC_NOT_SUPPORTED,
+        State)}.
 
 do_publish_will(#state{will_msg = undefined}) ->
     ok;
@@ -986,12 +977,11 @@ do_puback(TopicId, MsgId, ReturnCode, StateName,
             handle_incoming(?PUBACK_PACKET(MsgId), StateName, State);
         ?SN_RC_INVALID_TOPIC_ID ->
             case emqx_sn_registry:lookup_topic(Registry, self(), TopicId) of
-                undefined -> ok;
+                undefined -> {keep_state, State};
                 TopicName ->
                     %%notice that this TopicName maybe normal or predefined,
                     %% involving the predefined topic name in register to enhance the gateway's robustness even inconsistent with MQTT-SN channels
-                    send_register(TopicName, TopicId, MsgId, State),
-                    {keep_state, State}
+                    {keep_state, send_register(TopicName, TopicId, MsgId, State)}
             end;
         _ ->
             ?LOG(error, "CAN NOT handle PUBACK ReturnCode=~p", [ReturnCode]),
@@ -1070,30 +1060,45 @@ channel_handle_in(Packet = ?PACKET(Type), #state{channel = Channel}) ->
     emqx_channel:handle_in(Packet, Channel).
 
 handle_outgoing(Packets, State) when is_list(Packets) ->
-    lists:foreach(fun(Packet) -> handle_outgoing(Packet, State) end, Packets);
+    lists:foldl(fun(Packet, State0) ->
+        handle_outgoing(Packet, State0)
+    end, State, Packets);
 
-handle_outgoing(PubPkt = ?PUBLISH_PACKET(QoS, TopicName, PacketId, Payload),
+handle_outgoing(PubPkt = ?PUBLISH_PACKET(_, TopicName, _, _),
                 State = #state{registry = Registry}) ->
-    #mqtt_packet{header = #mqtt_packet_header{dup = Dup, retain = Retain}} = PubPkt,
-    MsgId = message_id(PacketId),
-    ?LOG(debug, "Handle outgoing: ~0p", [PubPkt]),
-
-    (emqx_sn_registry:lookup_topic_id(Registry, self(), TopicName) == undefined)
-        andalso (byte_size(TopicName) =/= 2)
-            andalso register_and_notify_client(TopicName, Payload, Dup, QoS,
-                                               Retain, MsgId, State),
-
-    send_message(mqtt2sn(PubPkt, State), State);
+    ?LOG(debug, "Handle outgoing publish: ~0p", [PubPkt]),
+    TopicId = emqx_sn_registry:lookup_topic_id(Registry, self(), TopicName),
+    case (TopicId == undefined) andalso (byte_size(TopicName) =/= 2) of
+        true -> register_and_notify_client(PubPkt, State);
+        false -> send_message(mqtt2sn(PubPkt, State), State)
+    end;
 
 handle_outgoing(Packet, State) ->
     send_message(mqtt2sn(Packet, State), State).
 
-register_and_notify_client(TopicName, Payload, Dup, QoS, Retain, MsgId,
-                           State = #state{registry = Registry}) ->
+%% Append a PUBLISH that cannot be delivered yet to the queue of messages
+%% waiting for the client's REGACK of TopicId.
+%%
+%% Pendings maps a topic id to the list of already converted MQTT-SN messages,
+%% oldest first. The existing queue has to be fetched by TopicId: fetching it
+%% under any other key always yields [] and silently discards every message
+%% cached earlier for the same topic id.
+%%
+%% Returns the updated map. State is only needed to convert the packet with
+%% mqtt2sn/2.
+cache_no_reg_publish_message(Pendings, TopicId, PubPkt, State) ->
+    ?LOG(debug, "cache non-registered publish message for topic-id: ~p, msg: ~0p, pendings: ~0p",
+        [TopicId, PubPkt, Pendings]),
+    Msgs = maps:get(TopicId, Pendings, []),
+    %% Append (not prepend) so the replay keeps the original publish order.
+    Pendings#{TopicId => Msgs ++ [mqtt2sn(PubPkt, State)]}.
+
+%% Send every PUBLISH that was held back for TopicId, in the order it was
+%% cached, and then forget the topic id in the pending map. Invoked when the
+%% client answers a REGISTER with an accepted REGACK.
+%% Returns the state with the entry for TopicId removed.
+replay_no_reg_pending_publishes(TopicId, #state{pending_topic_ids = Pendings} = State0) ->
+    ?LOG(debug, "replay non-registered publish message for topic-id: ~p, pendings: ~0p",
+        [TopicId, Pendings]),
+    Queued = maps:get(TopicId, Pendings, []),
+    Flushed = lists:foldl(fun send_message/2, State0, Queued),
+    Remaining = maps:remove(TopicId, Pendings),
+    Flushed#state{pending_topic_ids = Remaining}.
+
+register_and_notify_client(?PUBLISH_PACKET(QoS, TopicName, PacketId, Payload) = PubPkt,
+        State = #state{registry = Registry, pending_topic_ids = Pendings}) ->
+    MsgId = message_id(PacketId),
+    #mqtt_packet{header = #mqtt_packet_header{dup = Dup, retain = Retain}} = PubPkt,
     TopicId = emqx_sn_registry:register_topic(Registry, self(), TopicName),
     ?LOG(debug, "Register TopicId=~p, TopicName=~p, Payload=~p, Dup=~p, QoS=~p, "
                 "Retain=~p, MsgId=~p", [TopicId, TopicName, Payload, Dup, QoS, Retain, MsgId]),
-    send_register(TopicName, TopicId, MsgId, State).
+    NewPendings = cache_no_reg_publish_message(Pendings, TopicId, PubPkt, State),
+    send_register(TopicName, TopicId, MsgId, State#state{pending_topic_ids = NewPendings}).
 
 message_id(undefined) ->
     rand:uniform(16#FFFF);
@@ -1126,3 +1131,8 @@ append(Replies, AddEvents) when is_list(Replies) ->
     Replies ++ AddEvents;
 append(Replies, AddEvents) ->
     [Replies] ++ AddEvents.
+
+%% Send a PUBACK carrying ReasonCode unless the publish was QoS 0, which is
+%% never acknowledged. Returns the (possibly updated) state either way.
+maybe_send_puback(QoS, _TopicId, _MsgId, _ReasonCode, State) when QoS =:= ?QOS_0 ->
+    State;
+maybe_send_puback(_QoS, TopicId, MsgId, ReasonCode, State) ->
+    PubAck = ?SN_PUBACK_MSG(TopicId, MsgId, ReasonCode),
+    send_message(PubAck, State).

+ 21 - 8
apps/emqx_sn/test/emqx_sn_protocol_SUITE.erl

@@ -1102,10 +1102,12 @@ t_asleep_test04_to_awake_qos1_dl_msg(_) ->
 
     %% send downlink data in asleep state. This message should be send to device once it wake up
     Payload1 = <<55, 66, 77, 88, 99>>,
+    Payload2 = <<55, 66, 77, 88, 100>>,
 
     {ok, C} = emqtt:start_link(),
     {ok, _} = emqtt:connect(C),
     {ok, _} = emqtt:publish(C, <<"a/b/c">>, Payload1, QoS),
+    {ok, _} = emqtt:publish(C, <<"a/b/c">>, Payload2, QoS),
     timer:sleep(100),
     ok = emqtt:disconnect(C),
 
@@ -1114,21 +1116,32 @@ t_asleep_test04_to_awake_qos1_dl_msg(_) ->
     % goto awake state, receive downlink messages, and go back to asleep
     send_pingreq_msg(Socket, <<"test">>),
 
-    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
-    %% get REGISTER first, since this topic has never been registered
-    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
-    UdpData2 = receive_response(Socket),
-    {TopicIdNew, MsgId3} = check_register_msg_on_udp(<<"a/b/c">>, UdpData2),
+    %% 1. get REGISTER first, since this topic has never been registered
+    UdpData1 = receive_response(Socket),
+    {TopicIdNew, MsgId3} = check_register_msg_on_udp(<<"a/b/c">>, UdpData1),
+
+    %% 2. but until we reply with the REGACK, the sn-gateway must not send any PUBLISH
+    ?assertError(_, receive_publish(Socket)),
+
     send_regack_msg(Socket, TopicIdNew, MsgId3),
 
-    UdpData = receive_response(Socket),
-    MsgId_udp = check_publish_msg_on_udp({Dup, QoS, Retain, WillBit, CleanSession, ?SN_NORMAL_TOPIC, TopicIdNew, Payload1}, UdpData),
-    send_puback_msg(Socket, TopicIdNew, MsgId_udp),
+    UdpData2 = receive_response(Socket),
+    MsgId_udp2 = check_publish_msg_on_udp({Dup, QoS, Retain, WillBit, CleanSession, ?SN_NORMAL_TOPIC, TopicIdNew, Payload1}, UdpData2),
+    send_puback_msg(Socket, TopicIdNew, MsgId_udp2),
+
+    UdpData3 = receive_response(Socket),
+    MsgId_udp3 = check_publish_msg_on_udp({Dup, QoS, Retain, WillBit, CleanSession, ?SN_NORMAL_TOPIC, TopicIdNew, Payload2}, UdpData3),
+    send_puback_msg(Socket, TopicIdNew, MsgId_udp3),
 
     ?assertEqual(<<2, ?SN_PINGRESP>>, receive_response(Socket)),
 
     gen_udp:close(Socket).
 
+%% Wait up to one second for a datagram and assert that it is an MQTT-SN
+%% PUBLISH. Any other outcome fails with a badmatch, which the caller relies
+%% on to prove that no PUBLISH arrived.
+receive_publish(Socket) ->
+    <<Header:5/binary, _:16, _/binary>> = receive_response(Socket, 1000),
+    %% First byte is skipped; the second byte must be the PUBLISH type.
+    <<_:8, ?SN_PUBLISH, _/binary>> = Header.
+
 t_asleep_test05_to_awake_qos1_dl_msg(_) ->
     QoS = 1,
     Duration = 5,