Просмотр исходного кода

fix(authn api): eliminate possible atom leak

Ilya Averyanov 4 лет назад
Родитель
Сommit
796553b5ea

+ 15 - 9
apps/emqx/src/emqx_authentication.erl

@@ -25,6 +25,8 @@
 -include("emqx.hrl").
 -include("emqx.hrl").
 -include("logger.hrl").
 -include("logger.hrl").
 
 
+-include_lib("stdlib/include/ms_transform.hrl").
+
 %% The authentication entrypoint.
 %% The authentication entrypoint.
 -export([ authenticate/2
 -export([ authenticate/2
         ]).
         ]).
@@ -45,6 +47,7 @@
         , delete_chain/1
         , delete_chain/1
         , lookup_chain/1
         , lookup_chain/1
         , list_chains/0
         , list_chains/0
+        , list_chain_names/0
         , create_authenticator/2
         , create_authenticator/2
         , delete_authenticator/2
         , delete_authenticator/2
         , update_authenticator/3
         , update_authenticator/3
@@ -312,13 +315,24 @@ delete_chain(Name) ->
 
 
 -spec lookup_chain(chain_name()) -> {ok, chain()} | {error, term()}.
 -spec lookup_chain(chain_name()) -> {ok, chain()} | {error, term()}.
 lookup_chain(Name) ->
 lookup_chain(Name) ->
-    call({lookup_chain, Name}).
+    case ets:lookup(?CHAINS_TAB, Name) of
+        [] ->
+            {error, {not_found, {chain, Name}}};
+        [Chain] ->
+            {ok, serialize_chain(Chain)}
+    end.
 
 
 -spec list_chains() -> {ok, [chain()]}.
 -spec list_chains() -> {ok, [chain()]}.
 list_chains() ->
 list_chains() ->
     Chains = ets:tab2list(?CHAINS_TAB),
     Chains = ets:tab2list(?CHAINS_TAB),
     {ok, [serialize_chain(Chain) || Chain <- Chains]}.
     {ok, [serialize_chain(Chain) || Chain <- Chains]}.
 
 
