Selaa lähdekoodia

feat(acl): make mqtt over websocket work with the new config

Shawn 4 vuotta sitten
vanhempi
commit
630b54f6ee

+ 9 - 2
apps/emqx/etc/emqx.conf

@@ -2218,8 +2218,8 @@ example_common_websocket_options {
   ##
   ## @doc listeners.<name>.websocket.compress
   ## ValueType: Boolean
-  ## Default: true
-  websocket.compress: true
+  ## Default: false
+  websocket.compress: false
 
   ## The idle timeout for external WebSocket connections.
   ##
@@ -2244,6 +2244,13 @@ example_common_websocket_options {
   ## Default: true
   websocket.fail_if_no_subprotocol: true
 
+  ## Supported subprotocols
+  ##
+  ## @doc listeners.<name>.websocket.supported_subprotocols
+  ## ValueType: String
+  ## Default: mqtt, mqtt-v3, mqtt-v3.1.1, mqtt-v5
+  websocket.supported_subprotocols: "mqtt, mqtt-v3, mqtt-v3.1.1, mqtt-v5"
+
   ## Enable origin check in header for websocket connection
   ##
   ## @doc listeners.<name>.websocket.check_origin_enable

+ 10 - 18
apps/emqx/src/emqx_connection.erl

@@ -243,7 +243,7 @@ init(Parent, Transport, RawSocket, Options) ->
             exit_on_sock_error(Reason)
     end.
 
-init_state(Transport, Socket, Options) ->
+init_state(Transport, Socket, #{zone := Zone, listener := Listener} = Opts) ->
     {ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]),
     {ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]),
     Peercert = Transport:ensure_ok_or_exit(peercert, [Socket]),
@@ -253,8 +253,6 @@ init_state(Transport, Socket, Options) ->
                  peercert => Peercert,
                  conn_mod => ?MODULE
                 },
-    Zone = maps:get(zone, Options),
-    Listener = maps:get(listener, Options),
     Limiter = emqx_limiter:init(Zone, undefined, undefined, []),
     FrameOpts = #{
         strict_mode => emqx_config:get_listener_conf(Zone, Listener, [mqtt, strict_mode]),
@@ -262,7 +260,7 @@ init_state(Transport, Socket, Options) ->
     },
     ParseState = emqx_frame:initial_parse_state(FrameOpts),
     Serialize = emqx_frame:serialize_opts(),
-    Channel = emqx_channel:init(ConnInfo, Options),
+    Channel = emqx_channel:init(ConnInfo, Opts),
     GcState = case emqx_config:get_listener_conf(Zone, Listener, [force_gc]) of
         #{enable := false} -> undefined;
         GcPolicy -> emqx_gc:init(GcPolicy)
@@ -295,11 +293,9 @@ run_loop(Parent, State = #state{transport = Transport,
                                 peername  = Peername,
                                 channel   = Channel}) ->
     emqx_logger:set_metadata_peername(esockd:format(Peername)),
-    case emqx_config:get_listener_conf(emqx_channel:info(zone, Channel),
-            emqx_channel:info(listener, Channel), [force_shutdown]) of
-        #{enable := false} -> ok;
-        ShutdownPolicy -> emqx_misc:tune_heap_size(ShutdownPolicy)
-    end,
+    ShutdownPolicy = emqx_config:get_listener_conf(emqx_channel:info(zone, Channel),
+            emqx_channel:info(listener, Channel), [force_shutdown]),
+    emqx_misc:tune_heap_size(ShutdownPolicy),
     case activate_socket(State) of
         {ok, NState} -> hibernate(Parent, NState);
         {error, Reason} ->
