Просмотр исходного кода

Merge pull request #13350 from HJianBo/add-peersni-to-client-attr

feat: support to extract the client peersni field to clientinfo
JianBo He 1 год назад
Родитель
Сommit
9f44c50025

+ 14 - 8
apps/emqx/src/emqx_channel.erl

@@ -269,7 +269,7 @@ init(
         },
         Zone
     ),
-    {NClientInfo, NConnInfo} = take_ws_cookie(ClientInfo, ConnInfo),
+    {NClientInfo, NConnInfo} = take_conn_info_fields([ws_cookie, peersni], ClientInfo, ConnInfo),
     #channel{
         conninfo = NConnInfo,
         clientinfo = NClientInfo,
@@ -309,13 +309,19 @@ set_peercert_infos(Peercert, ClientInfo, Zone) ->
     ClientId = PeercetAs(peer_cert_as_clientid),
     ClientInfo#{username => Username, clientid => ClientId, dn => DN, cn => CN}.
 
-take_ws_cookie(ClientInfo, ConnInfo) ->
-    case maps:take(ws_cookie, ConnInfo) of
-        {WsCookie, NConnInfo} ->
-            {ClientInfo#{ws_cookie => WsCookie}, NConnInfo};
-        _ ->
-            {ClientInfo, ConnInfo}
-    end.
+take_conn_info_fields(Fields, ClientInfo, ConnInfo) ->
+    lists:foldl(
+        fun(Field, {ClientInfo0, ConnInfo0}) ->
+            case maps:take(Field, ConnInfo0) of
+                {Value, NConnInfo} ->
+                    {ClientInfo0#{Field => Value}, NConnInfo};
+                _ ->
+                    {ClientInfo0, ConnInfo0}
+            end
+        end,
+        {ClientInfo, ConnInfo},
+        Fields
+    ).
 
 %%--------------------------------------------------------------------
 %% Handle incoming packet

+ 2 - 0
apps/emqx/src/emqx_connection.erl

@@ -305,11 +305,13 @@ init_state(
     {ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]),
     {ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]),
     Peercert = Transport:ensure_ok_or_exit(peercert, [Socket]),
+    PeerSNI = Transport:ensure_ok_or_exit(peersni, [Socket]),
     ConnInfo = #{
         socktype => Transport:type(Socket),
         peername => Peername,
         sockname => Sockname,
         peercert => Peercert,
+        peersni => PeerSNI,
         conn_mod => ?MODULE
     },
 

+ 6 - 1
apps/emqx/src/emqx_quic_stream.erl

@@ -39,7 +39,8 @@
     getopts/2,
     peername/1,
     sockname/1,
-    peercert/1
+    peercert/1,
+    peersni/1
 ]).
 -include_lib("quicer/include/quicer.hrl").
 -include_lib("emqx/include/emqx_quic.hrl").
@@ -106,6 +107,10 @@ peercert(_S) ->
     %% @todo but unsupported by msquic
     nossl.
 
+peersni(_S) ->
+    %% @todo
+    undefined.
+
 getstat({quic, Conn, _Stream, _Info}, Stats) ->
     case quicer:getstat(Conn, Stats) of
         {error, _} -> {error, closed};

+ 10 - 6
apps/emqx/src/emqx_ws_connection.erl

@@ -280,7 +280,7 @@ websocket_init([Req, Opts]) ->
     #{zone := Zone, limiter := LimiterCfg, listener := {Type, Listener} = ListenerCfg} = Opts,
     case check_max_connection(Type, Listener) of
         allow ->
-            {Peername, PeerCert} = get_peer_info(Type, Listener, Req, Opts),
+            {Peername, PeerCert, PeerSNI} = get_peer_info(Type, Listener, Req, Opts),
             Sockname = cowboy_req:sock(Req),
             WsCookie = get_ws_cookie(Req),
             ConnInfo = #{
@@ -288,6 +288,7 @@ websocket_init([Req, Opts]) ->
                 peername => Peername,
                 sockname => Sockname,
                 peercert => PeerCert,
+                peersni => PeerSNI,
                 ws_cookie => WsCookie,
                 conn_mod => ?MODULE
             },
@@ -376,11 +377,12 @@ get_ws_cookie(Req) ->
     end.
 
 get_peer_info(Type, Listener, Req, Opts) ->
+    Host = maps:get(host, Req, undefined),
     case
         emqx_config:get_listener_conf(Type, Listener, [proxy_protocol]) andalso
             maps:get(proxy_header, Req)
     of
