Browse Source

Merge pull request #13264 from thalesmg/fix-postgres-disabled-prepared-r57-20240614

fix(postgres): authn/authz/batch requests when prepared statements are disabled
Thales Macedo Garitezi 1 year ago
parent
commit
fb492e3dc5

+ 46 - 5
apps/emqx_auth_postgresql/test/emqx_authn_postgresql_SUITE.erl

@@ -30,15 +30,28 @@
 
 
 -define(PATH, [authentication]).
 -define(PATH, [authentication]).
 
 
+-import(emqx_common_test_helpers, [on_exit/1]).
+
 all() ->
 all() ->
+    AllTCs = emqx_common_test_helpers:all(?MODULE),
+    TCs = AllTCs -- require_seeds_tests(),
     [
     [
-        {group, require_seeds},
-        t_update_with_invalid_config,
-        t_update_with_bad_config_value
+        {group, require_seeds}
+        | TCs
     ].
     ].
 
 
 groups() ->
 groups() ->
-    [{require_seeds, [], [t_create, t_authenticate, t_update, t_destroy, t_is_superuser]}].
+    [{require_seeds, [], require_seeds_tests()}].
+
+require_seeds_tests() ->
+    [
+        t_create,
+        t_authenticate,
+        t_authenticate_disabled_prepared_statements,
+        t_update,
+        t_destroy,
+        t_is_superuser
+    ].
 
 
 init_per_testcase(_, Config) ->
 init_per_testcase(_, Config) ->
     emqx_authn_test_lib:delete_authenticators(
     emqx_authn_test_lib:delete_authenticators(
@@ -47,6 +60,10 @@ init_per_testcase(_, Config) ->
     ),
     ),
     Config.
     Config.
 
 
+end_per_testcase(_TestCase, _Config) ->
+    emqx_common_test_helpers:call_janitor(),
+    ok.
+
 init_per_group(require_seeds, Config) ->
 init_per_group(require_seeds, Config) ->
     ok = init_seeds(),
     ok = init_seeds(),
     Config.
     Config.
@@ -70,7 +87,12 @@ init_per_suite(Config) ->
             ),
             ),
             [{apps, Apps} | Config];
             [{apps, Apps} | Config];
         false ->
         false ->
-            {skip, no_pgsql}
+            case os:getenv("IS_CI") of
+                "yes" ->
+                    throw(no_postgres);
+                _ ->
+                    {skip, no_postgres}
+            end
     end.
     end.
 
 
 end_per_suite(Config) ->
 end_per_suite(Config) ->
@@ -174,6 +196,25 @@ test_user_auth(#{
         ?GLOBAL
         ?GLOBAL
     ).
     ).
 
 
+t_authenticate_disabled_prepared_statements(Config) ->
+    ResConfig = maps:merge(pgsql_config(), #{disable_prepared_statements => true}),
+    {ok, _} = emqx_resource:recreate_local(?PGSQL_RESOURCE, emqx_postgresql, ResConfig),
+    on_exit(fun() ->
+        emqx_resource:recreate_local(?PGSQL_RESOURCE, emqx_postgresql, pgsql_config())
+    end),
+    ok = lists:foreach(
+        fun(Sample0) ->
+            Sample = maps:update_with(
+                config_params,
+                fun(Cfg) -> Cfg#{<<"disable_prepared_statements">> => true} end,
+                Sample0
+            ),
+            ct:pal("test_user_auth sample: ~p", [Sample]),
+            test_user_auth(Sample)
+        end,
+        user_seeds()
+    ).
+
 t_destroy(_Config) ->
 t_destroy(_Config) ->
     AuthConfig = raw_pgsql_auth_config(),
     AuthConfig = raw_pgsql_auth_config(),
 
 

+ 2 - 2
apps/emqx_bridge_pgsql/test/emqx_bridge_pgsql_SUITE.erl

@@ -601,7 +601,7 @@ t_simple_sql_query(Config) ->
         {ok, _},
         {ok, _},
         create_bridge(Config)
         create_bridge(Config)
     ),
     ),
