Преглед на файлове

Merge pull request #10944 from thalesmg/pubsub-jwt-on-demand-v50

feat(gcp_pubsub): generate jwt tokens on demand without workers (5.1)
Thales Macedo Garitezi преди 2 години
родител
ревизия
f9ff1007a0

+ 45 - 62
apps/emqx_bridge_gcp_pubsub/src/emqx_bridge_gcp_pubsub_connector.erl

@@ -6,6 +6,7 @@
 
 -behaviour(emqx_resource).
 
+-include_lib("jose/include/jose_jwk.hrl").
 -include_lib("emqx_connector/include/emqx_connector_tables.hrl").
 -include_lib("emqx_resource/include/emqx_resource.hrl").
 -include_lib("typerefl/include/types.hrl").
@@ -26,7 +27,6 @@
 ]).
 -export([reply_delegator/3]).
 
--type jwt_worker() :: binary().
 -type service_account_json() :: emqx_bridge_gcp_pubsub:service_account_json().
 -type config() :: #{
     connect_timeout := emqx_schema:duration_ms(),
@@ -38,7 +38,7 @@
 }.
 -type state() :: #{
     connect_timeout := timer:time(),
-    jwt_worker_id := jwt_worker(),
+    jwt_config := emqx_connector_jwt:jwt_config(),
     max_retries := non_neg_integer(),
     payload_template := emqx_plugin_libs_rule:tmpl_token(),
     pool_name := binary(),
@@ -97,12 +97,12 @@ on_start(
         {enable_pipelining, maps:get(enable_pipelining, Config, ?DEFAULT_PIPELINE_SIZE)}
     ],
     #{
-        jwt_worker_id := JWTWorkerId,
+        jwt_config := JWTConfig,
         project_id := ProjectId
-    } = ensure_jwt_worker(ResourceId, Config),
+    } = parse_jwt_config(ResourceId, Config),
     State = #{
         connect_timeout => ConnectTimeout,
-        jwt_worker_id => JWTWorkerId,
+        jwt_config => JWTConfig,
         max_retries => MaxRetries,
         payload_template => emqx_plugin_libs_rule:preproc_tmpl(PayloadTemplate),
         pool_name => ResourceId,
