Prechádzať zdrojové kódy

Merge pull request #10095 from fix/EEC-782/mysql-prepstmt-exhaustion

fix(mysql): be explicit that batch queries are parameterless
Andrew Mayorov 2 rokov pred
rodič
commit
a530ccbe3d

+ 9 - 7
.ci/docker-compose-file/docker-compose-mysql-tcp.yaml

@@ -13,10 +13,12 @@ services:
     networks:
       - emqx_bridge
     command:
-      --bind-address "::"
-      --character-set-server=utf8mb4
-      --collation-server=utf8mb4_general_ci
-      --explicit_defaults_for_timestamp=true
-      --lower_case_table_names=1
-      --max_allowed_packet=128M
-      --skip-symbolic-links
+      - --bind-address=0.0.0.0
+      - --character-set-server=utf8mb4
+      - --collation-server=utf8mb4_general_ci
+      - --lower-case-table-names=1
+      - --max-allowed-packet=128M
+      # Severely limit maximum number of prepared statements the server must permit
+      # so that we hit potential resource exhaustion earlier in tests.
+      - --max-prepared-stmt-count=64
+      - --skip-symbolic-links

+ 5 - 3
.ci/docker-compose-file/docker-compose-mysql-tls.yaml

@@ -23,9 +23,11 @@ services:
       - --port=3306
       - --character-set-server=utf8mb4
       - --collation-server=utf8mb4_general_ci
-      - --explicit_defaults_for_timestamp=true
-      - --lower_case_table_names=1
-      - --max_allowed_packet=128M
+      - --lower-case-table-names=1
+      - --max-allowed-packet=128M
+      # Severely limit maximum number of prepared statements the server must permit
+      # so that we hit potential resource exhaustion earlier in tests.
+      - --max-prepared-stmt-count=64
       - --ssl-ca=/etc/certs/ca-cert.pem
       - --ssl-cert=/etc/certs/server-cert.pem
       - --ssl-key=/etc/certs/server-key.pem

+ 13 - 9
apps/emqx_connector/src/emqx_connector_mysql.erl