-    Request = {sql, <<"SELECT count(1) AS T">>},
+    Request = {query, <<"SELECT count(1) AS T">>},
     Result =
     Result =
         case QueryMode of
         case QueryMode of
             sync ->
             sync ->
@@ -651,7 +651,7 @@ t_bad_sql_parameter(Config) ->
         {ok, _},
         {ok, _},
         create_bridge(Config)
         create_bridge(Config)
     ),
     ),
-    Request = {sql, <<"">>, [bad_parameter]},
+    Request = {query, <<"">>, [bad_parameter]},
     Result =
     Result =
         case QueryMode of
         case QueryMode of
             sync ->
             sync ->

+ 70 - 7
apps/emqx_bridge_pgsql/test/emqx_bridge_v2_pgsql_SUITE.erl

@@ -102,6 +102,18 @@ init_per_group(Group, Config) when
         {connector_type, group_to_type(Group)}
         {connector_type, group_to_type(Group)}
         | Config
         | Config
     ];
     ];
+init_per_group(batch_enabled, Config) ->
+    [
+        {batch_size, 10},
+        {batch_time, <<"10ms">>}
+        | Config
+    ];
+init_per_group(batch_disabled, Config) ->
+    [
+        {batch_size, 1},
+        {batch_time, <<"0ms">>}
+        | Config
+    ];
 init_per_group(_Group, Config) ->
 init_per_group(_Group, Config) ->
     Config.
     Config.
 
 
@@ -262,17 +274,68 @@ t_start_action_or_source_with_disabled_connector(Config) ->
     ok.
     ok.
 
 
 t_disable_prepared_statements(matrix) ->
 t_disable_prepared_statements(matrix) ->
-    [[postgres], [timescale], [matrix]];
+    [
+        [postgres, batch_disabled],
+        [postgres, batch_enabled],
+        [timescale, batch_disabled],
+        [timescale, batch_enabled],
+        [matrix, batch_disabled],
+        [matrix, batch_enabled]
+    ];
 t_disable_prepared_statements(Config0) ->
 t_disable_prepared_statements(Config0) ->
+    BatchSize = ?config(batch_size, Config0),
+    BatchTime = ?config(batch_time, Config0),
     ConnectorConfig0 = ?config(connector_config, Config0),
     ConnectorConfig0 = ?config(connector_config, Config0),
     ConnectorConfig = maps:merge(ConnectorConfig0, #{<<"disable_prepared_statements">> => true}),
     ConnectorConfig = maps:merge(ConnectorConfig0, #{<<"disable_prepared_statements">> => true}),
-    Config = lists:keyreplace(connector_config, 1, Config0, {connector_config, ConnectorConfig}),
-    ok = emqx_bridge_v2_testlib:t_sync_query(
-        Config,
-        fun make_message/0,
-        fun(Res) -> ?assertMatch({ok, _}, Res) end,
-        postgres_bridge_connector_on_query_return
+    BridgeConfig0 = ?config(bridge_config, Config0),
+    BridgeConfig = emqx_utils_maps:deep_merge(
+        BridgeConfig0,
+        #{
+            <<"resource_opts">> => #{
+                <<"batch_size">> => BatchSize,
+                <<"batch_time">> => BatchTime,
+                <<"query_mode">> => <<"async">>
+            }
+        }
     ),
     ),