@@ -136,14 +136,13 @@ on_start(
 -spec on_stop(resource_id(), state()) -> ok | {error, term()}.
 on_stop(
     ResourceId,
-    _State = #{jwt_worker_id := JWTWorkerId}
+    _State = #{jwt_config := JWTConfig}
 ) ->
-    ?tp(gcp_pubsub_stop, #{resource_id => ResourceId, jwt_worker_id => JWTWorkerId}),
+    ?tp(gcp_pubsub_stop, #{resource_id => ResourceId, jwt_config => JWTConfig}),
     ?SLOG(info, #{
         msg => "stopping_gcp_pubsub_bridge",
         connector => ResourceId
     }),
-    emqx_connector_jwt_sup:ensure_worker_deleted(JWTWorkerId),
     emqx_connector_jwt:delete_jwt(?JWT_TABLE, ResourceId),
     ehttpc_sup:stop_pool(ResourceId).
 
@@ -228,12 +227,12 @@ on_get_status(ResourceId, #{connect_timeout := Timeout} = State) ->
 %% Helper fns
 %%-------------------------------------------------------------------------------------------------
 
--spec ensure_jwt_worker(resource_id(), config()) ->
+-spec parse_jwt_config(resource_id(), config()) ->
     #{
-        jwt_worker_id := jwt_worker(),
+        jwt_config := emqx_connector_jwt:jwt_config(),
         project_id := binary()
     }.
-ensure_jwt_worker(ResourceId, #{
+parse_jwt_config(ResourceId, #{
     service_account_json := ServiceAccountJSON
 }) ->
     #{
@@ -246,8 +245,32 @@ ensure_jwt_worker(ResourceId, #{
     Aud = <<"https://pubsub.googleapis.com/">>,
     ExpirationMS = timer:hours(1),
     Alg = <<"RS256">>,
-    Config = #{
-        private_key => PrivateKeyPEM,
+    JWK =
+        try jose_jwk:from_pem(PrivateKeyPEM) of
+            JWK0 = #jose_jwk{} ->
+                %% Don't wrap the JWK with `emqx_secret:wrap' here;
+                %% this is stored in mnesia and synchronized among the
+                %% nodes, and will easily become a bad fun.
+                JWK0;
+            [] ->
+                ?tp(error, gcp_pubsub_connector_startup_error, #{error => empty_key}),
+                throw("empty private in service account json");
+            {error, Reason} ->
+                Error = {invalid_private_key, Reason},
+                ?tp(error, gcp_pubsub_connector_startup_error, #{error => Error}),
+                throw("invalid private key in service account json");
+            Error0 ->
+                Error = {invalid_private_key, Error0},
+                ?tp(error, gcp_pubsub_connector_startup_error, #{error => Error}),
+                throw("invalid private key in service account json")
+        catch
+            Kind:Reason ->
+                Error = {Kind, Reason},
+                ?tp(error, gcp_pubsub_connector_startup_error, #{error => Error}),
+                throw("invalid private key in service account json")
+        end,
+    JWTConfig = #{
+        jwk => emqx_secret:wrap(JWK),
         resource_id => ResourceId,
         expiration => ExpirationMS,
         table => ?JWT_TABLE,
@@ -257,46 +280,8 @@ ensure_jwt_worker(ResourceId, #{
         kid => KId,
         alg => Alg
     },
-
-    JWTWorkerId = <<"gcp_pubsub_jwt_worker:", ResourceId/binary>>,
-    Worker =
-        case emqx_connector_jwt_sup:ensure_worker_present(JWTWorkerId, Config) of
-            {ok, Worker0} ->
-                Worker0;
-            Error ->
-                ?tp(error, "gcp_pubsub_bridge_jwt_worker_failed_to_start", #{
-                    connector => ResourceId,
-                    reason => Error
-                }),
-                _ = emqx_connector_jwt_sup:ensure_worker_deleted(JWTWorkerId),
-                throw(failed_to_start_jwt_worker)
-        end,
-    MRef = monitor(process, Worker),
-    Ref = emqx_connector_jwt_worker:ensure_jwt(Worker),
-
-    %% to ensure that this resource and its actions will be ready to
-    %% serve when started, we must ensure that the first JWT has been
-    %% produced by the worker.
-    receive
-        {Ref, token_created} ->
-            ?tp(gcp_pubsub_bridge_jwt_created, #{resource_id => ResourceId}),
-            demonitor(MRef, [flush]),
-            ok;
-        {'DOWN', MRef, process, Worker, Reason} ->
-            ?tp(error, "gcp_pubsub_bridge_jwt_worker_failed_to_start", #{
-                connector => ResourceId,
-                reason => Reason
-            }),
-            _ = emqx_connector_jwt_sup:ensure_worker_deleted(JWTWorkerId),
-            throw(failed_to_start_jwt_worker)
-    after 10_000 ->
-        ?tp(warning, "gcp_pubsub_bridge_jwt_timeout", #{connector => ResourceId}),
-        demonitor(MRef, [flush]),
-        _ = emqx_connector_jwt_sup:ensure_worker_deleted(JWTWorkerId),
-        throw(timeout_creating_jwt)
-    end,
     #{
-        jwt_worker_id => JWTWorkerId,
+        jwt_config => JWTConfig,
         project_id => ProjectId
     }.
 
@@ -322,14 +307,10 @@ publish_path(
 ) ->
     <<"/v1/projects/", ProjectId/binary, "/topics/", PubSubTopic/binary, ":publish">>.
 
--spec get_jwt_authorization_header(resource_id()) -> [{binary(), binary()}].
-get_jwt_authorization_header(ResourceId) ->
-    case emqx_connector_jwt:lookup_jwt(?JWT_TABLE, ResourceId) of
-        %% Since we synchronize the JWT creation during resource start
-        %% (see `on_start/2'), this will be always be populated.
-        {ok, JWT} ->
-            [{<<"Authorization">>, <<"Bearer ", JWT/binary>>}]
-    end.
+%% Returns the HTTP `Authorization' header list carrying, as a bearer token,
+%% the JWT obtained from `emqx_connector_jwt:ensure_jwt/1' for `JWTConfig'.
+-spec get_jwt_authorization_header(emqx_connector_jwt:jwt_config()) -> [{binary(), binary()}].
+get_jwt_authorization_header(JWTConfig) ->
+    Token = emqx_connector_jwt:ensure_jwt(JWTConfig),
+    BearerValue = <<"Bearer ", Token/binary>>,
+    [{<<"Authorization">>, BearerValue}].
 
 -spec do_send_requests_sync(
     state(),
@@ -342,6 +323,7 @@ get_jwt_authorization_header(ResourceId) ->
     | {error, term()}.
 do_send_requests_sync(State, Requests, ResourceId) ->
     #{
+        jwt_config := JWTConfig,
         pool_name := PoolName,
         max_retries := MaxRetries,
         request_ttl := RequestTTL
@@ -354,7 +336,7 @@ do_send_requests_sync(State, Requests, ResourceId) ->
             requests => Requests
         }
     ),
-    Headers = get_jwt_authorization_header(ResourceId),
+    Headers = get_jwt_authorization_header(JWTConfig),
     Payloads =
         lists:map(
             fun({send_message, Selected}) ->
@@ -466,6 +448,7 @@ do_send_requests_sync(State, Requests, ResourceId) ->
 ) -> {ok, pid()}.
 do_send_requests_async(State, Requests, ReplyFunAndArgs, ResourceId) ->
     #{
+        jwt_config := JWTConfig,
         pool_name := PoolName,
         request_ttl := RequestTTL
     } = State,
@@ -477,7 +460,7 @@ do_send_requests_async(State, Requests, ReplyFunAndArgs, ResourceId) ->
             requests => Requests
         }
     ),
-    Headers = get_jwt_authorization_header(ResourceId),
+    Headers = get_jwt_authorization_header(JWTConfig),
     Payloads =
         lists:map(
             fun({send_message, Selected}) ->

+ 71 - 61
apps/emqx_bridge_gcp_pubsub/test/emqx_bridge_gcp_pubsub_SUITE.erl

@@ -55,8 +55,9 @@ single_config_tests() ->
         t_not_of_service_account_type,
         t_json_missing_fields,
         t_invalid_private_key,
-        t_jwt_worker_start_timeout,
-        t_failed_to_start_jwt_worker,
+        t_truncated_private_key,
+        t_jose_error_tuple,
+        t_jose_other_error,
         t_stop,
         t_get_status_ok,
         t_get_status_down,
@@ -580,14 +581,7 @@ t_publish_success(Config) ->
     ServiceAccountJSON = ?config(service_account_json, Config),
     TelemetryTable = ?config(telemetry_table, Config),
     Topic = <<"t/topic">>,
-    ?check_trace(
-        create_bridge(Config),
-        fun(Res, Trace) ->
-            ?assertMatch({ok, _}, Res),
-            ?assertMatch([_], ?of_kind(gcp_pubsub_bridge_jwt_created, Trace)),
-            ok
-        end
-    ),
+    ?assertMatch({ok, _}, create_bridge(Config)),
     {ok, #{<<"id">> := RuleId}} = create_rule_and_action_http(Config),
     on_exit(fun() -> ok = emqx_rule_engine:delete_rule(RuleId) end),
     assert_empty_metrics(ResourceId),
@@ -686,14 +680,7 @@ t_publish_success_local_topic(Config) ->
     ok.
 
 t_create_via_http(Config) ->
-    ?check_trace(
-        create_bridge_http(Config),
-        fun(Res, Trace) ->
-            ?assertMatch({ok, _}, Res),
-            ?assertMatch([_, _], ?of_kind(gcp_pubsub_bridge_jwt_created, Trace)),
-            ok
-        end
-    ),
+    ?assertMatch({ok, _}, create_bridge_http(Config)),
     ok.
 
 t_publish_templated(Config) ->
@@ -705,16 +692,12 @@ t_publish_templated(Config) ->
         "{\"payload\": \"${payload}\","
         " \"pub_props\": ${pub_props}}"
     >>,
-    ?check_trace(
+    ?assertMatch(
+        {ok, _},
         create_bridge(
             Config,
             #{<<"payload_template">> => PayloadTemplate}
-        ),
-        fun(Res, Trace) ->
-            ?assertMatch({ok, _}, Res),
-            ?assertMatch([_], ?of_kind(gcp_pubsub_bridge_jwt_created, Trace)),
-            ok
-        end
+        )
     ),
     {ok, #{<<"id">> := RuleId}} = create_rule_and_action_http(Config),
     on_exit(fun() -> ok = emqx_rule_engine:delete_rule(RuleId) end),
@@ -908,36 +891,26 @@ t_invalid_private_key(Config) ->
                                 #{<<"private_key">> => InvalidPrivateKeyPEM}
                         }
                     ),
-                    #{?snk_kind := "gcp_pubsub_bridge_jwt_worker_failed_to_start"},
+                    #{?snk_kind := gcp_pubsub_connector_startup_error},
                     20_000
                 ),
             Res
         end,
         fun(Res, Trace) ->
             ?assertMatch({ok, _}, Res),
-            ?assertMatch(
-                [#{reason := Reason}] when
-                    Reason =:= noproc orelse
-                        Reason =:= {shutdown, {error, empty_key}},
-                ?of_kind("gcp_pubsub_bridge_jwt_worker_failed_to_start", Trace)
-            ),
             ?assertMatch(
                 [#{error := empty_key}],
-                ?of_kind(connector_jwt_worker_startup_error, Trace)
+                ?of_kind(gcp_pubsub_connector_startup_error, Trace)
             ),
             ok
         end
     ),
     ok.
 
-t_jwt_worker_start_timeout(Config) ->
-    InvalidPrivateKeyPEM = <<"xxxxxx">>,
+t_truncated_private_key(Config) ->
+    InvalidPrivateKeyPEM = <<"-----BEGIN PRIVATE KEY-----\nMIIEvQI...">>,
     ?check_trace(
         begin
-            ?force_ordering(
-                #{?snk_kind := will_never_happen},
-                #{?snk_kind := connector_jwt_worker_make_key}
-            ),
             {Res, {ok, _Event}} =
                 ?wait_async_action(
                     create_bridge(
@@ -947,14 +920,71 @@ t_jwt_worker_start_timeout(Config) ->
                                 #{<<"private_key">> => InvalidPrivateKeyPEM}
                         }
                     ),
-                    #{?snk_kind := "gcp_pubsub_bridge_jwt_timeout"},
+                    #{?snk_kind := gcp_pubsub_connector_startup_error},
                     20_000
                 ),
             Res
         end,
         fun(Res, Trace) ->
             ?assertMatch({ok, _}, Res),
-            ?assertMatch([_], ?of_kind("gcp_pubsub_bridge_jwt_timeout", Trace)),
+            ?assertMatch(
+                [#{error := {error, function_clause}}],
+                ?of_kind(gcp_pubsub_connector_startup_error, Trace)
+            ),
+            ok
+        end
+    ),
+    ok.
+
+%% Checks the `{error, Reason}' branch of private key parsing: with
+%% `jose_jwk:from_pem/1' mocked to return `{error, some_error}', creating
+%% the bridge must emit exactly one `gcp_pubsub_connector_startup_error'
+%% trace event whose `error' field is `{invalid_private_key, some_error}'.
+t_jose_error_tuple(Config) ->
+    ?check_trace(
+        begin
+            %% Create the bridge under the mock and wait (up to 20 s) for the
+            %% startup error trace event.
+            {Res, {ok, _Event}} =
+                ?wait_async_action(
+                    emqx_common_test_helpers:with_mock(
+                        jose_jwk,
+                        from_pem,
+                        fun(_PrivateKeyPEM) -> {error, some_error} end,
+                        fun() -> create_bridge(Config) end
+                    ),
+                    #{?snk_kind := gcp_pubsub_connector_startup_error},
+                    20_000
+                ),
+            Res
+        end,
+        fun(Res, Trace) ->
+            %% Bridge creation itself is expected to return `{ok, _}'; the
+            %% failure is only observed through the trace event.
+            ?assertMatch({ok, _}, Res),
+            ?assertMatch(
+                [#{error := {invalid_private_key, some_error}}],
+                ?of_kind(gcp_pubsub_connector_startup_error, Trace)
+            ),
+            ok
+        end
+    ),
+    ok.
+
+t_jose_other_error(Config) ->
+    ?check_trace(
+        begin
+            {Res, {ok, _Event}} =
+                ?wait_async_action(
+                    emqx_common_test_helpers:with_mock(
+                        jose_jwk,
+                        from_pem,
+                        fun(_PrivateKeyPEM) -> {unknown, error} end,
+                        fun() -> create_bridge(Config) end
+                    ),
+                    #{?snk_kind := gcp_pubsub_connector_startup_error},
+                    20_000
+                ),
+            Res
+        end,
+        fun(Res, Trace) ->
+            ?assertMatch({ok, _}, Res),
+            ?assertMatch(
+                [#{error := {invalid_private_key, {unknown, error}}}],
+                ?of_kind(gcp_pubsub_connector_startup_error, Trace)
+            ),
             ok
         end
     ),
@@ -1309,26 +1339,6 @@ t_unrecoverable_error(Config) ->
     ),
     ok.
 
-t_failed_to_start_jwt_worker(Config) ->
-    ?check_trace(
-        emqx_common_test_helpers:with_mock(
-            emqx_connector_jwt_sup,
-            ensure_worker_present,
-            fun(_JWTWorkerId, _Config) -> {error, restarting} end,
-            fun() ->
-                ?assertMatch({ok, _}, create_bridge(Config))
-            end
-        ),
-        fun(Trace) ->
-            ?assertMatch(
-                [#{reason := {error, restarting}}],
-                ?of_kind("gcp_pubsub_bridge_jwt_worker_failed_to_start", Trace)
-            ),
-            ok
-        end
-    ),
-    ok.
-
 t_stop(Config) ->
     Name = ?config(gcp_pubsub_name, Config),
     {ok, _} = create_bridge(Config),

+ 86 - 1
apps/emqx_connector/src/emqx_connector_jwt.erl

@@ -19,15 +19,33 @@
 -include_lib("emqx_connector/include/emqx_connector_tables.hrl").
 -include_lib("emqx_resource/include/emqx_resource.hrl").
 -include_lib("snabbkaffe/include/snabbkaffe.hrl").
+-include_lib("jose/include/jose_jwt.hrl").
+-include_lib("jose/include/jose_jws.hrl").
 
 %% API
 -export([
     lookup_jwt/1,
     lookup_jwt/2,
-    delete_jwt/2
+    delete_jwt/2,
+    ensure_jwt/1
 ]).
 
 -type jwt() :: binary().
+-type wrapped_jwk() :: fun(() -> jose_jwk:key()).
+-type jwk() :: jose_jwk:key().
+-type jwt_config() :: #{
+    expiration := timer:time(),
+    resource_id := resource_id(),
+    table := ets:table(),
+    jwk := wrapped_jwk() | jwk(),
+    iss := binary(),
+    sub := binary(),
+    aud := binary(),
+    kid := binary(),
+    alg := binary()
+}.
+
+-export_type([jwt_config/0, jwt/0]).
 
 -spec lookup_jwt(resource_id()) -> {ok, jwt()} | {error, not_found}.
 lookup_jwt(ResourceId) ->
@@ -57,3 +75,70 @@ delete_jwt(TId, ResourceId) ->
         error:badarg ->
             ok
     end.
+
+%% @doc Returns a JWT for the resource described by `JWTConfig'.  The
+%% cached token is returned as-is unless it is missing or about to
+%% expire; in those cases a fresh token is generated, cached and
+%% returned instead.
+-spec ensure_jwt(jwt_config()) -> jwt().
+ensure_jwt(JWTConfig) ->
+    #{resource_id := ResourceId, table := Table} = JWTConfig,
+    %% Shared by both "nothing cached" and "cached but stale" paths.
+    Refresh = fun() ->
+        FreshJWT = do_generate_jwt(JWTConfig),
+        store_jwt(JWTConfig, FreshJWT),
+        FreshJWT
+    end,
+    case lookup_jwt(Table, ResourceId) of
+        {ok, CachedJWT} ->
+            case is_about_to_expire(CachedJWT) of
+                false -> CachedJWT;
+                true -> Refresh()
+            end;
+        {error, not_found} ->
+            Refresh()
+    end.
+
+%%-----------------------------------------------------------------------------------------
+%% Helper fns
+%%-----------------------------------------------------------------------------------------
+
+%% Signs a new JWT with the (unwrapped) key from the config and returns the
+%% result of `jose_jws:compact/1'.  `iat' is the current system time in
+%% seconds; `exp' is `iat' plus the configured expiration, which is given
+%% in milliseconds and converted to seconds here.
+-spec do_generate_jwt(jwt_config()) -> jwt().
+do_generate_jwt(#{
+    expiration := ExpirationMS,
+    iss := Iss,
+    sub := Sub,
+    aud := Aud,
+    kid := KId,
+    alg := Alg,
+    jwk := WrappedJWK
+}) ->
+    JWK = emqx_secret:unwrap(WrappedJWK),
+    IssuedAt = erlang:system_time(seconds),
+    LifetimeS = erlang:convert_time_unit(ExpirationMS, millisecond, second),
+    JWSHeader = #{<<"kid">> => KId, <<"alg">> => Alg},
+    ClaimSet = #{
+        <<"exp">> => IssuedAt + LifetimeS,
+        <<"iat">> => IssuedAt,
+        <<"aud">> => Aud,
+        <<"sub">> => Sub,
+        <<"iss">> => Iss
+    },
+    Signed = jose_jwt:sign(JWK, JWSHeader, ClaimSet),
+    {_Meta, CompactJWT} = jose_jws:compact(Signed),
+    CompactJWT.
+
+%% Inserts `JWT' into the config's ETS table under the key
+%% `{ResourceId, jwt}' and emits the `emqx_connector_jwt_token_stored'
+%% trace point.
+-spec store_jwt(jwt_config(), jwt()) -> ok.
+store_jwt(#{table := Table, resource_id := ResourceId}, JWT) ->
+    CacheEntry = {{ResourceId, jwt}, JWT},
+    true = ets:insert(Table, CacheEntry),
+    ?tp(emqx_connector_jwt_token_stored, #{resource_id => ResourceId}),
+    ok.
+
+%% Returns `true' when the token's `exp' claim is at most 5 seconds ahead
+%% of the current system time (or already in the past), i.e. when the
+%% token should be replaced before being used again.
+%%
+%% `exp' is written by `do_generate_jwt/1' in seconds and `Now' is read in
+%% seconds, so the grace period must be expressed in seconds too.
+%% `timer:seconds/1' returns milliseconds; using it here would stretch the
+%% grace period to 5000 s and make any token with a shorter remaining
+%% lifetime (including the 1 h GCP token) look stale on every call, so a
+%% new token would be signed for each request.
+-spec is_about_to_expire(jwt()) -> boolean().
+is_about_to_expire(JWT) ->
+    #jose_jwt{fields = #{<<"exp">> := Exp}} = jose_jwt:peek(JWT),
+    Now = erlang:system_time(seconds),
+    GracePeriodS = 5,
+    Now >= Exp - GracePeriodS.

+ 2 - 37
apps/emqx_connector/src/emqx_connector_jwt_worker.erl

@@ -189,49 +189,14 @@ terminate(_Reason, State) ->
 %% Helper fns
 %%-----------------------------------------------------------------------------------------
 
--spec do_generate_jwt(state()) -> jwt().
-do_generate_jwt(
-    #{
-        expiration := ExpirationMS,
-        iss := Iss,
-        sub := Sub,
-        aud := Aud,
-        kid := KId,
-        alg := Alg,
-        jwk := JWK
-    } = _State
-) ->
-    Headers = #{
-        <<"alg">> => Alg,
-        <<"kid">> => KId
-    },
-    Now = erlang:system_time(seconds),
-    ExpirationS = erlang:convert_time_unit(ExpirationMS, millisecond, second),
-    Claims = #{
-        <<"iss">> => Iss,
-        <<"sub">> => Sub,
-        <<"aud">> => Aud,
-        <<"iat">> => Now,
-        <<"exp">> => Now + ExpirationS
-    },
-    JWT0 = jose_jwt:sign(JWK, Headers, Claims),
-    {_, JWT} = jose_jws:compact(JWT0),
-    JWT.
-
 -spec generate_and_store_jwt(state()) -> state().
 generate_and_store_jwt(State0) ->
-    JWT = do_generate_jwt(State0),
-    store_jwt(State0, JWT),
+    JWTConfig = maps:without([jwt, refresh_timer], State0),
+    JWT = emqx_connector_jwt:ensure_jwt(JWTConfig),
     ?tp(connector_jwt_worker_refresh, #{jwt => JWT}),
     State1 = State0#{jwt := JWT},
     ensure_timer(State1).
 
--spec store_jwt(state(), jwt()) -> ok.
-store_jwt(#{resource_id := ResourceId, table := TId}, JWT) ->
-    true = ets:insert(TId, {{ResourceId, jwt}, JWT}),
-    ?tp(connector_jwt_worker_token_stored, #{resource_id => ResourceId}),
-    ok.
-
 -spec ensure_timer(state()) -> state().
 ensure_timer(
     State = #{

+ 66 - 0
apps/emqx_connector/test/emqx_connector_jwt_SUITE.erl

@@ -18,7 +18,10 @@
 
 -include_lib("eunit/include/eunit.hrl").
 -include_lib("common_test/include/ct.hrl").
+-include_lib("jose/include/jose_jwt.hrl").
+-include_lib("jose/include/jose_jws.hrl").
 -include("emqx_connector_tables.hrl").
+-include_lib("snabbkaffe/include/snabbkaffe.hrl").
 
 -compile([export_all, nowarn_export_all]).
 
@@ -51,6 +54,33 @@ end_per_testcase(_TestCase, _Config) ->
 insert_jwt(TId, ResourceId, JWT) ->
     ets:insert(TId, {{ResourceId, jwt}, JWT}).
 
+%% Generates a fresh 2048-bit RSA key (public exponent 65537) and returns
+%% it PEM-encoded as a single unencrypted `PrivateKeyInfo' entry.
+generate_private_key_pem() ->
+    RSAParams = {rsa, 2048, 65537},
+    PrivateKey = public_key:generate_key(RSAParams),
+    DER = public_key:der_encode('PrivateKeyInfo', PrivateKey),
+    PemEntry = {'PrivateKeyInfo', DER, not_encrypted},
+    public_key:pem_encode([PemEntry]).
+
+%% Builds a config map for the JWT tests: a freshly generated private key
+%% PEM, a new resource id, a new public ETS table, a 1 hour expiration and
+%% fixed issuer/subject/audience/key id/algorithm values.
+generate_config() ->
+    PrivateKeyPEM = generate_private_key_pem(),
+    ResourceID = emqx_guid:gen(),
+    Table = ets:new(test_jwt_table, [ordered_set, public]),
+    FixedFields = #{
+        iss => <<"issuer">>,
+        sub => <<"subject">>,
+        aud => <<"audience">>,
+        kid => <<"key id">>,
+        alg => <<"RS256">>
+    },
+    FixedFields#{
+        private_key => PrivateKeyPEM,
+        expiration => timer:hours(1),
+        resource_id => ResourceID,
+        table => Table
+    }.
+
+%% Returns `true' when the token's `exp' claim is not later than the
+%% current system time read in seconds.
+is_expired(JWT) ->
+    #jose_jwt{fields = #{<<"exp">> := ExpiresAt}} = jose_jwt:peek(JWT),
+    ExpiresAt =< erlang:system_time(seconds).
+
 %%-----------------------------------------------------------------------------
 %% Test cases
 %%-----------------------------------------------------------------------------
@@ -77,3 +107,39 @@ t_delete_jwt(_Config) ->
     ?assertEqual(ok, emqx_connector_jwt:delete_jwt(TId, ResourceId)),
     ?assertEqual({error, not_found}, emqx_connector_jwt:lookup_jwt(TId, ResourceId)),
     ok.
+
+%% Exercises `emqx_connector_jwt:ensure_jwt/1' in three situations: empty
+%% cache, cached token that is still valid but inside the 5 s
+%% pre-expiration grace window, and cached token that is fully expired.
+%% Every call must return a new, non-expired token and store it.
+t_ensure_jwt(_Config) ->
+    Config0 =
+        #{
+            table := Table,
+            resource_id := ResourceId,
+            private_key := PrivateKeyPEM
+        } = generate_config(),
+    JWK = jose_jwk:from_pem(PrivateKeyPEM),
+    Config1 = maps:without([private_key], Config0),
+    Expiration = timer:seconds(10),
+    JWTConfig = Config1#{jwk => JWK, expiration := Expiration},
+    ?assertEqual({error, not_found}, emqx_connector_jwt:lookup_jwt(Table, ResourceId)),
+    ?check_trace(
+        begin
+            JWT0 = emqx_connector_jwt:ensure_jwt(JWTConfig),
+            ?assertNot(is_expired(JWT0)),
+            %% The token should be refreshed once 5 s or less remain before
+            %% expiration.  Sleep until 4 s remain so that we are inside
+            %% that window even after timestamps are truncated to whole
+            %% seconds; sleeping only 4.5 s would stop short of the window.
+            ct:sleep(Expiration - 4000),
+            %% the refresh below is triggered by the grace window, not by
+            %% actual expiration
+            ?assertNot(is_expired(JWT0)),
+            JWT1 = emqx_connector_jwt:ensure_jwt(JWTConfig),
+            ?assertNot(is_expired(JWT1)),
+            %% fully expired
+            ct:sleep(2 * Expiration),
+            JWT2 = emqx_connector_jwt:ensure_jwt(JWTConfig),
+            ?assertNot(is_expired(JWT2)),
+            {JWT0, JWT1, JWT2}
+        end,
+        fun({JWT0, JWT1, JWT2}, Trace) ->
+            ?assertNotEqual(JWT0, JWT1),
+            ?assertNotEqual(JWT1, JWT2),
+            ?assertNotEqual(JWT2, JWT0),
+            %% one store per `ensure_jwt/1' call above
+            ?assertMatch([_, _, _], ?of_kind(emqx_connector_jwt_token_stored, Trace)),
+            ok
+        end
+    ),
+    ok.

+ 2 - 2
apps/emqx_connector/test/emqx_connector_jwt_worker_SUITE.erl

@@ -176,7 +176,7 @@ t_refresh(_Config) ->
             {{ok, _Pid}, {ok, _Event}} =
                 ?wait_async_action(
                     emqx_connector_jwt_worker:start_link(Config),
-                    #{?snk_kind := connector_jwt_worker_token_stored},
+                    #{?snk_kind := emqx_connector_jwt_token_stored},
                     5_000
                 ),
             {ok, FirstJWT} = emqx_connector_jwt:lookup_jwt(Table, ResourceId),
@@ -209,7 +209,7 @@ t_refresh(_Config) ->
         fun({FirstJWT, SecondJWT, ThirdJWT}, Trace) ->
             ?assertMatch(
                 [_, _, _ | _],
-                ?of_kind(connector_jwt_worker_token_stored, Trace)
+                ?of_kind(emqx_connector_jwt_token_stored, Trace)
             ),
             ?assertNotEqual(FirstJWT, SecondJWT),
             ?assertNotEqual(SecondJWT, ThirdJWT),

+ 1 - 0
changes/ee/feat-10944.en.md

@@ -0,0 +1 @@
+Improved the GCP PubSub bridge to avoid a potential issue where messages could fail to be sent when restarting a node.