Explorar o código

fix(psk): fix bugs and add test case

zhouzb %!s(int64=4) %!d(string=hai) anos
pai
achega
74c9a38e9f

+ 1 - 1
apps/emqx/src/emqx_schema.erl

@@ -1300,7 +1300,7 @@ parse_user_lookup_fun(StrConf) ->
     [ModStr, FunStr] = string:tokens(StrConf, ":"),
     Mod = list_to_atom(ModStr),
     Fun = list_to_atom(FunStr),
-    {fun Mod:Fun/3, <<>>}.
+    {fun Mod:Fun/3, undefined}.
 
 validate_ciphers(Ciphers) ->
     All = ssl:cipher_suites(all, 'tlsv1.3', openssl) ++

+ 2 - 2
apps/emqx/src/emqx_tls_psk.erl

@@ -25,8 +25,8 @@
 -type psk_user_state() :: term().
 
 -spec lookup(psk, psk_identity(), psk_user_state()) -> {ok, SharedSecret :: binary()} | error.
-lookup(psk, PSKIdentity, UserState) ->
-    try emqx_hooks:run_fold('tls_handshake.psk_lookup', [PSKIdentity, UserState], normal) of
+lookup(psk, PSKIdentity, _UserState) ->
+    try emqx_hooks:run_fold('tls_handshake.psk_lookup', [PSKIdentity], normal) of
         {ok, SharedSecret} when is_binary(SharedSecret) ->
             {ok, SharedSecret};
         normal ->

+ 1 - 0
apps/emqx_machine/src/emqx_machine.erl

@@ -149,6 +149,7 @@ reboot_apps() ->
     , emqx_rule_actions
     , emqx_authn
     , emqx_authz
+    , emqx_psk
     ].
 
 sorted_reboot_apps() ->

+ 1 - 0
apps/emqx_machine/src/emqx_machine_schema.erl

@@ -53,6 +53,7 @@
         , emqx_prometheus_schema
         , emqx_rule_engine_schema
         , emqx_exhook_schema
+        , emqx_psk_schema
         ]).
 
 namespace() -> undefined.

+ 2 - 2
apps/emqx_psk/etc/emqx_psk.conf

@@ -4,14 +4,14 @@
 
 psk {
     ## Whether to enable the PSK feature.
-    enable = true
+    enable = false
 
     ## If init file is specified, emqx will import PSKs from the file 
     ## into the built-in database at startup for use by the runtime.
     ##
     ## The file has to be structured line-by-line, each line must be in
     ## the format: <PSKIdentity>:<SharedSecret>
-    ## init_file = {{ platform_data_dir }}/init.psk
+    ## init_file = "{{ platform_data_dir }}/init.psk"
 
     ## Specifies the separator for PSKIdentity and SharedSecret in the init file.
     ## The default is colon (:)

+ 7 - 4
apps/emqx_psk/src/emqx_psk.erl

@@ -157,6 +157,9 @@ get_config(chunk_size) ->
 import_psks(SrcFile) ->
     case file:open(SrcFile, [read, raw, binary, read_ahead]) of
         {error, Reason} ->
+            ?SLOG(error, #{msg => "failed_to_open_psk_file",
+                                   file => SrcFile,
+                                   reason => Reason}),
             {error, Reason};
         {ok, Io} ->
             try import_psks(Io, get_config(separator), get_config(chunk_size)) of
@@ -167,10 +170,10 @@ import_psks(SrcFile) ->
                                    reason => Reason}),
                     {error, Reason}
             catch
-                Class:Reason:Stacktrace ->
+                Exception:Reason:Stacktrace ->
                     ?SLOG(error, #{msg => "failed_to_import_psk_file",
                                    file => SrcFile,
-                                   class => Class,
+                                   exception => Exception,
                                    reason => Reason,
                                    stacktrace => Stacktrace}),
                     {error, Reason}
@@ -182,10 +185,10 @@ import_psks(SrcFile) ->
 import_psks(Io, Delimiter, ChunkSize) ->
     case get_psks(Io, Delimiter, ChunkSize) of
         {ok, Entries} ->
-            _ = trans(fun insert_psks/1, Entries),
+            _ = trans(fun insert_psks/1, [Entries]),
             import_psks(Io, Delimiter, ChunkSize);
         {eof, Entries} ->
-            _ = trans(fun insert_psks/1, Entries),
+            _ = trans(fun insert_psks/1, [Entries]),
             ok;
         {error, Reaosn} ->
             {error, Reaosn}

