Bladeren bron

Merge pull request #11041 from HJianBo/fix-superuser-not-working

Fix superuser not working
JianBo He 2 jaren geleden
bovenliggende
commit
d9cf9c2cb6

+ 9 - 4
apps/emqx_gateway/src/emqx_gateway_ctx.erl

@@ -69,8 +69,9 @@
 authenticate(_Ctx, ClientInfo0) ->
 authenticate(_Ctx, ClientInfo0) ->
     ClientInfo = ClientInfo0#{zone => default},
     ClientInfo = ClientInfo0#{zone => default},
     case emqx_access_control:authenticate(ClientInfo) of
     case emqx_access_control:authenticate(ClientInfo) of
-        {ok, _} ->
-            {ok, mountpoint(ClientInfo)};
+        {ok, AuthResult} ->
+            ClientInfo1 = merge_auth_result(ClientInfo, AuthResult),
+            {ok, eval_mountpoint(ClientInfo1)};
         {error, Reason} ->
         {error, Reason} ->
             {error, Reason}
             {error, Reason}
     end.
     end.
@@ -174,8 +175,12 @@ metrics_inc(_Ctx = #{gwname := GwName}, Name, Oct) ->
 %% Internal funcs
 %% Internal funcs
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
-mountpoint(ClientInfo = #{mountpoint := undefined}) ->
+eval_mountpoint(ClientInfo = #{mountpoint := undefined}) ->
     ClientInfo;
     ClientInfo;
-mountpoint(ClientInfo = #{mountpoint := MountPoint}) ->
+eval_mountpoint(ClientInfo = #{mountpoint := MountPoint}) ->
     MountPoint1 = emqx_mountpoint:replvar(MountPoint, ClientInfo),
     MountPoint1 = emqx_mountpoint:replvar(MountPoint, ClientInfo),
     ClientInfo#{mountpoint := MountPoint1}.
     ClientInfo#{mountpoint := MountPoint1}.
+
+merge_auth_result(ClientInfo, AuthResult) when is_map(ClientInfo) andalso is_map(AuthResult) ->
+    IsSuperuser = maps:get(is_superuser, AuthResult, false),
+    maps:merge(ClientInfo, AuthResult#{is_superuser => IsSuperuser}).

+ 15 - 7
apps/emqx_gateway/test/emqx_gateway_ctx_SUITE.erl

@@ -36,8 +36,10 @@ init_per_suite(Conf) ->
         fun
         fun
             (#{clientid := bad_client}) ->
             (#{clientid := bad_client}) ->
                 {error, bad_username_or_password};
                 {error, bad_username_or_password};
-            (ClientInfo) ->
-                {ok, ClientInfo}
+            (#{clientid := admin}) ->
+                {ok, #{is_superuser => true}};
+            (_) ->
+                {ok, #{}}
         end
         end
     ),
     ),
     Conf.
     Conf.
@@ -56,15 +58,15 @@ t_authenticate(_) ->
         mountpoint => undefined,
         mountpoint => undefined,
         clientid => <<"user1">>
         clientid => <<"user1">>
     },
     },
-    NInfo1 = zone(Info1),
-    ?assertEqual({ok, NInfo1}, emqx_gateway_ctx:authenticate(Ctx, Info1)),
+    NInfo1 = default_result(Info1),
+    ?assertMatch({ok, NInfo1}, emqx_gateway_ctx:authenticate(Ctx, Info1)),
 
 
     Info2 = #{
     Info2 = #{
         mountpoint => <<"mqttsn/${clientid}/">>,
         mountpoint => <<"mqttsn/${clientid}/">>,
         clientid => <<"user1">>
         clientid => <<"user1">>
     },
     },
-    NInfo2 = zone(Info2#{mountpoint => <<"mqttsn/user1/">>}),
-    ?assertEqual({ok, NInfo2}, emqx_gateway_ctx:authenticate(Ctx, Info2)),
+    NInfo2 = default_result(Info2#{mountpoint => <<"mqttsn/user1/">>}),
+    ?assertMatch({ok, NInfo2}, emqx_gateway_ctx:authenticate(Ctx, Info2)),
 
 
     Info3 = #{
     Info3 = #{
         mountpoint => <<"mqttsn/${clientid}/">>,
         mountpoint => <<"mqttsn/${clientid}/">>,
@@ -72,6 +74,12 @@ t_authenticate(_) ->
     },
     },
     {error, bad_username_or_password} =
     {error, bad_username_or_password} =
         emqx_gateway_ctx:authenticate(Ctx, Info3),
         emqx_gateway_ctx:authenticate(Ctx, Info3),
+
+    Info4 = #{
+        mountpoint => undefined,
+        clientid => admin
+    },
+    ?assertMatch({ok, #{is_superuser := true}}, emqx_gateway_ctx:authenticate(Ctx, Info4)),
     ok.
     ok.
 
 
-zone(Info) -> Info#{zone => default}.
+default_result(Info) -> Info#{zone => default, is_superuser => false}.

+ 24 - 18
apps/emqx_gateway_stomp/src/emqx_stomp_channel.erl

@@ -448,7 +448,9 @@ handle_in(
     Topic = header(<<"destination">>, Headers),
     Topic = header(<<"destination">>, Headers),
     case emqx_gateway_ctx:authorize(Ctx, ClientInfo, publish, Topic) of
     case emqx_gateway_ctx:authorize(Ctx, ClientInfo, publish, Topic) of
         deny ->
         deny ->
-            handle_out(error, {receipt_id(Headers), "Authorization Deny"}, Channel);
+            ErrMsg = io_lib:format("Insufficient permissions for ~s", [Topic]),
+            ErrorFrame = error_frame(receipt_id(Headers), ErrMsg),
+            shutdown(acl_denied, ErrorFrame, Channel);
         allow ->
         allow ->
             case header(<<"transaction">>, Headers) of
             case header(<<"transaction">>, Headers) of
                 undefined ->
                 undefined ->
@@ -494,20 +496,25 @@ handle_in(
             ),
             ),
             case do_subscribe(NTopicFilters, NChannel) of
             case do_subscribe(NTopicFilters, NChannel) of
                 [] ->
                 [] ->
-                    ErrMsg = "Permission denied",
-                    handle_out(error, {receipt_id(Headers), ErrMsg}, Channel);
+                    ErrMsg = io_lib:format(
+                        "The client.subscribe hook blocked the ~s subscription request",
+                        [TopicFilter]
+                    ),
+                    ErrorFrame = error_frame(receipt_id(Headers), ErrMsg),
+                    shutdown(normal, ErrorFrame, Channel);
                 [{MountedTopic, SubOpts} | _] ->
                 [{MountedTopic, SubOpts} | _] ->
                     NSubs = [{SubId, MountedTopic, Ack, SubOpts} | Subs],
                     NSubs = [{SubId, MountedTopic, Ack, SubOpts} | Subs],
                     NChannel1 = NChannel#channel{subscriptions = NSubs},
                     NChannel1 = NChannel#channel{subscriptions = NSubs},
                     handle_out_and_update(receipt, receipt_id(Headers), NChannel1)
                     handle_out_and_update(receipt, receipt_id(Headers), NChannel1)
             end;
             end;
-        {error, ErrMsg, NChannel} ->
-            ?SLOG(error, #{
-                msg => "failed_top_subscribe_topic",
-                topic => Topic,
-                reason => ErrMsg
-            }),
-            handle_out(error, {receipt_id(Headers), ErrMsg}, NChannel)
+        {error, subscription_id_inused, NChannel} ->
+            ErrMsg = io_lib:format("Subscription id ~w is in used", [SubId]),
+            ErrorFrame = error_frame(receipt_id(Headers), ErrMsg),
+            shutdown(subscription_id_inused, ErrorFrame, NChannel);
+        {error, acl_denied, NChannel} ->
+            ErrMsg = io_lib:format("Insufficient permissions for ~s", [Topic]),
+            ErrorFrame = error_frame(receipt_id(Headers), ErrMsg),
+            shutdown(acl_denied, ErrorFrame, NChannel)
     end;
     end;
 handle_in(
 handle_in(
     ?PACKET(?CMD_UNSUBSCRIBE, Headers),
     ?PACKET(?CMD_UNSUBSCRIBE, Headers),
@@ -691,7 +698,7 @@ check_subscribed_status(
         {SubId, MountedTopic, _Ack, _} ->
         {SubId, MountedTopic, _Ack, _} ->
             ok;
             ok;
         {SubId, _OtherTopic, _Ack, _} ->
         {SubId, _OtherTopic, _Ack, _} ->
-            {error, "Conflict subscribe id"};
+            {error, subscription_id_inused};
         false ->
         false ->
             ok
             ok
     end.
     end.
@@ -704,7 +711,7 @@ check_sub_acl(
     }
     }
 ) ->
 ) ->
     case emqx_gateway_ctx:authorize(Ctx, ClientInfo, subscribe, ParsedTopic) of
     case emqx_gateway_ctx:authorize(Ctx, ClientInfo, subscribe, ParsedTopic) of
-        deny -> {error, "ACL Deny"};
+        deny -> {error, acl_denied};
         allow -> ok
         allow -> ok
     end.
     end.
 
 
@@ -987,7 +994,7 @@ handle_deliver(
     Delivers,
     Delivers,
     Channel = #channel{
     Channel = #channel{
         ctx = Ctx,
         ctx = Ctx,
-        clientinfo = ClientInfo,
+        clientinfo = ClientInfo = #{mountpoint := Mountpoint},
         subscriptions = Subs
         subscriptions = Subs
     }
     }
 ) ->
 ) ->
@@ -998,22 +1005,21 @@ handle_deliver(
         fun({_, _, Message}, Acc) ->
         fun({_, _, Message}, Acc) ->
             Topic0 = emqx_message:topic(Message),
             Topic0 = emqx_message:topic(Message),
             case lists:keyfind(Topic0, 2, Subs) of
             case lists:keyfind(Topic0, 2, Subs) of
-                {Id, Topic, Ack, _SubOpts} ->
-                    %% XXX: refactor later
+                {Id, _Topic, Ack, _SubOpts} ->
+                    Message1 = emqx_mountpoint:unmount(Mountpoint, Message),
                     metrics_inc('messages.delivered', Channel),
                     metrics_inc('messages.delivered', Channel),
                     NMessage = run_hooks_without_metrics(
                     NMessage = run_hooks_without_metrics(
                         Ctx,
                         Ctx,
                         'message.delivered',
                         'message.delivered',
                         [ClientInfo],
                         [ClientInfo],
-                        Message
+                        Message1
                     ),
                     ),
-                    Topic = emqx_message:topic(NMessage),
                     Headers = emqx_message:get_headers(NMessage),
                     Headers = emqx_message:get_headers(NMessage),
                     Payload = emqx_message:payload(NMessage),
                     Payload = emqx_message:payload(NMessage),
                     Headers0 = [
                     Headers0 = [
                         {<<"subscription">>, Id},
                         {<<"subscription">>, Id},
                         {<<"message-id">>, next_msgid()},
                         {<<"message-id">>, next_msgid()},
-                        {<<"destination">>, Topic},
+                        {<<"destination">>, emqx_message:topic(NMessage)},
                         {<<"content-type">>, <<"text/plain">>}
                         {<<"content-type">>, <<"text/plain">>}
                     ],
                     ],
                     Headers1 =
                     Headers1 =

+ 2 - 0
apps/emqx_gateway_stomp/src/emqx_stomp_frame.erl

@@ -185,6 +185,8 @@ parse(headers, Bin, State) ->
     parse(hdname, Bin, State);
     parse(hdname, Bin, State);
 parse(hdname, <<?LF, _Rest/binary>>, _State) ->
 parse(hdname, <<?LF, _Rest/binary>>, _State) ->
     error(unexpected_linefeed);
     error(unexpected_linefeed);
+parse(hdname, <<?COLON, $\s, Rest/binary>>, State = #parser_state{acc = Acc}) ->
+    parse(hdvalue, Rest, State#parser_state{hdname = Acc, acc = <<>>});
 parse(hdname, <<?COLON, Rest/binary>>, State = #parser_state{acc = Acc}) ->
 parse(hdname, <<?COLON, Rest/binary>>, State = #parser_state{acc = Acc}) ->
     parse(hdvalue, Rest, State#parser_state{hdname = Acc, acc = <<>>});
     parse(hdvalue, Rest, State#parser_state{hdname = Acc, acc = <<>>});
 parse(hdname, <<Ch:8, Rest/binary>>, State) ->
 parse(hdname, <<Ch:8, Rest/binary>>, State) ->

+ 263 - 120
apps/emqx_gateway_stomp/test/emqx_stomp_SUITE.erl

@@ -60,11 +60,11 @@ all() -> emqx_common_test_helpers:all(?MODULE).
 init_per_suite(Cfg) ->
 init_per_suite(Cfg) ->
     application:load(emqx_gateway_stomp),
     application:load(emqx_gateway_stomp),
     ok = emqx_common_test_helpers:load_config(emqx_gateway_schema, ?CONF_DEFAULT),
     ok = emqx_common_test_helpers:load_config(emqx_gateway_schema, ?CONF_DEFAULT),
-    emqx_mgmt_api_test_util:init_suite([emqx_authn, emqx_gateway]),
+    emqx_mgmt_api_test_util:init_suite([emqx_conf, emqx_authn, emqx_gateway]),
     Cfg.
     Cfg.
 
 
 end_per_suite(_Cfg) ->
 end_per_suite(_Cfg) ->
-    emqx_mgmt_api_test_util:end_suite([emqx_gateway, emqx_authn]),
+    emqx_mgmt_api_test_util:end_suite([emqx_gateway, emqx_authn, emqx_conf]),
     ok.
     ok.
 
 
 default_config() ->
 default_config() ->
@@ -73,73 +73,40 @@ default_config() ->
 stomp_ver() ->
 stomp_ver() ->
     ?STOMP_VER.
     ?STOMP_VER.
 
 
+restart_stomp_with_mountpoint(Mountpoint) ->
+    Conf = emqx:get_raw_config([gateway, stomp]),
+    emqx_gateway_conf:update_gateway(
+        stomp,
+        Conf#{<<"mountpoint">> => Mountpoint}
+    ).
+
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 %% Test Cases
 %% Test Cases
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
 t_connect(_) ->
 t_connect(_) ->
-    %% Connect should be succeed
-    with_connection(fun(Sock) ->
-        gen_tcp:send(
-            Sock,
-            serialize(
-                <<"CONNECT">>,
-                [
-                    {<<"accept-version">>, ?STOMP_VER},
-                    {<<"host">>, <<"127.0.0.1:61613">>},
-                    {<<"login">>, <<"guest">>},
-                    {<<"passcode">>, <<"guest">>},
-                    {<<"heart-beat">>, <<"1000,2000">>}
-                ]
-            )
-        ),
-        {ok, Data} = gen_tcp:recv(Sock, 0),
-        {ok,
-            Frame = #stomp_frame{
-                command = <<"CONNECTED">>,
-                headers = _,
-                body = _
-            },
-            _, _} = parse(Data),
-        <<"2000,1000">> = proplists:get_value(<<"heart-beat">>, Frame#stomp_frame.headers),
-
-        gen_tcp:send(
-            Sock,
-            serialize(
-                <<"DISCONNECT">>,
-                [{<<"receipt">>, <<"12345">>}]
-            )
+    %% Successful connect
+    ConnectSucced = fun(Sock) ->
+        ok = send_connection_frame(Sock, <<"guest">>, <<"guest">>, <<"1000,2000">>),
+        {ok, Frame} = recv_a_frame(Sock),
+        ?assertMatch(<<"CONNECTED">>, Frame#stomp_frame.command),
+        ?assertEqual(
+            <<"2000,1000">>, proplists:get_value(<<"heart-beat">>, Frame#stomp_frame.headers)
         ),
         ),
 
 
-        {ok, Data1} = gen_tcp:recv(Sock, 0),
-        {ok,
-            #stomp_frame{
+        ok = send_disconnect_frame(Sock, <<"12345">>),
+        ?assertMatch(
+            {ok, #stomp_frame{
                 command = <<"RECEIPT">>,
                 command = <<"RECEIPT">>,
-                headers = [{<<"receipt-id">>, <<"12345">>}],
-                body = _
-            },
-            _, _} = parse(Data1)
-    end),
-
-    %% Connect will be failed, because of bad login or passcode
-    %% FIXME: Waiting for authentication works
-    %with_connection(
-    %    fun(Sock) ->
-    %        gen_tcp:send(Sock, serialize(<<"CONNECT">>,
-    %                                     [{<<"accept-version">>, ?STOMP_VER},
-    %                                      {<<"host">>, <<"127.0.0.1:61613">>},
-    %                                      {<<"login">>, <<"admin">>},
-    %                                      {<<"passcode">>, <<"admin">>},
-    %                                      {<<"heart-beat">>, <<"1000,2000">>}])),
-    %          {ok, Data} = gen_tcp:recv(Sock, 0),
-    %          {ok, Frame, _, _} = parse(Data),
-    %          #stomp_frame{command = <<"ERROR">>,
-    %                       headers = _,
-    %                       body    = <<"Login or passcode error!">>} = Frame
-    %      end),
+                headers = [{<<"receipt-id">>, <<"12345">>}]
+            }},
+            recv_a_frame(Sock)
+        )
+    end,
+    with_connection(ConnectSucced),
 
 
     %% Connect will be failed, because of bad version
     %% Connect will be failed, because of bad version
-    with_connection(fun(Sock) ->
+    ProtocolError = fun(Sock) ->
         gen_tcp:send(
         gen_tcp:send(
             Sock,
             Sock,
             serialize(
             serialize(
@@ -160,7 +127,8 @@ t_connect(_) ->
             headers = _,
             headers = _,
             body = <<"Login Failed: Supported protocol versions < 1.2">>
             body = <<"Login Failed: Supported protocol versions < 1.2">>
         } = Frame
         } = Frame
-    end).
+    end,
+    with_connection(ProtocolError).
 
 
 t_heartbeat(_) ->
 t_heartbeat(_) ->
     %% Test heart beat
     %% Test heart beat
@@ -755,8 +723,7 @@ t_frame_error_too_many_headers(_) ->
     ),
     ),
     Assert =
     Assert =
         fun(Sock) ->
         fun(Sock) ->
-            {ok, Data} = gen_tcp:recv(Sock, 0),
-            {ok, ErrorFrame, _, _} = parse(Data),
+            {ok, ErrorFrame} = recv_a_frame(Sock),
             ?assertMatch(#stomp_frame{command = <<"ERROR">>}, ErrorFrame),
             ?assertMatch(#stomp_frame{command = <<"ERROR">>}, ErrorFrame),
             ?assertMatch(
             ?assertMatch(
                 match, re:run(ErrorFrame#stomp_frame.body, "too_many_headers", [{capture, none}])
                 match, re:run(ErrorFrame#stomp_frame.body, "too_many_headers", [{capture, none}])
@@ -777,8 +744,7 @@ t_frame_error_too_long_header(_) ->
     ),
     ),
     Assert =
     Assert =
         fun(Sock) ->
         fun(Sock) ->
-            {ok, Data} = gen_tcp:recv(Sock, 0),
-            {ok, ErrorFrame, _, _} = parse(Data),
+            {ok, ErrorFrame} = recv_a_frame(Sock),
             ?assertMatch(#stomp_frame{command = <<"ERROR">>}, ErrorFrame),
             ?assertMatch(#stomp_frame{command = <<"ERROR">>}, ErrorFrame),
             ?assertMatch(
             ?assertMatch(
                 match, re:run(ErrorFrame#stomp_frame.body, "too_long_header", [{capture, none}])
                 match, re:run(ErrorFrame#stomp_frame.body, "too_long_header", [{capture, none}])
@@ -796,8 +762,7 @@ t_frame_error_too_long_body(_) ->
     ),
     ),
     Assert =
     Assert =
         fun(Sock) ->
         fun(Sock) ->
-            {ok, Data} = gen_tcp:recv(Sock, 0),
-            {ok, ErrorFrame, _, _} = parse(Data),
+            {ok, ErrorFrame} = recv_a_frame(Sock),
             ?assertMatch(#stomp_frame{command = <<"ERROR">>}, ErrorFrame),
             ?assertMatch(#stomp_frame{command = <<"ERROR">>}, ErrorFrame),
             ?assertMatch(
             ?assertMatch(
                 match, re:run(ErrorFrame#stomp_frame.body, "too_long_body", [{capture, none}])
                 match, re:run(ErrorFrame#stomp_frame.body, "too_long_body", [{capture, none}])
@@ -808,54 +773,16 @@ t_frame_error_too_long_body(_) ->
 
 
 test_frame_error(Frame, AssertFun) ->
 test_frame_error(Frame, AssertFun) ->
     with_connection(fun(Sock) ->
     with_connection(fun(Sock) ->
-        gen_tcp:send(
-            Sock,
-            serialize(
-                <<"CONNECT">>,
-                [
-                    {<<"accept-version">>, ?STOMP_VER},
-                    {<<"host">>, <<"127.0.0.1:61613">>},
-                    {<<"login">>, <<"guest">>},
-                    {<<"passcode">>, <<"guest">>},
-                    {<<"heart-beat">>, <<"0,0">>}
-                ]
-            )
-        ),
-        {ok, Data} = gen_tcp:recv(Sock, 0),
-        {ok,
-            #stomp_frame{
-                command = <<"CONNECTED">>,
-                headers = _,
-                body = _
-            },
-            _, _} = parse(Data),
+        send_connection_frame(Sock, <<"guest">>, <<"guest">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"CONNECTED">>}}, recv_a_frame(Sock)),
         gen_tcp:send(Sock, Frame),
         gen_tcp:send(Sock, Frame),
         AssertFun(Sock)
         AssertFun(Sock)
     end).
     end).
 
 
 t_rest_clienit_info(_) ->
 t_rest_clienit_info(_) ->
     with_connection(fun(Sock) ->
     with_connection(fun(Sock) ->
-        gen_tcp:send(
-            Sock,
-            serialize(
-                <<"CONNECT">>,
-                [
-                    {<<"accept-version">>, ?STOMP_VER},
-                    {<<"host">>, <<"127.0.0.1:61613">>},
-                    {<<"login">>, <<"guest">>},
-                    {<<"passcode">>, <<"guest">>},
-                    {<<"heart-beat">>, <<"0,0">>}
-                ]
-            )
-        ),
-        {ok, Data} = gen_tcp:recv(Sock, 0),
-        {ok,
-            #stomp_frame{
-                command = <<"CONNECTED">>,
-                headers = _,
-                body = _
-            },
-            _, _} = parse(Data),
+        send_connection_frame(Sock, <<"guest">>, <<"guest">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"CONNECTED">>}}, recv_a_frame(Sock)),
 
 
         %% client lists
         %% client lists
         {200, Clients} = request(get, "/gateways/stomp/clients"),
         {200, Clients} = request(get, "/gateways/stomp/clients"),
@@ -909,18 +836,8 @@ t_rest_clienit_info(_) ->
 
 
         %% sub & unsub
         %% sub & unsub
         {200, []} = request(get, ClientPath ++ "/subscriptions"),
         {200, []} = request(get, ClientPath ++ "/subscriptions"),
-        gen_tcp:send(
-            Sock,
-            serialize(
-                <<"SUBSCRIBE">>,
-                [
-                    {<<"id">>, 0},
-                    {<<"destination">>, <<"/queue/foo">>},
-                    {<<"ack">>, <<"client">>}
-                ]
-            )
-        ),
-        timer:sleep(100),
+        ok = send_subscribe_frame(Sock, 0, <<"/queue/foo">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"RECEIPT">>}}, recv_a_frame(Sock)),
 
 
         {200, Subs} = request(get, ClientPath ++ "/subscriptions"),
         {200, Subs} = request(get, ClientPath ++ "/subscriptions"),
         ?assertEqual(1, length(Subs)),
         ?assertEqual(1, length(Subs)),
@@ -956,6 +873,141 @@ t_rest_clienit_info(_) ->
         ?assertEqual(0, length(maps:get(data, Clients2)))
         ?assertEqual(0, length(maps:get(data, Clients2)))
     end).
     end).
 
 
+t_authn_superuser(_) ->
+    %% mock authn
+    meck:new(emqx_access_control, [passthrough]),
+    meck:expect(
+        emqx_access_control,
+        authenticate,
+        fun
+            (#{username := <<"admin">>}) ->
+                {ok, #{is_superuser => true}};
+            (#{username := <<"bad_user">>}) ->
+                {error, not_authorized};
+            (_) ->
+                {ok, #{is_superuser => false}}
+        end
+    ),
+    %% mock authz
+    meck:expect(
+        emqx_access_control,
+        authorize,
+        fun
+            (_ClientInfo = #{is_superuser := true}, _PubSub, _Topic) ->
+                allow;
+            (_ClientInfo, _PubSub, _Topic) ->
+                deny
+        end
+    ),
+
+    LoginFailure = fun(Sock) ->
+        ok = send_connection_frame(Sock, <<"bad_user">>, <<"public">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"ERROR">>}}, recv_a_frame(Sock)),
+        ?assertMatch({error, closed}, recv_a_frame(Sock))
+    end,
+
+    PublishFailure = fun(Sock) ->
+        ok = send_connection_frame(Sock, <<"user1">>, <<"public">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"CONNECTED">>}}, recv_a_frame(Sock)),
+        ok = send_message_frame(Sock, <<"t/a">>, <<"hello">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"ERROR">>}}, recv_a_frame(Sock)),
+        ?assertMatch({error, closed}, recv_a_frame(Sock))
+    end,
+
+    SubscribeFailed = fun(Sock) ->
+        ok = send_connection_frame(Sock, <<"user1">>, <<"public">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"CONNECTED">>}}, recv_a_frame(Sock)),
+        ok = send_subscribe_frame(Sock, 0, <<"t/a">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"ERROR">>}}, recv_a_frame(Sock)),
+        ?assertMatch({error, closed}, recv_a_frame(Sock))
+    end,
+
+    LoginAsSuperUser = fun(Sock) ->
+        ok = send_connection_frame(Sock, <<"admin">>, <<"public">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"CONNECTED">>}}, recv_a_frame(Sock)),
+        ok = send_subscribe_frame(Sock, 0, <<"t/a">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"RECEIPT">>}}, recv_a_frame(Sock)),
+        ok = send_message_frame(Sock, <<"t/a">>, <<"hello">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"RECEIPT">>}}, recv_a_frame(Sock)),
+        ?assertMatch(
+            {ok, #stomp_frame{
+                command = <<"MESSAGE">>,
+                body = <<"hello">>
+            }},
+            recv_a_frame(Sock)
+        ),
+        ok = send_disconnect_frame(Sock)
+    end,
+
+    with_connection(LoginFailure),
+    with_connection(PublishFailure),
+    with_connection(SubscribeFailed),
+    with_connection(LoginAsSuperUser),
+    meck:unload(emqx_access_control).
+
+t_mountpoint(_) ->
+    restart_stomp_with_mountpoint(<<"stomp/">>),
+
+    PubSub = fun(Sock) ->
+        ok = send_connection_frame(Sock, <<"user1">>, <<"public">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"CONNECTED">>}}, recv_a_frame(Sock)),
+        ok = send_subscribe_frame(Sock, 0, <<"t/a">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"RECEIPT">>}}, recv_a_frame(Sock)),
+        ok = send_message_frame(Sock, <<"t/a">>, <<"hello">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"RECEIPT">>}}, recv_a_frame(Sock)),
+
+        {ok, #stomp_frame{
+            command = <<"MESSAGE">>,
+            headers = Headers,
+            body = <<"hello">>
+        }} = recv_a_frame(Sock),
+        ?assertEqual(<<"t/a">>, proplists:get_value(<<"destination">>, Headers)),
+
+        ok = send_disconnect_frame(Sock)
+    end,
+
+    PubToMqtt = fun(Sock) ->
+        ok = send_connection_frame(Sock, <<"user1">>, <<"public">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"CONNECTED">>}}, recv_a_frame(Sock)),
+
+        ok = emqx:subscribe(<<"stomp/t/a">>),
+        ok = send_message_frame(Sock, <<"t/a">>, <<"hello">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"RECEIPT">>}}, recv_a_frame(Sock)),
+
+        receive
+            {deliver, Topic, Msg} ->
+                ?assertEqual(<<"stomp/t/a">>, Topic),
+                ?assertEqual(<<"hello">>, emqx_message:payload(Msg))
+        after 100 ->
+            ?assert(false, "waiting message timeout")
+        end,
+        ok = send_disconnect_frame(Sock)
+    end,
+
+    ReceiveMsgFromMqtt = fun(Sock) ->
+        ok = send_connection_frame(Sock, <<"user1">>, <<"public">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"CONNECTED">>}}, recv_a_frame(Sock)),
+        ok = send_subscribe_frame(Sock, 0, <<"t/a">>),
+        ?assertMatch({ok, #stomp_frame{command = <<"RECEIPT">>}}, recv_a_frame(Sock)),
+
+        Msg = emqx_message:make(<<"stomp/t/a">>, <<"hello">>),
+        emqx:publish(Msg),
+
+        {ok, #stomp_frame{
+            command = <<"MESSAGE">>,
+            headers = Headers,
+            body = <<"hello">>
+        }} = recv_a_frame(Sock),
+        ?assertEqual(<<"t/a">>, proplists:get_value(<<"destination">>, Headers)),
+
+        ok = send_disconnect_frame(Sock)
+    end,
+
+    with_connection(PubSub),
+    with_connection(PubToMqtt),
+    with_connection(ReceiveMsgFromMqtt),
+    restart_stomp_with_mountpoint(<<>>).
+
 %% TODO: Mountpoint, AuthChain, Authorization + Mountpoint, ClientInfoOverride,
 %% TODO: Mountpoint, AuthChain, Authorization + Mountpoint, ClientInfoOverride,
 %%       Listeners, Metrics, Stats, ClientInfo
 %%       Listeners, Metrics, Stats, ClientInfo
 %%
 %%
@@ -963,6 +1015,9 @@ t_rest_clienit_info(_) ->
 %%
 %%
 %% TODO: RateLimit, OOM,
 %% TODO: RateLimit, OOM,
 
 
+%%--------------------------------------------------------------------
+%% helpers
+
 with_connection(DoFun) ->
 with_connection(DoFun) ->
     {ok, Sock} = gen_tcp:connect(
     {ok, Sock} = gen_tcp:connect(
         {127, 0, 0, 1},
         {127, 0, 0, 1},
@@ -973,6 +1028,8 @@ with_connection(DoFun) ->
     try
     try
         DoFun(Sock)
         DoFun(Sock)
     after
     after
+        erase(parser),
+        erase(rest),
         gen_tcp:close(Sock)
         gen_tcp:close(Sock)
     end.
     end.
 
 
@@ -982,6 +1039,46 @@ serialize(Command, Headers) ->
 serialize(Command, Headers, Body) ->
 serialize(Command, Headers, Body) ->
     emqx_stomp_frame:serialize_pkt(emqx_stomp_frame:make(Command, Headers, Body), #{}).
     emqx_stomp_frame:serialize_pkt(emqx_stomp_frame:make(Command, Headers, Body), #{}).
 
 
+recv_a_frame(Sock) ->
+    Parser =
+        case get(parser) of
+            undefined ->
+                ProtoEnv = #{
+                    max_headers => 1024,
+                    max_header_length => 10240,
+                    max_body_length => 81920
+                },
+                emqx_stomp_frame:initial_parse_state(ProtoEnv);
+            P ->
+                P
+        end,
+    LastRest =
+        case get(rest) of
+            undefined -> <<>>;
+            R -> R
+        end,
+    case emqx_stomp_frame:parse(LastRest, Parser) of
+        {more, NParser} ->
+            case gen_tcp:recv(Sock, 0, 5000) of
+                {ok, Data} ->
+                    put(parser, NParser),
+                    put(rest, <<LastRest/binary, Data/binary>>),
+                    recv_a_frame(Sock);
+                {error, _} = Err1 ->
+                    erase(parser),
+                    erase(rest),
+                    Err1
+            end;
+        {ok, Frame, Rest, NParser} ->
+            put(parser, NParser),
+            put(rest, Rest),
+            {ok, Frame};
+        {error, _} = Err ->
+            erase(parser),
+            erase(rest),
+            Err
+    end.
+
 parse(Data) ->
 parse(Data) ->
     ProtoEnv = #{
     ProtoEnv = #{
         max_headers => 1024,
         max_headers => 1024,
@@ -996,6 +1093,52 @@ get_field(command, #stomp_frame{command = Command}) ->
 get_field(body, #stomp_frame{body = Body}) ->
 get_field(body, #stomp_frame{body = Body}) ->
     Body.
     Body.
 
 
+send_connection_frame(Sock, Username, Password) ->
+    send_connection_frame(Sock, Username, Password, <<"0,0">>).
+
+send_connection_frame(Sock, Username, Password, Heartbeat) ->
+    Headers =
+        case Username == undefined of
+            true -> [];
+            false -> [{<<"login">>, Username}]
+        end ++
+            case Password == undefined of
+                true -> [];
+                false -> [{<<"passcode">>, Password}]
+            end,
+    Headers1 = [
+        {<<"accept-version">>, ?STOMP_VER},
+        {<<"host">>, <<"127.0.0.1:61613">>},
+        {<<"heart-beat">>, Heartbeat}
+        | Headers
+    ],
+    ok = gen_tcp:send(Sock, serialize(<<"CONNECT">>, Headers1)).
+
+send_subscribe_frame(Sock, Id, Topic) ->
+    Headers =
+        [
+            {<<"id">>, Id},
+            {<<"receipt">>, Id},
+            {<<"destination">>, Topic},
+            {<<"ack">>, <<"auto">>}
+        ],
+    ok = gen_tcp:send(Sock, serialize(<<"SUBSCRIBE">>, Headers)).
+
+send_message_frame(Sock, Topic, Payload) ->
+    Headers =
+        [
+            {<<"destination">>, Topic},
+            {<<"receipt">>, <<"rp-", Topic/binary>>}
+        ],
+    ok = gen_tcp:send(Sock, serialize(<<"SEND">>, Headers, Payload)).
+
+send_disconnect_frame(Sock) ->
+    ok = gen_tcp:send(Sock, serialize(<<"DISCONNECT">>, [])).
+
+send_disconnect_frame(Sock, ReceiptId) ->
+    Headers = [{<<"receipt">>, ReceiptId}],
+    ok = gen_tcp:send(Sock, serialize(<<"DISCONNECT">>, Headers)).
+
 clients() ->
 clients() ->
     {200, Clients} = request(get, "/gateways/stomp/clients"),
     {200, Clients} = request(get, "/gateways/stomp/clients"),
     maps:get(data, Clients).
     maps:get(data, Clients).

+ 5 - 0
changes/ce/fix-11018.en.md

@@ -0,0 +1,5 @@
+Fixed multiple issues with the Stomp gateway, including:
+- Fixed an issue where `is_superuser` was not working correctly.
+- Fixed an issue where the mountpoint was not being removed in message delivery.
+- After a message or subscription request fails, the Stomp client should be disconnected
+  immediately after replying with an ERROR message.