Prechádzať zdrojové kódy

fix(mysql): ensure proper escaping in batch inserts

Also hexencode non-utf8 binaries. This is essentially a heuristic.
We don't know column types at runtime, and there's no simple way
to find them out. Since we're already doing a full binary scan during
escaping, it should be cheap to bail out on non-utf8 strings and
hexencode them instead.

Also introduce a separate function to highlight that this escaping
is MySQL-specific.
Andrew Mayorov 3 rokov pred
rodič
commit
0a7f6c7d03

+ 6 - 2
apps/emqx_connector/src/emqx_connector_mysql.erl

@@ -391,8 +391,12 @@ proc_sql_params(TypeOrKey, SQLOrData, Params, #{params_tokens := ParamsTokens})
     end.
     end.
 
 
 on_batch_insert(InstId, BatchReqs, InsertPart, Tokens, State) ->
 on_batch_insert(InstId, BatchReqs, InsertPart, Tokens, State) ->
-    SQL = emqx_plugin_libs_rule:proc_batch_sql(BatchReqs, InsertPart, Tokens),
-    on_sql_query(InstId, query, SQL, no_params, default_timeout, State).
+    ValuesPart = lists:join($,, [
+        emqx_placeholder:proc_param_str(Tokens, Msg, fun emqx_placeholder:quote_mysql/1)
+     || {_, Msg} <- BatchReqs
+    ]),
+    Query = [InsertPart, <<" values ">> | ValuesPart],
+    on_sql_query(InstId, query, Query, no_params, default_timeout, State).
 
 
 on_sql_query(
 on_sql_query(
     InstId,
     InstId,

+ 89 - 21
apps/emqx_plugin_libs/src/emqx_placeholder.erl

@@ -30,6 +30,7 @@
     proc_sql/2,
     proc_sql/2,
     proc_sql_param_str/2,
     proc_sql_param_str/2,
     proc_cql_param_str/2,
     proc_cql_param_str/2,
+    proc_param_str/3,
     preproc_tmpl_deep/1,
     preproc_tmpl_deep/1,
     preproc_tmpl_deep/2,
     preproc_tmpl_deep/2,
     proc_tmpl_deep/2,
     proc_tmpl_deep/2,
@@ -39,6 +40,12 @@
     sql_data/1
     sql_data/1
 ]).
 ]).
 
 
+-export([
+    quote_sql/1,
+    quote_cql/1,
+    quote_mysql/1
+]).
+
 -include_lib("emqx/include/emqx_placeholder.hrl").
 -include_lib("emqx/include/emqx_placeholder.hrl").
 
 
 -define(EX_PLACE_HOLDER, "(\\$\\{[a-zA-Z0-9\\._]+\\})").
 -define(EX_PLACE_HOLDER, "(\\$\\{[a-zA-Z0-9\\._]+\\})").
@@ -83,6 +90,8 @@
     | {tmpl, tmpl_token()}
     | {tmpl, tmpl_token()}
     | {value, term()}.
     | {value, term()}.
 
 
+-dialyzer({no_improper_lists, [quote_mysql/1, escape_mysql/4, escape_prepend/4]}).
+
 %%------------------------------------------------------------------------------
 %%------------------------------------------------------------------------------
 %% APIs
 %% APIs
 %%------------------------------------------------------------------------------
 %%------------------------------------------------------------------------------
@@ -162,12 +171,22 @@ proc_sql(Tokens, Data) ->
 
 
 -spec proc_sql_param_str(tmpl_token(), map()) -> binary().
 -spec proc_sql_param_str(tmpl_token(), map()) -> binary().
 proc_sql_param_str(Tokens, Data) ->
 proc_sql_param_str(Tokens, Data) ->
