Просмотр исходного кода

fix(mysql_bridge): forbid update statements with batch operations

Fixes https://emqx.atlassian.net/browse/EMQX-11605
Thales Macedo Garitezi 2 лет назад
Родитель
Сommit
2c61b2bfbb

+ 56 - 9
apps/emqx_bridge_mysql/src/emqx_bridge_mysql_connector.erl

@@ -35,13 +35,18 @@ on_add_channel(
 ) ->
     ChannelConfig1 = emqx_utils_maps:unindent(parameters, ChannelConfig0),
     QueryTemplates = emqx_mysql:parse_prepare_sql(ChannelId, ChannelConfig1),
-    ChannelConfig2 = maps:merge(ChannelConfig1, QueryTemplates),
-    ChannelConfig = set_prepares(ChannelConfig2, ConnectorState),
-    State = State0#{
-        channels => maps:put(ChannelId, ChannelConfig, Channels),
-        connector_state => ConnectorState
-    },
-    {ok, State}.
+    case validate_sql_type(ChannelId, ChannelConfig1, QueryTemplates) of
+        ok ->
+            ChannelConfig2 = maps:merge(ChannelConfig1, QueryTemplates),
+            ChannelConfig = set_prepares(ChannelConfig2, ConnectorState),
+            State = State0#{
+                channels => maps:put(ChannelId, ChannelConfig, Channels),
+                connector_state => ConnectorState
+            },
+            {ok, State};
+        {error, Error} ->
+            {error, Error}
+    end.
 
 on_get_channel_status(_InstanceId, ChannelId, #{channels := Channels}) ->
     case maps:get(ChannelId, Channels) of
@@ -116,11 +121,13 @@ on_batch_query(InstanceId, BatchRequest, _State = #{connector_state := Connector
 
 on_remove_channel(
     _InstanceId, #{channels := Channels, connector_state := ConnectorState} = State, ChannelId
-) ->
+) when is_map_key(ChannelId, Channels) ->
     ChannelConfig = maps:get(ChannelId, Channels),
     emqx_mysql:unprepare_sql(maps:merge(ChannelConfig, ConnectorState)),
     NewState = State#{channels => maps:remove(ChannelId, Channels)},
-    {ok, NewState}.
+    {ok, NewState};
+on_remove_channel(_InstanceId, State, _ChannelId) ->
+    {ok, State}.
 
 -spec on_start(binary(), hocon:config()) ->
     {ok, #{connector_state := emqx_mysql:state(), channels := map()}} | {error, _}.
@@ -148,3 +155,43 @@ set_prepares(ChannelConfig, ConnectorState) ->
     #{prepares := Prepares} =
         emqx_mysql:init_prepare(maps:merge(ConnectorState, ChannelConfig)),
     ChannelConfig#{prepares => Prepares}.
+
+validate_sql_type(ChannelId, ChannelConfig, #{query_templates := QueryTemplates}) ->
+    Batch =
+        case emqx_utils_maps:deep_get([resource_opts, batch_size], ChannelConfig) of
+            N when N > 1 -> batch;
+            _ -> single
+        end,
+    BatchKey = {ChannelId, batch},
+    SingleKey = {ChannelId, prepstmt},
+    case {QueryTemplates, Batch} of
+        {#{BatchKey := _}, batch} ->
+            ok;
+        {#{SingleKey := _}, single} ->
+            ok;
+        {_, batch} ->
+            %% try to provide helpful info
+            SQL = maps:get(sql, ChannelConfig),
+            Type = emqx_utils_sql:get_statement_type(SQL),
+            ErrorContext0 = #{
+                reason => failed_to_prepare_statement,
+                statement_type => Type,
+                operation_type => Batch
+            },
+            ErrorContext = emqx_utils_maps:put_if(
+                ErrorContext0,
+                hint,
+                <<"UPDATE statements are not supported for batch operations">>,
+                Type =:= update
+            ),
+            {error, ErrorContext};
+        _ ->
+            SQL = maps:get(sql, ChannelConfig),
+            Type = emqx_utils_sql:get_statement_type(SQL),
+            ErrorContext = #{
+                reason => failed_to_prepare_statement,
+                statement_type => Type,
+                operation_type => Batch
+            },
+            {error, ErrorContext}
+    end.

+ 104 - 2
apps/emqx_bridge_mysql/test/emqx_bridge_mysql_SUITE.erl

@@ -31,6 +31,8 @@
 
 -define(WORKER_POOL_SIZE, 4).
 
+-define(ACTION_TYPE, mysql).
+
 -import(emqx_common_test_helpers, [on_exit/1]).
 
 %%------------------------------------------------------------------------------
@@ -45,7 +47,14 @@ all() ->
 
 groups() ->
     TCs = emqx_common_test_helpers:all(?MODULE),
-    NonBatchCases = [t_write_timeout, t_uninitialized_prepared_statement],
+    NonBatchCases = [
+        t_write_timeout,
+        t_uninitialized_prepared_statement,
+        t_non_batch_update_is_allowed
+    ],
+    OnlyBatchCases = [
+        t_batch_update_is_forbidden
+    ],
     BatchingGroups = [
         {group, with_batch},
         {group, without_batch}
@@ -57,7 +66,7 @@ groups() ->
         {async, BatchingGroups},
         {sync, BatchingGroups},
         {with_batch, TCs -- NonBatchCases},
-        {without_batch, TCs}
+        {without_batch, TCs -- OnlyBatchCases}
     ].
 
 init_per_group(tcp, Config) ->
@@ -103,6 +112,8 @@ end_per_group(_Group, _Config) ->
     ok.
 
 init_per_suite(Config) ->
+    emqx_common_test_helpers:clear_screen(),
+
     Config.
 
 end_per_suite(_Config) ->
@@ -151,6 +162,9 @@ common_init(Config0) ->
                     {mysql_config, MysqlConfig},
                     {mysql_bridge_type, BridgeType},
                     {mysql_name, Name},
+                    {bridge_type, BridgeType},
+                    {bridge_name, Name},
+                    {bridge_config, MysqlConfig},
                     {proxy_host, ProxyHost},
                     {proxy_port, ProxyPort}
                     | Config0
@@ -874,3 +888,91 @@ t_nested_payload_template(Config) ->
         connect_and_get_payload(Config)
     ),
     ok.
+
+t_batch_update_is_forbidden(Config) ->
+    ?check_trace(
+        begin
+            Overrides = #{
+                <<"sql">> =>
+                    <<
+                        "UPDATE mqtt_test "
+                        "SET arrived = FROM_UNIXTIME(${timestamp}/1000) "
+                        "WHERE payload = ${payload.value}"
+                    >>
+            },
+            ProbeRes = emqx_bridge_testlib:probe_bridge_api(Config, Overrides),
+            ?assertMatch({error, {{_, 400, _}, _, _Body}}, ProbeRes),
+            {error, {{_, 400, _}, _, ProbeBodyRaw}} = ProbeRes,
+            ?assertEqual(
+                match,
+                re:run(
+                    ProbeBodyRaw,
+                    <<"UPDATE statements are not supported for batch operations">>,
+                    [global, {capture, none}]
+                )
+            ),
+            CreateRes = emqx_bridge_testlib:create_bridge_api(Config, Overrides),
+            ?assertMatch(
+                {ok, {{_, 201, _}, _, #{<<"status">> := <<"disconnected">>}}},
+                CreateRes
+            ),
+            {ok, {{_, 201, _}, _, #{<<"status_reason">> := Reason}}} = CreateRes,
+            ?assertEqual(
+                match,
+                re:run(
+                    Reason,
+                    <<"UPDATE statements are not supported for batch operations">>,
+                    [global, {capture, none}]
+                )
+            ),
+            ok
+        end,
+        []
+    ),
+    ok.
+
+t_non_batch_update_is_allowed(Config) ->
+    ?check_trace(
+        begin
+            BridgeName = ?config(bridge_name, Config),
+            Overrides = #{
+                <<"resource_opts">> => #{<<"metrics_flush_interval">> => <<"500ms">>},
+                <<"sql">> =>
+                    <<
+                        "UPDATE mqtt_test "
+                        "SET arrived = FROM_UNIXTIME(${timestamp}/1000) "
+                        "WHERE payload = ${payload.value}"
+                    >>
+            },
+            ProbeRes = emqx_bridge_testlib:probe_bridge_api(Config, Overrides),
+            ?assertMatch({ok, {{_, 204, _}, _, _Body}}, ProbeRes),
+            ?assertMatch(
+                {ok, {{_, 201, _}, _, #{<<"status">> := <<"connected">>}}},
+                emqx_bridge_testlib:create_bridge_api(Config, Overrides)
+            ),
+            {ok, #{
+                <<"id">> := RuleId,
+                <<"from">> := [Topic]
+            }} = create_rule_and_action_http(Config),
+            Payload = emqx_utils_json:encode(#{value => <<"aaaa">>}),
+            Message = emqx_message:make(Topic, Payload),
+            {_, {ok, _}} =
+                ?wait_async_action(
+                    emqx:publish(Message),
+                    #{?snk_kind := mysql_connector_query_return},
+                    10_000
+                ),
+            ActionId = emqx_bridge_v2:id(?ACTION_TYPE, BridgeName),
+            ?assertEqual(1, emqx_resource_metrics:matched_get(ActionId)),
+            ?retry(
+                _Sleep0 = 200,
+                _Attempts0 = 10,
+                ?assertEqual(1, emqx_resource_metrics:success_get(ActionId))
+            ),
+
+            ?assertEqual(1, emqx_metrics_worker:get(rule_metrics, RuleId, 'actions.success')),
+            ok
+        end,
+        []
+    ),
+    ok.

+ 2 - 2
apps/emqx_mysql/src/emqx_mysql.erl

@@ -436,11 +436,11 @@ parse_batch_sql(Key, Query, Acc) ->
             end;
         select ->
             Acc;
-        Otherwise ->
+        Type ->
             ?SLOG(error, #{
                 msg => "invalid sql statement type",
                 sql => Query,
-                type => Otherwise
+                type => Type
             }),
             Acc
     end.

+ 2 - 1
apps/emqx_utils/src/emqx_utils_sql.erl

@@ -28,7 +28,7 @@
 
 -export_type([value/0]).
 
--type statement_type() :: select | insert | delete.
+-type statement_type() :: select | insert | delete | update.
 -type value() :: null | binary() | number() | boolean() | [value()].
 
 -dialyzer({no_improper_lists, [escape_mysql/4, escape_prepend/4]}).
@@ -38,6 +38,7 @@ get_statement_type(Query) ->
     KnownTypes = #{
         <<"select">> => select,
         <<"insert">> => insert,
+        <<"update">> => update,
         <<"delete">> => delete
     },
     case re:run(Query, <<"^\\s*([a-zA-Z]+)">>, [{capture, all_but_first, binary}]) of