+    Config1 = lists:keyreplace(connector_config, 1, Config0, {connector_config, ConnectorConfig}),
+    Config = lists:keyreplace(bridge_config, 1, Config1, {bridge_config, BridgeConfig}),
+    ?check_trace(
+        #{timetrap => 5_000},
+        begin
+            ?assertMatch({ok, _}, emqx_bridge_v2_testlib:create_bridge_api(Config)),
+            RuleTopic = <<"t/postgres">>,
+            Type = ?config(bridge_type, Config),
+            {ok, _} = emqx_bridge_v2_testlib:create_rule_and_action_http(Type, RuleTopic, Config),
+            ResourceId = emqx_bridge_v2_testlib:resource_id(Config),
+            ?retry(
+                _Sleep = 1_000,
+                _Attempts = 20,
+                ?assertEqual({ok, connected}, emqx_resource_manager:health_check(ResourceId))
+            ),
+            {ok, C} = emqtt:start_link(),
+            {ok, _} = emqtt:connect(C),
+            lists:foreach(
+                fun(N) ->
+                    emqtt:publish(C, RuleTopic, integer_to_binary(N))
+                end,
+                lists:seq(1, BatchSize)
+            ),
+            case BatchSize > 1 of
+                true ->
+                    ?block_until(#{
+                        ?snk_kind := "postgres_success_batch_result",
+                        row_count := BatchSize
+                    }),
+                    ok;
+                false ->
+                    ok
+            end,
+            ok
+        end,
+        []
+    ),
+    emqx_bridge_v2_testlib:delete_all_bridges_and_connectors(),
     ok = emqx_bridge_v2_testlib:t_on_get_status(Config, #{failure_status => connecting}),
     ok = emqx_bridge_v2_testlib:t_on_get_status(Config, #{failure_status => connecting}),
     emqx_bridge_v2_testlib:delete_all_bridges_and_connectors(),
     emqx_bridge_v2_testlib:delete_all_bridges_and_connectors(),
     ok = emqx_bridge_v2_testlib:t_create_via_http(Config),
     ok = emqx_bridge_v2_testlib:t_create_via_http(Config),

+ 73 - 50
apps/emqx_postgresql/src/emqx_postgresql.erl

@@ -59,6 +59,9 @@
     default_port => ?PGSQL_DEFAULT_PORT
     default_port => ?PGSQL_DEFAULT_PORT
 }).
 }).
 
 
+-type connector_resource_id() :: binary().
+-type action_resource_id() :: binary().
+
 -type template() :: {unicode:chardata(), emqx_template_sql:row_template()}.
 -type template() :: {unicode:chardata(), emqx_template_sql:row_template()}.
 -type state() ::
 -type state() ::
     #{
     #{
@@ -319,38 +322,40 @@ do_check_channel_sql(
 on_get_channels(ResId) ->
 on_get_channels(ResId) ->
     emqx_bridge_v2:get_channels_for_connector(ResId).
     emqx_bridge_v2:get_channels_for_connector(ResId).
 
 
-on_query(InstId, {TypeOrKey, NameOrSQL}, State) ->
-    on_query(InstId, {TypeOrKey, NameOrSQL, []}, State);
+-spec on_query
+    %% Called from authn and authz modules
+    (connector_resource_id(), {prepared_query, binary(), [term()]}, state()) ->
+        {ok, _} | {error, term()};
+    %% Called from bridges
+    (connector_resource_id(), {action_resource_id(), map()}, state()) ->
+        {ok, _} | {error, term()}.
+on_query(InstId, {TypeOrKey, NameOrMap}, State) ->
+    on_query(InstId, {TypeOrKey, NameOrMap, []}, State);
 on_query(
 on_query(
     InstId,
     InstId,
-    {TypeOrKey, NameOrSQL, Params},
+    {TypeOrKey, NameOrMap, Params},
     #{pool_name := PoolName} = State
     #{pool_name := PoolName} = State
 ) ->
 ) ->
     ?SLOG(debug, #{
     ?SLOG(debug, #{
         msg => "postgresql_connector_received_sql_query",
         msg => "postgresql_connector_received_sql_query",
         connector => InstId,
         connector => InstId,
         type => TypeOrKey,
         type => TypeOrKey,
-        sql => NameOrSQL,
+        sql => NameOrMap,
         state => State
         state => State
     }),
     }),