+    % NOTE
+    % This is a bit misleading: currently, escaping logic in `quote_sql/1` likely
+    % won't work with pgsql since it does not support C-style escapes by default.
+    % https://www.postgresql.org/docs/14/sql-syntax-lexical.html#SQL-SYNTAX-CONSTANTS
     proc_param_str(Tokens, Data, fun quote_sql/1).
     proc_param_str(Tokens, Data, fun quote_sql/1).
 
 
 -spec proc_cql_param_str(tmpl_token(), map()) -> binary().
 -spec proc_cql_param_str(tmpl_token(), map()) -> binary().
 proc_cql_param_str(Tokens, Data) ->
 proc_cql_param_str(Tokens, Data) ->
     proc_param_str(Tokens, Data, fun quote_cql/1).
     proc_param_str(Tokens, Data, fun quote_cql/1).
 
 
+-spec proc_param_str(tmpl_token(), map(), fun((_Value) -> iodata())) -> binary().
+proc_param_str(Tokens, Data, Quote) ->
+    iolist_to_binary(
+        proc_tmpl(Tokens, Data, #{return => rawlist, var_trans => Quote})
+    ).
+
 -spec preproc_tmpl_deep(term()) -> deep_template().
 -spec preproc_tmpl_deep(term()) -> deep_template().
 preproc_tmpl_deep(Data) ->
 preproc_tmpl_deep(Data) ->
     preproc_tmpl_deep(Data, #{process_keys => true}).
     preproc_tmpl_deep(Data, #{process_keys => true}).
@@ -226,15 +245,29 @@ sql_data(Map) when is_map(Map) -> emqx_json:encode(Map).
 -spec bin(term()) -> binary().
 -spec bin(term()) -> binary().
 bin(Val) -> emqx_plugin_libs_rule:bin(Val).
 bin(Val) -> emqx_plugin_libs_rule:bin(Val).
 
 
+-spec quote_sql(_Value) -> iolist().
+quote_sql(Str) ->
+    quote_escape(Str, fun escape_sql/1).
+
+-spec quote_cql(_Value) -> iolist().
+quote_cql(Str) ->
+    quote_escape(Str, fun escape_cql/1).
+
+-spec quote_mysql(_Value) -> iolist().
+quote_mysql(Str) when is_binary(Str) ->
+    try
+        escape_mysql(Str)
+    catch
+        throw:invalid_utf8 ->
+            [<<"0x">> | binary:encode_hex(Str)]
+    end;
+quote_mysql(Str) ->
+    quote_escape(Str, fun escape_mysql/1).
+
 %%------------------------------------------------------------------------------
 %%------------------------------------------------------------------------------
 %% Internal functions
 %% Internal functions
 %%------------------------------------------------------------------------------
 %%------------------------------------------------------------------------------
 
 
-proc_param_str(Tokens, Data, Quote) ->
-    iolist_to_binary(
-        proc_tmpl(Tokens, Data, #{return => rawlist, var_trans => Quote})
-    ).
-
 get_phld_var(Phld, Data) ->
 get_phld_var(Phld, Data) ->
     emqx_rule_maps:nested_get(Phld, Data).
     emqx_rule_maps:nested_get(Phld, Data).
 
 
@@ -312,21 +345,56 @@ unwrap(<<"\"${", Val/binary>>, _StripDoubleQuote = true) ->
 unwrap(<<"${", Val/binary>>, _StripDoubleQuote) ->
 unwrap(<<"${", Val/binary>>, _StripDoubleQuote) ->
     binary:part(Val, {0, byte_size(Val) - 1}).
     binary:part(Val, {0, byte_size(Val) - 1}).
 
 
-quote_sql(Str) ->
-    quote(Str, <<"\\\\'">>).
-
-quote_cql(Str) ->
-    quote(Str, <<"''">>).
-
-quote(Str, ReplaceWith) when
-    is_list(Str);
-    is_binary(Str);
-    is_atom(Str);
-    is_map(Str)
-->
-    [$', escape_apo(bin(Str), ReplaceWith), $'];
-quote(Val, _) ->
+-spec quote_escape(_Value, fun((binary()) -> iodata())) -> iodata().
+quote_escape(Str, EscapeFun) when is_binary(Str) ->
+    EscapeFun(Str);
+quote_escape(Str, EscapeFun) when is_list(Str) ->
+    case unicode:characters_to_binary(Str) of
+        Bin when is_binary(Bin) ->
+            EscapeFun(Bin);
+        Otherwise ->
+            error(Otherwise)
+    end;
+quote_escape(Str, EscapeFun) when is_atom(Str) orelse is_map(Str) ->
+    EscapeFun(bin(Str));
+quote_escape(Val, _EscapeFun) ->
     bin(Val).
     bin(Val).
 
 
-escape_apo(Str, ReplaceWith) ->
-    re:replace(Str, <<"'">>, ReplaceWith, [{return, binary}, global]).
+-spec escape_sql(binary()) -> iolist().
+escape_sql(S) ->
+    ES = binary:replace(S, [<<"\\">>, <<"'">>], <<"\\">>, [global, {insert_replaced, 1}]),
+    [$', ES, $'].
+
+-spec escape_cql(binary()) -> iolist().
+escape_cql(S) ->
+    ES = binary:replace(S, <<"'">>, <<"'">>, [global, {insert_replaced, 1}]),
+    [$', ES, $'].
+
+-spec escape_mysql(binary()) -> iolist().
+escape_mysql(S0) ->
+    % https://dev.mysql.com/doc/refman/8.0/en/string-literals.html
+    [$', escape_mysql(S0, 0, 0, S0), $'].
+
+%% NOTE
+%% This thing looks more complicated than needed because it's optimized for as few
+%% intermediate memory (re)allocations as possible.
+escape_mysql(<<$', Rest/binary>>, I, Run, Src) ->
+    escape_prepend(I, Run, Src, [<<"\\'">> | escape_mysql(Rest, I + Run + 1, 0, Src)]);
+escape_mysql(<<$\\, Rest/binary>>, I, Run, Src) ->
+    escape_prepend(I, Run, Src, [<<"\\\\">> | escape_mysql(Rest, I + Run + 1, 0, Src)]);
+escape_mysql(<<0, Rest/binary>>, I, Run, Src) ->
+    escape_prepend(I, Run, Src, [<<"\\0">> | escape_mysql(Rest, I + Run + 1, 0, Src)]);
+escape_mysql(<<_/utf8, Rest/binary>> = S, I, Run, Src) ->
+    CWidth = byte_size(S) - byte_size(Rest),
+    escape_mysql(Rest, I, Run + CWidth, Src);
+escape_mysql(<<>>, 0, _, Src) ->
+    Src;
+escape_mysql(<<>>, I, Run, Src) ->
+    binary:part(Src, I, Run);
+escape_mysql(_, _I, _Run, _Src) ->
+    throw(invalid_utf8).
+
+escape_prepend(_RunI, 0, _Src, Tail) ->
+    Tail;
+escape_prepend(I, Run, Src, Tail) ->
+    [binary:part(Src, I, Run) | Tail].

+ 2 - 2
apps/emqx_plugin_libs/src/emqx_plugin_libs_rule.erl

@@ -172,8 +172,8 @@ detect_sql_type(SQL) ->
 ) -> InsertSQL :: binary().
 ) -> InsertSQL :: binary().
 proc_batch_sql(BatchReqs, InsertPart, Tokens) ->
 proc_batch_sql(BatchReqs, InsertPart, Tokens) ->
     ValuesPart = erlang:iolist_to_binary(
     ValuesPart = erlang:iolist_to_binary(
-        lists:join(", ", [
-            emqx_plugin_libs_rule:proc_sql_param_str(Tokens, Msg)
+        lists:join($,, [
+            proc_sql_param_str(Tokens, Msg)
          || {_, Msg} <- BatchReqs
          || {_, Msg} <- BatchReqs
         ])
         ])
     ),
     ),

+ 13 - 5
apps/emqx_plugin_libs/test/emqx_placeholder_SUITE.erl

@@ -105,19 +105,27 @@ t_preproc_sql3(_) ->
         emqx_placeholder:proc_sql_param_str(ParamsTokens, Selected)
         emqx_placeholder:proc_sql_param_str(ParamsTokens, Selected)
     ).
     ).
 
 
-t_preproc_sql4(_) ->
+t_preproc_mysql1(_) ->
     %% with apostrophes
     %% with apostrophes
     %% https://github.com/emqx/emqx/issues/4135
     %% https://github.com/emqx/emqx/issues/4135
     Selected = #{
     Selected = #{
         a => <<"1''2">>,
         a => <<"1''2">>,
         b => 1,
         b => 1,
         c => 1.0,
         c => 1.0,
-        d => #{d1 => <<"someone's phone">>}
+        d => #{d1 => <<"someone's phone">>},
+        e => <<$\\, 0, "💩"/utf8>>,
+        f => <<"non-utf8", 16#DCC900:24>>,
+        g => "utf8's cool 🐸"
     },
     },
