Просмотр исходного кода

refactor(auth_mnesia): Export transaction funs

ieQu1 3 лет назад
Родитель
Сommit
9449e3cb32

+ 67 - 65
apps/emqx_authn/src/enhanced_authn/emqx_enhanced_authn_scram_mnesia.erl

@@ -52,6 +52,14 @@
     group_match_spec/1
     group_match_spec/1
 ]).
 ]).
 
 
+%% Internal exports (RPC)
+-export([
+    do_destroy/1,
+    do_add_user/2,
+    do_delete_user/2,
+    do_update_user/3
+]).
+
 -define(TAB, ?MODULE).
 -define(TAB, ?MODULE).
 -define(AUTHN_QSCHEMA, [
 -define(AUTHN_QSCHEMA, [
     {<<"like_user_id">>, binary},
     {<<"like_user_id">>, binary},
@@ -170,83 +178,79 @@ authenticate(_Credential, _State) ->
     ignore.
     ignore.
 
 
 destroy(#{user_group := UserGroup}) ->
 destroy(#{user_group := UserGroup}) ->
+    trans(fun ?MODULE:do_destroy/1, [UserGroup]).
+
+do_destroy(UserGroup) ->
     MatchSpec = group_match_spec(UserGroup),
     MatchSpec = group_match_spec(UserGroup),
-    trans(
-        fun() ->
-            ok = lists:foreach(
-                fun(UserInfo) ->
-                    mnesia:delete_object(?TAB, UserInfo, write)
-                end,
-                mnesia:select(?TAB, MatchSpec, write)
-            )
-        end
+    ok = lists:foreach(
+        fun(UserInfo) ->
+            mnesia:delete_object(?TAB, UserInfo, write)
+        end,
+        mnesia:select(?TAB, MatchSpec, write)
     ).
     ).
 
 
-add_user(
+add_user(UserInfo, State) ->
+    trans(fun ?MODULE:do_add_user/2, [UserInfo, State]).
+
+do_add_user(
     #{
     #{
         user_id := UserID,
         user_id := UserID,
         password := Password
         password := Password
     } = UserInfo,
     } = UserInfo,
     #{user_group := UserGroup} = State
     #{user_group := UserGroup} = State
 ) ->
 ) ->
-    trans(
-        fun() ->
-            case mnesia:read(?TAB, {UserGroup, UserID}, write) of
-                [] ->
-                    IsSuperuser = maps:get(is_superuser, UserInfo, false),
-                    add_user(UserGroup, UserID, Password, IsSuperuser, State),
-                    {ok, #{user_id => UserID, is_superuser => IsSuperuser}};
-                [_] ->
-                    {error, already_exist}
-            end
-        end
-    ).
+    case mnesia:read(?TAB, {UserGroup, UserID}, write) of
+        [] ->
+            IsSuperuser = maps:get(is_superuser, UserInfo, false),
+            add_user(UserGroup, UserID, Password, IsSuperuser, State),
+            {ok, #{user_id => UserID, is_superuser => IsSuperuser}};
+        [_] ->
+            {error, already_exist}
+    end.
 
 
-delete_user(UserID, #{user_group := UserGroup}) ->
-    trans(
-        fun() ->
-            case mnesia:read(?TAB, {UserGroup, UserID}, write) of
-                [] ->
-                    {error, not_found};
-                [_] ->
-                    mnesia:delete(?TAB, {UserGroup, UserID}, write)
-            end
-        end
-    ).
+delete_user(UserID, State) ->
+    trans(fun ?MODULE:do_delete_user/2, [UserID, State]).
+
+do_delete_user(UserID, #{user_group := UserGroup}) ->
+    case mnesia:read(?TAB, {UserGroup, UserID}, write) of
+        [] ->
+            {error, not_found};
+        [_] ->
+            mnesia:delete(?TAB, {UserGroup, UserID}, write)
+    end.
+
+update_user(UserID, User, State) ->
+    trans(fun ?MODULE:do_update_user/3, [UserID, User, State]).
 
 
-update_user(
+do_update_user(
     UserID,
     UserID,
     User,
     User,
     #{user_group := UserGroup} = State
     #{user_group := UserGroup} = State
 ) ->
 ) ->
-    trans(
-        fun() ->
-            case mnesia:read(?TAB, {UserGroup, UserID}, write) of
-                [] ->
-                    {error, not_found};
-                [#user_info{is_superuser = IsSuperuser} = UserInfo] ->
-                    UserInfo1 = UserInfo#user_info{
-                        is_superuser = maps:get(is_superuser, User, IsSuperuser)
-                    },
-                    UserInfo2 =
-                        case maps:get(password, User, undefined) of
-                            undefined ->
-                                UserInfo1;
-                            Password ->
-                                {StoredKey, ServerKey, Salt} = esasl_scram:generate_authentication_info(
-                                    Password, State
-                                ),
-                                UserInfo1#user_info{
-                                    stored_key = StoredKey,
-                                    server_key = ServerKey,
-                                    salt = Salt
-                                }
-                        end,
-                    mnesia:write(?TAB, UserInfo2, write),
-                    {ok, format_user_info(UserInfo2)}
-            end
-        end
-    ).
+    case mnesia:read(?TAB, {UserGroup, UserID}, write) of
+        [] ->
+            {error, not_found};
+        [#user_info{is_superuser = IsSuperuser} = UserInfo] ->
+            UserInfo1 = UserInfo#user_info{
+                is_superuser = maps:get(is_superuser, User, IsSuperuser)
+            },
+            UserInfo2 =
+                case maps:get(password, User, undefined) of
+                    undefined ->
+                        UserInfo1;
+                    Password ->
+                        {StoredKey, ServerKey, Salt} = esasl_scram:generate_authentication_info(
+                            Password, State
+                        ),
+                        UserInfo1#user_info{
+                            stored_key = StoredKey,
+                            server_key = ServerKey,
+                            salt = Salt
+                        }
+                end,
+            mnesia:write(?TAB, UserInfo2, write),
+            {ok, format_user_info(UserInfo2)}
+    end.
 
 
 lookup_user(UserID, #{user_group := UserGroup}) ->
 lookup_user(UserID, #{user_group := UserGroup}) ->
     case mnesia:dirty_read(?TAB, {UserGroup, UserID}) of
     case mnesia:dirty_read(?TAB, {UserGroup, UserID}) of
@@ -386,12 +390,10 @@ retrieve(UserID, #{user_group := UserGroup}) ->
     end.
     end.
 
 
 %% TODO: Move to emqx_authn_utils.erl
 %% TODO: Move to emqx_authn_utils.erl
-trans(Fun) ->
-    trans(Fun, []).
-
 trans(Fun, Args) ->
 trans(Fun, Args) ->
     case mria:transaction(?AUTH_SHARD, Fun, Args) of
     case mria:transaction(?AUTH_SHARD, Fun, Args) of
         {atomic, Res} -> Res;
         {atomic, Res} -> Res;
+        {aborted, {function_clause, Stack}} -> erlang:raise(error, function_clause, Stack);
         {aborted, Reason} -> {error, Reason}
         {aborted, Reason} -> {error, Reason}
     end.
     end.
 
 

+ 70 - 67
apps/emqx_authn/src/simple_authn/emqx_authn_mnesia.erl

@@ -54,6 +54,16 @@
     group_match_spec/1
     group_match_spec/1
 ]).
 ]).
 
 
+%% Internal exports (RPC)
+-export([
+    do_destroy/1,
+    do_add_user/2,
+    do_delete_user/2,
+    do_update_user/3,
+    import/2,
+    import_csv/3
+]).
+
 -type user_group() :: binary().
 -type user_group() :: binary().
 -type user_id() :: binary().
 -type user_id() :: binary().
 
 
@@ -175,15 +185,14 @@ authenticate(
     end.
     end.
 
 
 destroy(#{user_group := UserGroup}) ->
 destroy(#{user_group := UserGroup}) ->
-    trans(
-        fun() ->
-            ok = lists:foreach(
-                fun(User) ->
-                    mnesia:delete_object(?TAB, User, write)
-                end,
-                mnesia:select(?TAB, group_match_spec(UserGroup), write)
-            )
-        end
+    trans(fun ?MODULE:do_destroy/1, [UserGroup]).
+
+do_destroy(UserGroup) ->
+    ok = lists:foreach(
+        fun(User) ->
+            mnesia:delete_object(?TAB, User, write)
+        end,
+        mnesia:select(?TAB, group_match_spec(UserGroup), write)
     ).
     ).
 
 
 import_users({Filename0, FileData}, State) ->
 import_users({Filename0, FileData}, State) ->
@@ -200,7 +209,10 @@ import_users({Filename0, FileData}, State) ->
             {error, {unsupported_file_format, Extension}}
             {error, {unsupported_file_format, Extension}}
     end.
     end.
 
 
-add_user(
+add_user(UserInfo, State) ->
+    trans(fun ?MODULE:do_add_user/2, [UserInfo, State]).
+
+do_add_user(
     #{
     #{
         user_id := UserID,
         user_id := UserID,
         password := Password
         password := Password
@@ -210,33 +222,31 @@ add_user(
         password_hash_algorithm := Algorithm
         password_hash_algorithm := Algorithm
     }
     }
 ) ->
 ) ->
-    trans(
-        fun() ->
-            case mnesia:read(?TAB, {UserGroup, UserID}, write) of
-                [] ->
-                    {PasswordHash, Salt} = emqx_authn_password_hashing:hash(Algorithm, Password),
-                    IsSuperuser = maps:get(is_superuser, UserInfo, false),
-                    insert_user(UserGroup, UserID, PasswordHash, Salt, IsSuperuser),
-                    {ok, #{user_id => UserID, is_superuser => IsSuperuser}};
-                [_] ->
-                    {error, already_exist}
-            end
-        end
-    ).
+    case mnesia:read(?TAB, {UserGroup, UserID}, write) of
+        [] ->
+            {PasswordHash, Salt} = emqx_authn_password_hashing:hash(Algorithm, Password),
+            IsSuperuser = maps:get(is_superuser, UserInfo, false),
+            insert_user(UserGroup, UserID, PasswordHash, Salt, IsSuperuser),
+            {ok, #{user_id => UserID, is_superuser => IsSuperuser}};
+        [_] ->
+            {error, already_exist}
+    end.
 
 
-delete_user(UserID, #{user_group := UserGroup}) ->
-    trans(
-        fun() ->
-            case mnesia:read(?TAB, {UserGroup, UserID}, write) of
-                [] ->
-                    {error, not_found};
-                [_] ->
-                    mnesia:delete(?TAB, {UserGroup, UserID}, write)
-            end
-        end
-    ).
+delete_user(UserID, State) ->
+    trans(fun ?MODULE:do_delete_user/2, [UserID, State]).
+
+do_delete_user(UserID, #{user_group := UserGroup}) ->
+    case mnesia:read(?TAB, {UserGroup, UserID}, write) of
+        [] ->
+            {error, not_found};
+        [_] ->
+            mnesia:delete(?TAB, {UserGroup, UserID}, write)
+    end.
+
+update_user(UserID, UserInfo, State) ->
+    trans(fun ?MODULE:do_update_user/3, [UserID, UserInfo, State]).
 
 
-update_user(
+do_update_user(
     UserID,
     UserID,
     UserInfo,
     UserInfo,
     #{
     #{
@@ -244,33 +254,29 @@ update_user(
         password_hash_algorithm := Algorithm
         password_hash_algorithm := Algorithm
     }
     }
 ) ->
 ) ->
-    trans(
-        fun() ->
-            case mnesia:read(?TAB, {UserGroup, UserID}, write) of
-                [] ->
-                    {error, not_found};
-                [
-                    #user_info{
-                        password_hash = PasswordHash,
-                        salt = Salt,
-                        is_superuser = IsSuperuser
-                    }
-                ] ->
-                    NSuperuser = maps:get(is_superuser, UserInfo, IsSuperuser),
-                    {NPasswordHash, NSalt} =
-                        case UserInfo of
-                            #{password := Password} ->
-                                emqx_authn_password_hashing:hash(
-                                    Algorithm, Password
-                                );
-                            #{} ->
-                                {PasswordHash, Salt}
-                        end,
-                    insert_user(UserGroup, UserID, NPasswordHash, NSalt, NSuperuser),
-                    {ok, #{user_id => UserID, is_superuser => NSuperuser}}
-            end
-        end
-    ).
+    case mnesia:read(?TAB, {UserGroup, UserID}, write) of
+        [] ->
+            {error, not_found};
+        [
+            #user_info{
+                password_hash = PasswordHash,
+                salt = Salt,
+                is_superuser = IsSuperuser
+            }
+        ] ->
+            NSuperuser = maps:get(is_superuser, UserInfo, IsSuperuser),
+            {NPasswordHash, NSalt} =
+                case UserInfo of
+                    #{password := Password} ->
+                        emqx_authn_password_hashing:hash(
+                            Algorithm, Password
+                        );
+                    #{} ->
+                        {PasswordHash, Salt}
+                end,
+            insert_user(UserGroup, UserID, NPasswordHash, NSalt, NSuperuser),
+            {ok, #{user_id => UserID, is_superuser => NSuperuser}}
+    end.
 
 
 lookup_user(UserID, #{user_group := UserGroup}) ->
 lookup_user(UserID, #{user_group := UserGroup}) ->
     case mnesia:dirty_read(?TAB, {UserGroup, UserID}) of
     case mnesia:dirty_read(?TAB, {UserGroup, UserID}) of
@@ -335,7 +341,7 @@ run_fuzzy_filter(
 import_users_from_json(Bin, #{user_group := UserGroup}) ->
 import_users_from_json(Bin, #{user_group := UserGroup}) ->
     case emqx_json:safe_decode(Bin, [return_maps]) of
     case emqx_json:safe_decode(Bin, [return_maps]) of
         {ok, List} ->
         {ok, List} ->
-            trans(fun import/2, [UserGroup, List]);
+            trans(fun ?MODULE:import/2, [UserGroup, List]);
         {error, Reason} ->
         {error, Reason} ->
             {error, Reason}
             {error, Reason}
     end.
     end.
@@ -344,7 +350,7 @@ import_users_from_json(Bin, #{user_group := UserGroup}) ->
 import_users_from_csv(CSV, #{user_group := UserGroup}) ->
 import_users_from_csv(CSV, #{user_group := UserGroup}) ->
     case get_csv_header(CSV) of
     case get_csv_header(CSV) of
         {ok, Seq, NewCSV} ->
         {ok, Seq, NewCSV} ->
-            trans(fun import_csv/3, [UserGroup, NewCSV, Seq]);
+            trans(fun ?MODULE:import_csv/3, [UserGroup, NewCSV, Seq]);
         {error, Reason} ->
         {error, Reason} ->
             {error, Reason}
             {error, Reason}
     end.
     end.
@@ -435,9 +441,6 @@ get_user_identity(#{clientid := ClientID}, clientid) ->
 get_user_identity(_, Type) ->
 get_user_identity(_, Type) ->
     {error, {bad_user_identity_type, Type}}.
     {error, {bad_user_identity_type, Type}}.
 
 
-trans(Fun) ->
-    trans(Fun, []).
-
 trans(Fun, Args) ->
 trans(Fun, Args) ->
     case mria:transaction(?AUTH_SHARD, Fun, Args) of
     case mria:transaction(?AUTH_SHARD, Fun, Args) of
         {atomic, Res} -> Res;
         {atomic, Res} -> Res;