@@ -793,15 +789,11 @@ check_oom(State = #state{channel = Channel}) ->
     ShutdownPolicy = emqx_config:get_listener_conf(emqx_channel:info(zone, Channel),
         emqx_channel:info(listener, Channel), [force_shutdown]),
     ?tp(debug, check_oom, #{policy => ShutdownPolicy}),
-    case ShutdownPolicy of
-        #{enable := false} -> ok;
-        ShutdownPolicy ->
-            case emqx_misc:check_oom(ShutdownPolicy) of
-                {shutdown, Reason} ->
-                    %% triggers terminate/2 callback immediately
-                    erlang:exit({shutdown, Reason});
-                _ -> ok
-            end
+    case emqx_misc:check_oom(ShutdownPolicy) of
+        {shutdown, Reason} ->
+            %% triggers terminate/2 callback immediately
+            erlang:exit({shutdown, Reason});
+        _ -> ok
     end,
     State.
 

+ 14 - 13
apps/emqx/src/emqx_listeners.erl

@@ -74,13 +74,13 @@ do_start_listener(ZoneName, ListenerName, #{type := tcp, bind := ListenOn} = Opt
 %% Start MQTT/WS listener
 do_start_listener(ZoneName, ListenerName, #{type := ws, bind := ListenOn} = Opts) ->
     Id = listener_id(ZoneName, ListenerName),
-    RanchOpts = ranch_opts(Opts),
+    RanchOpts = ranch_opts(ListenOn, Opts),
     WsOpts = ws_opts(ZoneName, ListenerName, Opts),
     case is_ssl(Opts) of
         false ->
-            cowboy:start_clear(Id, with_port(ListenOn, RanchOpts), WsOpts);
+            cowboy:start_clear(Id, RanchOpts, WsOpts);
         true ->
-            cowboy:start_tls(Id, with_port(ListenOn, RanchOpts), WsOpts)
+            cowboy:start_tls(Id, RanchOpts, WsOpts)
     end.
 
 esockd_opts(Opts0) ->
@@ -104,21 +104,22 @@ ws_opts(ZoneName, ListenerName, Opts) ->
     ProxyProto = maps:get(proxy_protocol, Opts, false),
     #{env => #{dispatch => Dispatch}, proxy_header => ProxyProto}.
 
-ranch_opts(Opts) ->
+ranch_opts(ListenOn, Opts) ->
     NumAcceptors = maps:get(acceptors, Opts, 4),
     MaxConnections = maps:get(max_connections, Opts, 1024),
+    SocketOpts = case is_ssl(Opts) of
+        true -> tcp_opts(Opts) ++ proplists:delete(handshake_timeout, ssl_opts(Opts));
+        false -> tcp_opts(Opts)
+    end,
     #{num_acceptors => NumAcceptors,
       max_connections => MaxConnections,
       handshake_timeout => maps:get(handshake_timeout, Opts, 15000),
-      socket_opts => case is_ssl(Opts) of
-          true -> tcp_opts(Opts) ++ proplists:delete(handshake_timeout, ssl_opts(Opts));
-          false -> tcp_opts(Opts)
-        end}.
-
-with_port(Port, Opts = #{socket_opts := SocketOption}) when is_integer(Port) ->
-    Opts#{socket_opts => [{port, Port}| SocketOption]};
-with_port({Addr, Port}, Opts = #{socket_opts := SocketOption}) ->
-    Opts#{socket_opts => [{ip, Addr}, {port, Port}| SocketOption]}.
+      socket_opts => ip_port(ListenOn) ++ SocketOpts}.
+
+ip_port(Port) when is_integer(Port) ->
+    [{port, Port}];
+ip_port({Addr, Port}) ->
+    [{ip, Addr}, {port, Port}].
 
 esockd_access_rules(StrRules) ->
     Access = fun(S) ->

+ 6 - 5
apps/emqx/src/emqx_map_lib.erl

@@ -43,16 +43,17 @@ deep_get(ConfKeyPath, Map, Default) ->
         {ok, Data} -> Data
     end.
 
--spec deep_find(config_key_path(), map()) -> {ok, term()} | {not_found, config_key(), term()}.
+-spec deep_find(config_key_path(), map()) ->
+    {ok, term()} | {not_found, config_key_path(), term()}.
 deep_find([], Map) ->
     {ok, Map};
-deep_find([Key | KeyPath], Map) when is_map(Map) ->
+deep_find([Key | KeyPath] = Path, Map) when is_map(Map) ->
     case maps:find(Key, Map) of
         {ok, SubMap} -> deep_find(KeyPath, SubMap);
-        error -> {not_found, Key, Map}
+        error -> {not_found, Path, Map}
     end;
-deep_find([Key | _KeyPath], Data) ->
-    {not_found, Key, Data}.
+deep_find(_KeyPath, Data) ->
+    {not_found, _KeyPath, Data}.
 
 -spec deep_put(config_key_path(), map(), term()) -> map().
 deep_put([], Map, Config) when is_map(Map) ->

+ 5 - 5
apps/emqx/src/emqx_schema.erl

@@ -364,11 +364,11 @@ fields("mqtt_ws_listener") ->
 fields("ws_opts") ->
     [ {"mqtt_path", t(string(), undefined, "/mqtt")}
     , {"mqtt_piggyback", t(union(single, multiple), undefined, multiple)}
-    , {"compress", t(boolean())}
+    , {"compress", t(boolean(), undefined, false)}
     , {"idle_timeout", t(duration(), undefined, "15s")}
     , {"max_frame_size", maybe_infinity(integer())}
     , {"fail_if_no_subprotocol", t(boolean(), undefined, true)}
-    , {"supported_subprotocols", t(string(), undefined,
+    , {"supported_subprotocols", t(comma_separated_list(), undefined,
         "mqtt, mqtt-v3, mqtt-v3.1.1, mqtt-v5")}
     , {"check_origin_enable", t(boolean(), undefined, false)}
     , {"allow_origin_absence", t(boolean(), undefined, true)}
@@ -401,12 +401,12 @@ fields("ssl_opts") ->
 
 fields("deflate_opts") ->
     [ {"level", t(union([none, default, best_compression, best_speed]))}
-    , {"mem_level", t(range(1, 9))}
+    , {"mem_level", t(range(1, 9), undefined, 8)}
     , {"strategy", t(union([default, filtered, huffman_only, rle]))}
     , {"server_context_takeover", t(union(takeover, no_takeover))}
     , {"client_context_takeover", t(union(takeover, no_takeover))}
-    , {"server_max_window_bits", t(integer())}
-    , {"client_max_window_bits", t(integer())}
+    , {"server_max_window_bits", t(range(8, 15), undefined, 15)}
+    , {"client_max_window_bits", t(range(8, 15), undefined, 15)}
     ];
 
 fields("module") ->

+ 57 - 47
apps/emqx/src/emqx_ws_connection.erl

@@ -174,21 +174,13 @@ call(WsPid, Req, Timeout) when is_pid(WsPid) ->
 %% WebSocket callbacks
 %%--------------------------------------------------------------------
 
-init(Req, Opts) ->
+init(Req, #{zone := Zone, listener := Listener} = Opts) ->
     %% WS Transport Idle Timeout
-    IdleTimeout = proplists:get_value(idle_timeout, Opts, 7200000),
-    DeflateOptions = maps:from_list(proplists:get_value(deflate_options, Opts, [])),
-    MaxFrameSize = case proplists:get_value(max_frame_size, Opts, 0) of
-                       0 -> infinity;
-                       I -> I
-                   end,
-    Compress = proplists:get_bool(compress, Opts),
-    WsOpts = #{compress       => Compress,
-               deflate_opts   => DeflateOptions,
-               max_frame_size => MaxFrameSize,
-               idle_timeout   => IdleTimeout
+    WsOpts = #{compress => get_ws_opts(Zone, Listener, compress),
+               deflate_opts => get_ws_opts(Zone, Listener, deflate_opts),
+               max_frame_size => get_ws_opts(Zone, Listener, max_frame_size),
+               idle_timeout => get_ws_opts(Zone, Listener, idle_timeout)
               },
-
     case check_origin_header(Req, Opts) of
         {error, Message} ->
             ?LOG(error, "Invalid Origin Header ~p~n", [Message]),
@@ -196,18 +188,17 @@ init(Req, Opts) ->
         ok -> parse_sec_websocket_protocol(Req, Opts, WsOpts)
     end.
 
-parse_sec_websocket_protocol(Req, Opts, WsOpts) ->
-    FailIfNoSubprotocol = proplists:get_value(fail_if_no_subprotocol, Opts),
+parse_sec_websocket_protocol(Req, #{zone := Zone, listener := Listener} = Opts, WsOpts) ->
     case cowboy_req:parse_header(<<"sec-websocket-protocol">>, Req) of
         undefined ->
-            case FailIfNoSubprotocol of
+            case get_ws_opts(Zone, Listener, fail_if_no_subprotocol) of
                 true ->
                     {ok, cowboy_req:reply(400, Req), WsOpts};
                 false ->
                     {cowboy_websocket, Req, [Req, Opts], WsOpts}
             end;
         Subprotocols ->
-            SupportedSubprotocols = proplists:get_value(supported_subprotocols, Opts),
+            SupportedSubprotocols = get_ws_opts(Zone, Listener, supported_subprotocols),
             NSupportedSubprotocols = [list_to_binary(Subprotocol)
                                       || Subprotocol <- SupportedSubprotocols],
             case pick_subprotocol(Subprotocols, NSupportedSubprotocols) of
@@ -231,31 +222,30 @@ pick_subprotocol([Subprotocol | Rest], SupportedSubprotocols) ->
             pick_subprotocol(Rest, SupportedSubprotocols)
     end.
 
-parse_header_fun_origin(Req, Opts) ->
+parse_header_fun_origin(Req, #{zone := Zone, listener := Listener}) ->
     case cowboy_req:header(<<"origin">>, Req) of
         undefined ->
-                case proplists:get_bool(allow_origin_absence, Opts) of
+                case get_ws_opts(Zone, Listener, allow_origin_absence) of
                     true -> ok;
                     false -> {error, origin_header_cannot_be_absent}
                 end;
         Value ->
-            Origins = proplists:get_value(check_origins, Opts, []),
-            case lists:member(Value, Origins) of
+            case lists:member(Value, get_ws_opts(Zone, Listener, check_origins)) of
                 true -> ok;
                 false -> {origin_not_allowed, Value}
             end
     end.
 
-check_origin_header(Req, Opts) ->
-    case proplists:get_bool(check_origin_enable, Opts) of
+check_origin_header(Req, #{zone := Zone, listener := Listener} = Opts) ->
+    case get_ws_opts(Zone, Listener, check_origin_enable) of
         true -> parse_header_fun_origin(Req, Opts);
         false -> ok
     end.
 
-websocket_init([Req, Opts]) ->
+websocket_init([Req, #{zone := Zone, listener := Listener} = Opts]) ->
     {Peername, Peercert} =
-        case proplists:get_bool(proxy_protocol, Opts)
-        andalso maps:get(proxy_header, Req) of
+        case emqx_config:get_listener_conf(Zone, Listener, [proxy_protocol]) andalso
+             maps:get(proxy_header, Req) of
             #{src_address := SrcAddr, src_port := SrcPort, ssl := SSL} ->
                 SourceName = {SrcAddr, SrcPort},
                 %% Notice: Only CN is available in Proxy Protocol V2 additional info
@@ -266,7 +256,7 @@ websocket_init([Req, Opts]) ->
                 {SourceName, SourceSSL};
             #{src_address := SrcAddr, src_port := SrcPort} ->
                 SourceName = {SrcAddr, SrcPort},
-                {SourceName , nossl};
+                {SourceName, nossl};
             _ ->
                 {get_peer(Req, Opts), cowboy_req:cert(Req)}
         end,
@@ -288,22 +278,31 @@ websocket_init([Req, Opts]) ->
                  ws_cookie => WsCookie,
                  conn_mod  => ?MODULE
                 },
-    Zone = proplists:get_value(zone, Opts),
-    PubLimit = emqx_zone:publish_limit(Zone),
-    BytesIn = proplists:get_value(rate_limit, Opts),
-    RateLimit = emqx_zone:ratelimit(Zone),
-    Limiter = emqx_limiter:init(Zone, PubLimit, BytesIn, RateLimit),
-    MQTTPiggyback = proplists:get_value(mqtt_piggyback, Opts, multiple),
-    FrameOpts = emqx_zone:mqtt_frame_options(Zone),
+    Limiter = emqx_limiter:init(Zone, undefined, undefined, []),
+    MQTTPiggyback = get_ws_opts(Zone, Listener, mqtt_piggyback),
+    FrameOpts = #{
+        strict_mode => emqx_config:get_listener_conf(Zone, Listener, [mqtt, strict_mode]),
+        max_size => emqx_config:get_listener_conf(Zone, Listener, [mqtt, max_packet_size])
+    },
     ParseState = emqx_frame:initial_parse_state(FrameOpts),
     Serialize = emqx_frame:serialize_opts(),
     Channel = emqx_channel:init(ConnInfo, Opts),
-    GcState = emqx_zone:init_gc_state(Zone),
-    StatsTimer = emqx_zone:stats_timer(Zone),
+    GcState = case emqx_config:get_listener_conf(Zone, Listener, [force_gc]) of
+        #{enable := false} -> undefined;
+        GcPolicy -> emqx_gc:init(GcPolicy)
+    end,
+    StatsTimer = case emqx_config:get_listener_conf(Zone, Listener, [stats, enable]) of
+        true -> undefined;
+        false -> disabled
+    end,
     %% MQTT Idle Timeout
-    IdleTimeout = emqx_zone:idle_timeout(Zone),
+    IdleTimeout = emqx_channel:get_mqtt_conf(Zone, Listener, idle_timeout),
     IdleTimer = start_timer(IdleTimeout, idle_timeout),
-    emqx_misc:tune_heap_size(emqx_zone:oom_policy(Zone)),
+    case emqx_config:get_listener_conf(emqx_channel:info(zone, Channel),
+            emqx_channel:info(listener, Channel), [force_shutdown]) of
+        #{enable := false} -> ok;
+        ShutdownPolicy -> emqx_misc:tune_heap_size(ShutdownPolicy)
+    end,
     emqx_logger:set_metadata_peername(esockd:format(Peername)),
     {ok, #state{peername       = Peername,
                 sockname       = Sockname,
@@ -317,7 +316,9 @@ websocket_init([Req, Opts]) ->
                 postponed      = [],
                 stats_timer    = StatsTimer,
                 idle_timeout   = IdleTimeout,
-                idle_timer     = IdleTimer
+                idle_timer     = IdleTimer,
+                zone           = Zone,
+                listener       = Listener
                }, hibernate}.
 
 websocket_handle({binary, Data}, State) when is_list(Data) ->
@@ -517,11 +518,16 @@ run_gc(Stats, State = #state{gc_state = GcSt}) ->
     end.
 
 check_oom(State = #state{channel = Channel}) ->
-    OomPolicy = emqx_zone:oom_policy(emqx_channel:info(zone, Channel)),
-    case ?ENABLED(OomPolicy) andalso emqx_misc:check_oom(OomPolicy) of
-        Shutdown = {shutdown, _Reason} ->
-            postpone(Shutdown, State);
-        _Other -> State
+    ShutdownPolicy = emqx_config:get_listener_conf(emqx_channel:info(zone, Channel),
+        emqx_channel:info(listener, Channel), [force_shutdown]),
+    case ShutdownPolicy of
+        #{enable := false} -> ok;
+        #{enable := true} ->
+            case emqx_misc:check_oom(ShutdownPolicy) of
+                Shutdown = {shutdown, _Reason} ->
+                    postpone(Shutdown, State);
+                _Other -> State
+            end
     end.
 
 %%--------------------------------------------------------------------
@@ -741,9 +747,10 @@ classify([Event|More], Packets, Cmds, Events) ->
 
 trigger(Event) -> erlang:send(self(), Event).
 
-get_peer(Req, Opts) ->
+get_peer(Req, #{zone := Zone, listener := Listener}) ->
     {PeerAddr, PeerPort} = cowboy_req:peer(Req),
-    AddrHeader = cowboy_req:header(proplists:get_value(proxy_address_header, Opts), Req, <<>>),
+    AddrHeader = cowboy_req:header(
+        get_ws_opts(Zone, Listener, proxy_address_header), Req, <<>>),
     ClientAddr = case string:tokens(binary_to_list(AddrHeader), ", ") of
                      [] ->
                          undefined;
@@ -756,7 +763,8 @@ get_peer(Req, Opts) ->
                _ ->
                    PeerAddr
            end,
-    PortHeader = cowboy_req:header(proplists:get_value(proxy_port_header, Opts), Req, <<>>),
+    PortHeader = cowboy_req:header(
+        get_ws_opts(Zone, Listener, proxy_port_header), Req, <<>>),
     ClientPort = case string:tokens(binary_to_list(PortHeader), ", ") of
                      [] ->
                          undefined;
@@ -777,3 +785,5 @@ set_field(Name, Value, State) ->
     Pos = emqx_misc:index_of(Name, record_info(fields, state)),
     setelement(Pos+1, State, Value).
 
+get_ws_opts(Zone, Listener, Key) ->
+    emqx_config:get_listener_conf(Zone, Listener, [websocket, Key]).