+ 1 - 1
apps/emqx_psk/src/emqx_psk_schema.erl

@@ -24,7 +24,7 @@
         , fields/1
         ]).
 
-roots() -> [].
+roots() -> ["psk"].
 
 fields("psk") ->
     [ {enable,     fun enable/1}

+ 2 - 0
apps/emqx_psk/test/data/init.psk

@@ -0,0 +1,2 @@
+myclient1:8c701116e9127c57a99d5563709af3deaca75563e2c4dd0865701ae839fb6d79
+myclient2:d1e617d3b963757bfc21dad3fea169716c3a2f053f23decaea5cdfaabd04bfc4

+ 85 - 0
apps/emqx_psk/test/emqx_psk_SUITE.erl

@@ -0,0 +1,85 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%% http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+-module(emqx_psk_SUITE).
+
+-compile(nowarn_export_all).
+-compile(export_all).
+
+-include_lib("common_test/include/ct.hrl").
+-include_lib("eunit/include/eunit.hrl").
+
+all() ->
+    emqx_ct:all(?MODULE).
+
+init_per_suite(Config) ->
+    meck:new(emqx_config, [non_strict, passthrough, no_history, no_link]),
+    meck:expect(emqx_config, get, fun([psk, enable]) -> true;
+                                     ([psk, chunk_size]) -> 50;
+                                     (KeyPath) -> meck:passthrough([KeyPath])
+                                  end),
+    meck:expect(emqx_config, get, fun([psk, init_file], _) ->
+                                         filename:join([code:lib_dir(emqx_psk, test), "data/init.psk"]);
+                                     ([psk, separator], _) -> <<":">>;
+                                     (KeyPath, Default) -> meck:passthrough([KeyPath, Default])
+                                  end),
+    emqx_ct_helpers:start_apps([emqx_psk]),
+    Config.
+
+end_per_suite(_) ->
+    meck:unload(emqx_config),
+    emqx_ct_helpers:stop_apps([emqx_psk]),
+    ok.
+
+t_psk_lookup(_) ->
+
+    PSKIdentity1 = <<"myclient1">>,
+    SharedSecret1 = <<"8c701116e9127c57a99d5563709af3deaca75563e2c4dd0865701ae839fb6d79">>,
+    ?assertEqual({stop, {ok, SharedSecret1}}, emqx_psk:on_psk_lookup(PSKIdentity1, any)),
+
+    PSKIdentity2 = <<"myclient2">>,
+    SharedSecret2 = <<"d1e617d3b963757bfc21dad3fea169716c3a2f053f23decaea5cdfaabd04bfc4">>,
+    ?assertEqual({stop, {ok, SharedSecret2}}, emqx_psk:on_psk_lookup(PSKIdentity2, any)),
+
+    ?assertEqual(ignore, emqx_psk:on_psk_lookup(<<"myclient3">>, any)),
+
+    ClientLookup = fun(psk, undefined, _) -> {ok, SharedSecret1};
+                      (psk, _, _) -> error
+                   end,
+
+    ClientTLSOpts = #{ versions => ['tlsv1.2']
+                     , ciphers => ["PSK-AES256-CBC-SHA"]
+                     , psk_identity => "myclient1"
+                     , verify => verify_none
+                     , user_lookup_fun => {ClientLookup, undefined}
+                     },
+
+    ServerTLSOpts = #{ versions => ['tlsv1.2']
+                     , ciphers => ["PSK-AES256-CBC-SHA"]
+                     , verify => verify_none
+                     , reuseaddr => true
+                     , user_lookup_fun => {fun emqx_tls_psk:lookup/3, undefined}
+                     },
+    emqx_config:put([listeners, ssl ,default, ssl], ServerTLSOpts),
+    emqx_listeners:restart_listener('ssl:default'),
+
+    {ok, Socket} = ssl:connect("127.0.0.1", 8883, maps:to_list(ClientTLSOpts)),
+    ssl:close(Socket),
+
+    ClientTLSOpts1 = ClientTLSOpts#{psk_identity => "myclient2"},
+    ?assertMatch({error, _}, ssl:connect("127.0.0.1", 8883, maps:to_list(ClientTLSOpts1))),
+
+    ok.
+

+ 1 - 0
rebar.config.erl

@@ -279,6 +279,7 @@ relx_apps(ReleaseType) ->
     , emqx_retainer
     , emqx_statsd
     , emqx_prometheus
+    , emqx_psk
     ]
     ++ [quicer || is_quicer_supported()]
     ++ [emqx_license || is_enterprise()]