Просмотр исходного кода

feat(snowflake): attempt to provide extra details when JWT is considered invalid

Thales Macedo Garitezi 1 год назад
Родитель
Сommit
b6f791ef59

+ 101 - 0
apps/emqx_bridge_snowflake/docs/dev-quick-ref.md

@@ -39,6 +39,85 @@ openssl genrsa 2048 | openssl pkcs8 -topk8 -inform PEM -out snowflake_rsa_key.pr
 openssl rsa -in snowflake_rsa_key.private.pem -pubout -out snowflake_rsa_key.public.pem
 ```
 
+## SQL setup cheat sheet
+
+```sql
+CREATE USER IF NOT EXISTS testuser
+    PASSWORD = 'TestUser99'
+    MUST_CHANGE_PASSWORD = FALSE;
+
+-- Set the RSA public key for 'testuser'
+-- Note: Remove the '-----BEGIN PUBLIC KEY-----' and '-----END PUBLIC KEY-----' lines from your PEM file,
+-- and include the remaining content below, preserving line breaks.
+
+ALTER USER testuser SET RSA_PUBLIC_KEY = '
+<YOUR_PUBLIC_KEY_CONTENTS_LINE_1>
+<YOUR_PUBLIC_KEY_CONTENTS_LINE_2>
+<YOUR_PUBLIC_KEY_CONTENTS_LINE_3>
+<YOUR_PUBLIC_KEY_CONTENTS_LINE_4>
+';
+
+
+create or replace role testrole;
+
+create warehouse testwarehouse;
+
+CREATE OR REPLACE TABLE testdatabase.public.test0 (
+    clientid STRING,
+    topic STRING,
+    payload BINARY,
+    publish_received_at TIMESTAMP_LTZ
+);
+
+CREATE STAGE IF NOT EXISTS testdatabase.public.teststage0
+FILE_FORMAT = (TYPE = CSV PARSE_HEADER = TRUE FIELD_OPTIONALLY_ENCLOSED_BY = '"')
+COPY_OPTIONS = (ON_ERROR = CONTINUE PURGE = TRUE);
+
+CREATE PIPE IF NOT EXISTS testdatabase.public.testpipe0 AS
+COPY INTO testdatabase.public.test0
+FROM @testdatabase.public.teststage0
+MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE;
+
+
+-- Grant the USAGE privilege on the database and schema that contain the pipe object.
+grant usage on database testdatabase to role testrole;
+grant usage on schema testdatabase.public to role testrole;
+-- Grant the USAGE privilege on the warehouse (only needed for test account)
+grant usage on warehouse testwarehouse to role testrole;
+-- Grant the INSERT, SELECT, TRUNCATE and DELETE privileges on the target table
+-- for cleaning up after tests
+grant insert, select, truncate, delete on testdatabase.public.test0 to role testrole;
+-- Grant the READ and WRITE privilege on the internal stage.
+grant read, write on stage testdatabase.public.teststage0 to role testrole;
+-- Grant the OPERATE and MONITOR privileges on the pipe object.
+grant operate, monitor on pipe testdatabase.public.testpipe0 to role testrole;
+-- Grant the role to a user
+grant role testrole to user testuser;
+-- Set the role as the default role for the user
+alter user testuser set default_role = testrole;
+
+-- Create a role for the Snowpipe privileges.
+create or replace role snowpipe;
+-- Grant the USAGE privilege on the database and schema that contain the pipe object.
+grant usage on database testdatabase to role snowpipe;
+grant usage on schema testdatabase.public to role snowpipe;
+-- Grant the INSERT and SELECT privileges on the target table.
+grant insert, select on testdatabase.public.test0 to role snowpipe;
+-- Grant the READ and WRITE privilege on the internal stage.
+grant read, write on stage testdatabase.public.teststage0 to role snowpipe;
+-- Grant the OPERATE and MONITOR privileges on the pipe object.
+grant operate, monitor on pipe testdatabase.public.testpipe0 to role snowpipe;
+-- Grant the role to a user
+grant role snowpipe to user snowpipeuser;
+-- Set the role as the default role for the user
+alter user snowpipeuser set default_role = snowpipe;
+
+---- OPTIONAL
+-- not required, but helps gather JWT failure reasons like skewed time
+grant monitor on account to role snowpipe;
+
+```
+
 ## Basic helper functions
 
 ### Elixir
@@ -60,6 +139,13 @@ dsn = "snowflake"
 query = fn conn, sql -> :odbc.sql_query(conn, sql |> to_charlist()) end
 ```
 
