emqx_message_validation_tests.erl 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. %%--------------------------------------------------------------------
  2. %% Copyright (c) 2024 EMQ Technologies Co., Ltd. All Rights Reserved.
  3. %%--------------------------------------------------------------------
  4. -module(emqx_message_validation_tests).
  5. -include_lib("eunit/include/eunit.hrl").
  6. -define(VALIDATIONS_PATH, "message_validation.validations").
  7. %%------------------------------------------------------------------------------
  8. %% Helper fns
  9. %%------------------------------------------------------------------------------
  10. parse_and_check(InnerConfigs) ->
  11. RootBin = <<"message_validation">>,
  12. InnerBin = <<"validations">>,
  13. RawConf = #{RootBin => #{InnerBin => InnerConfigs}},
  14. #{RootBin := #{InnerBin := Checked}} = hocon_tconf:check_plain(
  15. emqx_message_validation_schema,
  16. RawConf,
  17. #{
  18. required => false,
  19. atom_key => false,
  20. make_serializable => false
  21. }
  22. ),
  23. Checked.
  24. validation(Name, Checks) ->
  25. validation(Name, Checks, _Overrides = #{}).
  26. validation(Name, Checks, Overrides) ->
  27. Default = #{
  28. <<"tags">> => [<<"some">>, <<"tags">>],
  29. <<"description">> => <<"my validation">>,
  30. <<"enable">> => true,
  31. <<"name">> => Name,
  32. <<"topics">> => <<"t/+">>,
  33. <<"strategy">> => <<"all_pass">>,
  34. <<"failure_action">> => <<"drop">>,
  35. <<"log_failure">> => #{<<"level">> => <<"warning">>},
  36. <<"checks">> => Checks
  37. },
  38. emqx_utils_maps:deep_merge(Default, Overrides).
  39. sql_check() ->
  40. sql_check(<<"select * where true">>).
  41. sql_check(SQL) ->
  42. #{
  43. <<"type">> => <<"sql">>,
  44. <<"sql">> => SQL
  45. }.
  46. eval_sql(Message, SQL) ->
  47. {ok, Check} = emqx_message_validation:parse_sql_check(SQL),
  48. Validation = #{log_failure => #{level => warning}, name => <<"validation">>},
  49. emqx_message_validation:evaluate_sql_check(Check, Validation, Message).
  50. message() ->
  51. message(_Opts = #{}).
  52. message(Opts) ->
  53. Defaults = #{
  54. id => emqx_guid:gen(),
  55. qos => 0,
  56. from => emqx_guid:to_hexstr(emqx_guid:gen()),
  57. flags => #{retain => false},
  58. headers => #{
  59. proto_ver => v5,
  60. properties => #{'User-Property' => [{<<"a">>, <<"b">>}]}
  61. },
  62. topic => <<"t/t">>,
  63. payload => emqx_utils_json:encode(#{value => 10}),
  64. timestamp => 1710272561615,
  65. extra => []
  66. },
  67. emqx_message:from_map(emqx_utils_maps:deep_merge(Defaults, Opts)).
  68. %%------------------------------------------------------------------------------
  69. %% Test cases
  70. %%------------------------------------------------------------------------------
  71. schema_test_() ->
  72. [
  73. {"topics is always a list 1",
  74. ?_assertMatch(
  75. [#{<<"topics">> := [<<"t/1">>]}],
  76. parse_and_check([
  77. validation(
  78. <<"foo">>,
  79. [sql_check()],
  80. #{<<"topics">> => <<"t/1">>}
  81. )
  82. ])
  83. )},
  84. {"topics is always a list 2",
  85. ?_assertMatch(
  86. [#{<<"topics">> := [<<"t/1">>]}],
  87. parse_and_check([
  88. validation(
  89. <<"foo">>,
  90. [sql_check()],
  91. #{<<"topics">> => [<<"t/1">>]}
  92. )
  93. ])
  94. )},
  95. {"foreach expression is not allowed",
  96. ?_assertThrow(
  97. {_Schema, [
  98. #{
  99. reason := foreach_not_allowed,
  100. kind := validation_error
  101. }
  102. ]},
  103. parse_and_check([
  104. validation(
  105. <<"foo">>,
  106. [sql_check(<<"foreach foo as f where true">>)]
  107. )
  108. ])
  109. )},
  110. {"from clause is not allowed",
  111. ?_assertThrow(
  112. {_Schema, [
  113. #{
  114. reason := non_empty_from_clause,
  115. kind := validation_error
  116. }
  117. ]},
  118. parse_and_check([
  119. validation(
  120. <<"foo">>,
  121. [sql_check(<<"select * from t">>)]
  122. )
  123. ])
  124. )},
  125. {"names are unique",
  126. ?_assertThrow(
  127. {_Schema, [
  128. #{
  129. reason := <<"duplicated name:", _/binary>>,
  130. path := ?VALIDATIONS_PATH,
  131. kind := validation_error
  132. }
  133. ]},
  134. parse_and_check([
  135. validation(<<"foo">>, [sql_check()]),
  136. validation(<<"foo">>, [sql_check()])
  137. ])
  138. )},
  139. {"checks must be non-empty",
  140. ?_assertThrow(
  141. {_Schema, [
  142. #{
  143. reason := "at least one check must be defined",
  144. kind := validation_error
  145. }
  146. ]},
  147. parse_and_check([
  148. validation(
  149. <<"foo">>,
  150. []
  151. )
  152. ])
  153. )},
  154. {"bogus check type",
  155. ?_assertThrow(
  156. {_Schema, [
  157. #{
  158. expected := <<"sql", _/binary>>,
  159. kind := validation_error,
  160. field_name := type
  161. }
  162. ]},
  163. parse_and_check([validation(<<"foo">>, [#{<<"type">> => <<"foo">>}])])
  164. )}
  165. ].
  166. invalid_names_test_() ->
  167. [
  168. {InvalidName,
  169. ?_assertThrow(
  170. {_Schema, [
  171. #{
  172. reason := <<"must conform to regex:", _/binary>>,
  173. kind := validation_error,
  174. path := "message_validation.validations.1.name"
  175. }
  176. ]},
  177. parse_and_check([validation(InvalidName, [sql_check()])])
  178. )}
  179. || InvalidName <- [
  180. <<"">>,
  181. <<"_name">>,
  182. <<"name$">>,
  183. <<"name!">>,
  184. <<"some name">>,
  185. <<"nãme"/utf8>>,
  186. <<"test_哈哈"/utf8>>
  187. ]
  188. ].
  189. check_test_() ->
  190. [
  191. {"denied by payload 1",
  192. ?_assertNot(eval_sql(message(), <<"select * where payload.value > 15">>))},
  193. {"denied by payload 2",
  194. ?_assertNot(eval_sql(message(), <<"select payload.value as x where x > 15">>))},
  195. {"allowed by payload 1",
  196. ?_assert(eval_sql(message(), <<"select * where payload.value > 5">>))},
  197. {"allowed by payload 2",
  198. ?_assert(eval_sql(message(), <<"select payload.value as x where x > 5">>))},
  199. {"always passes 1", ?_assert(eval_sql(message(), <<"select * where true">>))},
  200. {"always passes 2", ?_assert(eval_sql(message(), <<"select * where 1 = 1">>))},
  201. {"never passes 1", ?_assertNot(eval_sql(message(), <<"select * where false">>))},
  202. {"never passes 2", ?_assertNot(eval_sql(message(), <<"select * where 1 = 2">>))},
  203. {"never passes 3", ?_assertNot(eval_sql(message(), <<"select * where true and false">>))}
  204. ].