@@ -391,14 +391,18 @@ proc_sql_params(TypeOrKey, SQLOrData, Params, #{params_tokens := ParamsTokens})
     end.
 
 on_batch_insert(InstId, BatchReqs, InsertPart, Tokens, State) ->
-    SQL = emqx_plugin_libs_rule:proc_batch_sql(BatchReqs, InsertPart, Tokens),
-    on_sql_query(InstId, query, SQL, [], default_timeout, State).
+    ValuesPart = lists:join($,, [
+        emqx_placeholder:proc_param_str(Tokens, Msg, fun emqx_placeholder:quote_mysql/1)
+     || {_, Msg} <- BatchReqs
+    ]),
+    Query = [InsertPart, <<" values ">> | ValuesPart],
+    on_sql_query(InstId, query, Query, no_params, default_timeout, State).
 
 on_sql_query(
     InstId,
     SQLFunc,
     SQLOrKey,
-    Data,
+    Params,
     Timeout,
     #{poolname := PoolName} = State
 ) ->
@@ -409,9 +413,9 @@ on_sql_query(
         {ok, Conn} ->
             ?tp(
                 mysql_connector_send_query,
-                #{sql_func => SQLFunc, sql_or_key => SQLOrKey, data => Data}
+                #{sql_func => SQLFunc, sql_or_key => SQLOrKey, data => Params}
             ),
-            do_sql_query(SQLFunc, Conn, SQLOrKey, Data, Timeout, LogMeta);
+            do_sql_query(SQLFunc, Conn, SQLOrKey, Params, Timeout, LogMeta);
         {error, disconnected} ->
             ?SLOG(
                 error,
@@ -423,8 +427,8 @@ on_sql_query(
             {error, {recoverable_error, disconnected}}
     end.
 
-do_sql_query(SQLFunc, Conn, SQLOrKey, Data, Timeout, LogMeta) ->
-    try mysql:SQLFunc(Conn, SQLOrKey, Data, Timeout) of
+do_sql_query(SQLFunc, Conn, SQLOrKey, Params, Timeout, LogMeta) ->
+    try mysql:SQLFunc(Conn, SQLOrKey, Params, no_filtermap_fun, Timeout) of
         {error, disconnected} ->
             ?SLOG(
                 error,
@@ -466,7 +470,7 @@ do_sql_query(SQLFunc, Conn, SQLOrKey, Data, Timeout, LogMeta) ->
         error:badarg ->
             ?SLOG(
                 error,
-                LogMeta#{msg => "mysql_connector_invalid_params", params => Data}
+                LogMeta#{msg => "mysql_connector_invalid_params", params => Params}
             ),
-            {error, {unrecoverable_error, {invalid_params, Data}}}
+            {error, {unrecoverable_error, {invalid_params, Params}}}
     end.

+ 89 - 21
apps/emqx_plugin_libs/src/emqx_placeholder.erl

@@ -30,6 +30,7 @@
     proc_sql/2,
     proc_sql_param_str/2,
     proc_cql_param_str/2,
+    proc_param_str/3,
     preproc_tmpl_deep/1,
     preproc_tmpl_deep/2,
     proc_tmpl_deep/2,
@@ -39,6 +40,12 @@
     sql_data/1
 ]).
 
+-export([
+    quote_sql/1,
+    quote_cql/1,
+    quote_mysql/1
+]).
+
 -include_lib("emqx/include/emqx_placeholder.hrl").
 
 -define(EX_PLACE_HOLDER, "(\\$\\{[a-zA-Z0-9\\._]+\\})").
@@ -83,6 +90,8 @@
     | {tmpl, tmpl_token()}
     | {value, term()}.
 
+-dialyzer({no_improper_lists, [quote_mysql/1, escape_mysql/4, escape_prepend/4]}).
+
 %%------------------------------------------------------------------------------
 %% APIs
 %%------------------------------------------------------------------------------
@@ -162,12 +171,22 @@ proc_sql(Tokens, Data) ->
 
 -spec proc_sql_param_str(tmpl_token(), map()) -> binary().
 proc_sql_param_str(Tokens, Data) ->
+    % NOTE
+    % This is a bit misleading: currently, escaping logic in `quote_sql/1` likely
+    % won't work with pgsql since it does not support C-style escapes by default.
+    % https://www.postgresql.org/docs/14/sql-syntax-lexical.html#SQL-SYNTAX-CONSTANTS
     proc_param_str(Tokens, Data, fun quote_sql/1).
 
 -spec proc_cql_param_str(tmpl_token(), map()) -> binary().
 proc_cql_param_str(Tokens, Data) ->
     proc_param_str(Tokens, Data, fun quote_cql/1).
 
+-spec proc_param_str(tmpl_token(), map(), fun((_Value) -> iodata())) -> binary().
+proc_param_str(Tokens, Data, Quote) ->
+    iolist_to_binary(
+        proc_tmpl(Tokens, Data, #{return => rawlist, var_trans => Quote})
+    ).
+
 -spec preproc_tmpl_deep(term()) -> deep_template().
 preproc_tmpl_deep(Data) ->
     preproc_tmpl_deep(Data, #{process_keys => true}).
@@ -226,15 +245,29 @@ sql_data(Map) when is_map(Map) -> emqx_json:encode(Map).
 -spec bin(term()) -> binary().
 bin(Val) -> emqx_plugin_libs_rule:bin(Val).
 
+-spec quote_sql(_Value) -> iolist().
+quote_sql(Str) ->
+    quote_escape(Str, fun escape_sql/1).
+
+-spec quote_cql(_Value) -> iolist().
+quote_cql(Str) ->
+    quote_escape(Str, fun escape_cql/1).
+
+-spec quote_mysql(_Value) -> iolist().
+quote_mysql(Str) when is_binary(Str) ->
+    try
+        escape_mysql(Str)
+    catch
+        throw:invalid_utf8 ->
+            [<<"0x">> | binary:encode_hex(Str)]
+    end;
+quote_mysql(Str) ->
+    quote_escape(Str, fun escape_mysql/1).
+
 %%------------------------------------------------------------------------------
 %% Internal functions
 %%------------------------------------------------------------------------------
 
-proc_param_str(Tokens, Data, Quote) ->
-    iolist_to_binary(
-        proc_tmpl(Tokens, Data, #{return => rawlist, var_trans => Quote})
-    ).
-
 get_phld_var(Phld, Data) ->
     emqx_rule_maps:nested_get(Phld, Data).
 
@@ -312,21 +345,56 @@ unwrap(<<"\"${", Val/binary>>, _StripDoubleQuote = true) ->
 unwrap(<<"${", Val/binary>>, _StripDoubleQuote) ->
     binary:part(Val, {0, byte_size(Val) - 1}).
 
-quote_sql(Str) ->
-    quote(Str, <<"\\\\'">>).
-
-quote_cql(Str) ->
-    quote(Str, <<"''">>).
-
-quote(Str, ReplaceWith) when
-    is_list(Str);
-    is_binary(Str);
-    is_atom(Str);
-    is_map(Str)
-->
-    [$', escape_apo(bin(Str), ReplaceWith), $'];
-quote(Val, _) ->
+-spec quote_escape(_Value, fun((binary()) -> iodata())) -> iodata().
+quote_escape(Str, EscapeFun) when is_binary(Str) ->
+    EscapeFun(Str);
+quote_escape(Str, EscapeFun) when is_list(Str) ->
+    case unicode:characters_to_binary(Str) of
+        Bin when is_binary(Bin) ->
+            EscapeFun(Bin);
+        Otherwise ->
+            error(Otherwise)
+    end;
+quote_escape(Str, EscapeFun) when is_atom(Str) orelse is_map(Str) ->
+    EscapeFun(bin(Str));
+quote_escape(Val, _EscapeFun) ->
     bin(Val).
 
-escape_apo(Str, ReplaceWith) ->
-    re:replace(Str, <<"'">>, ReplaceWith, [{return, binary}, global]).
+-spec escape_sql(binary()) -> iolist().
+escape_sql(S) ->
+    ES = binary:replace(S, [<<"\\">>, <<"'">>], <<"\\">>, [global, {insert_replaced, 1}]),
+    [$', ES, $'].
+
+-spec escape_cql(binary()) -> iolist().
+escape_cql(S) ->
+    ES = binary:replace(S, <<"'">>, <<"'">>, [global, {insert_replaced, 1}]),
+    [$', ES, $'].
+
+-spec escape_mysql(binary()) -> iolist().
+escape_mysql(S0) ->
+    % https://dev.mysql.com/doc/refman/8.0/en/string-literals.html
+    [$', escape_mysql(S0, 0, 0, S0), $'].
+
+%% NOTE
+%% This thing looks more complicated than needed because it's optimized for as few
+%% intermediate memory (re)allocations as possible.
+escape_mysql(<<$', Rest/binary>>, I, Run, Src) ->
+    escape_prepend(I, Run, Src, [<<"\\'">> | escape_mysql(Rest, I + Run + 1, 0, Src)]);
+escape_mysql(<<$\\, Rest/binary>>, I, Run, Src) ->
+    escape_prepend(I, Run, Src, [<<"\\\\">> | escape_mysql(Rest, I + Run + 1, 0, Src)]);
+escape_mysql(<<0, Rest/binary>>, I, Run, Src) ->
+    escape_prepend(I, Run, Src, [<<"\\0">> | escape_mysql(Rest, I + Run + 1, 0, Src)]);
+escape_mysql(<<_/utf8, Rest/binary>> = S, I, Run, Src) ->
+    CWidth = byte_size(S) - byte_size(Rest),
+    escape_mysql(Rest, I, Run + CWidth, Src);
+escape_mysql(<<>>, 0, _, Src) ->
+    Src;
+escape_mysql(<<>>, I, Run, Src) ->
+    binary:part(Src, I, Run);
+escape_mysql(_, _I, _Run, _Src) ->
+    throw(invalid_utf8).
+
+escape_prepend(_RunI, 0, _Src, Tail) ->
+    Tail;
+escape_prepend(I, Run, Src, Tail) ->
+    [binary:part(Src, I, Run) | Tail].

+ 1 - 1
apps/emqx_plugin_libs/src/emqx_plugin_libs.app.src

@@ -1,7 +1,7 @@
 %% -*- mode: erlang -*-
 {application, emqx_plugin_libs, [
     {description, "EMQX Plugin utility libs"},
-    {vsn, "4.3.6"},
+    {vsn, "4.3.7"},
     {modules, []},
     {applications, [kernel, stdlib]},
     {env, []}

+ 2 - 7
apps/emqx_plugin_libs/src/emqx_plugin_libs_rule.erl

@@ -68,11 +68,6 @@
 
 -compile({no_auto_import, [float/1]}).
 
--define(EX_PLACE_HOLDER, "(\\$\\{[a-zA-Z0-9\\._]+\\})").
-
-%% Space and CRLF
--define(EX_WITHE_CHARS, "\\s").
-
 -type uri_string() :: iodata().
 
 -type tmpl_token() :: list({var, binary()} | {str, binary()}).
@@ -172,8 +167,8 @@ detect_sql_type(SQL) ->
 ) -> InsertSQL :: binary().
 proc_batch_sql(BatchReqs, InsertPart, Tokens) ->
     ValuesPart = erlang:iolist_to_binary(
-        lists:join(", ", [
-            emqx_plugin_libs_rule:proc_sql_param_str(Tokens, Msg)
+        lists:join($,, [
+            proc_sql_param_str(Tokens, Msg)
          || {_, Msg} <- BatchReqs
         ])
     ),

+ 13 - 5
apps/emqx_plugin_libs/test/emqx_placeholder_SUITE.erl

@@ -105,19 +105,27 @@ t_preproc_sql3(_) ->
         emqx_placeholder:proc_sql_param_str(ParamsTokens, Selected)
     ).
 
-t_preproc_sql4(_) ->
+t_preproc_mysql1(_) ->
     %% with apostrophes
     %% https://github.com/emqx/emqx/issues/4135
     Selected = #{
         a => <<"1''2">>,
         b => 1,
         c => 1.0,
-        d => #{d1 => <<"someone's phone">>}
+        d => #{d1 => <<"someone's phone">>},
+        e => <<$\\, 0, "💩"/utf8>>,
+        f => <<"non-utf8", 16#DCC900:24>>,
+        g => "utf8's cool 🐸"
     },
-    ParamsTokens = emqx_placeholder:preproc_tmpl(<<"a:${a},b:${b},c:${c},d:${d}">>),
+    ParamsTokens = emqx_placeholder:preproc_tmpl(
+        <<"a:${a},b:${b},c:${c},d:${d},e:${e},f:${f},g:${g}">>
+    ),
     ?assertEqual(
-        <<"a:'1\\'\\'2',b:1,c:1.0,d:'{\"d1\":\"someone\\'s phone\"}'">>,
-        emqx_placeholder:proc_sql_param_str(ParamsTokens, Selected)
+        <<
+            "a:'1\\'\\'2',b:1,c:1.0,d:'{\"d1\":\"someone\\'s phone\"}',"
+            "e:'\\\\\\0💩',f:0x6E6F6E2D75746638DCC900,g:'utf8\\'s cool 🐸'"/utf8
+        >>,
+        emqx_placeholder:proc_param_str(ParamsTokens, Selected, fun emqx_placeholder:quote_mysql/1)
     ).
 
 t_preproc_sql5(_) ->

+ 3 - 0
changes/ee/fix-10095.en.md

@@ -0,0 +1,3 @@
+Stop MySQL client from bombarding server repeatedly with unnecessary `PREPARE` queries on every batch, trashing the server and exhausting its internal limits. This was happening when the MySQL bridge was in the batch mode.
+
+Ensure safer and more careful escaping of strings and binaries in batch insert queries when the MySQL bridge is in the batch mode.

+ 1 - 0
changes/ee/fix-10095.zh.md

@@ -0,0 +1 @@
+优化 MySQL 桥接在批量模式下能更高效的使用预处理语句 ,减少了对 MySQL 服务器的查询压力, 并确保对 SQL 语句进行更安全和谨慎的转义。

+ 63 - 13
lib-ee/emqx_ee_bridge/test/emqx_ee_bridge_mysql_SUITE.erl

@@ -28,6 +28,9 @@
 -define(MYSQL_DATABASE, "mqtt").
 -define(MYSQL_USERNAME, "root").
 -define(MYSQL_PASSWORD, "public").
+-define(MYSQL_POOL_SIZE, 4).
+
+-define(WORKER_POOL_SIZE, 4).
 
 %%------------------------------------------------------------------------------
 %% CT boilerplate
@@ -168,11 +171,13 @@ mysql_config(BridgeType, Config) ->
             "  database = ~p\n"
             "  username = ~p\n"
             "  password = ~p\n"
+            "  pool_size = ~b\n"
             "  sql = ~p\n"
             "  resource_opts = {\n"
             "    request_timeout = 500ms\n"
             "    batch_size = ~b\n"
             "    query_mode = ~s\n"
+            "    worker_pool_size = ~b\n"
             "  }\n"
             "  ssl = {\n"
             "    enable = ~w\n"
@@ -185,9 +190,11 @@ mysql_config(BridgeType, Config) ->
                 ?MYSQL_DATABASE,
                 ?MYSQL_USERNAME,
                 ?MYSQL_PASSWORD,
+                ?MYSQL_POOL_SIZE,
                 ?SQL_BRIDGE,
                 BatchSize,
                 QueryMode,
+                ?WORKER_POOL_SIZE,
                 TlsEnabled
             ]
         ),
@@ -265,27 +272,26 @@ connect_direct_mysql(Config) ->
     {ok, Pid} = mysql:start_link(Opts ++ SslOpts),
     Pid.
 
+query_direct_mysql(Config, Query) ->
+    Pid = connect_direct_mysql(Config),
+    try
+        mysql:query(Pid, Query)
+    after
+        mysql:stop(Pid)
+    end.
+
 % These funs connect and then stop the mysql connection
 connect_and_create_table(Config) ->
-    DirectPid = connect_direct_mysql(Config),
-    ok = mysql:query(DirectPid, ?SQL_CREATE_TABLE),
-    mysql:stop(DirectPid).
+    query_direct_mysql(Config, ?SQL_CREATE_TABLE).
 
 connect_and_drop_table(Config) ->
-    DirectPid = connect_direct_mysql(Config),
-    ok = mysql:query(DirectPid, ?SQL_DROP_TABLE),
-    mysql:stop(DirectPid).
+    query_direct_mysql(Config, ?SQL_DROP_TABLE).
 
 connect_and_clear_table(Config) ->
-    DirectPid = connect_direct_mysql(Config),
-    ok = mysql:query(DirectPid, ?SQL_DELETE),
-    mysql:stop(DirectPid).
+    query_direct_mysql(Config, ?SQL_DELETE).
 
 connect_and_get_payload(Config) ->
-    DirectPid = connect_direct_mysql(Config),
-    Result = mysql:query(DirectPid, ?SQL_SELECT),
-    mysql:stop(DirectPid),
-    Result.
+    query_direct_mysql(Config, ?SQL_SELECT).
 
 %%------------------------------------------------------------------------------
 %% Testcases
@@ -505,6 +511,50 @@ t_bad_sql_parameter(Config) ->
     end,
     ok.
 
+t_nasty_sql_string(Config) ->
+    ?assertMatch({ok, _}, create_bridge(Config)),
+    Payload = list_to_binary(lists:seq(0, 255)),
+    Message = #{payload => Payload, timestamp => erlang:system_time(millisecond)},
+    Result = send_message(Config, Message),
+    ?assertEqual(ok, Result),
+    ?assertMatch(
+        {ok, [<<"payload">>], [[Payload]]},
+        connect_and_get_payload(Config)
+    ).
+
+t_workload_fits_prepared_statement_limit(Config) ->
+    N = 50,
+    ?assertMatch(
+        {ok, _},
+        create_bridge(Config)
+    ),
+    Results = lists:append(
+        emqx_misc:pmap(
+            fun(_) ->
+                [
+                    begin
+                        Payload = integer_to_binary(erlang:unique_integer()),
+                        Timestamp = erlang:system_time(millisecond),
+                        send_message(Config, #{payload => Payload, timestamp => Timestamp})
+                    end
+                 || _ <- lists:seq(1, N)
+                ]
+            end,
+            lists:seq(1, ?WORKER_POOL_SIZE * ?MYSQL_POOL_SIZE),
+            _Timeout = 10_000
+        )
+    ),
+    ?assertEqual(
+        [],
+        [R || R <- Results, R /= ok]
+    ),
+    {ok, _, [[_Var, Count]]} =
+        query_direct_mysql(Config, "SHOW GLOBAL STATUS LIKE 'Prepared_stmt_count'"),
+    ?assertEqual(
+        ?MYSQL_POOL_SIZE,
+        binary_to_integer(Count)
+    ).
+
 t_unprepared_statement_query(Config) ->
     ?assertMatch(
         {ok, _},

+ 7 - 0
lib-ee/emqx_ee_bridge/test/emqx_ee_bridge_pgsql_SUITE.erl

@@ -510,3 +510,10 @@ t_bad_sql_parameter(Config) ->
             )
     end,
     ok.
+
+t_nasty_sql_string(Config) ->
+    ?assertMatch({ok, _}, create_bridge(Config)),
+    Payload = list_to_binary(lists:seq(1, 127)),
+    Message = #{payload => Payload, timestamp => erlang:system_time(millisecond)},
+    ?assertEqual({ok, 1}, send_message(Config, Message)),
+    ?assertEqual(Payload, connect_and_get_payload(Config)).

+ 26 - 0
lib-ee/emqx_ee_bridge/test/emqx_ee_bridge_tdengine_SUITE.erl

@@ -426,6 +426,32 @@ t_bad_sql_parameter(Config) ->
     end,
     ok.
 
+t_nasty_sql_string(Config) ->
+    ?assertMatch(
+        {ok, _},
+        create_bridge(Config)
+    ),
+    % NOTE
+    % Column `payload` has BINARY type, so we would certainly like to test it
+    % with `lists:seq(1, 127)`, but:
+    % 1. There's no way to insert zero byte in an SQL string, seems that TDengine's
+    %    parser[1] has no escaping sequence for it so a zero byte probably confuses
+    %    interpreter somewhere down the line.
+    % 2. Bytes > 127 come back as U+FFFDs (i.e. replacement characters) in UTF-8 for
+    %    some reason.
+    %
+    % [1]: https://github.com/taosdata/TDengine/blob/066cb34a/source/libs/parser/src/parUtil.c#L279-L301
+    Payload = list_to_binary(lists:seq(1, 127)),
+    Message = #{payload => Payload, timestamp => erlang:system_time(millisecond)},
+    ?assertMatch(
+        {ok, #{<<"code">> := 0, <<"rows">> := 1}},
+        send_message(Config, Message)
+    ),
+    ?assertEqual(
+        Payload,
+        connect_and_get_payload(Config)
+    ).
+
 to_bin(List) when is_list(List) ->
     unicode:characters_to_binary(List, utf8);
 to_bin(Bin) when is_binary(Bin) ->