Просмотр исходного кода

test(authentication): add test case and fix some errors

zhouzb 4 лет назад
Родитель
Сommit
3a96aa3db2

+ 0 - 0
apps/emqx_authentication/etc/emqx_authentication.conf


+ 1 - 0
apps/emqx_authentication/include/emqx_authentication.hrl

@@ -25,6 +25,7 @@
         { name
         , type %% service_type
         , provider
+        , params
         , state
         }).
 

+ 0 - 0
apps/emqx_authentication/priv/emqx_authentication.schema


+ 78 - 23
apps/emqx_authentication/src/emqx_authentication.erl

@@ -28,8 +28,11 @@
 
 -export([ create_chain/1
         , delete_chain/1
+        , lookup_chain/1
         , add_services_to_chain/2
         , delete_services_from_chain/2
+        , lookup_service/2
+        , list_services/1
         , move_service_to_the_front_of_chain/2
         , move_service_to_the_end_of_chain/2
         , move_service_to_the_nth_of_chain/3
@@ -79,18 +82,18 @@ mnesia(copy) ->
     ok = ekka_mnesia:copy_table(?SERVICE_TYPE_TAB, ram_copies).
 
 enable() ->
-    case emqx:hook('client.authenticate', fun emqx_authentication:authenticate/2) of
-        ok -> ok;
-        {error, already_exists} -> ok
-    end,
-    case emqx:hook('client.enhanced_authenticate', fun emqx_authentication:enhanced_authenticate/2) of
+    case emqx:hook('client.authenticate', fun emqx_authentication:authenticate/1) of
         ok -> ok;
         {error, already_exists} -> ok
     end.
+    % case emqx:hook('client.enhanced_authenticate', fun emqx_authentication:enhanced_authenticate/2) of
+    %     ok -> ok;
+    %     {error, already_exists} -> ok
+    % end.
 
 disable() ->
     emqx:unhook('client.authenticate', {}),
-    emqx:unhook('client.enhanced_authenticate', {}),
+    % emqx:unhook('client.enhanced_authenticate', {}),
     ok.
 
 authenticate(#{chain_id := ChainID} = ClientInfo) ->
@@ -145,7 +148,7 @@ create_chain(Params = #{chain_id := ChainID}) ->
                                                    services = Services,
                                                    created_at = erlang:system_time(millisecond)},
                                     mnesia:write(?CHAIN_TAB, Chain, write),
-                                    {ok, Chain};
+                                    {ok, ChainID};
                                 {error, Reason} ->
                                     {error, Reason}
                             end;
@@ -158,7 +161,7 @@ create_chain(Params = #{chain_id := ChainID}) ->
     end.
 
 delete_chain(ChainID) ->
-    mnesia:transaction(
+    trans(
         fun() ->
             case mnesia:read(?CHAIN_TAB, ChainID, write) of
                 [] ->
@@ -169,6 +172,14 @@ delete_chain(ChainID) ->
             end
         end).
 
+lookup_chain(ChainID) ->
+    case mnesia:dirty_read(?CHAIN_TAB, ChainID) of
+        [] ->
+            {error, {not_found, {chain, ChainID}}};
+        [Chain] ->
+            {ok, serialize_chain(Chain)}
+    end.
+
 add_services_to_chain(ChainID, ServiceParams) ->
     case validate_service_params(ServiceParams) of
         {ok, NServiceParams} ->
@@ -210,6 +221,27 @@ delete_services_from_chain(ChainID, ServiceNames) ->
             {error, Reason}
     end.
 
+lookup_service(ChainID, ServiceName) ->
+    case mnesia:dirty_read(?CHAIN_TAB, ChainID) of
+        [] ->
+            {error, {not_found, {chain, ChainID}}};
+        [#chain{services = Services}] ->
+            case lists:keytake(ServiceName, 1, Services) of
+                {value, Service, _} ->
+                    {ok, serialize_service(Service)};
+                false ->
+                    {error, {not_found, {service, ServiceName}}}
+            end
+    end.
+
+list_services(ChainID) ->
+    case mnesia:dirty_read(?CHAIN_TAB, ChainID) of
+        [] ->
+            {error, {not_found, {chain, ChainID}}};
+        [#chain{services = Services}] ->
+            {ok, [serialize_service(Service) || Service <- Services]}
+    end.
+
 move_service_to_the_front_of_chain(ChainID, ServiceName) ->
     UpdateFun = fun(Chain = #chain{services = Services}) ->
                     case move_service_to_the_front(ServiceName, Services) of
@@ -246,17 +278,6 @@ move_service_to_the_nth_of_chain(ChainID, ServiceName, N) ->
                  end,
     update_chain(ChainID, UpdateFun).
 
-update_chain(ChainID, UpdateFun) ->
-    trans(
-        fun() ->
-            case mnesia:read(?CHAIN_TAB, ChainID, write) of
-                [] ->
-                    {error, {not_found, {chain, ChainID}}};
-                [Chain] ->
-                    UpdateFun(Chain)
-            end
-        end).
-
 import_user_credentials(ChainID, ServiceName, Filename, FileFormat) ->
     call_service(ChainID, ServiceName, import_user_credentials, [Filename, FileFormat]).
 
@@ -323,6 +344,7 @@ validate_other_service_params([#{type := Type, params := Params} = ServiceParams
             NParams = emqx_rule_validator:validate_params(Params, ParamsSpec),
             validate_other_service_params(More,
                                           [ServiceParams#{params => NParams,
+                                                          original_params => Params,
                                                           provider => Provider} | Acc]);
         {error, not_found} ->
             {error, {not_found, {service_type, Type}}}
@@ -344,12 +366,17 @@ create_services(ChainID, ServiceParams) ->
 
 create_services(_ChainID, [], Acc) ->
     {ok, lists:reverse(Acc)};
-create_services(ChainID, [#{name := Name, type := Type, provider := Provider, params := Params} | More], Acc) ->
+create_services(ChainID, [#{name := Name,
+                            type := Type,
+                            provider := Provider,
+                            params := Params,
+                            original_params := OriginalParams} | More], Acc) ->
     case Provider:create(ChainID, Name, Params) of
         {ok, State} ->
             Service = #service{name = Name,
                                type = Type,
                                provider = Provider,
+                               params = OriginalParams,
                                state = State},
             create_services(ChainID, More, [{Name, Service} | Acc]);
         {error, Reason} ->
@@ -397,7 +424,7 @@ move_service_to_the_end(ServiceName, [Service | More], Passed) ->
     move_service_to_the_end(ServiceName, More, [Service | Passed]).
 
 move_service_to_nth(ServiceName, Services, N)
-  when length(Services) < N ->
+  when N =< length(Services) andalso N > 0 ->
     move_service_to_nth(ServiceName, Services, N, []);
 move_service_to_nth(_, _, _) ->
     {error, out_of_range}.
@@ -407,10 +434,23 @@ move_service_to_nth(ServiceName, [], _, _) ->
 move_service_to_nth(ServiceName, [{ServiceName, _} = Service | More], N, Passed)
   when N =< length(Passed) ->
     {L1, L2} = lists:split(N - 1, lists:reverse(Passed)),
-    {ok, L1 ++ [Service] + L2 + More};
+    {ok, L1 ++ [Service] ++ L2 ++ More};
 move_service_to_nth(ServiceName, [{ServiceName, _} = Service | More], N, Passed) ->
     {L1, L2} = lists:split(N - length(Passed) - 1, More),
-    {ok, lists:reverse(Passed) ++ L1 ++ [Service] ++ L2}.
+    {ok, lists:reverse(Passed) ++ L1 ++ [Service] ++ L2};
+move_service_to_nth(ServiceName, [Service | More], N, Passed) ->
+    move_service_to_nth(ServiceName, More, N, [Service | Passed]).
+
+update_chain(ChainID, UpdateFun) ->
+    trans(
+        fun() ->
+            case mnesia:read(?CHAIN_TAB, ChainID, write) of
+                [] ->
+                    {error, {not_found, {chain, ChainID}}};
+                [Chain] ->
+                    UpdateFun(Chain)
+            end
+        end).
 
 call_service(ChainID, ServiceName, Func, Args) ->
     case mnesia:dirty_read(?CHAIN_TAB, ChainID) of
@@ -431,6 +471,21 @@ call_service(ChainID, ServiceName, Func, Args) ->
             end
     end.
 
+serialize_chain(#chain{id = ID,
+                       services = Services,
+                       created_at = CreatedAt}) ->
+    #{id => ID,
+      services => [serialize_service(Service) || Service <- Services],
+      created_at => CreatedAt}.
+
+
+serialize_service({_, #service{name = Name,
+                               type = Type,
+                               params = Params}}) ->
+    #{name => Name,
+      type => Type,
+      params => Params}.
+
 trans(Fun) ->
     trans(Fun, []).
 

+ 7 - 3
apps/emqx_authentication/src/emqx_authentication_mnesia.erl

@@ -139,8 +139,9 @@ import_user_credentials(Filename, csv,
                           password_hash_algorithm := Algorithm}) ->
     case file:open(Filename, [read, binary]) of
         {ok, File} ->
-            import(UserGroup, File, Algorithm),
-            file:close(File);
+            Result = import(UserGroup, File, Algorithm),
+            file:close(File),
+            Result;
         {error, Reason} ->
             {error, Reason}
     end.
@@ -211,7 +212,10 @@ do_import(_UserGroup, [_ | _More], _Algorithm) ->
 do_import(UserGroup, File, Algorithm)  ->
     case file:read_line(File) of
         {ok, Line} ->
-            case binary:split(Line, <<",">>, [global]) of
+            case binary:split(Line, [<<",">>, <<"\n">>], [global]) of
+                [UserIdentity, Password, <<>>] ->
+                    add(UserGroup, UserIdentity, Password, Algorithm),
+                    do_import(UserGroup, File, Algorithm);
                 [UserIdentity, Password] ->
                     add(UserGroup, UserIdentity, Password, Algorithm),
                     do_import(UserGroup, File, Algorithm);

+ 2 - 0
apps/emqx_authentication/test/data/user-credentials.csv

@@ -0,0 +1,2 @@
+myuser3,mypassword3
+myuser4,mypassword4

+ 4 - 0
apps/emqx_authentication/test/data/user-credentials.json

@@ -0,0 +1,4 @@
+{
+    "myuser1": "mypassword1",
+    "myuser2": "mypassword2"
+}

+ 192 - 0
apps/emqx_authentication/test/emqx_authentication_SUITE.erl

@@ -0,0 +1,192 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+-module(emqx_authentication_SUITE).
+
+-compile(export_all).
+-compile(nowarn_export_all).
+
+-include_lib("common_test/include/ct.hrl").
+-include_lib("eunit/include/eunit.hrl").
+
+-define(AUTH, emqx_authentication).
+
+all() ->
+    emqx_ct:all(?MODULE).
+
+init_per_suite(Config) ->
+    emqx_ct_helpers:start_apps([emqx_authentication]),
+    Config.
+
+end_per_suite(_) ->
+    emqx_ct_helpers:stop_apps([emqx_authentication]),
+    ok.
+
+t_chain(_) ->
+    ChainID = <<"mychain">>,
+    ChainParams = #{chain_id => ChainID,
+                    service_params => []},
+    ?assertEqual({ok, ChainID}, ?AUTH:create_chain(ChainParams)),
+    ?assertEqual({error, {already_exists, {chain, ChainID}}}, ?AUTH:create_chain(ChainParams)),
+    ?assertMatch({ok, #{id := ChainID, services := []}}, ?AUTH:lookup_chain(ChainID)),
+    ?assertEqual(ok, ?AUTH:delete_chain(ChainID)),
+    ?assertMatch({error, {not_found, {chain, ChainID}}}, ?AUTH:lookup_chain(ChainID)),
+    ok.
+
+t_service(_) ->
+    ChainID = <<"mychain">>,
+    ServiceName1 = <<"myservice1">>,
+    ServiceParams1 = #{name => ServiceName1,
+                       type => mnesia,
+                       params => #{
+                           user_identity_type => <<"username">>,
+                           password_hash_algorithm => <<"sha256">>}},
+    ChainParams = #{chain_id => ChainID,
+                    service_params => [ServiceParams1]},
+    ?assertEqual({ok, ChainID}, ?AUTH:create_chain(ChainParams)),
+    Service1 = ServiceParams1,
+    ?assertMatch({ok, #{id := ChainID, services := [Service1]}}, ?AUTH:lookup_chain(ChainID)),
+    ?assertEqual({ok, Service1}, ?AUTH:lookup_service(ChainID, ServiceName1)),
+    ?assertEqual({ok, [Service1]}, ?AUTH:list_services(ChainID)),
+    ?assertEqual({error, {already_exists, {service, ServiceName1}}}, ?AUTH:add_services_to_chain(ChainID, [ServiceParams1])),
+    ServiceName2 = <<"myservice2">>,
+    ServiceParams2 = ServiceParams1#{name => ServiceName2},
+    ?assertEqual(ok, ?AUTH:add_services_to_chain(ChainID, [ServiceParams2])),
+    Service2 = ServiceParams2,
+    ?assertMatch({ok, #{id := ChainID, services := [Service1, Service2]}}, ?AUTH:lookup_chain(ChainID)),
+    ?assertEqual({ok, Service2}, ?AUTH:lookup_service(ChainID, ServiceName2)),
+    ?assertEqual({ok, [Service1, Service2]}, ?AUTH:list_services(ChainID)),
+
+    ?assertEqual(ok, ?AUTH:move_service_to_the_front_of_chain(ChainID, ServiceName2)),
+    ?assertEqual({ok, [Service2, Service1]}, ?AUTH:list_services(ChainID)),
+    ?assertEqual(ok, ?AUTH:move_service_to_the_end_of_chain(ChainID, ServiceName2)),
+    ?assertEqual({ok, [Service1, Service2]}, ?AUTH:list_services(ChainID)),
+    ?assertEqual(ok, ?AUTH:move_service_to_the_nth_of_chain(ChainID, ServiceName2, 1)),
+    ?assertEqual({ok, [Service2, Service1]}, ?AUTH:list_services(ChainID)),
+    ?assertEqual({error, out_of_range}, ?AUTH:move_service_to_the_nth_of_chain(ChainID, ServiceName2, 3)),
+    ?assertEqual({error, out_of_range}, ?AUTH:move_service_to_the_nth_of_chain(ChainID, ServiceName2, 0)),
+    ?assertEqual(ok, ?AUTH:delete_services_from_chain(ChainID, [ServiceName1, ServiceName2])),
+    ?assertEqual({ok, []}, ?AUTH:list_services(ChainID)),
+    ?assertEqual(ok, ?AUTH:delete_chain(ChainID)),
+    ok.
+
+t_mnesia_service(_) ->
+    ChainID = <<"mychain">>,
+    ServiceName = <<"myservice">>,
+    ServiceParams = #{name => ServiceName,
+                      type => mnesia,
+                      params => #{
+                          user_identity_type => <<"username">>,
+                          password_hash_algorithm => <<"sha256">>}},
+    ChainParams = #{chain_id => ChainID,
+                    service_params => [ServiceParams]},
+    ?assertEqual({ok, ChainID}, ?AUTH:create_chain(ChainParams)),
+    UserCredential = #{user_identity => <<"myuser">>,
+                       password => <<"mypass">>},
+    ?assertEqual(ok, ?AUTH:add_user_credential(ChainID, ServiceName, UserCredential)),
+    ?assertMatch({ok, #{user_identity := <<"myuser">>, password_hash := _}},
+                 ?AUTH:lookup_user_credential(ChainID, ServiceName, <<"myuser">>)),
+    ClientInfo = #{chain_id => ChainID,
+			       username => <<"myuser">>,
+			       password => <<"mypass">>},
+    ?assertEqual(ok, ?AUTH:authenticate(ClientInfo)),
+    ClientInfo2 = ClientInfo#{username => <<"baduser">>},
+    ?assertEqual({error, user_credential_not_found}, ?AUTH:authenticate(ClientInfo2)),
+    ClientInfo3 = ClientInfo#{password => <<"badpass">>},
+    ?assertEqual({error, bad_password}, ?AUTH:authenticate(ClientInfo3)),
+    UserCredential2 = UserCredential#{password => <<"mypass2">>},
+    ?assertEqual(ok, ?AUTH:update_user_credential(ChainID, ServiceName, UserCredential2)),
+    ClientInfo4 = ClientInfo#{password => <<"mypass2">>},
+    ?assertEqual(ok, ?AUTH:authenticate(ClientInfo4)),
+    ?assertEqual(ok, ?AUTH:delete_user_credential(ChainID, ServiceName, <<"myuser">>)),
+    ?assertEqual({error, not_found}, ?AUTH:lookup_user_credential(ChainID, ServiceName, <<"myuser">>)),
+
+    ?assertEqual(ok, ?AUTH:add_user_credential(ChainID, ServiceName, UserCredential)),
+    ?assertMatch({ok, #{user_identity := <<"myuser">>}}, ?AUTH:lookup_user_credential(ChainID, ServiceName, <<"myuser">>)),
+    ?assertEqual(ok, ?AUTH:delete_services_from_chain(ChainID, [ServiceName])),
+    ?assertEqual(ok, ?AUTH:add_services_to_chain(ChainID, [ServiceParams])),
+    ?assertMatch({error, not_found}, ?AUTH:lookup_user_credential(ChainID, ServiceName, <<"myuser">>)),
+
+    ?assertEqual(ok, ?AUTH:delete_chain(ChainID)),
+    ?assertEqual([], ets:tab2list(mnesia_basic_auth)),
+    ok.
+
+t_import(_) ->
+    ChainID = <<"mychain">>,
+    ServiceName = <<"myservice">>,
+    ServiceParams = #{name => ServiceName,
+                      type => mnesia,
+                      params => #{
+                          user_identity_type => <<"username">>,
+                          password_hash_algorithm => <<"sha256">>}},
+    ChainParams = #{chain_id => ChainID,
+                    service_params => [ServiceParams]},
+    ?assertEqual({ok, ChainID}, ?AUTH:create_chain(ChainParams)),
+    Dir = code:lib_dir(emqx_authentication, test),
+    ?assertEqual(ok, ?AUTH:import_user_credentials(ChainID, ServiceName, filename:join([Dir, "data/user-credentials.json"]), json)),
+    ?assertEqual(ok, ?AUTH:import_user_credentials(ChainID, ServiceName, filename:join([Dir, "data/user-credentials.csv"]), csv)),
+    ?assertMatch({ok, #{user_identity := <<"myuser1">>}}, ?AUTH:lookup_user_credential(ChainID, ServiceName, <<"myuser1">>)),
+    ?assertMatch({ok, #{user_identity := <<"myuser3">>}}, ?AUTH:lookup_user_credential(ChainID, ServiceName, <<"myuser3">>)),
+    ClientInfo1 = #{chain_id => ChainID,
+			        username => <<"myuser1">>,
+			        password => <<"mypassword1">>},
+    ?assertEqual(ok, ?AUTH:authenticate(ClientInfo1)),
+    ClientInfo2 = ClientInfo1#{username => <<"myuser3">>,
+                               password => <<"mypassword3">>},
+    ?assertEqual(ok, ?AUTH:authenticate(ClientInfo2)),
+    ?assertEqual(ok, ?AUTH:delete_chain(ChainID)),
+    ok.
+
+t_multi_mnesia_service(_) ->
+    ChainID = <<"mychain">>,
+    ServiceName1 = <<"myservice1">>,
+    ServiceParams1 = #{name => ServiceName1,
+                       type => mnesia,
+                       params => #{
+                           user_identity_type => <<"username">>,
+                           password_hash_algorithm => <<"sha256">>}},
+    ServiceName2 = <<"myservice2">>,
+    ServiceParams2 = #{name => ServiceName2,
+                       type => mnesia,
+                       params => #{
+                           user_identity_type => <<"clientid">>,
+                           password_hash_algorithm => <<"sha256">>}},
+    ChainParams = #{chain_id => ChainID,
+                    service_params => [ServiceParams1, ServiceParams2]},
+    ?assertEqual({ok, ChainID}, ?AUTH:create_chain(ChainParams)),
+
+    ?assertEqual(ok, ?AUTH:add_user_credential(ChainID,
+                                               ServiceName1,
+                                               #{user_identity => <<"myuser">>,
+                                                 password => <<"mypass1">>})),
+    ?assertEqual(ok, ?AUTH:add_user_credential(ChainID,
+                                               ServiceName2,
+                                               #{user_identity => <<"myclient">>,
+                                                 password => <<"mypass2">>})),
+    ClientInfo1 = #{chain_id => ChainID,
+			        username => <<"myuser">>,
+                    clientid => <<"myclient">>,
+			        password => <<"mypass1">>},
+    ?assertEqual(ok, ?AUTH:authenticate(ClientInfo1)),
+    ?assertEqual(ok, ?AUTH:move_service_to_the_front_of_chain(ChainID, ServiceName2)),
+    ?assertEqual({error, bad_password}, ?AUTH:authenticate(ClientInfo1)),
+    ClientInfo2 = ClientInfo1#{password => <<"mypass2">>},
+    ?assertEqual(ok, ?AUTH:authenticate(ClientInfo2)),
+    ?assertEqual(ok, ?AUTH:delete_chain(ChainID)),
+    ok.
+
+
+