Parcourir la source

feat(websocket): support for checking subprotocols (#4099)

tigercl il y a 5 ans
Parent
commit
5878950dc3
4 fichiers modifiés avec 69 ajouts et 22 suppressions
  1. 18 6
      etc/emqx.conf
  2. 18 7
      priv/emqx.schema
  3. 30 8
      src/emqx_ws_connection.erl
  4. 3 1
      test/emqx_ws_connection_SUITE.erl

+ 18 - 6
etc/emqx.conf

@@ -1554,10 +1554,16 @@ listener.ws.external.zone = external
 ## Value: ACL Rule
 listener.ws.external.access.1 = allow all
 
-## Verify if the protocol header is valid. Turn off for WeChat MiniApp.
+## If set to true, the server fails if the client does not have a Sec-WebSocket-Protocol to send.
+## Set to false for WeChat MiniApp.
 ##
-## Value: on | off
-listener.ws.external.verify_protocol_header = on
+## Value: true | false
+## listener.ws.external.fail_if_no_subprotocol = on
+
+## Supported subprotocols
+##
+## Default: mqtt, mqtt-v3, mqtt-v3.1.1, mqtt-v5
+## listener.ws.external.supported_protocols = mqtt, mqtt-v3, mqtt-v3.1.1, mqtt-v5
 
 ## Enable the Proxy Protocol V1/2 if the EMQ cluster is deployed behind
 ## HAProxy or Nginx.
@@ -1769,10 +1775,16 @@ listener.wss.external.zone = external
 ## Value: ACL Rule
 listener.wss.external.access.1 = allow all
 
-## See: listener.ws.external.verify_protocol_header
+## If set to true, the server fails if the client does not have a Sec-WebSocket-Protocol to send.
+## Set to false for WeChat MiniApp.
 ##
-## Value: on | off
-listener.wss.external.verify_protocol_header = on
+## Value: true | false
+## listener.wss.external.fail_if_no_subprotocol = true
+
+## Supported subprotocols
+##
+## Default: mqtt, mqtt-v3, mqtt-v3.1.1, mqtt-v5
+## listener.ws.external.supported_protocols = mqtt, mqtt-v3, mqtt-v3.1.1, mqtt-v5
 
 ## Enable the Proxy Protocol V1/2 support.
 ##

+ 18 - 7
priv/emqx.schema

@@ -1472,9 +1472,14 @@ end}.
   {datatype, string}
 ]}.
 