-    Type = pgsql_query_type(TypeOrKey, State),
-    {NameOrSQL2, Data} = proc_sql_params(TypeOrKey, NameOrSQL, Params, State),
-    Res = on_sql_query(TypeOrKey, InstId, PoolName, Type, NameOrSQL2, Data),
+    {QueryType, NameOrSQL2, Data} = proc_sql_params(TypeOrKey, NameOrMap, Params, State),
+    emqx_trace:rendered_action_template(
+        TypeOrKey,
+        #{
+            statement_type => QueryType,
+            statement_or_name => NameOrSQL2,
+            data => Data
+        }
+    ),
+    Res = on_sql_query(InstId, PoolName, QueryType, NameOrSQL2, Data),
     ?tp(postgres_bridge_connector_on_query_return, #{instance_id => InstId, result => Res}),
     ?tp(postgres_bridge_connector_on_query_return, #{instance_id => InstId, result => Res}),
     handle_result(Res).
     handle_result(Res).
 
 
-pgsql_query_type(_TypeOrTag, #{prepares := disabled}) ->
-    query;
-pgsql_query_type(sql, _ConnectorState) ->
-    query;
-pgsql_query_type(query, _ConnectorState) ->
-    query;
-pgsql_query_type(prepared_query, _ConnectorState) ->
-    prepared_query;
-%% for bridge
-pgsql_query_type(_, ConnectorState) ->
-    pgsql_query_type(prepared_query, ConnectorState).
-
 on_batch_query(
 on_batch_query(
     InstId,
     InstId,
     [{Key, _} = Request | _] = BatchReq,
     [{Key, _} = Request | _] = BatchReq,
@@ -370,7 +375,15 @@ on_batch_query(
         {_Statement, RowTemplate} ->
         {_Statement, RowTemplate} ->
             StatementTemplate = get_templated_statement(BinKey, State),
             StatementTemplate = get_templated_statement(BinKey, State),
             Rows = [render_prepare_sql_row(RowTemplate, Data) || {_Key, Data} <- BatchReq],
             Rows = [render_prepare_sql_row(RowTemplate, Data) || {_Key, Data} <- BatchReq],
-            case on_sql_query(Key, InstId, PoolName, execute_batch, StatementTemplate, Rows) of
+            emqx_trace:rendered_action_template(
+                Key,
+                #{
+                    statement_type => execute_batch,
+                    statement_or_name => StatementTemplate,
+                    data => Rows
+                }
+            ),
+            case on_sql_query(InstId, PoolName, execute_batch, StatementTemplate, Rows) of
                 {error, _Error} = Result ->
                 {error, _Error} = Result ->
                     handle_result(Result);
                     handle_result(Result);
                 {_Column, Results} ->
                 {_Column, Results} ->
@@ -386,25 +399,38 @@ on_batch_query(InstId, BatchReq, State) ->
     }),
     }),
     {error, {unrecoverable_error, invalid_request}}.
     {error, {unrecoverable_error, invalid_request}}.
 
 
-proc_sql_params(query, SQLOrKey, Params, _State) ->
-    {SQLOrKey, Params};
-proc_sql_params(prepared_query, SQLOrKey, Params, _State) ->
-    {SQLOrKey, Params};
-proc_sql_params(TypeOrKey, SQLOrData, Params, State) ->
-    DisablePreparedStatements = maps:get(prepares, State, #{}) =:= disabled,
-    BinKey = to_bin(TypeOrKey),
-    case get_template(BinKey, State) of
-        undefined ->
-            {SQLOrData, Params};
-        {Statement, RowTemplate} ->
-            Rendered = render_prepare_sql_row(RowTemplate, SQLOrData),
-            case DisablePreparedStatements of
-                true ->
-                    {Statement, Rendered};
-                false ->
-                    {BinKey, Rendered}
-            end
-    end.
+proc_sql_params(ActionResId, #{} = Map, [], State) when is_binary(ActionResId) ->
+    %% When this connector is called from actions/bridges.
+    DisablePreparedStatements = prepared_statements_disabled(State),
+    {ExprTemplate, RowTemplate} = get_template(ActionResId, State),
+    Rendered = render_prepare_sql_row(RowTemplate, Map),
+    case DisablePreparedStatements of
+        true ->
+            {query, ExprTemplate, Rendered};
+        false ->
+            {prepared_query, ActionResId, Rendered}
+    end;
+proc_sql_params(prepared_query, ConnResId, Params, State) ->
+    %% When this connector is called from authn/authz modules
+    DisablePreparedStatements = prepared_statements_disabled(State),
+    case DisablePreparedStatements of
+        true ->
+            #{query_templates := #{ConnResId := {ExprTemplate, _VarsTemplate}}} = State,
+            {query, ExprTemplate, Params};
+        false ->
+            %% Connector resource id itself is the prepared statement name
+            {prepared_query, ConnResId, Params}
+    end;
+proc_sql_params(QueryType, SQL, Params, _State) when
+    is_atom(QueryType) andalso
+        (is_binary(SQL) orelse is_list(SQL)) andalso
+        is_list(Params)
+->
+    %% When called to do ad-hoc commands/queries.
+    {QueryType, SQL, Params}.
+
+prepared_statements_disabled(State) ->
+    maps:get(prepares, State, #{}) =:= disabled.
 
 
 get_template(Key, #{installed_channels := Channels} = _State) when is_map_key(Key, Channels) ->
 get_template(Key, #{installed_channels := Channels} = _State) when is_map_key(Key, Channels) ->
     BinKey = to_bin(Key),
     BinKey = to_bin(Key),
@@ -420,21 +446,17 @@ get_templated_statement(Key, #{installed_channels := Channels} = _State) when
 ->
 ->
     BinKey = to_bin(Key),
     BinKey = to_bin(Key),
     ChannelState = maps:get(BinKey, Channels),
     ChannelState = maps:get(BinKey, Channels),
-    ChannelPreparedStatements = maps:get(prepares, ChannelState),
-    maps:get(BinKey, ChannelPreparedStatements);
+    case ChannelState of
+        #{prepares := disabled, query_templates := #{BinKey := {ExprTemplate, _}}} ->
+            ExprTemplate;
+        #{prepares := #{BinKey := ExprTemplate}} ->
+            ExprTemplate
+    end;
 get_templated_statement(Key, #{prepares := PrepStatements}) ->
 get_templated_statement(Key, #{prepares := PrepStatements}) ->
     BinKey = to_bin(Key),
     BinKey = to_bin(Key),
     maps:get(BinKey, PrepStatements).
     maps:get(BinKey, PrepStatements).
 
 
-on_sql_query(Key, InstId, PoolName, Type, NameOrSQL, Data) ->
-    emqx_trace:rendered_action_template(
-        Key,
-        #{
-            statement_type => Type,
-            statement_or_name => NameOrSQL,
-            data => Data
-        }
-    ),
+on_sql_query(InstId, PoolName, Type, NameOrSQL, Data) ->
     try ecpool:pick_and_do(PoolName, {?MODULE, Type, [NameOrSQL, Data]}, no_handover) of
     try ecpool:pick_and_do(PoolName, {?MODULE, Type, [NameOrSQL, Data]}, no_handover) of
         {error, Reason} ->
         {error, Reason} ->
             ?tp(
             ?tp(
@@ -785,6 +807,7 @@ handle_batch_result([{error, Error} | _Rest], _Acc) ->
     TranslatedError = translate_to_log_context(Error),
     TranslatedError = translate_to_log_context(Error),
     {error, {unrecoverable_error, export_error(TranslatedError)}};
     {error, {unrecoverable_error, export_error(TranslatedError)}};
 handle_batch_result([], Acc) ->
 handle_batch_result([], Acc) ->
+    ?tp("postgres_success_batch_result", #{row_count => Acc}),
     {ok, Acc}.
     {ok, Acc}.
 
 
 translate_to_log_context({error, Reason}) ->
 translate_to_log_context({error, Reason}) ->