Просмотр исходного кода

fix(variform): add basic tests

zmstone 1 год назад
Родитель
Сommit
bf12efac6d

+ 38 - 11
apps/emqx_utils/src/emqx_variform.erl

@@ -1,5 +1,5 @@
 %%--------------------------------------------------------------------
-%% Copyright (c) 2020-2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%% Copyright (c) 2024 EMQ Technologies Co., Ltd. All Rights Reserved.
 %%
 %% Licensed under the Apache License, Version 2.0 (the "License");
 %% you may not use this file except in compliance with the License.
@@ -22,7 +22,12 @@
 %% or used to choose the first non-empty value from a list of variables.
 -module(emqx_variform).
 
--export([inject_allowed_modules/1]).
+-export([
+    inject_allowed_module/1,
+    inject_allowed_modules/1,
+    erase_allowed_module/1,
+    erase_allowed_modules/1
+]).
 -export([render/2, render/3]).
 
 %% @doc Render a variform expression with bindings.
@@ -48,6 +53,8 @@
 render(Expression, Bindings) ->
     render(Expression, Bindings, #{}).
 
+render(Expression, Bindings, Opts) when is_binary(Expression) ->
+    render(unicode:characters_to_list(Expression), Bindings, Opts);
 render(Expression, Bindings, Opts) ->
     case emqx_variform_scan:string(Expression) of
         {ok, Tokens, _Line} ->
@@ -66,7 +73,7 @@ render(Expression, Bindings, Opts) ->
 
 eval_as_string(Expr, Bindings, _Opts) ->
     try
-        {ok, iolist_to_binary(eval(Expr, Bindings))}
+        {ok, str(eval(Expr, Bindings))}
     catch
         throw:Reason ->
             {error, Reason};
@@ -97,7 +104,7 @@ call(emqx_variform_str, concat, Args) ->
 call(emqx_variform_str, coalesce, Args) ->
     str(emqx_variform_str:coalesce(Args));
 call(Mod, Fun, Args) ->
-    str(erlang:apply(Mod, Fun, Args)).
+    erlang:apply(Mod, Fun, Args).
 
 resolve_func_name(FuncNameStr) ->
     case string:tokens(FuncNameStr, ".") of
@@ -107,7 +114,10 @@ resolve_func_name(FuncNameStr) ->
                     list_to_existing_atom(Mod0)
                 catch
                     error:badarg ->
-                        throw(#{unknown_module => Mod0})
+                        throw(#{
+                            reason => unknown_variform_module,
+                            module => Mod0
+                        })
                 end,
             ok = assert_module_allowed(Mod),
             Fun =
@@ -115,7 +125,10 @@ resolve_func_name(FuncNameStr) ->
                     list_to_existing_atom(Fun0)
                 catch
                     error:badarg ->
-                        throw(#{unknown_function => Fun0})
+                        throw(#{
+                            reason => unknown_variform_function,
+                            function => Fun0
+                        })
                 end,
             {Mod, Fun};
         [Fun] ->
@@ -125,11 +138,13 @@ resolve_func_name(FuncNameStr) ->
                 catch
                     error:badarg ->
                         throw(#{
-                            reason => "unknown_variform_function",
+                            reason => unknown_variform_function,
                             function => Fun
                         })
                 end,
-            {emqx_variform_str, FuncName}
+            {emqx_variform_str, FuncName};
+        _ ->
+            throw(#{reason => invalid_function_reference, function => FuncNameStr})
     end.
 
 resolve_var_value(VarName, Bindings) ->
@@ -145,13 +160,14 @@ assert_func_exported(emqx_variform_str, concat, _Arity) ->
 assert_func_exported(emqx_variform_str, coalesce, _Arity) ->
     ok;
 assert_func_exported(Mod, Fun, Arity) ->
+    %% ensure beam loaded
     _ = Mod:module_info(md5),
     case erlang:function_exported(Mod, Fun, Arity) of
         true ->
             ok;
         false ->
             throw(#{
-                reason => "unknown_variform_function",
+                reason => unknown_variform_function,
                 module => Mod,
                 function => Fun,
                 arity => Arity
@@ -167,16 +183,27 @@ assert_module_allowed(Mod) ->
             ok;
         false ->
             throw(#{
-                reason => "unallowed_veriform_module",
+                reason => unallowed_veriform_module,
                 module => Mod
             })
     end.
 
-inject_allowed_modules(Modules) ->
+inject_allowed_module(Module) when is_atom(Module) ->
+    inject_allowed_modules([Module]).
+
+inject_allowed_modules(Modules) when is_list(Modules) ->
     Allowed0 = get_allowed_modules(),
     Allowed = lists:usort(Allowed0 ++ Modules),
     persistent_term:put({emqx_variform, allowed_modules}, Allowed).
 
+erase_allowed_module(Module) when is_atom(Module) ->
+    erase_allowed_modules([Module]).
+
+erase_allowed_modules(Modules) when is_list(Modules) ->
+    Allowed0 = get_allowed_modules(),
+    Allowed = Allowed0 -- Modules,
+    persistent_term:put({emqx_variform, allowed_modules}, Allowed).
+
 get_allowed_modules() ->
     persistent_term:get({emqx_variform, allowed_modules}, []).
 

+ 16 - 1
apps/emqx_utils/src/emqx_variform_str.erl

@@ -52,7 +52,8 @@
     find/3,
     join_to_string/1,
     join_to_string/2,
-    unescape/1
+    unescape/1,
+    nth/2
 ]).
 
 -define(IS_EMPTY(X), (X =:= <<>> orelse X =:= "" orelse X =:= undefined)).