+Or, if you have already set up a connector:
+
+```elixir
+conn_res_id = "connector:snowflake:name"
+query = fn sql -> conn_res_id |> :ecpool.pick_and_do(fn conn -> :odbc.sql_query(conn, sql |> to_charlist()) end, :handover) end
+```
+
 ### Erlang
 
 ```erlang
@@ -74,6 +160,13 @@ DSN = "snowflake".
 Query = fun(Conn, Sql) -> odbc:sql_query(Conn, Sql) end.
 ```
 
+Or, if you have already set up a connector:
+
+```erlang
+ConnResId = <<"connector:snowflake:name">>.
+Query = fun(Sql) -> ecpool:pick_and_do(ConnResId, fun(Conn) -> odbc:sql_query(Conn, Sql) end, handover) end.
+```
+
 ## Initialize Database and user accounts
 
 ### Elixir
@@ -148,6 +241,10 @@ query.(conn, "grant operate, monitor on pipe #{fqn_pipe} to role #{snowpipe_role
 query.(conn, "grant role #{snowpipe_role} to user #{snowpipe_user}")
 # Set the role as the default role for the user
 query.(conn, "alter user #{snowpipe_user} set default_role = #{snowpipe_role}")
+
+## OPTIONAL
+# not required, but helps gather JWT failure reasons like skewed time
+query.(conn, "grant monitor on account to role #{snowpipe_role}")
 ```
 
 ### Erlang
@@ -236,4 +333,8 @@ Query(Conn, ["grant role ", SnowpipeRole, " to user ", SnowpipeUser]).
 
 % Set the role as the default role for the user
 Query(Conn, ["alter user ", SnowpipeUser, " set default_role = ", SnowpipeRole]).
+
+%%  OPTIONAL
+% not required, but helps gather JWT failure reasons like skewed time
+Query(Conn, ["grant monitor on account to role ", SnowpipeRole]).
 ```

+ 21 - 0
apps/emqx_bridge_snowflake/docs/user-guide.md

@@ -115,3 +115,24 @@ SELECT
 FROM
   "t/#"
 ```
+
+## Debugging invalid JWT failures
+
+In case the following error appears in the logs:
+
+```
+JWT token is invalid. [eaa17004-5830-4b84-b357-2a981d28606f]
+```
+
+Copy the UUID in that message (`eaa17004-5830-4b84-b357-2a981d28606f` in this example) and, on a Snowflake worksheet with a user that has admin privileges on the account (at least `MONITOR` on account), run:
+
+```sql
+select SYSTEM$GET_LOGIN_FAILURE_DETAILS('eaa17004-5830-4b84-b357-2a981d28606f');
+```
+
+This can output more hints on why the JWT is considered invalid by Snowflake:
+
+Ex:
+```json
+{"clientIP":"xxx","clientType":"OTHER","clientVersion":"","username":null,"errorCode":"JWT_TOKEN_INVALID_ISSUE_TIME","timestamp":1728418411}
+```

+ 116 - 14
apps/emqx_bridge_snowflake/src/emqx_bridge_snowflake_connector.erl

@@ -42,7 +42,8 @@
     connect/1,
     disconnect/1,
     do_health_check_connector/1,
-    do_stage_file/6
+    do_stage_file/6,
+    do_get_login_failure_details/2
 ]).
 
 %% `emqx_connector_aggreg_delivery' API
@@ -267,7 +268,8 @@ on_remove_channel(
     destroy_action(ActionResId, ActionState),
     ConnState = ConnState0#{installed_actions := InstalledActions},
     {ok, ConnState};
-on_remove_channel(_ConnResId, ConnState, _ActionResId) ->
+on_remove_channel(_ConnResId, ConnState, ActionResId) ->
+    ensure_common_action_destroyed(ActionResId),
     {ok, ConnState}.
 
 -spec on_get_channels(connector_resource_id()) ->
@@ -282,12 +284,12 @@ on_get_channels(ConnResId) ->
 ) ->
     ?status_connected | ?status_disconnected.
 on_get_channel_status(
-    _ConnResId,
+    ConnResId,
     ActionResId,
     _ConnState = #{installed_actions := InstalledActions}
 ) when is_map_key(ActionResId, InstalledActions) ->
     ActionState = maps:get(ActionResId, InstalledActions),
-    action_status(ActionResId, ActionState);
+    action_status(ConnResId, ActionResId, ActionState);
 on_get_channel_status(_ConnResId, _ActionResId, _ConnState) ->
     ?status_disconnected.
 
@@ -636,7 +638,7 @@ create_action(
 ) ->
     maybe
         {ok, ActionState0} ?= start_http_pool(ActionResId, ActionConfig, ConnState),
-        _ = check_snowpipe_user_permission(ActionResId, ActionState0),
+        _ = check_snowpipe_user_permission(ActionResId, ConnResId, ActionState0),
         start_aggregator(ConnResId, ActionResId, ActionConfig, ActionState0)
     end.
 
@@ -798,6 +800,10 @@ destroy_action(ActionResId, ActionState) ->
         _ ->
             ok
     end,
+    ok = ensure_common_action_destroyed(ActionResId),
+    ok.
+
+ensure_common_action_destroyed(ActionResId) ->
     ok = ehttpc_sup:stop_pool(ActionResId),
     ok = emqx_connector_jwt:delete_jwt(?JWT_TABLE, ActionResId),
     ok.
@@ -907,13 +913,19 @@ insert_report_request(HTTPPool, Opts, HTTPClientConfig) ->
     JWTToken = emqx_connector_jwt:ensure_jwt(JWTConfig),
     AuthnHeader = [<<"BEARER ">>, JWTToken],
     Headers = http_headers(AuthnHeader),
+    QString = insert_report_query_string(Opts),
     InsertReportPath =
-        case Opts of
-            #{begin_mark := BeginMark} when is_binary(BeginMark) ->
-                <<InsertReportPath0/binary, "?beginMark=", BeginMark/binary>>;
+        case QString of
+            <<>> ->
+                InsertReportPath0;
             _ ->
-                InsertReportPath0
+                <<InsertReportPath0/binary, "?", QString/binary>>
         end,
+    ?SLOG(debug, #{
+        msg => "snowflake_insert_report_request",
+        path => InsertReportPath,
+        pool => HTTPPool
+    }),
     Req = {InsertReportPath, Headers},
     Response = ?MODULE:do_insert_report_request(HTTPPool, Req, RequestTTL, MaxRetries),
     case Response of