+-spec list_chain_names() -> {ok, [atom()]}.
+list_chain_names() ->
+    Select = ets:fun2ms(fun(#chain{name = Name}) -> Name end),
+    ChainNames = ets:select(?CHAINS_TAB, Select),
+    {ok, ChainNames}.
+
 -spec create_authenticator(chain_name(), config()) -> {ok, authenticator()} | {error, term()}.
 -spec create_authenticator(chain_name(), config()) -> {ok, authenticator()} | {error, term()}.
 create_authenticator(ChainName, Config) ->
 create_authenticator(ChainName, Config) ->
     call({create_authenticator, ChainName, Config}).
     call({create_authenticator, ChainName, Config}).
@@ -432,14 +446,6 @@ handle_call({delete_chain, Name}, _From, State) ->
             reply(ok, maybe_unhook(State))
             reply(ok, maybe_unhook(State))
     end;
     end;
 
 
-handle_call({lookup_chain, Name}, _From, State) ->
-    case ets:lookup(?CHAINS_TAB, Name) of
-        [] ->
-            reply({error, {not_found, {chain, Name}}}, State);
-        [Chain] ->
-            reply({ok, serialize_chain(Chain)}, State)
-    end;
-
 handle_call({create_authenticator, ChainName, Config}, _From, #{providers := Providers} = State) ->
 handle_call({create_authenticator, ChainName, Config}, _From, #{providers := Providers} = State) ->
     UpdateFun =
     UpdateFun =
         fun(#chain{authenticators = Authenticators} = Chain) ->
         fun(#chain{authenticators = Authenticators} = Chain) ->

+ 2 - 0
apps/emqx/test/emqx_authentication_SUITE.erl

@@ -108,10 +108,12 @@ t_chain(Config) when is_list(Config) ->
     % CRUD of authentication chain
     % CRUD of authentication chain
     ChainName = 'test',
     ChainName = 'test',
     ?assertMatch({ok, []}, ?AUTHN:list_chains()),
     ?assertMatch({ok, []}, ?AUTHN:list_chains()),
+    ?assertMatch({ok, []}, ?AUTHN:list_chain_names()),
     ?assertMatch({ok, #{name := ChainName, authenticators := []}}, ?AUTHN:create_chain(ChainName)),
     ?assertMatch({ok, #{name := ChainName, authenticators := []}}, ?AUTHN:create_chain(ChainName)),
     ?assertEqual({error, {already_exists, {chain, ChainName}}}, ?AUTHN:create_chain(ChainName)),
     ?assertEqual({error, {already_exists, {chain, ChainName}}}, ?AUTHN:create_chain(ChainName)),
     ?assertMatch({ok, #{name := ChainName, authenticators := []}}, ?AUTHN:lookup_chain(ChainName)),
     ?assertMatch({ok, #{name := ChainName, authenticators := []}}, ?AUTHN:lookup_chain(ChainName)),
     ?assertMatch({ok, [#{name := ChainName}]}, ?AUTHN:list_chains()),
     ?assertMatch({ok, [#{name := ChainName}]}, ?AUTHN:list_chains()),
+    ?assertEqual({ok, [ChainName]}, ?AUTHN:list_chain_names()),
     ?assertEqual(ok, ?AUTHN:delete_chain(ChainName)),
     ?assertEqual(ok, ?AUTHN:delete_chain(ChainName)),
     ?assertMatch({error, {not_found, {chain, ChainName}}}, ?AUTHN:lookup_chain(ChainName)),
     ?assertMatch({error, {not_found, {chain, ChainName}}}, ?AUTHN:lookup_chain(ChainName)),
     ok.
     ok.

+ 71 - 53
apps/emqx_authn/src/emqx_authn_api.erl

@@ -498,37 +498,37 @@ authenticator(delete, #{bindings := #{id := AuthenticatorID}}) ->
 
 
 listener_authenticators(post, #{bindings := #{listener_id := ListenerID}, body := Config}) ->
 listener_authenticators(post, #{bindings := #{listener_id := ListenerID}, body := Config}) ->
     with_listener(ListenerID,
     with_listener(ListenerID,
-                  fun(Type, Name) ->
+                  fun(Type, Name, ChainName) ->
                         create_authenticator([listeners, Type, Name, authentication],
                         create_authenticator([listeners, Type, Name, authentication],
-                                          ListenerID,
+                                          ChainName,
                                           Config)
                                           Config)
                   end);
                   end);
 
 
 listener_authenticators(get, #{bindings := #{listener_id := ListenerID}}) ->
 listener_authenticators(get, #{bindings := #{listener_id := ListenerID}}) ->
     with_listener(ListenerID,
     with_listener(ListenerID,
-                  fun(Type, Name) ->
+                  fun(Type, Name, _) ->
                         list_authenticators([listeners, Type, Name, authentication])
                         list_authenticators([listeners, Type, Name, authentication])
                   end).
                   end).
 
 
 listener_authenticator(get, #{bindings := #{listener_id := ListenerID, id := AuthenticatorID}}) ->
 listener_authenticator(get, #{bindings := #{listener_id := ListenerID, id := AuthenticatorID}}) ->
     with_listener(ListenerID,
     with_listener(ListenerID,
-                  fun(Type, Name) ->
+                  fun(Type, Name, _) ->
                         list_authenticator([listeners, Type, Name, authentication],
                         list_authenticator([listeners, Type, Name, authentication],
                                        AuthenticatorID)
                                        AuthenticatorID)
                   end);
                   end);
 listener_authenticator(put, #{bindings := #{listener_id := ListenerID, id := AuthenticatorID}, body := Config}) ->
 listener_authenticator(put, #{bindings := #{listener_id := ListenerID, id := AuthenticatorID}, body := Config}) ->
     with_listener(ListenerID,
     with_listener(ListenerID,
-                  fun(Type, Name) ->
+                  fun(Type, Name, ChainName) ->
                         update_authenticator([listeners, Type, Name, authentication],
                         update_authenticator([listeners, Type, Name, authentication],
-                                             ListenerID,
+                                             ChainName,
                                              AuthenticatorID,
                                              AuthenticatorID,
                                              Config)
                                              Config)
                   end);
                   end);
 listener_authenticator(delete, #{bindings := #{listener_id := ListenerID, id := AuthenticatorID}}) ->
 listener_authenticator(delete, #{bindings := #{listener_id := ListenerID, id := AuthenticatorID}}) ->
     with_listener(ListenerID,
     with_listener(ListenerID,
-                  fun(Type, Name) ->
+                  fun(Type, Name, ChainName) ->
                         delete_authenticator([listeners, Type, Name, authentication],
                         delete_authenticator([listeners, Type, Name, authentication],
-                                             ListenerID,
+                                             ChainName,
                                              AuthenticatorID)
                                              AuthenticatorID)
                   end).
                   end).
 
 
@@ -539,9 +539,9 @@ authenticator_move(post, #{bindings := #{id := _}, body := _}) ->
 
 
 listener_authenticator_move(post, #{bindings := #{listener_id := ListenerID, id := AuthenticatorID}, body := #{<<"position">> := Position}}) ->
 listener_authenticator_move(post, #{bindings := #{listener_id := ListenerID, id := AuthenticatorID}, body := #{<<"position">> := Position}}) ->
     with_listener(ListenerID,
     with_listener(ListenerID,
-                  fun(Type, Name) ->
+                  fun(Type, Name, ChainName) ->
                         move_authenitcator([listeners, Type, Name, authentication],
                         move_authenitcator([listeners, Type, Name, authentication],
-                                           ListenerID,
+                                           ChainName,
                                            AuthenticatorID,
                                            AuthenticatorID,
                                            Position)
                                            Position)
                   end);
                   end);
@@ -557,11 +557,13 @@ authenticator_import_users(post, #{bindings := #{id := _}, body := _}) ->
     serialize_error({missing_parameter, filename}).
     serialize_error({missing_parameter, filename}).
 
 
 listener_authenticator_import_users(post, #{bindings := #{listener_id := ListenerID, id := AuthenticatorID}, body := #{<<"filename">> := Filename}}) ->
 listener_authenticator_import_users(post, #{bindings := #{listener_id := ListenerID, id := AuthenticatorID}, body := #{<<"filename">> := Filename}}) ->
-    ChainName = to_atom(ListenerID),
-    case ?AUTHN:import_users(ChainName, AuthenticatorID, Filename) of
-        ok -> {204};
-        {error, Reason} -> serialize_error(Reason)
-    end;
+    with_chain(ListenerID,
+                    fun(ChainName) ->
+                        case ?AUTHN:import_users(ChainName, AuthenticatorID, Filename) of
+                            ok -> {204};
+                            {error, Reason} -> serialize_error(Reason)
+                        end
+                    end);
 listener_authenticator_import_users(post, #{bindings := #{listener_id := _, id := _}, body := _}) ->
 listener_authenticator_import_users(post, #{bindings := #{listener_id := _, id := _}, body := _}) ->
     serialize_error({missing_parameter, filename}).
     serialize_error({missing_parameter, filename}).
 
 
@@ -580,23 +582,38 @@ authenticator_user(delete, #{bindings := #{id := AuthenticatorID, user_id := Use
 
 
 listener_authenticator_users(post, #{bindings := #{listener_id := ListenerID,
 listener_authenticator_users(post, #{bindings := #{listener_id := ListenerID,
                              id := AuthenticatorID}, body := UserInfo}) ->
                              id := AuthenticatorID}, body := UserInfo}) ->
-    add_user(ListenerID, AuthenticatorID, UserInfo);
+    with_chain(ListenerID,
+                    fun(ChainName) ->
+                        add_user(ChainName, AuthenticatorID, UserInfo)
+                    end);
 listener_authenticator_users(get, #{bindings := #{listener_id := ListenerID,
 listener_authenticator_users(get, #{bindings := #{listener_id := ListenerID,
                             id := AuthenticatorID}, query_string := PageParams}) ->
                             id := AuthenticatorID}, query_string := PageParams}) ->
-    list_users(ListenerID, AuthenticatorID, PageParams).
+    with_chain(ListenerID,
+                    fun(ChainName) ->
+                        list_users(ChainName, AuthenticatorID, PageParams)
+                    end).
 
 
 listener_authenticator_user(put, #{bindings := #{listener_id := ListenerID,
 listener_authenticator_user(put, #{bindings := #{listener_id := ListenerID,
                             id := AuthenticatorID,
                             id := AuthenticatorID,
                             user_id := UserID}, body := UserInfo}) ->
                             user_id := UserID}, body := UserInfo}) ->
-    update_user(ListenerID, AuthenticatorID, UserID, UserInfo);
+    with_chain(ListenerID,
+                    fun(ChainName) ->
+                        update_user(ChainName, AuthenticatorID, UserID, UserInfo)
+                    end);
 listener_authenticator_user(get, #{bindings := #{listener_id := ListenerID,
 listener_authenticator_user(get, #{bindings := #{listener_id := ListenerID,
                             id := AuthenticatorID,
                             id := AuthenticatorID,
                             user_id := UserID}}) ->
                             user_id := UserID}}) ->
-    find_user(ListenerID, AuthenticatorID, UserID);
+    with_chain(ListenerID,
+                    fun(ChainName) ->
+                        find_user(ChainName, AuthenticatorID, UserID)
+                    end);
 listener_authenticator_user(delete, #{bindings := #{listener_id := ListenerID,
 listener_authenticator_user(delete, #{bindings := #{listener_id := ListenerID,
                                id := AuthenticatorID,
                                id := AuthenticatorID,
                                user_id := UserID}}) ->
                                user_id := UserID}}) ->
-    delete_user(ListenerID, AuthenticatorID, UserID).
+    with_chain(ListenerID,
+                    fun(ChainName) ->
+                        delete_user(ChainName, AuthenticatorID, UserID)
+                    end).
 
 
 %%------------------------------------------------------------------------------
 %%------------------------------------------------------------------------------
 %% Internal functions
 %% Internal functions
@@ -604,27 +621,41 @@ listener_authenticator_user(delete, #{bindings := #{listener_id := ListenerID,
 
 
 with_listener(ListenerID, Fun) ->
 with_listener(ListenerID, Fun) ->
     case find_listener(ListenerID) of
     case find_listener(ListenerID) of
-        {ok, {Type, Name}} ->
-           Fun(Type, Name);
+        {ok, {BType, BName}} ->
+            Type = binary_to_existing_atom(BType),
+            Name = binary_to_existing_atom(BName),
+            ChainName = binary_to_atom(ListenerID),
+            Fun(Type, Name, ChainName);
         {error, Reason} ->
         {error, Reason} ->
             serialize_error(Reason)
             serialize_error(Reason)
     end.
     end.
 
 
 find_listener(ListenerID) ->
 find_listener(ListenerID) ->
-    case emqx_listeners:parse_listener_id(ListenerID) of
-        {error, _} ->
-            {error, {not_found, {listener, ListenerID}}};
-        {Type, Name} ->
-            case emqx_config:find([listeners, Type, Name]) of
-                {not_found, _, _} ->
-                    {error, {not_found, {listener, ListenerID}}};
+    case binary:split(ListenerID, <<":">>) of
+        [BType, BName] ->
+            case emqx_config:find([listeners, BType, BName]) of
                 {ok, _} ->
                 {ok, _} ->
-                    {ok, {Type, Name}}
-            end
+                    {ok, {BType, BName}};
+                {not_found, _, _} ->
+                    {error, {not_found, {listener, ListenerID}}}
+            end;
+        _ ->
+            {error, {not_found, {listener, ListenerID}}}
+    end.
+
+with_chain(ListenerID, Fun) ->
+    {ok, ChainNames} = ?AUTHN:list_chain_names(),
+    ListenerChainName =
+        [ Name || Name <- ChainNames, atom_to_binary(Name) =:= ListenerID ],
+    case ListenerChainName of
+        [ChainName] ->
+            Fun(ChainName);
+        _ ->
+            serialize_error({not_found, {chain, ListenerID}})
     end.
     end.
 
 
 create_authenticator(ConfKeyPath, ChainName, Config) ->
 create_authenticator(ConfKeyPath, ChainName, Config) ->
-    case update_config(ConfKeyPath, {create_authenticator, to_atom(ChainName), Config}) of
+    case update_config(ConfKeyPath, {create_authenticator, ChainName, Config}) of
         {ok, #{post_config_update := #{?AUTHN := #{id := ID}},
         {ok, #{post_config_update := #{?AUTHN := #{id := ID}},
             raw_config := AuthenticatorsConfig}} ->
             raw_config := AuthenticatorsConfig}} ->
             {ok, AuthenticatorConfig} = find_config(ID, AuthenticatorsConfig),
             {ok, AuthenticatorConfig} = find_config(ID, AuthenticatorsConfig),
@@ -649,7 +680,7 @@ list_authenticator(ConfKeyPath, AuthenticatorID) ->
     end.
     end.
 
 
 update_authenticator(ConfKeyPath, ChainName, AuthenticatorID, Config) ->
 update_authenticator(ConfKeyPath, ChainName, AuthenticatorID, Config) ->
-    case update_config(ConfKeyPath, {update_authenticator, to_atom(ChainName), AuthenticatorID, Config}) of
+    case update_config(ConfKeyPath, {update_authenticator, ChainName, AuthenticatorID, Config}) of
         {ok, #{post_config_update := #{?AUTHN := #{id := ID}},
         {ok, #{post_config_update := #{?AUTHN := #{id := ID}},
                raw_config := AuthenticatorsConfig}} ->
                raw_config := AuthenticatorsConfig}} ->
             {ok, AuthenticatorConfig} = find_config(ID, AuthenticatorsConfig),
             {ok, AuthenticatorConfig} = find_config(ID, AuthenticatorsConfig),
@@ -658,8 +689,7 @@ update_authenticator(ConfKeyPath, ChainName, AuthenticatorID, Config) ->
             serialize_error(Reason)
             serialize_error(Reason)
     end.
     end.
 
 
-delete_authenticator(ConfKeyPath, ChainName0, AuthenticatorID) ->
-    ChainName = to_atom(ChainName0),
+delete_authenticator(ConfKeyPath, ChainName, AuthenticatorID) ->
     case update_config(ConfKeyPath, {delete_authenticator, ChainName, AuthenticatorID}) of
     case update_config(ConfKeyPath, {delete_authenticator, ChainName, AuthenticatorID}) of
         {ok, _} ->
         {ok, _} ->
             {204};
             {204};
@@ -667,8 +697,7 @@ delete_authenticator(ConfKeyPath, ChainName0, AuthenticatorID) ->
             serialize_error(Reason)
             serialize_error(Reason)
     end.
     end.
 
 
-move_authenitcator(ConfKeyPath, ChainName0, AuthenticatorID, Position) ->
-    ChainName = to_atom(ChainName0),
+move_authenitcator(ConfKeyPath, ChainName, AuthenticatorID, Position) ->
     case parse_position(Position) of
     case parse_position(Position) of
         {ok, NPosition} ->
         {ok, NPosition} ->
             case update_config(ConfKeyPath, {move_authenticator, ChainName, AuthenticatorID, NPosition}) of
             case update_config(ConfKeyPath, {move_authenticator, ChainName, AuthenticatorID, NPosition}) of
@@ -681,8 +710,7 @@ move_authenitcator(ConfKeyPath, ChainName0, AuthenticatorID, Position) ->
             serialize_error(Reason)
             serialize_error(Reason)
     end.
     end.
 
 
-add_user(ChainName0, AuthenticatorID, #{<<"user_id">> := UserID, <<"password">> := Password} = UserInfo) ->
-    ChainName = to_atom(ChainName0),
+add_user(ChainName, AuthenticatorID, #{<<"user_id">> := UserID, <<"password">> := Password} = UserInfo) ->
     IsSuperuser = maps:get(<<"is_superuser">>, UserInfo, false),
     IsSuperuser = maps:get(<<"is_superuser">>, UserInfo, false),
     case ?AUTHN:add_user(ChainName, AuthenticatorID, #{ user_id => UserID
     case ?AUTHN:add_user(ChainName, AuthenticatorID, #{ user_id => UserID
                                                       , password => Password
                                                       , password => Password
@@ -697,8 +725,7 @@ add_user(_, _, #{<<"user_id">> := _}) ->
 add_user(_, _, _) ->
 add_user(_, _, _) ->
     serialize_error({missing_parameter, user_id}).
     serialize_error({missing_parameter, user_id}).
 
 
-update_user(ChainName0, AuthenticatorID, UserID, UserInfo) ->
-    ChainName = to_atom(ChainName0),
+update_user(ChainName, AuthenticatorID, UserID, UserInfo) ->
     case maps:with([<<"password">>, <<"is_superuser">>], UserInfo) =:= #{} of
     case maps:with([<<"password">>, <<"is_superuser">>], UserInfo) =:= #{} of
         true ->
         true ->
             serialize_error({missing_parameter, password});
             serialize_error({missing_parameter, password});
@@ -711,8 +738,7 @@ update_user(ChainName0, AuthenticatorID, UserID, UserInfo) ->
             end
             end
     end.
     end.
 
 
-find_user(ChainName0, AuthenticatorID, UserID) ->
-    ChainName = to_atom(ChainName0),
+find_user(ChainName, AuthenticatorID, UserID) ->
     case ?AUTHN:lookup_user(ChainName, AuthenticatorID, UserID) of
     case ?AUTHN:lookup_user(ChainName, AuthenticatorID, UserID) of
         {ok, User} ->
         {ok, User} ->
             {200, User};
             {200, User};
@@ -720,8 +746,7 @@ find_user(ChainName0, AuthenticatorID, UserID) ->
             serialize_error({user_error, Reason})
             serialize_error({user_error, Reason})
     end.
     end.
 
 
-delete_user(ChainName0, AuthenticatorID, UserID) ->
-    ChainName = to_atom(ChainName0),
+delete_user(ChainName, AuthenticatorID, UserID) ->
     case ?AUTHN:delete_user(ChainName, AuthenticatorID, UserID) of
     case ?AUTHN:delete_user(ChainName, AuthenticatorID, UserID) of
         ok ->
         ok ->
             {204};
             {204};
@@ -729,8 +754,7 @@ delete_user(ChainName0, AuthenticatorID, UserID) ->
             serialize_error({user_error, Reason})
             serialize_error({user_error, Reason})
     end.
     end.
 
 
-list_users(ChainName0, AuthenticatorID, PageParams) ->
-    ChainName = to_atom(ChainName0),
+list_users(ChainName, AuthenticatorID, PageParams) ->
     case ?AUTHN:list_users(ChainName, AuthenticatorID, PageParams) of
     case ?AUTHN:list_users(ChainName, AuthenticatorID, PageParams) of
         {ok, Users} ->
         {ok, Users} ->
             {200, Users};
             {200, Users};
@@ -834,12 +858,6 @@ parse_position(_) ->
 ensure_list(M) when is_map(M) -> [M];
 ensure_list(M) when is_map(M) -> [M];
 ensure_list(L) when is_list(L) -> L.
 ensure_list(L) when is_list(L) -> L.
 
 
-% TODO: fix atom leak!
-to_atom(B) when is_binary(B) ->
-    binary_to_atom(B);
-to_atom(A) when is_atom(A) ->
-    A.
-
 binfmt(Fmt, Args) -> iolist_to_binary(io_lib:format(Fmt, Args)).
 binfmt(Fmt, Args) -> iolist_to_binary(io_lib:format(Fmt, Args)).
 
 
 authenticator_array_example() ->
 authenticator_array_example() ->

+ 1 - 1
apps/emqx_dashboard/test/emqx_swagger_requestBody_SUITE.erl

@@ -187,7 +187,7 @@ t_api_spec(_Config) ->
 
 
     Filter0 = filter(Spec0, Path),
     Filter0 = filter(Spec0, Path),
     ?assertMatch(
     ?assertMatch(
-        {ok, #{body := ActualBody}},
+        {ok, #{body := #{<<"timeout">> := <<"infinity">>}}},
         trans_requestBody(Path, Body, Filter0)),
         trans_requestBody(Path, Body, Filter0)),
 
 
     {Spec1, _} = emqx_dashboard_swagger:spec(?MODULE, #{check_schema => true, translate_body => true}),
     {Spec1, _} = emqx_dashboard_swagger:spec(?MODULE, #{check_schema => true, translate_body => true}),