Просмотр исходного кода

Merge pull request #7796 from zhongwencool/fix-ws-max-connection-not-work

fix: websocket's max_connection not work
zhongwencool 3 лет назад
Родитель
Сommit
737abc5700
2 измененных файлов с 131 добавлено и 97 удалено
  1. 2 2
      apps/emqx/src/emqx_schema.erl
  2. 129 95
      apps/emqx/src/emqx_ws_connection.erl

+ 2 - 2
apps/emqx/src/emqx_schema.erl

@@ -1515,7 +1515,7 @@ base_listener() ->
             )},
         {"acceptors",
             sc(
-                integer(),
+                pos_integer(),
                 #{
                     default => 16,
                     desc => ?DESC(base_listener_acceptors)
@@ -1523,7 +1523,7 @@ base_listener() ->
             )},
         {"max_connections",
             sc(
-                hoconsc:union([infinity, integer()]),
+                hoconsc:union([infinity, pos_integer()]),
                 #{
                     default => infinity,
                     desc => ?DESC(base_listener_max_connections)

+ 129 - 95
apps/emqx/src/emqx_ws_connection.erl

@@ -272,78 +272,65 @@ check_origin_header(Req, #{listener := {Type, Listener}} = Opts) ->
         false -> ok
     end.
 
-websocket_init([
-    Req,
-    #{zone := Zone, limiter := LimiterCfg, listener := {Type, Listener}} = Opts
-]) ->
-    {Peername, Peercert} =
-        case
-            emqx_config:get_listener_conf(Type, Listener, [proxy_protocol]) andalso
-                maps:get(proxy_header, Req)
-        of
-            #{src_address := SrcAddr, src_port := SrcPort, ssl := SSL} ->
-                SourceName = {SrcAddr, SrcPort},
-                %% Notice: Only CN is available in Proxy Protocol V2 additional info
-                SourceSSL =
-                    case maps:get(cn, SSL, undefined) of
-                        undeined -> nossl;
-                        CN -> [{pp2_ssl_cn, CN}]
-                    end,
-                {SourceName, SourceSSL};
-            #{src_address := SrcAddr, src_port := SrcPort} ->
-                SourceName = {SrcAddr, SrcPort},
-                {SourceName, nossl};
-            _ ->
-                {get_peer(Req, Opts), cowboy_req:cert(Req)}
-        end,
-    Sockname = cowboy_req:sock(Req),
-    WsCookie =
-        try
-            cowboy_req:parse_cookies(Req)
-        catch
-            error:badarg ->
-                ?SLOG(error, #{msg => "bad_cookie"}),
-                undefined;
-            Error:Reason ->
-                ?SLOG(error, #{
-                    msg => "failed_to_parse_cookie",
-                    exception => Error,
-                    reason => Reason
-                }),
-                undefined
-        end,
-    ConnInfo = #{
-        socktype => ws,
-        peername => Peername,
-        sockname => Sockname,
-        peercert => Peercert,
-        ws_cookie => WsCookie,
-        conn_mod => ?MODULE
-    },
-    Limiter = emqx_limiter_container:get_limiter_by_names(
-        [?LIMITER_BYTES_IN, ?LIMITER_MESSAGE_IN], LimiterCfg
-    ),
-    MQTTPiggyback = get_ws_opts(Type, Listener, mqtt_piggyback),
-    FrameOpts = #{
-        strict_mode => emqx_config:get_zone_conf(Zone, [mqtt, strict_mode]),
-        max_size => emqx_config:get_zone_conf(Zone, [mqtt, max_packet_size])
-    },
-    ParseState = emqx_frame:initial_parse_state(FrameOpts),
-    Serialize = emqx_frame:serialize_opts(),
-    Channel = emqx_channel:init(ConnInfo, Opts),
-    GcState =
-        case emqx_config:get_zone_conf(Zone, [force_gc]) of
-            #{enable := false} -> undefined;
-            GcPolicy -> emqx_gc:init(GcPolicy)
-        end,
-    StatsTimer =
-        case emqx_config:get_zone_conf(Zone, [stats, enable]) of
-            true -> undefined;
-            false -> disabled
-        end,
-    %% MQTT Idle Timeout
-    IdleTimeout = emqx_channel:get_mqtt_conf(Zone, idle_timeout),
-    IdleTimer = start_timer(IdleTimeout, idle_timeout),
+websocket_init([Req, Opts]) ->
+    #{zone := Zone, limiter := LimiterCfg, listener := {Type, Listener}} = Opts,
+    case check_max_connection(Type, Listener) of
+        allow ->
+            {Peername, PeerCert} = get_peer_info(Type, Listener, Req, Opts),
+            Sockname = cowboy_req:sock(Req),
+            WsCookie = get_ws_cookie(Req),
+            ConnInfo = #{
+                socktype => ws,
+                peername => Peername,
+                sockname => Sockname,
+                peercert => PeerCert,
+                ws_cookie => WsCookie,
+                conn_mod => ?MODULE
+            },
+            Limiter = emqx_limiter_container:get_limiter_by_names(
+                [?LIMITER_BYTES_IN, ?LIMITER_MESSAGE_IN], LimiterCfg
+            ),
+            MQTTPiggyback = get_ws_opts(Type, Listener, mqtt_piggyback),
+            FrameOpts = #{
+                strict_mode => emqx_config:get_zone_conf(Zone, [mqtt, strict_mode]),
+                max_size => emqx_config:get_zone_conf(Zone, [mqtt, max_packet_size])
+            },
+            ParseState = emqx_frame:initial_parse_state(FrameOpts),
+            Serialize = emqx_frame:serialize_opts(),
+            Channel = emqx_channel:init(ConnInfo, Opts),
+            GcState = get_force_gc(Zone),
+            StatsTimer = get_stats_enable(Zone),
+            %% MQTT Idle Timeout
+            IdleTimeout = emqx_channel:get_mqtt_conf(Zone, idle_timeout),
+            IdleTimer = start_timer(IdleTimeout, idle_timeout),
+            tune_heap_size(Channel),
+            emqx_logger:set_metadata_peername(esockd:format(Peername)),
+            {ok,
+                #state{
+                    peername = Peername,
+                    sockname = Sockname,
+                    sockstate = running,
+                    mqtt_piggyback = MQTTPiggyback,
+                    limiter = Limiter,
+                    parse_state = ParseState,
+                    serialize = Serialize,
+                    channel = Channel,
+                    gc_state = GcState,
+                    postponed = [],
+                    stats_timer = StatsTimer,
+                    idle_timeout = IdleTimeout,
+                    idle_timer = IdleTimer,
+                    zone = Zone,
+                    listener = {Type, Listener},
+                    limiter_timer = undefined,
+                    limiter_cache = queue:new()
+                },
+                hibernate};
+        {denny, Reason} ->
+            {stop, Reason}
+    end.
+
+tune_heap_size(Channel) ->
     case
         emqx_config:get_zone_conf(
             emqx_channel:info(zone, Channel),
@@ -352,29 +339,56 @@ websocket_init([
     of
         #{enable := false} -> ok;
         ShutdownPolicy -> emqx_misc:tune_heap_size(ShutdownPolicy)
-    end,
-    emqx_logger:set_metadata_peername(esockd:format(Peername)),
-    {ok,
-        #state{
-            peername = Peername,
-            sockname = Sockname,
-            sockstate = running,
-            mqtt_piggyback = MQTTPiggyback,
-            limiter = Limiter,
-            parse_state = ParseState,
-            serialize = Serialize,
-            channel = Channel,
-            gc_state = GcState,
-            postponed = [],
-            stats_timer = StatsTimer,
-            idle_timeout = IdleTimeout,
-            idle_timer = IdleTimer,
-            zone = Zone,
-            listener = {Type, Listener},
-            limiter_timer = undefined,
-            limiter_cache = queue:new()
-        },
-        hibernate}.
+    end.
+
+get_stats_enable(Zone) ->
+    case emqx_config:get_zone_conf(Zone, [stats, enable]) of
+        true -> undefined;
+        false -> disabled
+    end.
+
+get_force_gc(Zone) ->
+    case emqx_config:get_zone_conf(Zone, [force_gc]) of
+        #{enable := false} -> undefined;
+        GcPolicy -> emqx_gc:init(GcPolicy)
+    end.
+
+get_ws_cookie(Req) ->
+    try
+        cowboy_req:parse_cookies(Req)
+    catch
+        error:badarg ->
+            ?SLOG(error, #{msg => "bad_cookie"}),
+            undefined;
+        Error:Reason ->
+            ?SLOG(error, #{
+                msg => "failed_to_parse_cookie",
+                exception => Error,
+                reason => Reason
+            }),
+            undefined
+    end.
+
+get_peer_info(Type, Listener, Req, Opts) ->
+    case
+        emqx_config:get_listener_conf(Type, Listener, [proxy_protocol]) andalso
+            maps:get(proxy_header, Req)
+    of
+        #{src_address := SrcAddr, src_port := SrcPort, ssl := SSL} ->
+            SourceName = {SrcAddr, SrcPort},
+            %% Notice: Only CN is available in Proxy Protocol V2 additional info
+            SourceSSL =
+                case maps:get(cn, SSL, undefined) of
+                    undeined -> nossl;
+                    CN -> [{pp2_ssl_cn, CN}]
+                end,
+            {SourceName, SourceSSL};
+        #{src_address := SrcAddr, src_port := SrcPort} ->
+            SourceName = {SrcAddr, SrcPort},
+            {SourceName, nossl};
+        _ ->
+            {get_peer(Req, Opts), cowboy_req:cert(Req)}
+    end.
 
 websocket_handle({binary, Data}, State) when is_list(Data) ->
     websocket_handle({binary, iolist_to_binary(Data)}, State);
@@ -1000,6 +1014,26 @@ get_peer(Req, #{listener := {Type, Listener}}) ->
         _:_ -> {Addr, PeerPort}
     end.
 
+check_max_connection(Type, Listener) ->
+    case emqx_config:get_listener_conf(Type, Listener, [max_connections]) of
+        infinity ->
+            allow;
+        Max ->
+            MatchSpec = [{{'_', emqx_ws_connection}, [], [true]}],
+            Curr = ets:select_count(emqx_channel_conn, MatchSpec),
+            case Curr >= Max of
+                false ->
+                    allow;
+                true ->
+                    Reason = #{
+                        max => Max,
+                        current => Curr,
+                        msg => "websocket_max_connections_limited"
+                    },
+                    ?SLOG(warning, Reason),
+                    {denny, Reason}
+            end
+    end.
 %%--------------------------------------------------------------------
 %% For CT tests
 %%--------------------------------------------------------------------