-{mapping, "listener.ws.$name.verify_protocol_header", "emqx.listeners", [
-  {default, on},
-  {datatype, flag}
+{mapping, "listener.ws.$name.fail_if_no_subprotocol", "emqx.listeners", [
+  {default, true},
+  {datatype, {enum, [true, false]}}
+]}.
+
+{mapping, "listener.ws.$name.supported_subprotocols", "emqx.listeners", [
+  {default, "mqtt, mqtt-v3, mqtt-v3.1.1, mqtt-v5"},
+  {datatype, string}
 ]}.
 
 {mapping, "listener.ws.$name.proxy_protocol", "emqx.listeners", [
@@ -1638,9 +1643,14 @@ end}.
   {datatype, string}
 ]}.
 
-{mapping, "listener.wss.$name.verify_protocol_header", "emqx.listeners", [
-  {default, on},
-  {datatype, flag}
+{mapping, "listener.wss.$name.fail_if_no_subprotocol", "emqx.listeners", [
+  {default, true},
+  {datatype, {enum, [true, false]}}
+]}.
+
+{mapping, "listener.wss.$name.supported_subprotocols", "emqx.listeners", [
+  {default, "mqtt, mqtt-v3, mqtt-v3.1.1, mqtt-v5"},
+  {datatype, string}
 ]}.
 
 {mapping, "listener.wss.$name.access.$id", "emqx.listeners", [
@@ -1892,7 +1902,8 @@ end}.
                           {rate_limit, RateLimit(cuttlefish:conf_get(Prefix ++ ".rate_limit", Conf, undefined))},
                           {proxy_protocol, cuttlefish:conf_get(Prefix ++ ".proxy_protocol", Conf, undefined)},
                           {proxy_protocol_timeout, cuttlefish:conf_get(Prefix ++ ".proxy_protocol_timeout", Conf, undefined)},
-                          {verify_protocol_header, cuttlefish:conf_get(Prefix ++ ".verify_protocol_header", Conf, undefined)},
+                          {fail_if_no_subprotocol, cuttlefish:conf_get(Prefix ++ ".fail_if_no_subprotocol", Conf, undefined)},
+                          {supported_subprotocols, string:tokens(cuttlefish:conf_get(Prefix ++ ".supported_subprotocols", Conf, ""), ", ")},
                           {peer_cert_as_username, cuttlefish:conf_get(Prefix ++ ".peer_cert_as_username", Conf, undefined)},
                           {compress, cuttlefish:conf_get(Prefix ++ ".compress", Conf, undefined)},
                           {idle_timeout, cuttlefish:conf_get(Prefix ++ ".idle_timeout", Conf, undefined)},

+ 30 - 8
src/emqx_ws_connection.erl

@@ -192,16 +192,38 @@ init(Req, Opts) ->
     end.
 
 parse_sec_websocket_protocol(Req, Opts, WsOpts) ->
+    FailIfNoSubprotocol = proplists:get_value(fail_if_no_subprotocol, Opts),
     case cowboy_req:parse_header(<<"sec-websocket-protocol">>, Req) of
         undefined ->
-            %% TODO: why not reply 500???
-            {cowboy_websocket, Req, [Req, Opts], WsOpts};
-        [<<"mqtt", Vsn/binary>>] ->
-            Resp = cowboy_req:set_resp_header(
-                <<"sec-websocket-protocol">>, <<"mqtt", Vsn/binary>>, Req),
-            {cowboy_websocket, Resp, [Req, Opts], WsOpts};
-        _ ->
-            {ok, cowboy_req:reply(400, Req), WsOpts}
+            case FailIfNoSubprotocol of
+                true ->
+                    {ok, cowboy_req:reply(400, Req), WsOpts};
+                false ->
+                    {cowboy_websocket, Req, [Req, Opts], WsOpts}
+            end;
+        Subprotocols ->
+            SupportedSubprotocols = proplists:get_value(supported_subprotocols, Opts),
+            NSupportedSubprotocols = [list_to_binary(Subprotocol)
+                                      || Subprotocol <- SupportedSubprotocols],
+            case pick_subprotocol(Subprotocols, NSupportedSubprotocols) of
+                {ok, Subprotocol} ->
+                    Resp = cowboy_req:set_resp_header(<<"sec-websocket-protocol">>,
+                                                      Subprotocol,
+                                                      Req),
+                    {cowboy_websocket, Resp, [Req, Opts], WsOpts};
+                {error, no_supported_subprotocol} ->
+                    {ok, cowboy_req:reply(400, Req), WsOpts}
+            end
+    end.
+
+pick_subprotocol([], _SupportedSubprotocols) ->
+    {error, no_supported_subprotocol};
+pick_subprotocol([Subprotocol | Rest], SupportedSubprotocols) ->
+    case lists:member(Subprotocol, SupportedSubprotocols) of
+        true ->
+            {ok, Subprotocol};
+        false ->
+            pick_subprotocol(Rest, SupportedSubprotocols)
     end.
 
 parse_header_fun_origin(Req, Opts) ->

+ 3 - 1
test/emqx_ws_connection_SUITE.erl

@@ -146,7 +146,9 @@ t_call(_) ->
     ?assertEqual(Info, ?ws_conn:call(WsPid, info)).
 
 t_init(_) ->
-    Opts = [{idle_timeout, 300000}],
+    Opts = [{idle_timeout, 300000},
+            {fail_if_no_subprotocol, false},
+            {supported_subprotocols, ["mqtt"]}],
     WsOpts = #{compress       => false,
                deflate_opts   => #{},
                max_frame_size => infinity,