Просмотр исходного кода

Merge pull request #8243 from HJianBo/gw-fixes

fix(gw): enhance the authn resources managing logic
JianBo He 3 лет назад
Родитель
Сommit
34fe5e67e7

+ 1 - 1
apps/emqx_gateway/src/emqx_gateway.app.src

@@ -4,7 +4,7 @@
     {vsn, "0.1.0"},
     {registered, []},
     {mod, {emqx_gateway_app, []}},
-    {applications, [kernel, stdlib, grpc, emqx]},
+    {applications, [kernel, stdlib, grpc, emqx, emqx_authn]},
     {env, []},
     {modules, []},
     {licenses, ["Apache 2.0"]},

+ 1 - 0
apps/emqx_gateway/src/emqx_gateway_ctx.erl

@@ -29,6 +29,7 @@
     #{
         %% Gateway Name
         gwname := gateway_name(),
+        %% FIXME: use process name instead of pid()
         %% The ConnectionManager PID
         cm := pid()
     }.

+ 3 - 0
apps/emqx_gateway/src/emqx_gateway_http.erl

@@ -556,6 +556,9 @@ with_gateway(GwName0, Fun) ->
             return_http_error(404, "Resource not found. path: " ++ Path);
         error:{badmatch, {error, einval}} ->
             return_http_error(400, "Invalid bind address");
+        error:{badauth, Reason} ->
+            Reason1 = emqx_gateway_utils:stringfy(Reason),
+            return_http_error(400, ["Bad authentication config: ", Reason1]);
         Class:Reason:Stk ->
             ?SLOG(error, #{
                 msg => "uncaught_exception",

+ 108 - 115
apps/emqx_gateway/src/emqx_gateway_insta_sup.erl

@@ -45,7 +45,6 @@
     name :: gateway_name(),
     config :: emqx_config:config(),
     ctx :: emqx_gateway_ctx:context(),
-    authns :: [{emqx_authentication:chain_name(), map()}],
     status :: stopped | running,
     child_pids :: [pid()],
     gw_state :: emqx_gateway_impl:state() | undefined,
@@ -101,13 +100,14 @@ init([Gateway, Ctx, _GwDscrptr]) ->
     State = #state{
         ctx = Ctx,
         name = GwName,
-        authns = [],
         config = Config,
         child_pids = [],
         status = stopped,
         created_at = erlang:system_time(millisecond)
     },
-    case maps:get(enable, Config, true) of
+    Enable = maps:get(enable, Config, true),
+    ok = ensure_authn_running(State, Enable),
+    case Enable of
         false ->
             ?SLOG(info, #{
                 msg => "skip_to_start_gateway_due_to_disabled",
@@ -115,11 +115,11 @@ init([Gateway, Ctx, _GwDscrptr]) ->
             }),
             {ok, State};
         true ->
-            case cb_gateway_load(ensure_authn_created(State)) of
+            case cb_gateway_load(State) of
                 {error, Reason} ->
                     {stop, Reason};
-                {ok, NState1} ->
-                    {ok, NState1}
+                {ok, NState} ->
+                    {ok, NState}
             end
     end.
 