-    ParamsTokens = emqx_placeholder:preproc_tmpl(<<"a:${a},b:${b},c:${c},d:${d}">>),
+    ParamsTokens = emqx_placeholder:preproc_tmpl(
+        <<"a:${a},b:${b},c:${c},d:${d},e:${e},f:${f},g:${g}">>
+    ),
     ?assertEqual(
     ?assertEqual(
-        <<"a:'1\\'\\'2',b:1,c:1.0,d:'{\"d1\":\"someone\\'s phone\"}'">>,
-        emqx_placeholder:proc_sql_param_str(ParamsTokens, Selected)
+        <<
+            "a:'1\\'\\'2',b:1,c:1.0,d:'{\"d1\":\"someone\\'s phone\"}',"
+            "e:'\\\\\\0💩',f:0x6E6F6E2D75746638DCC900,g:'utf8\\'s cool 🐸'"/utf8
+        >>,
+        emqx_placeholder:proc_param_str(ParamsTokens, Selected, fun emqx_placeholder:quote_mysql/1)
     ).
     ).
 
 
 t_preproc_sql5(_) ->
 t_preproc_sql5(_) ->

+ 11 - 0
lib-ee/emqx_ee_bridge/test/emqx_ee_bridge_mysql_SUITE.erl

@@ -511,6 +511,17 @@ t_bad_sql_parameter(Config) ->
     end,
     end,
     ok.
     ok.
 
 