-        #{src_address := SrcAddr, src_port := SrcPort, ssl := SSL} ->
+        #{src_address := SrcAddr, src_port := SrcPort, ssl := SSL} = ProxyInfo ->
             SourceName = {SrcAddr, SrcPort},
             %% Notice: CN is only available in Proxy Protocol V2 additional info.
             %% `CN` is unsupported in Proxy Protocol V1
@@ -392,12 +394,14 @@ get_peer_info(Type, Listener, Req, Opts) ->
                     undefined -> undefined;
                     CN -> [{pp2_ssl_cn, CN}]
                 end,
-            {SourceName, SourceSSL};
-        #{src_address := SrcAddr, src_port := SrcPort} ->
+            PeerSNI = maps:get(authority, ProxyInfo, Host),
+            {SourceName, SourceSSL, PeerSNI};
+        #{src_address := SrcAddr, src_port := SrcPort} = ProxyInfo ->
+            PeerSNI = maps:get(authority, ProxyInfo, Host),
             SourceName = {SrcAddr, SrcPort},
-            {SourceName, nossl};
+            {SourceName, nossl, PeerSNI};
         _ ->
-            {get_peer(Req, Opts), cowboy_req:cert(Req)}
+            {get_peer(Req, Opts), cowboy_req:cert(Req), Host}
     end.
 
 websocket_handle({binary, Data}, State) when is_list(Data) ->

+ 2 - 1
apps/emqx/test/emqx_connection_SUITE.erl

@@ -84,7 +84,8 @@ init_per_testcase(TestCase, Config) when
         fun
             (peername, [sock]) -> {ok, {{127, 0, 0, 1}, 3456}};
             (sockname, [sock]) -> {ok, {{127, 0, 0, 1}, 1883}};
-            (peercert, [sock]) -> undefined
+            (peercert, [sock]) -> undefined;
+            (peersni, [sock]) -> undefined
         end
     ),
     ok = meck:expect(emqx_transport, setopts, fun(_Sock, _Opts) -> ok end),

+ 145 - 0
apps/emqx/test/emqx_cth_listener.erl