@@ -224,6 +225,20 @@ unescape(Bin) when is_binary(Bin) ->
             throw({invalid_unicode_character, Error})
     end.
 
+nth(N, List) when (is_list(N) orelse is_binary(N)) andalso is_list(List) ->
+    try binary_to_integer(iolist_to_binary(N)) of
+        N1 ->
+            nth(N1, List)
+    catch
+        _:_ ->
+            throw(#{reason => invalid_argument, func => nth, index => N})
+    end;
+nth(N, List) when is_integer(N) andalso is_list(List) ->
+    case length(List) of
+        L when L < N -> <<>>;
+        _ -> lists:nth(N, List)
+    end.
+
 unescape_string(Input) -> unescape_string(Input, []).
 
 unescape_string([], Acc) ->

+ 129 - 0
apps/emqx_utils/test/emqx_variform_tests.erl

@@ -0,0 +1,129 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+-module(emqx_variform_tests).
+
+-compile(export_all).
+-compile(nowarn_export_all).
+
+-include_lib("eunit/include/eunit.hrl").
+
+-define(SYNTAX_ERROR, {error, "syntax error before:" ++ _}).
+
+redner_test_() ->
+    [
+        {"direct var reference", fun() -> ?assertEqual({ok, <<"1">>}, render("a", #{a => 1})) end},
+        {"concat strings", fun() ->
+            ?assertEqual({ok, <<"a,b">>}, render("concat('a',',','b')", #{}))
+        end},
+        {"concat empty string", fun() -> ?assertEqual({ok, <<"">>}, render("concat('')", #{})) end},
+        {"tokens 1st", fun() ->
+            ?assertEqual({ok, <<"a">>}, render("nth(1,tokens(var, ','))", #{var => <<"a,b">>}))
+        end},
+        {"unknown var as empty str", fun() ->
+            ?assertEqual({ok, <<>>}, render("var", #{}))
+        end},
+        {"out of range nth index", fun() ->
+            ?assertEqual({ok, <<>>}, render("nth(2, tokens(var, ','))", #{var => <<"a">>}))
+        end},
+        {"not a index number for nth", fun() ->
+            ?assertMatch(
+                {error, #{reason := invalid_argument, func := nth, index := <<"notnum">>}},
+                render("nth('notnum', tokens(var, ','))", #{var => <<"a">>})
+            )
+        end}
+    ].
+
+unknown_func_test_() ->
+    [
+        {"unknown function", fun() ->
+            ?assertMatch(
+                {error, #{reason := unknown_variform_function}},
+                render("nonexistingatom__(a)", #{})
+            )
+        end},
+        {"unknown module", fun() ->
+            ?assertMatch(
+                {error, #{reason := unknown_variform_module}},
+                render("nonexistingatom__.nonexistingatom__(a)", #{})
+            )
+        end},
+        {"unknown function in a known module", fun() ->
+            ?assertMatch(
+                {error, #{reason := unknown_variform_function}},
+                render("emqx_variform_str.nonexistingatom__(a)", #{})
+            )
+        end},
+        {"invalid func reference", fun() ->
+            ?assertMatch(
+                {error, #{reason := invalid_function_reference, function := "a.b.c"}},
+                render("a.b.c(var)", #{})
+            )
+        end}
+    ].
+
+concat(L) -> iolist_to_binary(L).
+
+inject_allowed_module_test() ->
+    try
+        emqx_variform:inject_allowed_module(?MODULE),
+        ?assertEqual({ok, <<"ab">>}, render(atom_to_list(?MODULE) ++ ".concat(['a','b'])", #{})),
+        ?assertMatch(
+            {error, #{
+                reason := unknown_variform_function,
+                module := ?MODULE,
+                function := concat,
+                arity := 2
+            }},
+            render(atom_to_list(?MODULE) ++ ".concat('a','b')", #{})
+        ),
+        ?assertMatch(
+            {error, #{reason := unallowed_veriform_module, module := emqx}},
+            render("emqx.concat('a','b')", #{})
+        )
+    after
+        emqx_variform:erase_allowed_module(?MODULE)
+    end.
+
+coalesce_test_() ->
+    [
+        {"coalesce first", fun() ->
+            ?assertEqual({ok, <<"a">>}, render("coalesce('a','b')", #{}))
+        end},
+        {"coalesce second", fun() ->
+            ?assertEqual({ok, <<"b">>}, render("coalesce('', 'b')", #{}))
+        end},
+        {"coalesce first var", fun() ->
+            ?assertEqual({ok, <<"a">>}, render("coalesce(a,b)", #{a => <<"a">>, b => <<"b">>}))
+        end},
+        {"coalesce second var", fun() ->
+            ?assertEqual({ok, <<"b">>}, render("coalesce(a,b)", #{b => <<"b">>}))
+        end},
+        {"coalesce empty", fun() -> ?assertEqual({ok, <<>>}, render("coalesce(a,b)", #{})) end}
+    ].
+
+syntax_error_test_() ->
+    [
+        {"empty expression", fun() -> ?assertMatch(?SYNTAX_ERROR, render("", #{})) end},
+        {"const string single quote", fun() -> ?assertMatch(?SYNTAX_ERROR, render("'a'", #{})) end},
+        {"const string double quote", fun() ->
+            ?assertMatch(?SYNTAX_ERROR, render(<<"\"a\"">>, #{}))
+        end},
+        {"no arity", fun() -> ?assertMatch(?SYNTAX_ERROR, render("concat()", #{})) end}
+    ].
+
+render(Expression, Bindings) ->
+    emqx_variform:render(Expression, Bindings).