@@ -130,7 +130,8 @@ handle_call(disable, _From, State = #state{status = Status}) ->
         running ->
             case cb_gateway_unload(State) of
                 {ok, NState} ->
-                    {reply, ok, disable_authns(NState)};
+                    ok = disable_authns(State),
+                    {reply, ok, NState};
                 {error, Reason} ->
                     {reply, {error, Reason}, State}
             end;
@@ -140,11 +141,16 @@ handle_call(disable, _From, State = #state{status = Status}) ->
 handle_call(enable, _From, State = #state{status = Status}) ->
     case Status of
         stopped ->
-            case cb_gateway_load(ensure_authn_running(State)) of
+            case ensure_authn_running(State) of
+                ok ->
+                    case cb_gateway_load(State) of
+                        {error, Reason} ->
+                            {reply, {error, Reason}, State};
+                        {ok, NState1} ->
+                            {reply, ok, NState1}
+                    end;
                 {error, Reason} ->
-                    {reply, {error, Reason}, State};
-                {ok, NState1} ->
-                    {reply, ok, NState1}
+                    {reply, {error, Reason}, State}
             end;
         _ ->
             {reply, {error, already_started}, State}
@@ -210,7 +216,7 @@ handle_info(Info, State) ->
 
 terminate(_Reason, State = #state{child_pids = Pids}) ->
     Pids /= [] andalso (_ = cb_gateway_unload(State)),
-    _ = do_deinit_authn(State#state.authns),
+    _ = remove_all_authns(State),
     ok.
 
 code_change(_OldVsn, State, _Extra) ->
@@ -236,65 +242,73 @@ detailed_gateway_info(State) ->
 %%--------------------------------------------------------------------
 %% Authn resources managing funcs
 
-%% ensure authentication chain, authenticator created and keep its status
-%% as expected
-ensure_authn_created(State = #state{ctx = Ctx, name = GwName, config = Config}) ->
-    Authns = init_authn(GwName, Config),
-    AuthnNames = lists:map(fun({ChainName, _}) -> ChainName end, Authns),
-    State#state{authns = Authns, ctx = maps:put(auth, AuthnNames, Ctx)}.
+pipeline(_, []) ->
+    ok;
+pipeline(Fun, [Args | More]) ->
+    case Fun(Args) of
+        ok ->
+            pipeline(Fun, More);
+        {error, Reason} ->
+            {error, Reason}
+    end.
 
-%% temporarily disable authenticators after gateway disabled
-disable_authns(State = #state{ctx = Ctx, authns = Authns}) ->
-    lists:foreach(
+%% ensure authentication chain, authenticator created and keep its configured
+%% status
+ensure_authn_running(#state{name = GwName, config = Config}) ->
+    pipeline(
         fun({ChainName, AuthConf}) ->
-            TempConf = maps:put(enable, false, AuthConf),
-            do_update_authenticator(ChainName, TempConf)
+            ensure_authenticator_created(ChainName, AuthConf)
         end,
-        Authns
-    ),
-    State#state{ctx = maps:remove(auth, Ctx)}.
+        authns(GwName, Config)
+    ).
 
-%% keep authenticators running as expected
-ensure_authn_running(State = #state{ctx = Ctx, authns = Authns}) ->
-    AuthnNames = lists:map(
+%% ensure authentication chain, authenticator created and keep its status
+%% as given
+ensure_authn_running(#state{name = GwName, config = Config}, Enable) ->
+    pipeline(
         fun({ChainName, AuthConf}) ->
-            ok = do_update_authenticator(ChainName, AuthConf),
-            ChainName
+            ensure_authenticator_created(ChainName, AuthConf#{enable => Enable})
         end,
-        Authns
-    ),
-    State#state{ctx = maps:put(auth, AuthnNames, Ctx)}.
+        authns(GwName, Config)
+    ).
 
-do_update_authenticator({ChainName, Confs}) ->
-    do_update_authenticator(ChainName, Confs).
+%% temporarily disable authenticators after gateway disabled
+disable_authns(State) ->
+    ensure_authn_running(State, false).
 
-do_update_authenticator(ChainName, Confs) ->
-    {ok, [#{id := AuthenticatorId}]} = emqx_authentication:list_authenticators(ChainName),
-    {ok, _} = emqx_authentication:update_authenticator(ChainName, AuthenticatorId, Confs),
-    ok.
+%% remove all authns if gateway unloaded
+remove_all_authns(#state{name = GwName, config = Config}) ->
+    lists:foreach(
+        fun({ChainName, _}) ->
+            case emqx_authentication:delete_chain(ChainName) of
+                ok ->
+                    ok;
+                {error, {not_found, _}} ->
+                    ok;
+                {error, Reason} ->
+                    ?SLOG(error, #{
+                        msg => "failed_to_clean_authn_chain",
+                        chain_name => ChainName,
+                        reason => Reason
+                    })
+            end
+        end,
+        authns(GwName, Config)
+    ).
 
-%% There are two layer authentication configs
-%%       stomp.authn
-%%           /                   \
-%%   listeners.tcp.default.authn  *.ssl.default.authn
-%%
-init_authn(GwName, Config) ->
-    Authns = authns(GwName, Config),
-    try
-        ok = do_init_authn(Authns),
-        Authns
-    catch
-        throw:Reason = {badauth, _} ->
-            do_deinit_authn(Authns),
-            throw(Reason)
+ensure_authenticator_created(ChainName, Confs) ->
+    case emqx_authentication:list_authenticators(ChainName) of
+        {ok, [#{id := AuthenticatorId}]} ->
+            case emqx_authentication:update_authenticator(ChainName, AuthenticatorId, Confs) of
+                {ok, _} -> ok;
+                {error, Reason} -> {error, {badauth, Reason}}
+            end;
+        {ok, []} ->
+            do_create_authenticator(ChainName, Confs);
+        {error, {not_found, {chain, _}}} ->
+            do_create_authenticator(ChainName, Confs)
     end.
 
-do_init_authn([]) ->
-    ok;
-do_init_authn([{ChainName, AuthConf} | More]) when is_map(AuthConf) ->
-    ok = do_create_authn_chain(ChainName, AuthConf),
-    do_init_authn(More).
-
 authns(GwName, Config) ->
     Listeners = maps:to_list(maps:get(listeners, Config, #{})),
     Authns0 =
@@ -319,7 +333,7 @@ authns(GwName, Config) ->
 authn_conf(Conf) ->
     maps:get(authentication, Conf, undefined).
 
-do_create_authn_chain(ChainName, AuthConf) ->
+do_create_authenticator(ChainName, AuthConf) ->
     case emqx_authentication:create_authenticator(ChainName, AuthConf) of
         {ok, _} ->
             ok;
@@ -330,28 +344,9 @@ do_create_authn_chain(ChainName, AuthConf) ->
                 reason => Reason,
                 config => AuthConf
             }),
-            throw({badauth, Reason})
+            {error, {badauth, Reason}}
     end.
 
-do_deinit_authn(Authns) ->
-    lists:foreach(
-        fun({ChainName, _}) ->
-            case emqx_authentication:delete_chain(ChainName) of
-                ok ->
-                    ok;
-                {error, {not_found, _}} ->
-                    ok;
-                {error, Reason} ->
-                    ?SLOG(error, #{
-                        msg => "failed_to_clean_authn_chain",
-                        chain_name => ChainName,
-                        reason => Reason
-                    })
-            end
-        end,
-        Authns
-    ).
-
 do_update_one_by_one(
     NCfg,
     State = #state{
@@ -365,53 +360,53 @@ do_update_one_by_one(
     OAuthns = authns(GwName, OCfg),
     NAuthns = authns(GwName, NCfg),
 
+    ok = remove_deleted_authns(NAuthns, OAuthns),
+
     case {Status, NEnable} of
         {stopped, true} ->
-            NState = State#state{config = NCfg},
-            cb_gateway_load(ensure_authn_running(NState));
+            case ensure_authn_running(State#state{config = NCfg}) of
+                ok ->
+                    cb_gateway_load(State#state{config = NCfg});
+                {error, Reason} ->
+                    {error, Reason}
+            end;
         {stopped, false} ->
-            {ok, State#state{config = NCfg}};
+            case disable_authns(State#state{config = NCfg}) of
+                ok ->
+                    {ok, State#state{config = NCfg}};
+                {error, Reason} ->
+                    {error, Reason}
+            end;
         {running, true} ->
-            {Added, Updated, Deleted} = diff_auths(NAuthns, OAuthns),
-            _ = do_deinit_authn(Deleted),
-            _ = do_init_authn(Added),
-            _ = lists:foreach(fun do_update_authenticator/1, Updated),
-            NState = State#state{authns = NAuthns},
-            %% TODO: minimum impact update ???
-            cb_gateway_update(NCfg, NState);
+            %% FIXME: minimum impact update
+            case ensure_authn_running(State#state{config = NCfg}) of
+                ok ->
+                    cb_gateway_update(NCfg, State);
+                {error, Reason} ->
+                    {error, Reason}
+            end;
         {running, false} ->
             case cb_gateway_unload(State) of
-                {ok, NState} -> {ok, disable_authns(NState#state{config = NCfg})};
-                {error, Reason} -> {error, Reason}
+                {ok, NState} ->
+                    ok = disable_authns(State#state{config = NCfg}),
+                    {ok, NState#state{config = NCfg}};
+                {error, Reason} ->
+                    {error, Reason}
             end;
         _ ->
             throw(nomatch)
     end.
 
-diff_auths(NAuthns, OAuthns) ->
+remove_deleted_authns(NAuthns, OAuthns) ->
     NNames = proplists:get_keys(NAuthns),
     ONames = proplists:get_keys(OAuthns),
-    AddedNames = NNames -- ONames,
     DeletedNames = ONames -- NNames,
-    BothNames = NNames -- AddedNames,
-    UpdatedNames = lists:foldl(
-        fun(Name, Acc) ->
-            case
-                proplists:get_value(Name, NAuthns) ==
-                    proplists:get_value(Name, OAuthns)
-            of
-                true -> Acc;
-                false -> [Name | Acc]
-            end
+    lists:foreach(
+        fun(ChainName) ->
+            _ = emqx_authentication:delete_chain(ChainName)
         end,
-        [],
-        BothNames
-    ),
-    {
-        lists:filter(fun({Name, _}) -> lists:member(Name, AddedNames) end, NAuthns),
-        lists:filter(fun({Name, _}) -> lists:member(Name, UpdatedNames) end, NAuthns),
-        lists:filter(fun({Name, _}) -> lists:member(Name, DeletedNames) end, OAuthns)
-    }.
+        DeletedNames
+    ).
 
 cb_gateway_unload(
     State = #state{
@@ -461,7 +456,6 @@ cb_gateway_load(
             {ok, ChildPidOrSpecs, GwState} ->
                 ChildPids = start_child_process(ChildPidOrSpecs),
                 {ok, State#state{
-                    ctx = Ctx,
                     status = running,
                     child_pids = ChildPids,
                     gw_state = GwState,
@@ -475,7 +469,6 @@ cb_gateway_load(
                 msg => "load_gateway_crashed",
                 gateway_name => GwName,
                 gateway => Gateway,
-                ctx => Ctx,
                 reason => {Class, Reason1},
                 stacktrace => Stk
             }),

+ 2 - 2
apps/emqx_gateway/test/emqx_coap_SUITE.erl

@@ -57,14 +57,14 @@ all() -> emqx_common_test_helpers:all(?MODULE).
 
 init_per_suite(Config) ->
     ok = emqx_common_test_helpers:load_config(emqx_gateway_schema, ?CONF_DEFAULT),
-    emqx_mgmt_api_test_util:init_suite([emqx_gateway]),
+    emqx_mgmt_api_test_util:init_suite([emqx_authn, emqx_gateway]),
     ok = meck:new(emqx_access_control, [passthrough, no_history, no_link]),
     Config.
 
 end_per_suite(_) ->
     meck:unload(emqx_access_control),
     {ok, _} = emqx:remove_config([<<"gateway">>, <<"coap">>]),
-    emqx_mgmt_api_test_util:end_suite([emqx_gateway]).
+    emqx_mgmt_api_test_util:end_suite([emqx_gateway, emqx_authn]).
 
 init_per_testcase(t_connection_with_authn_failed, Config) ->
     ok = meck:expect(

+ 2 - 2
apps/emqx_gateway/test/emqx_coap_api_SUITE.erl

@@ -57,12 +57,12 @@ all() ->
 
 init_per_suite(Config) ->
     ok = emqx_common_test_helpers:load_config(emqx_gateway_schema, ?CONF_DEFAULT),
-    emqx_mgmt_api_test_util:init_suite([emqx_gateway]),
+    emqx_mgmt_api_test_util:init_suite([emqx_authn, emqx_gateway]),
     Config.
 
 end_per_suite(Config) ->
     {ok, _} = emqx:remove_config([<<"gateway">>, <<"coap">>]),
-    emqx_mgmt_api_test_util:end_suite([emqx_gateway]),
+    emqx_mgmt_api_test_util:end_suite([emqx_gateway, emqx_authn]),
     Config.
 
 %%--------------------------------------------------------------------

+ 2 - 2
apps/emqx_gateway/test/emqx_exproto_SUITE.erl

@@ -73,12 +73,12 @@ metrics() ->
 init_per_group(GrpName, Cfg) ->
     put(grpname, GrpName),
     Svrs = emqx_exproto_echo_svr:start(),
-    emqx_common_test_helpers:start_apps([emqx_gateway], fun set_special_cfg/1),
+    emqx_common_test_helpers:start_apps([emqx_authn, emqx_gateway], fun set_special_cfg/1),
     [{servers, Svrs}, {listener_type, GrpName} | Cfg].
 
 end_per_group(_, Cfg) ->
     emqx_config:erase(gateway),
-    emqx_common_test_helpers:stop_apps([emqx_gateway]),
+    emqx_common_test_helpers:stop_apps([emqx_gateway, emqx_authn]),
     emqx_exproto_echo_svr:stop(proplists:get_value(servers, Cfg)).
 
 set_special_cfg(emqx_gateway) ->

+ 2 - 2
apps/emqx_gateway/test/emqx_gateway_authz_SUITE.erl

@@ -69,7 +69,7 @@ init_per_suite(Config) ->
     init_gateway_conf(),
     meck:new(emqx_authz_file, [non_strict, passthrough, no_history, no_link]),
     meck:expect(emqx_authz_file, create, fun(S) -> S end),
-    emqx_mgmt_api_test_util:init_suite([emqx_conf, emqx_authz, emqx_gateway]),
+    emqx_mgmt_api_test_util:init_suite([emqx_conf, emqx_authz, emqx_authn, emqx_gateway]),
     application:ensure_all_started(cowboy),
     emqx_gateway_auth_ct:start(),
     Config.
@@ -79,7 +79,7 @@ end_per_suite(Config) ->
     emqx_gateway_auth_ct:stop(),
     ok = emqx_authz_test_lib:restore_authorizers(),
     emqx_config:erase(gateway),
-    emqx_mgmt_api_test_util:end_suite([cowboy, emqx_authz, emqx_gateway]),
+    emqx_mgmt_api_test_util:end_suite([cowboy, emqx_authz, emqx_authn, emqx_gateway]),
     Config.
 
 init_per_testcase(_Case, Config) ->

+ 2 - 2
apps/emqx_gateway/test/emqx_gateway_registry_SUITE.erl

@@ -38,11 +38,11 @@ all() -> emqx_common_test_helpers:all(?MODULE).
 
 init_per_suite(Cfg) ->
     ok = emqx_common_test_helpers:load_config(emqx_gateway_schema, ?CONF_DEFAULT),
-    emqx_common_test_helpers:start_apps([emqx_gateway]),
+    emqx_common_test_helpers:start_apps([emqx_authn, emqx_gateway]),
     Cfg.
 
 end_per_suite(_Cfg) ->
-    emqx_common_test_helpers:stop_apps([emqx_gateway]),
+    emqx_common_test_helpers:stop_apps([emqx_gateway, emqx_authn]),
     ok.
 
 %%--------------------------------------------------------------------

+ 2 - 2
apps/emqx_gateway/test/emqx_lwm2m_SUITE.erl

@@ -155,13 +155,13 @@ groups() ->
 init_per_suite(Config) ->
     %% load application first for minirest api searching
     application:load(emqx_gateway),
-    emqx_mgmt_api_test_util:init_suite([emqx_conf]),
+    emqx_mgmt_api_test_util:init_suite([emqx_conf, emqx_authn]),
     Config.
 
 end_per_suite(Config) ->
     timer:sleep(300),
     {ok, _} = emqx_conf:remove([<<"gateway">>, <<"lwm2m">>], #{}),
-    emqx_mgmt_api_test_util:end_suite([emqx_conf]),
+    emqx_mgmt_api_test_util:end_suite([emqx_conf, emqx_authn]),
     Config.
 
 init_per_testcase(_AllTestCase, Config) ->

+ 2 - 2
apps/emqx_gateway/test/emqx_lwm2m_api_SUITE.erl

@@ -83,13 +83,13 @@ all() ->
 init_per_suite(Config) ->
     ok = emqx_common_test_helpers:load_config(emqx_gateway_schema, ?CONF_DEFAULT),
     application:load(emqx_gateway),
-    emqx_mgmt_api_test_util:init_suite([emqx_conf]),
+    emqx_mgmt_api_test_util:init_suite([emqx_conf, emqx_authn]),
     Config.
 
 end_per_suite(Config) ->
     timer:sleep(300),
     {ok, _} = emqx_conf:remove([<<"gateway">>, <<"lwm2m">>], #{}),
-    emqx_mgmt_api_test_util:end_suite([emqx_conf]),
+    emqx_mgmt_api_test_util:end_suite([emqx_authn, emqx_conf]),
     Config.
 
 init_per_testcase(_AllTestCase, Config) ->

+ 2 - 2
apps/emqx_gateway/test/emqx_sn_protocol_SUITE.erl

@@ -98,12 +98,12 @@ all() ->
 
 init_per_suite(Config) ->
     ok = emqx_common_test_helpers:load_config(emqx_gateway_schema, ?CONF_DEFAULT),
-    emqx_mgmt_api_test_util:init_suite([emqx_conf, emqx_gateway]),
+    emqx_mgmt_api_test_util:init_suite([emqx_conf, emqx_authn, emqx_gateway]),
     Config.
 
 end_per_suite(_) ->
     {ok, _} = emqx:remove_config([gateway, mqttsn]),
-    emqx_mgmt_api_test_util:end_suite([emqx_gateway, emqx_conf]).
+    emqx_mgmt_api_test_util:end_suite([emqx_gateway, emqx_auhtn, emqx_conf]).
 
 restart_mqttsn_with_subs_resume_on() ->
     Conf = emqx:get_raw_config([gateway, mqttsn]),

+ 2 - 2
apps/emqx_gateway/test/emqx_stomp_SUITE.erl

@@ -54,11 +54,11 @@ all() -> emqx_common_test_helpers:all(?MODULE).
 
 init_per_suite(Cfg) ->
     ok = emqx_common_test_helpers:load_config(emqx_gateway_schema, ?CONF_DEFAULT),
-    emqx_mgmt_api_test_util:init_suite([emqx_gateway]),
+    emqx_mgmt_api_test_util:init_suite([emqx_authn, emqx_gateway]),
     Cfg.
 
 end_per_suite(_Cfg) ->
-    emqx_mgmt_api_test_util:end_suite([emqx_gateway]),
+    emqx_mgmt_api_test_util:end_suite([emqx_gateway, emqx_authn]),
     ok.
 
 default_config() ->