@@ -0,0 +1,145 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+-module(emqx_cth_listener).
+
+-include_lib("esockd/include/esockd.hrl").
+
+-export([
+    reload_listener_with_ppv2/1,
+    reload_listener_with_ppv2/2,
+    reload_listener_without_ppv2/1
+]).
+
+-export([meck_recv_ppv2/1, clear_meck_recv_ppv2/1]).
+
+-define(DEFAULT_OPTS, #{
+    host => "127.0.0.1",
+    proto_ver => v5,
+    connect_timeout => 5,
+    ssl => false
+}).
+
+%%--------------------------------------------------------------------
+%% APIs
+%%--------------------------------------------------------------------
+
+reload_listener_with_ppv2(Path = [listeners, _Type, _Name]) ->
+    reload_listener_with_ppv2(Path, <<>>).
+
+reload_listener_with_ppv2(Path = [listeners, Type, Name], DefaultSni) when
+    Type == tcp; Type == ws
+->
+    Cfg = emqx_config:get(Path),
+    ok = emqx_config:put(Path, Cfg#{proxy_protocol => true}),
+    ok = emqx_listeners:restart_listener(
+        emqx_listeners:listener_id(Type, Name)
+    ),
+    ok = meck_recv_ppv2(Type),
+    client_conn_fn(Type, maps:get(bind, Cfg), DefaultSni).
+
+client_conn_fn(tcp, Bind, Sni) ->
+    client_conn_fn_gen(connect, ?DEFAULT_OPTS#{port => bind2port(Bind), sni => Sni});
+client_conn_fn(ws, Bind, Sni) ->
+    client_conn_fn_gen(ws_connect, ?DEFAULT_OPTS#{port => bind2port(Bind), sni => Sni}).
+
+bind2port({_, Port}) -> Port;
+bind2port(Port) when is_integer(Port) -> Port.
+
+client_conn_fn_gen(Connect, Opts0) ->
+    fun(ClientId, Opts1) ->
+        Opts2 = maps:merge(Opts0, Opts1#{clientid => ClientId}),
+        Sni = maps:get(sni, Opts2, undefined),
+        NOpts = prepare_sni_for_meck(Sni, Opts2),
+        {ok, C} = emqtt:start_link(NOpts),
+        case emqtt:Connect(C) of
+            {ok, _} -> {ok, C};
+            {error, _} = Err -> Err
+        end
+    end.
+
+prepare_sni_for_meck(ClientSni, Opts) when is_binary(ClientSni) ->
+    ServerSni =
+        case ClientSni of
+            disable -> undefined;
+            _ -> ClientSni
+        end,
+    persistent_term:put(current_client_sni, ServerSni),
+    case maps:get(ssl, Opts, false) of
+        false ->
+            Opts;
+        true ->
+            SslOpts = maps:get(ssl_opts, Opts, #{}),
+            Opts#{ssl_opts => [{server_name_indication, ClientSni} | SslOpts]}
+    end.
+
+reload_listener_without_ppv2(Path = [listeners, Type, Name]) when
+    Type == tcp; Type == ws
+->
+    Cfg = emqx_config:get(Path),
+    ok = emqx_config:put(Path, Cfg#{proxy_protocol => false}),
+    ok = emqx_listeners:restart_listener(
+        emqx_listeners:listener_id(Type, Name)
+    ),
+    ok = clear_meck_recv_ppv2(Type).
+
+meck_recv_ppv2(tcp) ->
+    ok = meck:new(esockd_proxy_protocol, [passthrough, no_history, no_link]),
+    ok = meck:expect(
+        esockd_proxy_protocol,
+        recv,
+        fun(_Transport, Socket, _Timeout) ->
+            SNI = persistent_term:get(current_client_sni, undefined),
+            {ok, {SrcAddr, SrcPort}} = esockd_transport:peername(Socket),
+            {ok, {DstAddr, DstPort}} = esockd_transport:sockname(Socket),
+            {ok, #proxy_socket{
+                inet = inet4,
+                socket = Socket,
+                src_addr = SrcAddr,
+                dst_addr = DstAddr,
+                src_port = SrcPort,
+                dst_port = DstPort,
+                pp2_additional_info = [{pp2_authority, SNI}]
+            }}
+        end
+    );
+meck_recv_ppv2(ws) ->
+    ok = meck:new(ranch_tcp, [passthrough, no_history, no_link]),
+    ok = meck:expect(
+        ranch_tcp,
+        recv_proxy_header,
+        fun(Socket, _Timeout) ->
+            SNI = persistent_term:get(current_client_sni, undefined),
+            {ok, {SrcAddr, SrcPort}} = esockd_transport:peername(Socket),
+            {ok, {DstAddr, DstPort}} = esockd_transport:sockname(Socket),
+            {ok, #{
+                authority => SNI,
+                command => proxy,
+                dest_address => DstAddr,
+                dest_port => DstPort,
+                src_address => SrcAddr,
+                src_port => SrcPort,
+                transport_family => ipv4,
+                transport_protocol => stream,
+                version => 2
+            }}
+        end
+    ).
+
+clear_meck_recv_ppv2(tcp) ->
+    ok = meck:unload(esockd_proxy_protocol);
+clear_meck_recv_ppv2(ws) ->
+    ok = meck:unload(ranch_tcp).

+ 169 - 0
apps/emqx/test/emqx_peersni_SUITE.erl

@@ -0,0 +1,169 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+-module(emqx_peersni_SUITE).
+
+-compile(export_all).
+-compile(nowarn_export_all).
+
+-include_lib("emqx/include/emqx.hrl").
+-include_lib("emqx/include/emqx_mqtt.hrl").
+-include_lib("eunit/include/eunit.hrl").
+-include_lib("common_test/include/ct.hrl").
+
+-include_lib("esockd/include/esockd.hrl").
+
+-define(SERVER_NAME, <<"localhost">>).
+
+%%--------------------------------------------------------------------
+%% setups
+%%--------------------------------------------------------------------
+
+all() ->
+    [
+        {group, tcp_ppv2},
+        {group, ws_ppv2},
+        {group, ssl},
+        {group, wss}
+    ].
+
+groups() ->
+    TCs = emqx_common_test_helpers:all(?MODULE),
+    [
+        {tcp_ppv2, [], TCs},
+        {ws_ppv2, [], TCs},
+        {ssl, [], TCs},
+        {wss, [], TCs}
+    ].
+
+init_per_suite(Config) ->
+    Apps = emqx_cth_suite:start(
+        [{emqx, #{}}],
+        #{work_dir => emqx_cth_suite:work_dir(Config)}
+    ),
+    [{apps, Apps} | Config].
+
+end_per_suite(Config) ->
+    emqx_cth_suite:stop(proplists:get_value(apps, Config)).
+
+init_per_group(tcp_ppv2, Config) ->
+    ClientFn = emqx_cth_listener:reload_listener_with_ppv2(
+        [listeners, tcp, default],
+        ?SERVER_NAME
+    ),
+    [{client_fn, ClientFn} | Config];
+init_per_group(ws_ppv2, Config) ->
+    ClientFn = emqx_cth_listener:reload_listener_with_ppv2(
+        [listeners, ws, default],
+        ?SERVER_NAME
+    ),
+    [{client_fn, ClientFn} | Config];
+init_per_group(ssl, Config) ->
+    ClientFn = fun(ClientId, Opts) ->
+        Opts1 = Opts#{
+            host => ?SERVER_NAME,
+            port => 8883,
+            ssl => true,
+            ssl_opts => [
+                {verify, verify_none},
+                {server_name_indication, binary_to_list(?SERVER_NAME)}
+            ]
+        },
+        {ok, C} = emqtt:start_link(Opts1#{clientid => ClientId}),
+        case emqtt:connect(C) of
+            {ok, _} -> {ok, C};
+            {error, _} = Err -> Err
+        end
+    end,
+    [{client_fn, ClientFn} | Config];
+init_per_group(wss, Config) ->
+    ClientFn = fun(ClientId, Opts) ->
+        Opts1 = Opts#{
+            host => ?SERVER_NAME,
+            port => 8084,
+            ws_transport_options => [
+                {transport, tls},
+                {protocols, [http]},
+                {transport_opts, [
+                    {verify, verify_none},
+                    {server_name_indication, binary_to_list(?SERVER_NAME)},
+                    {customize_hostname_check, []}
+                ]}
+            ]
+        },
+        {ok, C} = emqtt:start_link(Opts1#{clientid => ClientId}),
+        case emqtt:ws_connect(C) of
+            {ok, _} -> {ok, C};
+            {error, _} = Err -> Err
+        end
+    end,
+    [{client_fn, ClientFn} | Config];
+init_per_group(_, Config) ->
+    Config.
+
+end_per_group(tcp_ppv2, _Config) ->
+    emqx_cth_listener:reload_listener_without_ppv2([listeners, tcp, default]);
+end_per_group(ws_ppv2, _Config) ->
+    emqx_cth_listener:reload_listener_without_ppv2([listeners, ws, default]);
+end_per_group(_, _Config) ->
+    ok.
+
+init_per_testcase(TestCase, Config) ->
+    case erlang:function_exported(?MODULE, TestCase, 2) of
+        true -> ?MODULE:TestCase(init, Config);
+        _ -> Config
+    end.
+
+end_per_testcase(TestCase, Config) ->
+    case erlang:function_exported(?MODULE, TestCase, 2) of
+        true -> ?MODULE:TestCase('end', Config);
+        false -> ok
+    end,
+    Config.
+
+%%--------------------------------------------------------------------
+%% cases
+%%--------------------------------------------------------------------
+
+t_peersni_saved_into_conninfo(Config) ->
+    process_flag(trap_exit, true),
+
+    ClientId = <<"test-clientid1">>,
+    ClientFn = proplists:get_value(client_fn, Config),
+
+    {ok, Client} = ClientFn(ClientId, _Opts = #{}),
+    ?assertMatch(#{clientinfo := #{peersni := ?SERVER_NAME}}, emqx_cm:get_chan_info(ClientId)),
+
+    ok = emqtt:disconnect(Client).
+
+t_parse_peersni_to_client_attr(Config) ->
+    process_flag(trap_exit, true),
+
+    %% set the peersni to the client attribute
+    {ok, Variform} = emqx_variform:compile("nth(1, tokens(peersni, 'h'))"),
+    emqx_config:put([mqtt, client_attrs_init], [
+        #{expression => Variform, set_as_attr => mnts}
+    ]),
+
+    ClientId = <<"test-clientid2">>,
+    ClientFn = proplists:get_value(client_fn, Config),
+    {ok, Client} = ClientFn(ClientId, _Opts = #{}),
+
+    ?assertMatch(
+        #{clientinfo := #{client_attrs := #{mnts := <<"local">>}}}, emqx_cm:get_chan_info(ClientId)
+    ),
+
+    ok = emqtt:disconnect(Client).

+ 2 - 2
apps/emqx/test/emqx_ws_connection_SUITE.erl

@@ -169,7 +169,7 @@ t_header(_) ->
     set_ws_opts(proxy_address_header, <<"x-forwarded-for">>),
     set_ws_opts(proxy_port_header, <<"x-forwarded-port">>),
     {ok, St, _} = ?ws_conn:websocket_init([
-        req,
+        #{},
         #{
             zone => default,
             limiter => limiter_cfg(),
@@ -573,7 +573,7 @@ t_shutdown(_) ->
 st() -> st(#{}).
 st(InitFields) when is_map(InitFields) ->
     {ok, St, _} = ?ws_conn:websocket_init([
-        req,
+        #{},
         #{
             zone => default,
             listener => {ws, default},

+ 1 - 0
changes/ce/feat-13350.md

@@ -0,0 +1 @@
+Support for getting the Server Name of the client connected and storing it in the client information as `peersni`.