Browse Source

Merge pull request #5760 from zmstone/minor-refactors

refactor(authn): minor refactors
Zaiming (Stone) Shi 4 years ago
parent
commit
1af8148e30

+ 68 - 40
apps/emqx/src/emqx_authentication.erl

@@ -40,8 +40,10 @@
         , stop/0
         ]).
 
--export([ add_provider/2
-        , remove_provider/1
+-export([ register_provider/2
+        , register_providers/1
+        , deregister_provider/1
+        , deregister_providers/1
         , create_chain/1
         , delete_chain/1
         , lookup_chain/1
@@ -266,8 +268,9 @@ do_post_config_update({move_authenticator, ChainName, AuthenticatorID, Position}
     move_authenticator(ChainName, AuthenticatorID, Position).
 
 check_config(Config) ->
-    #{authentication := CheckedConfig} = hocon_schema:check_plain(emqx_authentication,
-        #{<<"authentication">> => Config}, #{nullable => true, atom_key => true}),
+    #{authentication := CheckedConfig} =
+        hocon_schema:check_plain(?MODULE, #{<<"authentication">> => Config},
+                                 #{nullable => true, atom_key => true}),
     CheckedConfig.
 
 %%------------------------------------------------------------------------------
@@ -331,27 +334,41 @@ stop() ->
 
 -spec get_refs() -> {ok, Refs} when Refs :: [{authn_type(), module()}].
 get_refs() ->
-    gen_server:call(?MODULE, get_refs).
-
--spec add_provider(authn_type(), module()) -> ok.
-add_provider(AuthNType, Provider) ->
-    gen_server:call(?MODULE, {add_provider, AuthNType, Provider}).
-
--spec remove_provider(authn_type()) -> ok.
-remove_provider(AuthNType) ->
-    gen_server:call(?MODULE, {remove_provider, AuthNType}).
+    call(get_refs).
+
+%% @doc Register authentication providers.
+%% A provider is a tuple of `AuthNType' the module which implements
+%% the authenticator callbacks.
+%% For example, ``[{{'password-based', redis}, emqx_authn_redis}]''
+%% NOTE: Later registered provider may override earlier registered if they
+%% happen to clash the same `AuthNType'.
+-spec register_providers([{authn_type(), module()}]) -> ok.
+register_providers(Providers) ->
+    call({register_providers, Providers}).
+
+-spec register_provider(authn_type(), module()) -> ok.
+register_provider(AuthNType, Provider) ->
+    register_providers([{AuthNType, Provider}]).
+
+-spec deregister_providers([authn_type()]) -> ok.
+deregister_providers(AuthNTypes) when is_list(AuthNTypes) ->
+    call({deregister_providers, AuthNTypes}).
+
+-spec deregister_provider(authn_type()) -> ok.
+deregister_provider(AuthNType) ->
+    deregister_providers([AuthNType]).
 
 -spec create_chain(chain_name()) -> {ok, chain()} | {error, term()}.
 create_chain(Name) ->
-    gen_server:call(?MODULE, {create_chain, Name}).
+    call({create_chain, Name}).
 
 -spec delete_chain(chain_name()) -> ok | {error, term()}.
 delete_chain(Name) ->
-    gen_server:call(?MODULE, {delete_chain, Name}).
+    call({delete_chain, Name}).
 
 -spec lookup_chain(chain_name()) -> {ok, chain()} | {error, term()}.
 lookup_chain(Name) ->
-    gen_server:call(?MODULE, {lookup_chain, Name}).
+    call({lookup_chain, Name}).
 
 -spec list_chains() -> {ok, [chain()]}.
 list_chains() ->
@@ -360,15 +377,15 @@ list_chains() ->
 
 -spec create_authenticator(chain_name(), config()) -> {ok, authenticator()} | {error, term()}.
 create_authenticator(ChainName, Config) ->
-    gen_server:call(?MODULE, {create_authenticator, ChainName, Config}).
+    call({create_authenticator, ChainName, Config}).
 
 -spec delete_authenticator(chain_name(), authenticator_id()) -> ok | {error, term()}.
 delete_authenticator(ChainName, AuthenticatorID) ->
-    gen_server:call(?MODULE, {delete_authenticator, ChainName, AuthenticatorID}).
+    call({delete_authenticator, ChainName, AuthenticatorID}).
 
 -spec update_authenticator(chain_name(), authenticator_id(), config()) -> {ok, authenticator()} | {error, term()}.
 update_authenticator(ChainName, AuthenticatorID, Config) ->
-    gen_server:call(?MODULE, {update_authenticator, ChainName, AuthenticatorID, Config}).
+    call({update_authenticator, ChainName, AuthenticatorID, Config}).
 
 -spec lookup_authenticator(chain_name(), authenticator_id()) -> {ok, authenticator()} | {error, term()}.
 lookup_authenticator(ChainName, AuthenticatorID) ->
@@ -395,32 +412,32 @@ list_authenticators(ChainName) ->
 
 -spec move_authenticator(chain_name(), authenticator_id(), position()) -> ok | {error, term()}.
 move_authenticator(ChainName, AuthenticatorID, Position) ->
-    gen_server:call(?MODULE, {move_authenticator, ChainName, AuthenticatorID, Position}).
+    call({move_authenticator, ChainName, AuthenticatorID, Position}).
 
 -spec import_users(chain_name(), authenticator_id(), binary()) -> ok | {error, term()}.
 import_users(ChainName, AuthenticatorID, Filename) ->
-    gen_server:call(?MODULE, {import_users, ChainName, AuthenticatorID, Filename}).
+    call({import_users, ChainName, AuthenticatorID, Filename}).
 
 -spec add_user(chain_name(), authenticator_id(), user_info()) -> {ok, user_info()} | {error, term()}.
 add_user(ChainName, AuthenticatorID, UserInfo) ->
-    gen_server:call(?MODULE, {add_user, ChainName, AuthenticatorID, UserInfo}).
+    call({add_user, ChainName, AuthenticatorID, UserInfo}).
 
 -spec delete_user(chain_name(), authenticator_id(), binary()) -> ok | {error, term()}.
 delete_user(ChainName, AuthenticatorID, UserID) ->
-    gen_server:call(?MODULE, {delete_user, ChainName, AuthenticatorID, UserID}).
+    call({delete_user, ChainName, AuthenticatorID, UserID}).
 
 -spec update_user(chain_name(), authenticator_id(), binary(), map()) -> {ok, user_info()} | {error, term()}.
 update_user(ChainName, AuthenticatorID, UserID, NewUserInfo) ->
-    gen_server:call(?MODULE, {update_user, ChainName, AuthenticatorID, UserID, NewUserInfo}).
+    call({update_user, ChainName, AuthenticatorID, UserID, NewUserInfo}).
 
 -spec lookup_user(chain_name(), authenticator_id(), binary()) -> {ok, user_info()} | {error, term()}.
 lookup_user(ChainName, AuthenticatorID, UserID) ->
-    gen_server:call(?MODULE, {lookup_user, ChainName, AuthenticatorID, UserID}).
+    call({lookup_user, ChainName, AuthenticatorID, UserID}).
 
 %% TODO: Support pagination
 -spec list_users(chain_name(), authenticator_id()) -> {ok, [user_info()]} | {error, term()}.
 list_users(ChainName, AuthenticatorID) ->
-    gen_server:call(?MODULE, {list_users, ChainName, AuthenticatorID}).
+    call({list_users, ChainName, AuthenticatorID}).
 
 -spec generate_id(config()) -> authenticator_id().
 generate_id(#{mechanism := Mechanism0, backend := Backend0}) ->
@@ -446,11 +463,20 @@ init(_Opts) ->
     ok = emqx_config_handler:add_handler([listeners, '?', '?', authentication], ?MODULE),
     {ok, #{hooked => false, providers => #{}}}.
 
-handle_call({add_provider, AuthNType, Provider}, _From, #{providers := Providers} = State) ->
-    reply(ok, State#{providers := Providers#{AuthNType => Provider}});
+handle_call({register_providers, Providers}, _From,
+            #{providers := Reg0} = State) ->
+    case lists:filter(fun({T, _}) -> maps:is_key(T, Reg0) end, Providers) of
+        [] ->
+            Reg = lists:foldl(fun({AuthNType, Module}, Pin) ->
+                                      Pin#{AuthNType => Module}
+                              end, Reg0, Providers),
+            reply(ok, State#{providers := Reg});
+        Clashes ->
+            reply({error, {authentication_type_clash, Clashes}}, State)
+    end;
 
-handle_call({remove_provider, AuthNType}, _From, #{providers := Providers} = State) ->
-    reply(ok, State#{providers := maps:remove(AuthNType, Providers)});
+handle_call({deregister_providers, AuthNTypes}, _From, #{providers := Providers} = State) ->
+    reply(ok, State#{providers := maps:without(AuthNTypes, Providers)});
 
 handle_call(get_refs, _From, #{providers := Providers} = State) ->
     Refs = lists:foldl(fun({_, Provider}, Acc) ->
@@ -476,7 +502,7 @@ handle_call({delete_chain, Name}, _From, State) ->
         [#chain{authenticators = Authenticators}] ->
             _ = [do_delete_authenticator(Authenticator) || Authenticator <- Authenticators],
             true = ets:delete(?CHAINS_TAB, Name),
-            reply(ok, may_unhook(State))
+            reply(ok, maybe_unhook(State))
     end;
 
 handle_call({lookup_chain, Name}, _From, State) ->
@@ -506,7 +532,7 @@ handle_call({create_authenticator, ChainName, Config}, _From, #{providers := Pro
             end
         end,
     Reply = update_chain(ChainName, UpdateFun),
-    reply(Reply, may_hook(State));
+    reply(Reply, maybe_hook(State));
 
 handle_call({delete_authenticator, ChainName, AuthenticatorID}, _From, State) ->
     UpdateFun = 
@@ -521,7 +547,7 @@ handle_call({delete_authenticator, ChainName, AuthenticatorID}, _From, State) ->
             end
         end,
     Reply = update_chain(ChainName, UpdateFun),
-    reply(Reply, may_unhook(State));
+    reply(Reply, maybe_unhook(State));
 
 handle_call({update_authenticator, ChainName, AuthenticatorID, Config}, _From, State) ->
     UpdateFun =
@@ -726,30 +752,30 @@ global_chain(stomp) ->
 global_chain(_) ->
     'unknown:global'.
 
-may_hook(#{hooked := false} = State) ->
+maybe_hook(#{hooked := false} = State) ->
     case lists:any(fun(#chain{authenticators = []}) -> false;
                       (_) -> true
                    end, ets:tab2list(?CHAINS_TAB)) of
         true ->
-            _ = emqx:hook('client.authenticate', {emqx_authentication, authenticate, []}),
+            _ = emqx:hook('client.authenticate', {?MODULE, authenticate, []}),
             State#{hooked => true};
         false ->
             State
     end;
-may_hook(State) ->
+maybe_hook(State) ->
     State.
 
-may_unhook(#{hooked := true} = State) ->
+maybe_unhook(#{hooked := true} = State) ->
     case lists:all(fun(#chain{authenticators = []}) -> true;
                       (_) -> false
                    end, ets:tab2list(?CHAINS_TAB)) of
         true ->
-            _ = emqx:unhook('client.authenticate', {emqx_authentication, authenticate, []}),
+            _ = emqx:unhook('client.authenticate', {?MODULE, authenticate, []}),
             State#{hooked => false};
         false ->
             State
     end;
-may_unhook(State) ->
+maybe_unhook(State) ->
     State.
 
 do_create_authenticator(ChainName, AuthenticatorID, #{enable := Enable} = Config, Providers) ->
@@ -773,7 +799,7 @@ do_create_authenticator(ChainName, AuthenticatorID, #{enable := Enable} = Config
 do_delete_authenticator(#authenticator{provider = Provider, state = State}) ->
     _ = Provider:destroy(State),
     ok.
-    
+
 replace_authenticator(ID, Authenticator, Authenticators) ->
     lists:keyreplace(ID, #authenticator.id, Authenticators, Authenticator).
 
@@ -875,3 +901,5 @@ to_list(L) when is_list(L) ->
 
 to_bin(B) when is_binary(B) -> B;
 to_bin(L) when is_list(L) -> list_to_binary(L).
+
+call(Call) -> gen_server:call(?MODULE, Call, infinity).

+ 51 - 37
apps/emqx/test/emqx_authentication_SUITE.erl

@@ -36,6 +36,7 @@
         ]).
 
 -define(AUTHN, emqx_authentication).
+-define(config(KEY), (fun() -> {KEY, _V_} = lists:keyfind(KEY, 1, Config), _V_ end)()).
 
 %%------------------------------------------------------------------------------
 %% Hocon Schema
@@ -92,20 +93,22 @@ end_per_suite(_) ->
     emqx_ct_helpers:stop_apps([]),
     ok.
 
-init_per_testcase(_, Config) ->
+init_per_testcase(Case, Config) ->
     meck:new(emqx, [non_strict, passthrough, no_history, no_link]),
     meck:expect(emqx, get_config, fun([node, data_dir]) ->
                                           {data_dir, Data} = lists:keyfind(data_dir, 1, Config),
                                           Data;
                                      (C) -> meck:passthrough([C])
                                   end),
-    Config.
+    ?MODULE:Case({'init', Config}).
 
-end_per_testcase(_, _Config) ->
+end_per_testcase(Case, Config) ->
+    _ = ?MODULE:Case({'end', Config}),
     meck:unload(emqx),
     ok.
 
-t_chain(_) ->
+t_chain({_, Config}) -> Config;
+t_chain(Config) when is_list(Config) ->
     % CRUD of authentication chain
     ChainName = 'test',
     ?assertMatch({ok, []}, ?AUTHN:list_chains()),
@@ -117,7 +120,10 @@ t_chain(_) ->
     ?assertMatch({error, {not_found, {chain, ChainName}}}, ?AUTHN:lookup_chain(ChainName)),
     ok.
 
-t_authenticator(_) ->
+t_authenticator({'init', Config}) ->
+    [{"auth1", {'password-based', 'built-in-database'}},
+     {"auth2", {'password-based', mysql}} | Config];
+t_authenticator(Config) when is_list(Config) ->
     ChainName = 'test',
     AuthenticatorConfig1 = #{mechanism => 'password-based',
                              backend => 'built-in-database',
@@ -129,8 +135,8 @@ t_authenticator(_) ->
     % Create an authenticator when the provider does not exist
     ?assertEqual({error, no_available_provider}, ?AUTHN:create_authenticator(ChainName, AuthenticatorConfig1)),
 
-    AuthNType1 = {'password-based', 'built-in-database'},
-    ?AUTHN:add_provider(AuthNType1, ?MODULE),
+    AuthNType1 = ?config("auth1"),
+    register_provider(AuthNType1, ?MODULE),
     ID1 = <<"password-based:built-in-database">>,
 
     % CRUD of authencaticator
@@ -144,8 +150,8 @@ t_authenticator(_) ->
     ?assertMatch({ok, []}, ?AUTHN:list_authenticators(ChainName)),
 
     % Multiple authenticators exist at the same time
-    AuthNType2 = {'password-based', mysql},
-    ?AUTHN:add_provider(AuthNType2, ?MODULE),
+    AuthNType2 = ?config("auth2"),
+    register_provider(AuthNType2, ?MODULE),
     ID2 = <<"password-based:mysql">>,
     AuthenticatorConfig2 = #{mechanism => 'password-based',
                              backend => mysql,
@@ -160,15 +166,18 @@ t_authenticator(_) ->
     ?assertEqual(ok, ?AUTHN:move_authenticator(ChainName, ID2, bottom)),
     ?assertMatch({ok, [#{id := ID1}, #{id := ID2}]}, ?AUTHN:list_authenticators(ChainName)),
     ?assertEqual(ok, ?AUTHN:move_authenticator(ChainName, ID2, {before, ID1})),
-    ?assertMatch({ok, [#{id := ID2}, #{id := ID1}]}, ?AUTHN:list_authenticators(ChainName)),
-
-    ?AUTHN:delete_chain(ChainName),
-    ?AUTHN:remove_provider(AuthNType1),
-    ?AUTHN:remove_provider(AuthNType2),
+    ?assertMatch({ok, [#{id := ID2}, #{id := ID1}]}, ?AUTHN:list_authenticators(ChainName));
+t_authenticator({'end', Config}) ->
+    ?AUTHN:delete_chain(test),
+    ?AUTHN:deregister_providers([?config("auth1"), ?config("auth2")]),
     ok.
 
-t_authenticate(_) ->
-    ListenerID = 'tcp:default',
+t_authenticate({init, Config}) ->
+    [{listener_id, 'tcp:default'},
+     {authn_type, {'password-based', 'built-in-database'}} | Config];
+t_authenticate(Config) when is_list(Config) ->
+    ListenerID = ?config(listener_id),
+    AuthNType = ?config(authn_type),
     ClientInfo = #{zone => default,
                    listener => ListenerID,
                    protocol => mqtt,
@@ -176,8 +185,7 @@ t_authenticate(_) ->
 			       password => <<"any">>},
     ?assertEqual({ok, #{is_superuser => false}}, emqx_access_control:authenticate(ClientInfo)),
 
-    AuthNType = {'password-based', 'built-in-database'},
-    ?AUTHN:add_provider(AuthNType, ?MODULE),
+    register_provider(AuthNType, ?MODULE),
 
     AuthenticatorConfig = #{mechanism => 'password-based',
                             backend => 'built-in-database',
@@ -185,21 +193,24 @@ t_authenticate(_) ->
     ?AUTHN:create_chain(ListenerID),
     ?assertMatch({ok, _}, ?AUTHN:create_authenticator(ListenerID, AuthenticatorConfig)),
     ?assertEqual({ok, #{is_superuser => true}}, emqx_access_control:authenticate(ClientInfo)),
-    ?assertEqual({error, bad_username_or_password}, emqx_access_control:authenticate(ClientInfo#{username => <<"bad">>})),
-
-    ?AUTHN:delete_chain(ListenerID),
-    ?AUTHN:remove_provider(AuthNType),
+    ?assertEqual({error, bad_username_or_password}, emqx_access_control:authenticate(ClientInfo#{username => <<"bad">>}));
+t_authenticate({'end', Config}) ->
+    ?AUTHN:delete_chain(?config(listener_id)),
+    ?AUTHN:deregister_provider(?config(authn_type)),
     ok.
 
-t_update_config(_) ->
-    emqx_config_handler:add_handler([authentication], emqx_authentication),
-
+t_update_config({init, Config}) ->
+    Global = 'mqtt:global',
     AuthNType1 = {'password-based', 'built-in-database'},
     AuthNType2 = {'password-based', mysql},
-    ?AUTHN:add_provider(AuthNType1, ?MODULE),
-    ?AUTHN:add_provider(AuthNType2, ?MODULE),
-
-    Global = 'mqtt:global',
+    [{global, Global},
+     {"auth1", AuthNType1},
+     {"auth2", AuthNType2} | Config];
+t_update_config(Config) when is_list(Config) ->
+    emqx_config_handler:add_handler([authentication], emqx_authentication),
+    ok = register_provider(?config("auth1"), ?MODULE),
+    ok = register_provider(?config("auth2"), ?MODULE),
+    Global = ?config(global),
     AuthenticatorConfig1 = #{mechanism => 'password-based',
                              backend => 'built-in-database',
                              enable => true},
@@ -208,7 +219,7 @@ t_update_config(_) ->
                              enable => true},
     ID1 = <<"password-based:built-in-database">>,
     ID2 = <<"password-based:mysql">>,
-    
+
     ?assertMatch({ok, []}, ?AUTHN:list_chains()),
     ?assertMatch({ok, _}, update_config([authentication], {create_authenticator, Global, AuthenticatorConfig1})),
     ?assertMatch({ok, #{id := ID1, state := #{mark := 1}}}, ?AUTHN:lookup_authenticator(Global, ID1)),
@@ -240,14 +251,14 @@ t_update_config(_) ->
     ?assertMatch({ok, [#{id := ID2}, #{id := ID1}]}, ?AUTHN:list_authenticators(ListenerID)),
 
     ?assertMatch({ok, _}, update_config(ConfKeyPath, {delete_authenticator, ListenerID, ID1})),
-    ?assertEqual({error, {not_found, {authenticator, ID1}}}, ?AUTHN:lookup_authenticator(ListenerID, ID1)),
-
-    ?AUTHN:delete_chain(Global),
-    ?AUTHN:remove_provider(AuthNType1),
-    ?AUTHN:remove_provider(AuthNType2),
+    ?assertEqual({error, {not_found, {authenticator, ID1}}}, ?AUTHN:lookup_authenticator(ListenerID, ID1));
+t_update_config({'end', Config}) ->
+    ?AUTHN:delete_chain(?config(global)),
+    ?AUTHN:deregister_providers([?config("auth1"), ?config("auth2")]),
     ok.
 
-t_convert_cert_options(_) ->
+t_convert_cert_options({_, Config}) -> Config;
+t_convert_cert_options(Config) when is_list(Config) ->
     Certs = certs([ {<<"keyfile">>, "key.pem"}
                   , {<<"certfile">>, "cert.pem"}
                   , {<<"cacertfile">>, "cacert.pem"}
@@ -284,4 +295,7 @@ certs(Certs) ->
 
 diff_cert(CertFile, CertPem2) ->
     {ok, CertPem1} = file:read_file(CertFile),
-    ?AUTHN:diff_cert(CertPem1, CertPem2).
+    ?AUTHN:diff_cert(CertPem1, CertPem2).
+
+register_provider(Type, Module) ->
+    ok = ?AUTHN:register_providers([{Type, Module}]).

+ 5 - 0
apps/emqx_authn/include/emqx_authn.hrl

@@ -14,6 +14,9 @@
 %% limitations under the License.
 %%--------------------------------------------------------------------
 
+-ifndef(EMQX_AUTHN_HRL).
+-define(EMQX_AUTHN_HRL, true).
+
 -define(APP, emqx_authn).
 
 -define(AUTHN, emqx_authentication).
@@ -23,3 +26,5 @@
 -define(RE_PLACEHOLDER, "\\$\\{[a-z0-9\\-]+\\}").
 
 -define(AUTH_SHARD, emqx_authn_shard).
+
+-endif.

+ 7 - 11
apps/emqx_authn/src/emqx_authn_app.erl

@@ -32,30 +32,26 @@
 start(_StartType, _StartArgs) ->
     ok = ekka_rlog:wait_for_shards([?AUTH_SHARD], infinity),
     {ok, Sup} = emqx_authn_sup:start_link(),
-    ok = add_providers(),
+    ok = ?AUTHN:register_providers(providers()),
     ok = initialize(),
     {ok, Sup}.
 
 stop(_State) ->
-    ok = remove_providers(),
+    ok = ?AUTHN:deregister_providers(provider_types()),
     ok.
 
 %%------------------------------------------------------------------------------
 %% Internal functions
 %%------------------------------------------------------------------------------
 
-add_providers() ->
-    _ = [?AUTHN:add_provider(AuthNType, Provider) || {AuthNType, Provider} <- providers()], ok.
-
-remove_providers() ->
-    _ = [?AUTHN:remove_provider(AuthNType) || {AuthNType, _} <- providers()], ok.
-
 initialize() ->
     ?AUTHN:initialize_authentication(?GLOBAL, emqx:get_raw_config([authentication], [])),
     lists:foreach(fun({ListenerID, ListenerConfig}) ->
                       ?AUTHN:initialize_authentication(ListenerID, maps:get(authentication, ListenerConfig, []))
-                  end, emqx_listeners:list()),
-    ok.
+                  end, emqx_listeners:list()).
+
+provider_types() ->
+    lists:map(fun({Type, _Module}) -> Type end, providers()).
 
 providers() ->
     [ {{'password-based', 'built-in-database'}, emqx_authn_mnesia}
@@ -66,4 +62,4 @@ providers() ->
     , {{'password-based', 'http-server'}, emqx_authn_http}
     , {jwt, emqx_authn_jwt}
     , {{scram, 'built-in-database'}, emqx_enhanced_authn_scram_mnesia}
-    ].
+    ].