Просмотр исходного кода

refactor(ssl): use 'available' for defaults

prior to this change, for TLS versions, 'default's are 'available's
there is no need for such an alias.

now we call available_versions with a specific tag:
tls, dtls, or all
Zaiming (Stone) Shi 3 лет назад
Родитель
Сommit
7851a3aefd

+ 1 - 1
apps/emqx/src/emqx_listeners.erl

@@ -583,7 +583,7 @@ enable_authn(Opts) ->
     maps:get(enable_authn, Opts, true).
 
 ssl_opts(Opts) ->
-    emqx_tls_lib:to_server_opts(maps:get(ssl_options, Opts, #{})).
+    emqx_tls_lib:to_server_opts(tls, maps:get(ssl_options, Opts, #{})).
 
 tcp_opts(Opts) ->
     maps:to_list(

+ 8 - 9
apps/emqx/src/emqx_schema.erl

@@ -1843,6 +1843,8 @@ filter(Opts) ->
 common_ssl_opts_schema(Defaults) ->
     D = fun(Field) -> maps:get(to_atom(Field), Defaults, undefined) end,
     Df = fun(Field, Default) -> maps:get(to_atom(Field), Defaults, Default) end,
+    Collection = maps:get(versions, Defaults, tls_all_available),
+    AvailableVersions = default_tls_vsns(Collection),
     [
         {"cacertfile",
             sc(
@@ -1909,9 +1911,9 @@ common_ssl_opts_schema(Defaults) ->
             sc(
                 hoconsc:array(typerefl:atom()),
                 #{
-                    default => default_tls_vsns(maps:get(versions, Defaults, tls_all_available)),
+                    default => AvailableVersions,
                     desc => ?DESC(common_ssl_opts_schema_versions),
-                    validator => fun validate_tls_versions/1
+                    validator => fun(Inputs) -> validate_tls_versions(AvailableVersions, Inputs) end
                 }
             )},
         {"ciphers", ciphers_schema(D("ciphers"))},
@@ -2022,9 +2024,9 @@ client_ssl_opts_schema(Defaults) ->
         ].
 
 default_tls_vsns(dtls_all_available) ->
-    proplists:get_value(available_dtls, ssl:versions());
+    emqx_tls_lib:available_versions(dtls);
 default_tls_vsns(tls_all_available) ->
-    emqx_tls_lib:default_versions().
+    emqx_tls_lib:available_versions(tls).
 
 -spec ciphers_schema(quic | dtls_all_available | tls_all_available | undefined) ->
     hocon_schema:field_schema().
@@ -2248,13 +2250,10 @@ validate_ciphers(Ciphers) ->
         Bad -> {error, {bad_ciphers, Bad}}
     end.
 
-validate_tls_versions(Versions) ->
-    AvailableVersions =
-        proplists:get_value(available, ssl:versions()) ++
-            proplists:get_value(available_dtls, ssl:versions()),
+validate_tls_versions(AvailableVersions, Versions) ->
     case lists:filter(fun(V) -> not lists:member(V, AvailableVersions) end, Versions) of
         [] -> ok;
-        Vs -> {error, {unsupported_ssl_versions, Vs}}
+        Vs -> {error, {unsupported_tls_versions, Vs}}
     end.
 
 validations() ->

+ 44 - 44
apps/emqx/src/emqx_tls_lib.erl

@@ -18,8 +18,8 @@
 
 %% version & cipher suites
 -export([
-    default_versions/0,
-    integral_versions/1,
+    available_versions/1,
+    integral_versions/2,
     default_ciphers/0,
     selected_ciphers/1,
     integral_ciphers/2,
@@ -37,8 +37,9 @@
 ]).
 
 -export([
-    to_server_opts/1,
-    to_client_opts/1
+    to_server_opts/2,
+    to_client_opts/1,
+    to_client_opts/2
 ]).
 
 -include("logger.hrl").
@@ -111,27 +112,23 @@
     "RSA-PSK-AES128-CBC-SHA"
 ]).
 
-%% @doc Returns the default supported tls versions.
--spec default_versions() -> [atom()].
-default_versions() -> available_versions().
-
 %% @doc Validate a given list of desired tls versions.
 %% raise an error exception if non of them are available.
 %% The input list can be a string/binary of comma separated versions.
--spec integral_versions(undefined | string() | binary() | [ssl:tls_version()]) ->
+-spec integral_versions(tls | dtls, undefined | string() | binary() | [ssl:tls_version()]) ->
     [ssl:tls_version()].
-integral_versions(undefined) ->
-    integral_versions(default_versions());
-integral_versions([]) ->
-    integral_versions(default_versions());
-integral_versions(<<>>) ->
-    integral_versions(default_versions());
-integral_versions(Desired) when ?IS_STRING(Desired) ->
-    integral_versions(iolist_to_binary(Desired));
-integral_versions(Desired) when is_binary(Desired) ->
-    integral_versions(parse_versions(Desired));
-integral_versions(Desired) ->
-    Available = available_versions(),
+integral_versions(Type, undefined) ->
+    available_versions(Type);
+integral_versions(Type, []) ->
+    available_versions(Type);
+integral_versions(Type, <<>>) ->
+    available_versions(Type);
+integral_versions(Type, Desired) when ?IS_STRING(Desired) ->
+    integral_versions(Type, iolist_to_binary(Desired));
+integral_versions(Type, Desired) when is_binary(Desired) ->
+    integral_versions(Type, parse_versions(Desired));
+integral_versions(Type, Desired) ->
+    Available = available_versions(Type),
     case lists:filter(fun(V) -> lists:member(V, Available) end, Desired) of
         [] ->
             erlang:error(#{
@@ -153,11 +150,11 @@ all_ciphers_set_cached() ->
             Set
     end.
 
-%% @doc Return a list of all supported ciphers.
+%% @hidden Return a list of all supported ciphers.
 all_ciphers() ->
-    all_ciphers(default_versions()).
+    all_ciphers(available_versions(all)).
 
-%% @doc Return a list of (openssl string format) cipher suites.
+%% @hidden Return a list of (openssl string format) cipher suites.
 -spec all_ciphers([ssl:tls_version()]) -> [string()].
 all_ciphers(['tlsv1.3']) ->
     %% When it's only tlsv1.3 wanted, use 'exclusive' here
@@ -172,7 +169,7 @@ all_ciphers(Versions) ->
 
 %% @doc All Pre-selected TLS ciphers.
 default_ciphers() ->
-    selected_ciphers(available_versions()).
+    selected_ciphers(available_versions(all)).
 
 %% @doc Pre-selected TLS ciphers for given versions..
 selected_ciphers(Vsns) ->
@@ -218,22 +215,17 @@ ensure_tls13_cipher(true, Ciphers) ->
 ensure_tls13_cipher(false, Ciphers) ->
     Ciphers.
 
-%% default ssl versions based on available versions.
--spec available_versions() -> [atom()].
-available_versions() ->
-    OtpRelease = list_to_integer(erlang:system_info(otp_release)),
-    default_versions(OtpRelease).
-
-%% tlsv1.3 is available from OTP-22 but we do not want to use until 23.
-default_versions(OtpRelease) when OtpRelease >= 23 ->
-    availables();
-default_versions(_) ->
-    lists:delete('tlsv1.3', availables()).
-
-availables() ->
+%% @doc Returns the default available tls/dtls versions.
+available_versions(Type) ->
     All = ssl:versions(),
-    proplists:get_value(available, All) ++
-        proplists:get_value(available_dtls, All).
+    available_versions(Type, All).
+
+available_versions(tls, All) ->
+    proplists:get_value(available, All);
+available_versions(dtls, All) ->
+    proplists:get_value(available_dtls, All);
+available_versions(all, All) ->
+    available_versions(tls, All) ++ available_versions(dtls, All).
 
 %% Deduplicate a list without re-ordering the elements.
 dedup([]) ->
@@ -494,17 +486,25 @@ do_drop_invalid_certs([Key | Keys], SSL) ->
 
 %% @doc Convert hocon-checked ssl server options (map()) to
 %% proplist accepted by ssl library.
-to_server_opts(Opts) ->
-    Versions = integral_versions(maps:get(versions, Opts, undefined)),
+-spec to_server_opts(tls | dtls, map()) -> [{atom(), term()}].
+to_server_opts(Type, Opts) ->
+    Versions = integral_versions(Type, maps:get(versions, Opts, undefined)),
     Ciphers = integral_ciphers(Versions, maps:get(ciphers, Opts, undefined)),
     maps:to_list(Opts#{
         ciphers => Ciphers,
         versions => Versions
     }).
 
-%% @doc Convert hocon-checked ssl client options (map()) to
+%% @doc Convert hocon-checked tls client options (map()) to
 %% proplist accepted by ssl library.
+-spec to_client_opts(map()) -> [{atom(), term()}].
 to_client_opts(Opts) ->
+    to_client_opts(tls, Opts).
+
+%% @doc Convert hocon-checked tls or dtls client options (map()) to
+%% proplist accepted by ssl library.
+-spec to_client_opts(tls | dtls, map()) -> [{atom(), term()}].
+to_client_opts(Type, Opts) ->
     GetD = fun(Key, Default) -> fuzzy_map_get(Key, Opts, Default) end,
     Get = fun(Key) -> GetD(Key, undefined) end,
     case GetD(enable, false) of
@@ -514,7 +514,7 @@ to_client_opts(Opts) ->
             CAFile = ensure_str(Get(cacertfile)),
             Verify = GetD(verify, verify_none),
             SNI = ensure_sni(Get(server_name_indication)),
-            Versions = integral_versions(Get(versions)),
+            Versions = integral_versions(Type, Get(versions)),
             Ciphers = integral_ciphers(Versions, Get(ciphers)),
             filter([
                 {keyfile, KeyFile},

+ 1 - 1
apps/emqx/test/emqx_schema_tests.erl

@@ -134,7 +134,7 @@ ciphers_schema_test() ->
 
 bad_tls_version_test() ->
     Sc = emqx_schema:server_ssl_opts_schema(#{}, false),
-    Reason = {unsupported_ssl_versions, [foo]},
+    Reason = {unsupported_tls_versions, [foo]},
     ?assertThrow(
         {_Sc, [#{kind := validation_error, reason := Reason}]},
         validate(Sc, #{<<"versions">> => [<<"foo">>]})

+ 28 - 18
apps/emqx/test/emqx_tls_lib_tests.erl

@@ -51,24 +51,34 @@ test_cipher_format(Input) ->
     ?assertEqual([?TLS_13_CIPHER, ?TLS_12_CIPHER], Ciphers).
 
 tls_versions_test() ->
-    ?assert(lists:member('tlsv1.3', emqx_tls_lib:default_versions())).
-
-tls_version_unknown_test() ->
-    ?assertEqual(
-        emqx_tls_lib:default_versions(),
-        emqx_tls_lib:integral_versions([])
-    ),
-    ?assertEqual(
-        emqx_tls_lib:default_versions(),
-        emqx_tls_lib:integral_versions(<<>>)
-    ),
-    ?assertEqual(
-        emqx_tls_lib:default_versions(),
-        emqx_tls_lib:integral_versions("foo")
-    ),
-    ?assertError(
-        #{reason := no_available_tls_version},
-        emqx_tls_lib:integral_versions([foo])
+    ?assert(lists:member('tlsv1.3', emqx_tls_lib:available_versions(tls))).
+
+tls_version_unknown_test_() ->
+    lists:flatmap(
+        fun(Type) ->
+            [
+                ?_assertEqual(
+                    emqx_tls_lib:available_versions(Type),
+                    emqx_tls_lib:integral_versions(Type, [])
+                ),
+                ?_assertEqual(
+                    emqx_tls_lib:available_versions(Type),
+                    emqx_tls_lib:integral_versions(Type, <<>>)
+                ),
+                ?_assertEqual(
+                    emqx_tls_lib:available_versions(Type),
+                    %% unknown version dropped
+                    emqx_tls_lib:integral_versions(Type, "foo")
+                ),
+                fun() ->
+                    ?assertError(
+                        #{reason := no_available_tls_version},
+                        emqx_tls_lib:integral_versions(Type, [foo])
+                    )
+                end
+            ]
+        end,
+        [tls, dtls]
     ).
 
 cipher_suites_no_duplication_test() ->

+ 6 - 1
apps/emqx_gateway/src/emqx_gateway_utils.erl

@@ -455,7 +455,12 @@ esockd_access_rules(StrRules) ->
     [Access(R) || R <- StrRules].
 
 ssl_opts(Name, Opts) ->
-    emqx_tls_lib:to_server_opts(maps:get(Name, Opts, #{})).
+    Type =
+        case Name of
+            ssl -> tls;
+            dtls -> dtls
+        end,
+    emqx_tls_lib:to_server_opts(Type, maps:get(Name, Opts, #{})).
 
 sock_opts(Name, Opts) ->
     maps:to_list(