+t_nasty_sql_string(Config) ->
+    ?assertMatch({ok, _}, create_bridge(Config)),
+    Payload = list_to_binary(lists:seq(0, 255)),
+    Message = #{payload => Payload, timestamp => erlang:system_time(millisecond)},
+    Result = send_message(Config, Message),
+    ?assertEqual(ok, Result),
+    ?assertMatch(
+        {ok, [<<"payload">>], [[Payload]]},
+        connect_and_get_payload(Config)
+    ).
+
 t_workload_fits_prepared_statement_limit(Config) ->
 t_workload_fits_prepared_statement_limit(Config) ->
     N = 50,
     N = 50,
     ?assertMatch(
     ?assertMatch(

+ 7 - 0
lib-ee/emqx_ee_bridge/test/emqx_ee_bridge_pgsql_SUITE.erl

@@ -510,3 +510,10 @@ t_bad_sql_parameter(Config) ->
             )
             )
     end,
     end,
     ok.
     ok.
+
+t_nasty_sql_string(Config) ->
+    ?assertMatch({ok, _}, create_bridge(Config)),
+    Payload = list_to_binary(lists:seq(1, 127)),
+    Message = #{payload => Payload, timestamp => erlang:system_time(millisecond)},
+    ?assertEqual({ok, 1}, send_message(Config, Message)),
+    ?assertEqual(Payload, connect_and_get_payload(Config)).