فهرست منبع

Merge pull request #13806 from thalesmg/20240913-m-snowflake-fixes

snowflake action fixes
Thales Macedo Garitezi 1 سال پیش
والد
کامیت
0d2c1b0013

+ 1 - 1
apps/emqx_bridge_snowflake/src/emqx_bridge_snowflake_action_schema.erl

@@ -84,7 +84,7 @@ fields(aggreg_parameters) ->
             })},
         {pipelining, mk(pos_integer(), #{default => 100, desc => ?DESC("pipelining")})},
         {pool_size, mk(pos_integer(), #{default => 8, desc => ?DESC("pool_size")})},
-        {max_retries, mk(non_neg_integer(), #{required => false, desc => ?DESC("max_retries")})},
+        {max_retries, mk(non_neg_integer(), #{default => 3, desc => ?DESC("max_retries")})},
         {max_block_size,
             mk(
                 emqx_schema:bytesize(),

+ 92 - 9
apps/emqx_bridge_snowflake/src/emqx_bridge_snowflake_connector.erl

@@ -58,7 +58,7 @@
 ]).
 
 %% Internal exports only for mocking
--export([do_insert_files_request/4]).
+-export([do_insert_files_request/4, do_insert_report_request/4]).
 
 %%------------------------------------------------------------------------------
 %% Type declarations
@@ -392,6 +392,7 @@ stage_file(ODBCPool, Filename, Database, Schema, Stage, ActionName) ->
     {ok, file:filename()} | {error, term()}.
 do_stage_file(ConnPid, Filename, Database, Schema, Stage, ActionName) ->
     SQL = stage_file_sql(Filename, Database, Schema, Stage, ActionName),
+    ?tp(debug, "snowflake_stage_file", #{sql => SQL, action => ActionName}),
     %% Should we also check if it actually succeeded by inspecting reportFiles?
     odbc:sql_query(ConnPid, SQL).
 
@@ -607,6 +608,7 @@ process_complete(TransferState0) ->
             {ok, 200, _, Body} ->
                 {ok, emqx_utils_json:decode(Body, [return_maps])};
             Res ->
+                ?tp("snowflake_insert_files_request_failed", #{response => Res}),
                 %% TODO: retry?
                 exit({insert_failed, Res})
         end
@@ -625,6 +627,7 @@ create_action(
 ) ->
     maybe
         {ok, ActionState0} ?= start_http_pool(ActionResId, ActionConfig, ConnState),
+        _ = check_snowpipe_user_permission(ActionResId, ActionState0),
         start_aggregator(ConnResId, ActionResId, ActionConfig, ActionState0)
     end.
 
@@ -879,13 +882,7 @@ insert_report_request(HTTPPool, Opts, HTTPClientConfig) ->
                 InsertReportPath0
         end,
     Req = {InsertReportPath, Headers},
-    Response = ehttpc:request(
-        HTTPPool,
-        get,
-        Req,
-        RequestTTL,
-        MaxRetries
-    ),
+    Response = ?MODULE:do_insert_report_request(HTTPPool, Req, RequestTTL, MaxRetries),
     case Response of
         {ok, 200, _Headers, Body0} ->
             Body = emqx_utils_json:decode(Body0, [return_maps]),
@@ -894,6 +891,10 @@ insert_report_request(HTTPPool, Opts, HTTPClientConfig) ->
             {error, Response}
     end.
 
+%% Thin wrapper around `ehttpc:request/5' that performs the `get' request for the
+%% insert report.  It is exported, and called through `?MODULE:', only so that tests
+%% can mock it.
+do_insert_report_request(HTTPPool, Req, RequestTTL, MaxRetries) ->
+    ehttpc:request(HTTPPool, get, Req, RequestTTL, MaxRetries).
+
 http_headers(AuthnHeader) ->
     [
         {<<"X-Snowflake-Authorization-Token-Type">>, <<"KEYPAIR_JWT">>},
@@ -915,6 +916,8 @@ action_status(ActionResId, #{mode := aggregated} = ActionState) ->
     %% NOTE: This will effectively trigger uploads of buffers yet to be uploaded.
     Timestamp = erlang:system_time(second),
     ok = emqx_connector_aggregator:tick(AggregId, Timestamp),
+    ok = check_aggreg_upload_errors(AggregId),
+    ok = check_snowpipe_user_permission(ActionResId, ActionState),
     case http_pool_workers_healthy(ActionResId, ConnectTimeout) of
         true ->
             ?status_connected;
@@ -968,7 +971,7 @@ http_pool_workers_healthy(HTTPPool, Timeout) ->
 
 %% https://docs.snowflake.com/en/sql-reference/identifiers-syntax
 needs_quoting(Identifier) ->
-    nomatch =:= re:run(Identifier, <<"^[A-Za-z_][A-Za-z_0-9$]$">>, [{capture, none}]).
+    nomatch =:= re:run(Identifier, <<"^[A-Za-z_][A-Za-z_0-9$]*$">>, [{capture, none}]).
 
 maybe_quote(Identifier) ->
     case needs_quoting(Identifier) of
@@ -977,3 +980,83 @@ maybe_quote(Identifier) ->
         false ->
             Identifier
     end.
+
+%% Takes a pending upload error (if any) from the aggregator identified by `AggregId'.
+%%
+%% Returns `ok' when there is no error.  Otherwise emits the
+%% "snowflake_check_aggreg_upload_error_found" trace point and throws
+%% `{unhealthy_target, ErrorMessage}', where `ErrorMessage' is the formatted error.
+%%
+%% NOTE(review): only `[]' and a single-element list are matched here; confirm that
+%% `emqx_connector_aggregator:take_error/1' never returns more than one error.
+check_aggreg_upload_errors(AggregId) ->
+    case emqx_connector_aggregator:take_error(AggregId) of
+        [Error] ->
+            ?tp("snowflake_check_aggreg_upload_error_found", #{error => Error}),
+            %% TODO: revisit.
+            %% This approach means that, for example, 3 upload failures will cause
+            %% the channel to be marked as unhealthy for 3 consecutive health checks.
+            ErrorMessage = emqx_utils:format(Error),
+            throw({unhealthy_target, ErrorMessage});
+        [] ->
+            ok
+    end.
+
+%% Checks that the configured Snowpipe user is able to operate on the pipe by issuing
+%% an insert report request through `HTTPPool'.
+%%
+%% Returns `ok' when the request succeeds.  Otherwise throws:
+%%   * `{unhealthy_target, Msg}' when the server answers with HTTP 401;
+%%   * a plain binary message for any other HTTP status, or when the request itself
+%%     fails without an HTTP response.  These are not tagged as `unhealthy_target'
+%%     because they could be spurious.
+check_snowpipe_user_permission(HTTPPool, ActionState) ->
+    #{http := HTTPClientConfig} = ActionState,
+    Opts = #{},
+    case insert_report_request(HTTPPool, Opts, HTTPClientConfig) of
+        {ok, _} ->
+            ok;
+        {error, {ok, 401, _, Body0}} ->
+            %% The body is only used for the debug log; fall back to the raw body
+            %% when it is not valid JSON.
+            Body =
+                case emqx_utils_json:safe_decode(Body0, [return_maps]) of
+                    {ok, JSON} -> JSON;
+                    {error, _} -> Body0
+                end,
+            ?SLOG(debug, #{
+                msg => "snowflake_check_snowpipe_user_permission_error",
+                body => Body
+            }),
+            Msg = <<
+                "Configured pipe user does not have permissions to operate on pipe,"
+                " or does not exist. Please check your configuration."
+            >>,
+            throw({unhealthy_target, Msg});
+        {error, {ok, StatusCode, _}} ->
+            Msg = iolist_to_binary([
+                <<"Error checking if configured snowpipe user has permissions.">>,
+                <<" HTTP Status Code:">>,
+                integer_to_binary(StatusCode)
+            ]),
+            %% Not marking it as unhealthy because it could be spurious
+            throw(Msg);
+        {error, {ok, StatusCode, _, Body}} ->
+            Msg = iolist_to_binary([
+                <<"Error checking if configured snowpipe user has permissions.">>,
+                <<" HTTP Status Code:">>,
+                integer_to_binary(StatusCode),
+                <<"; Body: ">>,
+                Body
+            ]),
+            %% Not marking it as unhealthy because it could be spurious
+            throw(Msg);
+        {error, {error, Reason}} ->
+            %% The request failed without producing an HTTP response (the wrapped
+            %% result is an `{error, Reason}' tuple).  Without this clause the `case'
+            %% would crash with `case_clause' instead of reporting the failure.
+            Msg = iolist_to_binary([
+                <<"Error checking if configured snowpipe user has permissions.">>,
+                <<" Reason: ">>,
+                io_lib:format("~0p", [Reason])
+            ]),
+            %% Not marking it as unhealthy because it could be spurious
+            throw(Msg)
+    end.
+
+%%------------------------------------------------------------------------------
+%% Tests
+%%------------------------------------------------------------------------------
+-ifdef(TEST).
+-include_lib("eunit/include/eunit.hrl").
+
+%% EUnit generator for `needs_quoting/1'.  Each identifier doubles as the test title.
+needs_quoting_test_() ->
+    %% Identifiers that must be quoted.
+    MustQuote = [
+        <<"with spaece">>,
+        <<"1_number_in_beginning">>,
+        <<"contains_açéntõ">>,
+        <<"with-hyphen">>,
+        <<"">>
+    ],
+    %% Identifiers that may be used unquoted.
+    PlainIds = [
+        <<"testdatabase">>,
+        <<"TESTDATABASE">>,
+        <<"TestDatabase">>,
+        <<"with_underscore">>,
+        <<"with_underscore_10">>
+    ],
+    [{Id, ?_assert(needs_quoting(Id))} || Id <- MustQuote] ++
+        [{Id, ?_assertNot(needs_quoting(Id))} || Id <- PlainIds].
+-endif.

+ 14 - 1
apps/emqx_bridge_snowflake/src/emqx_bridge_snowflake_connector_schema.erl

@@ -68,7 +68,12 @@ fields(connector_config) ->
                 #{required => true, desc => ?DESC("server")},
                 ?SERVER_OPTS
             )},
-        {account, mk(binary(), #{required => true, desc => ?DESC("account")})},
+        {account,
+            mk(binary(), #{
+                required => true,
+                desc => ?DESC("account"),
+                validator => fun account_id_validator/1
+            })},
         {dsn, mk(binary(), #{required => true, desc => ?DESC("dsn")})}
         | Fields
     ] ++
@@ -143,3 +148,11 @@ connector_example(put) ->
 %%------------------------------------------------------------------------------
 
 mk(Type, Meta) -> hoconsc:mk(Type, Meta).
+
+%% Validates that the account identifier has the form `ORGID-ACCOUNTNAME'.
+%%
+%% The identifier is split on the first hyphen only, so the account name part may
+%% itself contain hyphens.  Both parts must be non-empty: inputs such as `<<"-">>',
+%% `<<"orgid-">>' or `<<"-account">>' are rejected.
+%%
+%% Returns `ok' or `{error, Message}'.
+account_id_validator(AccountId) ->
+    case binary:split(AccountId, <<"-">>) of
+        [<<_, _/binary>>, <<_, _/binary>>] ->
+            ok;
+        _ ->
+            {error, <<"Account identifier must be of form ORGID-ACCOUNTNAME">>}
+    end.

+ 140 - 11
apps/emqx_bridge_snowflake/test/emqx_bridge_snowflake_SUITE.erl

@@ -53,8 +53,8 @@ init_per_suite(Config) ->
     case os:getenv("SNOWFLAKE_ACCOUNT_ID", "") of
         "" ->
             Mock = true,
-            AccountId = "mocked_account_id",
-            Server = <<"mocked.snowflakecomputing.com">>,
+            AccountId = "mocked_orgid-mocked_account_id",
+            Server = <<"mocked_orgid-mocked_account_id.snowflakecomputing.com">>,
             Username = <<"mock_username">>,
             Password = <<"mock_password">>;
         AccountId ->
@@ -206,6 +206,12 @@ mock_snowflake() ->
         {selected, Headers, Rows}
     end),
     meck:expect(Mod, do_health_check_connector, fun(_ConnPid) -> true end),
+    %% Used in health checks
+    meck:expect(Mod, do_insert_report_request, fun(_HTTPPool, _Req, _RequestTTL, _MaxRetries) ->
+        Headers = [],
+        Body = emqx_utils_json:encode(#{}),
+        {ok, 200, Headers, Body}
+    end),
     meck:expect(Mod, do_insert_files_request, fun(_HTTPPool, _Req, _RequestTTL, _MaxRetries) ->
         Headers = [],
         Body = emqx_utils_json:encode(#{}),
@@ -462,26 +468,36 @@ get_begin_mark(#{mock := false}, ActionResId) ->
         emqx_bridge_snowflake_connector:insert_report(ActionResId, #{}),
     BeginMark.
 
-wait_until_processed(Config, ActionResId, BeginMark) when is_list(Config) ->
-    wait_until_processed(maps:from_list(Config), ActionResId, BeginMark);
-wait_until_processed(#{mock := true} = Config, _ActionResId, _BeginMark) ->
-    {ok, _} = ?block_until(#{?snk_kind := "mock_snowflake_insert_file_request"}),
+wait_until_processed(Config, ActionResId, BeginMark) ->
+    wait_until_processed(Config, ActionResId, BeginMark, _ExpectedNumFiles = 1).
+
+wait_until_processed(Config, ActionResId, BeginMark, ExpectedNumFiles) when is_list(Config) ->
+    wait_until_processed(maps:from_list(Config), ActionResId, BeginMark, ExpectedNumFiles);
+wait_until_processed(#{mock := true} = Config, _ActionResId, _BeginMark, ExpectedNumFiles) ->
+    snabbkaffe:block_until(
+        ?match_n_events(
+            ExpectedNumFiles,
+            #{?snk_kind := "mock_snowflake_insert_file_request"}
+        ),
+        _Timeout = infinity,
+        _BackInTIme = infinity
+    ),
     InsertRes = maps:get(mocked_insert_report, Config, #{}),
     {ok, InsertRes};
-wait_until_processed(#{mock := false} = Config, ActionResId, BeginMark) ->
+wait_until_processed(#{mock := false} = Config, ActionResId, BeginMark, ExpectedNumFiles) ->
     {ok, Res} =
         emqx_bridge_snowflake_connector:insert_report(ActionResId, #{begin_mark => BeginMark}),
     ct:pal("insert report (begin mark ~s):\n  ~p", [BeginMark, Res]),
     case Res of
         #{
-            <<"files">> := [_ | _],
+            <<"files">> := Files,
             <<"statistics">> := #{<<"activeFilesCount">> := 0}
-        } ->
+        } when length(Files) >= ExpectedNumFiles ->
             ct:pal("insertReport response:\n  ~p", [Res]),
             {ok, Res};
         _ ->
             ct:sleep(2_000),
-            wait_until_processed(Config, ActionResId, BeginMark)
+            wait_until_processed(Config, ActionResId, BeginMark, ExpectedNumFiles)
     end.
 
 bin2hex(Bin) ->
@@ -564,7 +580,8 @@ t_aggreg_upload(Config) ->
                     #{?snk_kind := connector_aggreg_delivery_completed, action := AggregId}
                 ),
             %% Check the uploaded objects.
-            wait_until_processed(Config, ActionResId, BeginMark),
+            ExpectedNumFiles = 2,
+            wait_until_processed(Config, ActionResId, BeginMark, ExpectedNumFiles),
             Rows = get_all_rows(Config),
             [
                 P1Hex,
@@ -1032,6 +1049,118 @@ t_aggreg_invalid_column_values(Config0) ->
     ),
     ok.
 
+%% Setup clauses for `t_aggreg_inexistent_database/1' (presumably invoked by the
+%% suite's per-testcase initialization -- confirm against `init_per_testcase').
+%% A proplist config is converted to a map first.  In mock mode, `do_stage_file' is
+%% mocked to return `{error, Msg}' with a "database does not exist" SQL compilation
+%% error message; otherwise the config is returned unchanged.
+t_aggreg_inexistent_database(init, Config) when is_list(Config) ->
+    t_aggreg_inexistent_database(init, maps:from_list(Config));
+t_aggreg_inexistent_database(init, #{mock := true} = Config) ->
+    Mod = ?CONN_MOD,
+    meck:expect(Mod, do_stage_file, fun(
+        _ConnPid, _Filename, _Database, _Schema, _Stage, _ActionName
+    ) ->
+        Msg =
+            "SQL compilation error:, Database 'INEXISTENT' does not"
+            " exist or not authorized. SQLSTATE IS: 02000",
+        {error, Msg}
+    end),
+    maps:to_list(Config);
+t_aggreg_inexistent_database(init, #{} = Config) ->
+    maps:to_list(Config).
+%% Checks that a failed aggregated upload is noticed by the action health check.
+%% The bridge is created with the `database' parameter set to `inexistent', three
+%% messages are published, and the test then waits for the delivery failure and for
+%% the health check to find it.  Finally it asserts that the action API reports an
+%% `unhealthy_target' error, that `matched' is 3 and that `failed' is still 0.
+t_aggreg_inexistent_database(Config) ->
+    ?check_trace(
+        emqx_bridge_v2_testlib:snk_timetrap(),
+        begin
+            {ok, _} = emqx_bridge_v2_testlib:create_bridge_api(
+                Config,
+                #{<<"parameters">> => #{<<"database">> => <<"inexistent">>}}
+            ),
+            ActionResId = emqx_bridge_v2_testlib:bridge_id(Config),
+            {ok, _Rule} =
+                emqx_bridge_v2_testlib:create_rule_and_action_http(
+                    ?ACTION_TYPE_BIN, <<"">>, Config, #{
+                        sql => sql1()
+                    }
+                ),
+            Messages1 = lists:map(fun mk_message/1, [
+                {<<"C1">>, <<"sf/a/b/c">>, <<"{\"hello\":\"world\"}">>},
+                {<<"C2">>, <<"sf/foo/bar">>, <<"baz">>},
+                {<<"C3">>, <<"sf/t/42">>, <<"">>}
+            ]),
+            ok = publish_messages(Messages1),
+            %% Wait until the insert files request fails
+            ct:pal("waiting for delivery to fail..."),
+            ?block_until(#{?snk_kind := "aggregated_buffer_delivery_failed"}),
+            %% When channel health check happens, we check aggregator for errors.
+            %% Current implementation will mark the action as unhealthy.
+            ct:pal("waiting for delivery failure to be noticed by health check..."),
+            ?block_until(#{?snk_kind := "snowflake_check_aggreg_upload_error_found"}),
+
+            %% The error may take a moment to show up in the API, hence the retries.
+            ?retry(
+                _Sleep = 500,
+                _Retries = 10,
+                ?assertMatch(
+                    {200, #{
+                        <<"error">> :=
+                            <<"{unhealthy_target,", _/binary>>
+                    }},
+                    emqx_bridge_v2_testlib:simplify_result(
+                        emqx_bridge_v2_testlib:get_action_api(Config)
+                    )
+                )
+            ),
+
+            %% One match per published message.
+            ?assertEqual(3, emqx_resource_metrics:matched_get(ActionResId)),
+            %% Currently, failure metrics are not bumped when aggregated uploads fail
+            ?assertEqual(0, emqx_resource_metrics:failed_get(ActionResId)),
+
+            ok
+        end,
+        []
+    ),
+    ok.
+
+%% Checks that we detect early that the configured snowpipe user does not have the proper
+%% credentials (or does not exist) for accessing Snowpipe's REST API.
+%%
+%% Setup clauses: a proplist config is converted to a map first.  In mock mode,
+%% `do_insert_report_request' is mocked to answer with HTTP 401 and a JSON body
+%% carrying an "invalid JWT token" message; otherwise the config is returned unchanged.
+t_wrong_snowpipe_user(init, Config) when is_list(Config) ->
+    t_wrong_snowpipe_user(init, maps:from_list(Config));
+t_wrong_snowpipe_user(init, #{mock := true} = Config) ->
+    Mod = ?CONN_MOD,
+    InsertReportResponse = #{
+        <<"code">> => <<"390144">>,
+        <<"data">> => null,
+        <<"headers">> => null,
+        <<"message">> => <<"JWT token is invalid. [92d86b2e-d652-4d2d-9780-a6ed28b38356]">>,
+        <<"success">> => false
+    },
+    meck:expect(Mod, do_insert_report_request, fun(_HTTPPool, _Req, _RequestTTL, _MaxRetries) ->
+        Headers = [],
+        Body = emqx_utils_json:encode(InsertReportResponse),
+        {ok, 401, Headers, Body}
+    end),
+    maps:to_list(Config);
+t_wrong_snowpipe_user(init, #{} = Config) ->
+    maps:to_list(Config).
+%% Creates the connector and then calls `create_kind_api' with `pipe_user' overridden
+%% to a user that does not exist.  The creation request itself is expected to succeed
+%% (HTTP 201), but the returned status must be `disconnected' with an
+%% `unhealthy_target' error.
+t_wrong_snowpipe_user(Config) ->
+    ?check_trace(
+        emqx_bridge_v2_testlib:snk_timetrap(),
+        begin
+            {ok, _} = emqx_bridge_v2_testlib:create_connector_api(Config),
+            ?assertMatch(
+                {ok,
+                    {{_, 201, _}, _, #{
+                        <<"status">> := <<"disconnected">>,
+                        <<"error">> := <<"{unhealthy_target,", _/binary>>
+                    }}},
+                emqx_bridge_v2_testlib:create_kind_api(
+                    Config,
+                    #{<<"parameters">> => #{<<"pipe_user">> => <<"idontexist">>}}
+                )
+            ),
+            ok
+        end,
+        []
+    ),
+    ok.
+
 %% Todo: test scenarios
 %% * User error in rule definition; e.g.:
 %%    - forgot to use `bin2hexstr' to encode the payload

+ 53 - 0
apps/emqx_bridge_snowflake/test/emqx_bridge_snowflake_tests.erl

@@ -0,0 +1,53 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2024 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%--------------------------------------------------------------------
+-module(emqx_bridge_snowflake_tests).
+
+-include_lib("eunit/include/eunit.hrl").
+-include("src/emqx_bridge_snowflake.hrl").
+
+%%------------------------------------------------------------------------------
+%% Helper fns
+%%------------------------------------------------------------------------------
+
+-define(CONNECTOR_NAME, <<"my_connector">>).
+
+%% Runs `InnerConfig' through `emqx_bridge_v2_testlib:parse_and_check_connector/3'
+%% using this module's fixed connector type and name.
+parse_and_check_connector(InnerConfig) ->
+    emqx_bridge_v2_testlib:parse_and_check_connector(
+        ?CONNECTOR_TYPE_BIN,
+        ?CONNECTOR_NAME,
+        InnerConfig
+    ).
+
+%% Builds a connector config from the suite's `connector_config/5' helper, using an
+%% account of the form `orgid-accountid', and deep-merges `Overrides' on top of it.
+connector_config(Overrides) ->
+    Base = emqx_bridge_snowflake_SUITE:connector_config(
+        ?CONNECTOR_NAME,
+        <<"orgid-accountid">>,
+        <<"orgid-accountid.snowflakecomputing.com">>,
+        <<"odbcuser">>,
+        <<"odbcpass">>
+    ),
+    emqx_utils_maps:deep_merge(Base, Overrides).
+
+%%------------------------------------------------------------------------------
+%% Test cases
+%%------------------------------------------------------------------------------
+
+%% EUnit generator for connector schema validation: the base config must parse, and
+%% an account without a hyphen must be rejected with a `validation_error'.
+validation_test_() ->
+    [
+        {"good config", ?_assertMatch(#{}, parse_and_check_connector(connector_config(#{})))},
+        {"account must contain org id and account id",
+            ?_assertThrow(
+                {_SchemaMod, [
+                    #{
+                        reason := <<"Account identifier must be of form ORGID-ACCOUNTNAME">>,
+                        kind := validation_error
+                    }
+                ]},
+                parse_and_check_connector(
+                    connector_config(#{
+                        <<"account">> => <<"onlyaccid">>
+                    })
+                )
+            )}
+    ].

+ 8 - 4
apps/emqx_connector_aggregator/src/emqx_connector_aggreg_delivery.erl

@@ -158,7 +158,7 @@ process_write(Delivery = #delivery{callback_module = Mod, transfer = Transfer0})
             Delivery#delivery{transfer = Transfer};
         {error, Reason} ->
             %% Todo: handle more gracefully?  Retry?
-            error({transfer_failed, Reason})
+            exit({upload_failed, Reason})
     end.
 
 process_complete(#delivery{id = Id, empty = true}) ->
@@ -169,9 +169,13 @@ process_complete(#delivery{
 }) ->
     Trailer = emqx_connector_aggreg_csv:close(Container),
     Transfer = Mod:process_append(Trailer, Transfer0),
-    {ok, Completed} = Mod:process_complete(Transfer),
-    ?tp(connector_aggreg_delivery_completed, #{action => Id, transfer => Completed}),
-    ok.
+    case Mod:process_complete(Transfer) of
+        {ok, Completed} ->
+            ?tp(connector_aggreg_delivery_completed, #{action => Id, transfer => Completed}),
+            ok;
+        {error, Error} ->
+            exit({upload_failed, Error})
+    end.
 
 %%
 

+ 1 - 2
apps/emqx_connector_aggregator/src/emqx_connector_aggregator.erl

@@ -415,8 +415,7 @@ handle_delivery_exit(Buffer, {shutdown, {skipped, Reason}}, St = #st{name = Name
     ok = discard_buffer(Buffer),
     St;
 handle_delivery_exit(Buffer, Error, St = #st{name = Name}) ->
-    ?SLOG(error, #{
-        msg => "aggregated_buffer_delivery_failed",
+    ?tp(error, "aggregated_buffer_delivery_failed", #{
         action => Name,
         buffer => {Buffer#buffer.since, Buffer#buffer.seq},
         filename => Buffer#buffer.filename,

+ 8 - 0
apps/emqx_utils/src/emqx_utils_redact.erl

@@ -23,12 +23,18 @@
 -define(IS_KEY_HEADERS(K), (K == headers orelse K == <<"headers">> orelse K == "headers")).
 
 %% NOTE: keep alphabetical order
+is_sensitive_key(account_key) -> true;
+is_sensitive_key("account_key") -> true;
+is_sensitive_key(<<"account_key">>) -> true;
 is_sensitive_key(aws_secret_access_key) -> true;
 is_sensitive_key("aws_secret_access_key") -> true;
 is_sensitive_key(<<"aws_secret_access_key">>) -> true;
 is_sensitive_key(password) -> true;
 is_sensitive_key("password") -> true;
 is_sensitive_key(<<"password">>) -> true;
+is_sensitive_key(private_key) -> true;
+is_sensitive_key("private_key") -> true;
+is_sensitive_key(<<"private_key">>) -> true;
 is_sensitive_key(secret) -> true;
 is_sensitive_key("secret") -> true;
 is_sensitive_key(<<"secret">>) -> true;
@@ -236,8 +242,10 @@ redact_test_() ->
 
     Types = [atom, string, binary],
     Keys = [
+        account_key,
         aws_secret_access_key,
         password,
+        private_key,
         secret,
         secret_key,
         secret_access_key,