@@ -924,6 +936,13 @@ insert_report_request(HTTPPool, Opts, HTTPClientConfig) ->
             {error, Response}
     end.
 
+%% Builds the query string of the insert report request.
+%%
+%% Only the `begin_mark' and `request_id' options are considered, and only when their
+%% values are binaries.  They are renamed to the `beginMark' and `requestId' query
+%% parameters before being composed.  The caller treats an empty binary result as
+%% "no query string".
+insert_report_query_string(Opts) ->
+    Renames = [{begin_mark, <<"beginMark">>}, {request_id, <<"requestId">>}],
+    Supported = [OptKey || {OptKey, _ParamName} <- Renames],
+    BinaryOpts = maps:filter(
+        fun(_Key, Value) -> is_binary(Value) end,
+        maps:with(Supported, Opts)
+    ),
+    Params = lists:foldl(
+        fun({OptKey, ParamName}, Acc) ->
+            emqx_utils_maps:rename(OptKey, ParamName, Acc)
+        end,
+        BinaryOpts,
+        Renames
+    ),
+    QueryString = uri_string:compose_query(maps:to_list(Params)),
+    emqx_utils_conv:bin(QueryString).
+
 %% Internal export only for mocking
 do_insert_report_request(HTTPPool, Req, RequestTTL, MaxRetries) ->
     ehttpc:request(HTTPPool, get, Req, RequestTTL, MaxRetries).
@@ -941,7 +960,7 @@ row_to_map(Row0, Headers) ->
     Row = lists:zip(Headers, Row2),
     maps:from_list(Row).
 
-action_status(ActionResId, #{mode := aggregated} = ActionState) ->
+action_status(ConnResId, ActionResId, #{mode := aggregated} = ActionState) ->
     #{
         aggreg_id := AggregId,
         http := #{connect_timeout := ConnectTimeout}
@@ -950,7 +969,7 @@ action_status(ActionResId, #{mode := aggregated} = ActionState) ->
     Timestamp = erlang:system_time(second),
     ok = emqx_connector_aggregator:tick(AggregId, Timestamp),
     ok = check_aggreg_upload_errors(AggregId),
-    ok = check_snowpipe_user_permission(ActionResId, ActionState),
+    ok = check_snowpipe_user_permission(ActionResId, ConnResId, ActionState),
     case http_pool_workers_healthy(ActionResId, ConnectTimeout) of
         true ->
             ?status_connected;
@@ -1027,9 +1046,10 @@ check_aggreg_upload_errors(AggregId) ->
             ok
     end.
 
-check_snowpipe_user_permission(HTTPPool, ActionState) ->
+check_snowpipe_user_permission(HTTPPool, ODBCPool, ActionState) ->
     #{http := HTTPClientConfig} = ActionState,
-    Opts = #{},
+    RequestId = list_to_binary(uuid:uuid_to_string(uuid:get_v4())),
+    Opts = #{request_id => RequestId},
     case insert_report_request(HTTPPool, Opts, HTTPClientConfig) of
         {ok, _} ->
             ok;
@@ -1039,7 +1059,10 @@ check_snowpipe_user_permission(HTTPPool, ActionState) ->
                     {ok, JSON} -> JSON;
                     {error, _} -> Body0
                 end,
-            ?SLOG(debug, #{
+            FailureDetails = try_get_jwt_failure_details(ODBCPool, HTTPPool, Body),
+            ?SLOG(warning, FailureDetails#{
+                pool => HTTPPool,
+                request_id => RequestId,
                 msg => "snowflake_check_snowpipe_user_permission_error",
                 body => Body
             }),
@@ -1068,6 +1091,85 @@ check_snowpipe_user_permission(HTTPPool, ActionState) ->
             throw(Msg)
     end.
 
+%% Best-effort lookup of the reason behind a failed insert report request.
+%%
+%% Takes the request id out of the `message' field of the response body, queries the
+%% login failure details for it through `ODBCPool', and decodes the JSON answer.
+%% Returns `#{failure_details => Details}' when every step succeeds.  Otherwise the
+%% first non-matching value is logged at debug level (tagged with `ActionResId') and
+%% `#{failure_details => undefined, hint => Hint}' is returned.
+try_get_jwt_failure_details(ODBCPool, ActionResId, RespBody) ->
+    maybe
+        #{<<"message">> := ErrorMsg} ?= RespBody,
+        {ok, RequestId} ?= get_jwt_error_request_id(ErrorMsg),
+        {selected, [_Column], [{Raw}]} ?= get_login_failure_details(ODBCPool, RequestId),
+        %% The decoded column is expected to be a string (list), not a binary.
+        true ?= is_list(Raw) orelse {error, {not_string, Raw}},
+        {ok, Details} ?= emqx_utils_json:safe_decode(Raw, [return_maps]),
+        #{failure_details => Details}
+    else
+        Reason ->
+            ?SLOG(debug, #{
+                msg => "snowflake_action_get_jwt_failure_details_err",
+                action_res_id => ActionResId,
+                reason => Reason
+            }),
+            #{failure_details => undefined, hint => jwt_failure_details_hint()}
+    end.
+
+%% Human-readable advice returned when the failure details could not be fetched.
+jwt_failure_details_hint() ->
+    %% When role doesn't have MONITOR on account, the command returns:
+    %% SQL compilation error:\nUnknown function SYSTEM$GET_LOGIN_FAILURE_DETAILS
+    %% SQLSTATE IS: 42601
+    <<
+        "To get more details about the login failure, log into your",
+        " Snowflake account with an admin role that has the MONITOR privilege",
+        " on the account, and check the output of",
+        " SYSTEM$GET_LOGIN_FAILURE_DETAILS on logged request id."
+    >>.
+
+%% Extracts the request id that Snowflake embeds, between square brackets, in its JWT
+%% error messages, e.g. `[ece3379e-6715-4d48-adeb-d5507d05e3e2]'.
+%%
+%% Even if we provide a request id for the HTTP call, snowflake decides to use its own
+%% request id when returning JWT errors, hence the need to parse it out of the message.
+%%
+%% Returns `{ok, UUID}' (binary) for a binary message containing a bracketed UUID, and
+%% `{error, Reason}' for anything else, including non-binary input.
+get_jwt_error_request_id(Msg) when is_binary(Msg) ->
+    %% Adjacent literals are concatenated at compile time; the only capture group is
+    %% the UUID itself, without the surrounding brackets.
+    Pattern = <<
+        "\\[("
+        "[0-9a-fA-F]{8}-"
+        "[0-9a-fA-F]{4}-"
+        "[0-9a-fA-F]{4}-"
+        "[0-9a-fA-F]{4}-"
+        "[0-9a-fA-F]{12}"
+        ")\\]"
+    >>,
+    case re:run(Msg, Pattern, [{capture, all_but_first, binary}]) of
+        {match, [RequestId]} ->
+            {ok, RequestId};
+        _NoMatch ->
+            jwt_request_id_not_found()
+    end;
+get_jwt_error_request_id(_NotBinary) ->
+    jwt_request_id_not_found().
+
+%% Error returned whenever no request id could be extracted.
+jwt_request_id_not_found() ->
+    {error, <<"couldn't obtain jwt request id from error message">>}.
+
+%% Runs the login failure details query for `RequestId' on a worker of `ODBCPool'.
+%%
+%% Returns whatever the query returns.  Any exception raised while picking the worker
+%% or running the query is converted into
+%% `{error, #{kind, reason, stacktrace}}' instead of propagating to the caller.
+get_login_failure_details(ODBCPool, RequestId) ->
+    RunQuery = fun(ConnPid) ->
+        ?MODULE:do_get_login_failure_details(ConnPid, RequestId)
+    end,
+    try
+        %% `handover': must be executed by the ecpool worker, which owns the ODBC
+        %% connection.
+        ecpool:pick_and_do(ODBCPool, RunQuery, handover)
+    catch
+        Kind:Reason:Stacktrace ->
+            {error, #{kind => Kind, reason => Reason, stacktrace => Stacktrace}}
+    end.
+
+%% Issues `select SYSTEM$GET_LOGIN_FAILURE_DETAILS('<RequestId>')' on the given ODBC
+%% connection, with a 5 second timeout, and returns the raw `odbc:sql_query/3' result.
+%% Invoked through `?MODULE:' by `get_login_failure_details/2'.
+do_get_login_failure_details(ConnPid, RequestId) ->
+    Timeout = 5_000,
+    %% `odbc:sql_query/3' is given a flat string, so the iolist is flattened first.
+    SQL = binary_to_list(
+        iolist_to_binary(["select SYSTEM$GET_LOGIN_FAILURE_DETAILS('", RequestId, "')"])
+    ),
+    odbc:sql_query(ConnPid, SQL, Timeout).
+
 %%------------------------------------------------------------------------------
 %% Tests
 %%------------------------------------------------------------------------------

+ 16 - 1
apps/emqx_bridge_snowflake/test/emqx_bridge_snowflake_SUITE.erl

@@ -31,6 +31,7 @@
 -define(STAGE, <<"teststage0">>).
 -define(TABLE, <<"test0">>).
 -define(WAREHOUSE, <<"testwarehouse">>).
+-define(PIPE, <<"testpipe0">>).
 -define(PIPE_USER, <<"snowpipeuser">>).
 
 -define(CONF_COLUMN_ORDER, ?CONF_COLUMN_ORDER([])).
@@ -294,7 +295,7 @@ aggregated_action_config(Overrides0) ->
                     <<"private_key">> => private_key(),
                     <<"database">> => ?DATABASE,
                     <<"schema">> => ?SCHEMA,
-                    <<"pipe">> => <<"testpipe0">>,
+                    <<"pipe">> => ?PIPE,
                     <<"stage">> => ?STAGE,
                     <<"pipe_user">> => ?PIPE_USER,
                     <<"connect_timeout">> => <<"5s">>,
@@ -1136,6 +1137,20 @@ t_wrong_snowpipe_user(init, #{mock := true} = Config) ->
         Body = emqx_utils_json:encode(InsertReportResponse),
         {ok, 401, Headers, Body}
     end),
+    meck:expect(Mod, do_get_login_failure_details, fun(_Connpid, _RequestId) ->
+        Details = #{
+            <<"clientIP">> => <<"127.0.0.1">>,
+            <<"clientType">> => <<"OTHER">>,
+            <<"clientVersion">> => <<"">>,
+            <<"errorCode">> => <<"JWT_TOKEN_INVALID_ISSUE_TIME">>,
+            <<"timestamp">> => 1728418411,
+            <<"username">> => null
+        },
+        Col = binary_to_list(emqx_utils_json:encode(Details)),
+        {selected, ["SYSTEM$GET_LOGIN_FAILURE_DETAILS('92D86B2E-D652-4D2D-9780-A6ED28B38356')"], [
+            {Col}
+        ]}
+    end),
     maps:to_list(Config);
 t_wrong_snowpipe_user(init, #{} = Config) ->
     maps:to_list(Config).