Browse Source

Merge branch 'master' into emqx_config

Shawn 4 years ago
parent
commit
1d2cb6cb81

+ 10 - 0
apps/emqx_authz/etc/emqx_authz.conf

@@ -1,5 +1,15 @@
 emqx_authz:{
 emqx_authz:{
     rules: [
     rules: [
+       # {
+       #      type: http
+       #      config: {
+       #          url: "https://emqx.com"
+       #          headers: {
+       #              Accept: "application/json"
+       #              Content-Type: "application/json"
+       #          }
+       #      }
+       # },
        # {
        # {
        #     type: mysql
        #     type: mysql
        #     config: {
        #     config: {

+ 10 - 2
apps/emqx_authz/src/emqx_authz.erl

@@ -111,6 +111,14 @@ init_rule(#{topics := Topics,
           topics => NTopics
           topics => NTopics
          };
          };
 
 
+init_rule(#{principal := Principal,
+            type := http,
+            config := #{url := Url} = Config
+           } = Rule) ->
+    NConfig = maps:merge(Config, #{base_url => maps:remove(query, Url)}),
+    NRule = create_resource(Rule#{config := NConfig}),
+    NRule#{principal => compile_principal(Principal)};
+
 init_rule(#{principal := Principal,
 init_rule(#{principal := Principal,
             type := DB
             type := DB
          } = Rule) when DB =:= redis;
          } = Rule) when DB =:= redis;
@@ -173,8 +181,8 @@ b2l(B) when is_binary(B) -> binary_to_list(B).
 -spec(authorize(emqx_types:clientinfo(), emqx_types:all(), emqx_topic:topic(), emqx_permission_rule:acl_result(), rules())
 -spec(authorize(emqx_types:clientinfo(), emqx_types:all(), emqx_topic:topic(), emqx_permission_rule:acl_result(), rules())
       -> {stop, allow} | {ok, deny}).
       -> {stop, allow} | {ok, deny}).
 authorize(#{username := Username,
 authorize(#{username := Username,
-              peerhost := IpAddress
-             } = Client, PubSub, Topic, _DefaultResult, Rules) ->
+            peerhost := IpAddress
+           } = Client, PubSub, Topic, _DefaultResult, Rules) ->
     case do_authorize(Client, PubSub, Topic, Rules) of
     case do_authorize(Client, PubSub, Topic, Rules) of
         {matched, allow} ->
         {matched, allow} ->
             ?LOG(info, "Client succeeded authorization: Username: ~p, IP: ~p, Topic: ~p, Permission: allow", [Username, IpAddress, Topic]),
             ?LOG(info, "Client succeeded authorization: Username: ~p, IP: ~p, Topic: ~p, Permission: allow", [Username, IpAddress, Topic]),

+ 99 - 0
apps/emqx_authz/src/emqx_authz_http.erl

@@ -0,0 +1,99 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+-module(emqx_authz_http).
+
+-include("emqx_authz.hrl").
+-include_lib("emqx/include/emqx.hrl").
+-include_lib("emqx/include/logger.hrl").
+
+%% AuthZ Callbacks
+-export([ authorize/4
+        , description/0
+        ]).
+
+-ifdef(TEST).
+-compile(export_all).
+-compile(nowarn_export_all).
+-endif.
+
+description() ->
+    "AuthZ with http".
+
+authorize(Client, PubSub, Topic,
+            #{resource_id := ResourceID,
+              type := http,
+              config := #{url := #{path := Path} = Url,
+                          headers := Headers,
+                          method := Method,
+                          request_timeout := RequestTimeout} = Config
+             }) ->
+    Request = case Method of
+                  get  -> 
+                      Query = maps:get(query, Url, ""),
+                      Path1 = replvar(Path ++ "?" ++ Query, PubSub, Topic, Client),
+                      {Path1, maps:to_list(Headers)};
+                  _ ->
+                      Body0 = serialize_body(
+                                maps:get('Accept', Headers, <<"application/json">>),
+                                maps:get(body, Config, #{})
+                              ),
+                      Body1 = replvar(Body0, PubSub, Topic, Client),
+                      Path1 = replvar(Path, PubSub, Topic, Client),
+                      {Path1, maps:to_list(Headers), Body1}
+              end,
+    case emqx_resource:query(ResourceID,  {Method, Request, RequestTimeout}) of
+        {ok, 204, _Headers} -> {matched, allow};
+        {ok, 200, _Headers, _Body} -> {matched, allow};
+        _ -> nomatch
+    end.
+
+query_string(Body) ->
+    query_string(maps:to_list(Body), []).
+
+query_string([], Acc) ->
+    <<$&, Str/binary>> = iolist_to_binary(lists:reverse(Acc)),
+    Str;
+query_string([{K, V} | More], Acc) ->
+    query_string(More, [["&", emqx_http_lib:uri_encode(K), "=", emqx_http_lib:uri_encode(V)] | Acc]).
+
+serialize_body(<<"application/json">>, Body) ->
+    jsx:encode(Body);
+serialize_body(<<"application/x-www-form-urlencoded">>, Body) ->
+    query_string(Body).
+
+replvar(Str0, PubSub, Topic,
+        #{username := Username,
+          clientid := Clientid,
+          peerhost := IpAddress,
+          protocol := Protocol,
+          mountpoint := Mountpoint
+         }) when is_list(Str0);
+                 is_binary(Str0) ->
+    NTopic = emqx_http_lib:uri_encode(Topic),
+    Str1 = re:replace(Str0, "%c", Clientid, [global, {return, binary}]),
+    Str2 = re:replace(Str1, "%u", Username, [global, {return, binary}]),
+    Str3 = re:replace(Str2, "%a", inet_parse:ntoa(IpAddress), [global, {return, binary}]),
+    Str4 = re:replace(Str3, "%r", bin(Protocol), [global, {return, binary}]),
+    Str5 = re:replace(Str4, "%m", Mountpoint, [global, {return, binary}]),
+    Str6 = re:replace(Str5, "%t", NTopic, [global, {return, binary}]),
+    Str7 = re:replace(Str6, "%A", bin(PubSub), [global, {return, binary}]),
+    Str7.
+
+bin(A) when is_atom(A) -> atom_to_binary(A, utf8);
+bin(B) when is_binary(B) -> B;
+bin(L) when is_list(L) -> list_to_binary(L);
+bin(X) -> X.

+ 81 - 3
apps/emqx_authz/src/emqx_authz_schema.erl

@@ -4,18 +4,87 @@
 
 
 -type action() :: publish | subscribe | all.
 -type action() :: publish | subscribe | all.
 -type permission() :: allow | deny.
 -type permission() :: allow | deny.
+-type url() :: emqx_http_lib:uri_map().
 
 
 -reflect_type([ permission/0
 -reflect_type([ permission/0
               , action/0
               , action/0
+              , url/0
               ]).
               ]).
 
 
--export([structs/0, fields/1]).
+-typerefl_from_string({url/0, emqx_http_lib, uri_parse}).
+
+-export([ structs/0
+        , fields/1
+        ]).
 
 
 structs() -> ["emqx_authz"].
 structs() -> ["emqx_authz"].
 
 
 fields("emqx_authz") ->
 fields("emqx_authz") ->
     [ {rules, rules()}
     [ {rules, rules()}
     ];
     ];
+fields(http) ->
+    [ {principal, principal()}
+    , {type, #{type => http}}
+    , {config, #{type => hoconsc:union([ hoconsc:ref(?MODULE, http_get)
+                                       , hoconsc:ref(?MODULE, http_post)
+                                       ])}
+      }
+    ];
+fields(http_get) ->
+    [ {url, #{type => url()}}
+    , {headers, #{type => map(),
+                  default => #{ <<"accept">> => <<"application/json">>
+                              , <<"cache-control">> => <<"no-cache">>
+                              , <<"connection">> => <<"keep-alive">>
+                              , <<"keep-alive">> => <<"timeout=5">>
+                              },
+                  converter => fun (Headers0) ->
+                                    Headers1 = maps:fold(fun(K0, V, AccIn) ->
+                                                           K1 = iolist_to_binary(string:to_lower(binary_to_list(K0))),
+                                                           maps:put(K1, V, AccIn)
+                                                        end, #{}, Headers0),
+                                    maps:merge(#{ <<"accept">> => <<"application/json">>
+                                                , <<"cache-control">> => <<"no-cache">>
+                                                , <<"connection">> => <<"keep-alive">>
+                                                , <<"keep-alive">> => <<"timeout=5">>
+                                                }, Headers1)
+                               end
+                 }
+      }
+    , {method,  #{type => get,
+                  default => get
+                 }}
+    ]  ++ proplists:delete(base_url, emqx_connector_http:fields(config));
+fields(http_post) ->
+    [ {url, #{type => url()}}
+    , {headers, #{type => map(),
+                  default => #{ <<"accept">> => <<"application/json">>
+                              , <<"cache-control">> => <<"no-cache">>
+                              , <<"connection">> => <<"keep-alive">>
+                              , <<"content-type">> => <<"application/json">>
+                              , <<"keep-alive">> => <<"timeout=5">>
+                              },
+                  converter => fun (Headers0) ->
+                                    Headers1 = maps:fold(fun(K0, V, AccIn) ->
+                                                           K1 = iolist_to_binary(string:to_lower(binary_to_list(K0))),
+                                                           maps:put(K1, V, AccIn)
+                                                        end, #{}, Headers0),
+                                    maps:merge(#{ <<"accept">> => <<"application/json">>
+                                                , <<"cache-control">> => <<"no-cache">>
+                                                , <<"connection">> => <<"keep-alive">>
+                                                , <<"content-type">> => <<"application/json">>
+                                                , <<"keep-alive">> => <<"timeout=5">>
+                                                }, Headers1)
+                               end
+                 }
+      }
+    , {method,  #{type => hoconsc:enum([post, put]),
+                  default => get}}
+    , {body, #{type => map(),
+               nullable => true
+              }
+      }
+    ]  ++ proplists:delete(base_url, emqx_connector_http:fields(config));
 fields(mongo) ->
 fields(mongo) ->
     connector_fields(mongo) ++
     connector_fields(mongo) ++
     [ {collection, #{type => atom()}}
     [ {collection, #{type => atom()}}
@@ -75,9 +144,10 @@ fields(eq_topic) ->
 union_array(Item) when is_list(Item) ->
 union_array(Item) when is_list(Item) ->
     hoconsc:array(hoconsc:union(Item)).
     hoconsc:array(hoconsc:union(Item)).
 
 
-rules() -> 
+rules() ->
     #{type => union_array(
     #{type => union_array(
                 [ hoconsc:ref(?MODULE, simple_rule)
                 [ hoconsc:ref(?MODULE, simple_rule)
+                , hoconsc:ref(?MODULE, http)
                 , hoconsc:ref(?MODULE, mysql)
                 , hoconsc:ref(?MODULE, mysql)
                 , hoconsc:ref(?MODULE, pgsql)
                 , hoconsc:ref(?MODULE, pgsql)
                 , hoconsc:ref(?MODULE, redis)
                 , hoconsc:ref(?MODULE, redis)
@@ -108,7 +178,15 @@ query() ->
      }.
      }.
 
 
 connector_fields(DB) ->
 connector_fields(DB) ->
-    Mod = list_to_existing_atom(io_lib:format("~s_~s",[emqx_connector, DB])),
+    Mod0 = io_lib:format("~s_~s",[emqx_connector, DB]),
+    Mod = try
+              list_to_existing_atom(Mod0)
+          catch
+              error:badarg ->
+                  list_to_atom(Mod0);
+              Error ->
+                  erlang:error(Error)
+          end,
     [ {principal, principal()}
     [ {principal, principal()}
     , {type, #{type => DB}}
     , {type, #{type => DB}}
     ] ++ Mod:fields("").
     ] ++ Mod:fields("").

+ 94 - 0
apps/emqx_authz/test/emqx_authz_http_SUITE.erl

@@ -0,0 +1,94 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%% http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+-module(emqx_authz_http_SUITE).
+
+-compile(nowarn_export_all).
+-compile(export_all).
+
+-include("emqx_authz.hrl").
+-include_lib("eunit/include/eunit.hrl").
+-include_lib("common_test/include/ct.hrl").
+
+all() ->
+    emqx_ct:all(?MODULE).
+
+groups() ->
+    [].
+
+init_per_suite(Config) ->
+    meck:new(emqx_resource, [non_strict, passthrough, no_history, no_link]),
+    meck:expect(emqx_resource, create, fun(_, _, _) -> {ok, meck_data} end ),
+    ok = emqx_ct_helpers:start_apps([emqx_authz], fun set_special_configs/1),
+    Config.
+
+end_per_suite(_Config) ->
+    file:delete(filename:join(emqx:get_env(plugins_etc_dir), 'authz.conf')),
+    emqx_ct_helpers:stop_apps([emqx_authz, emqx_resource]),
+    meck:unload(emqx_resource).
+
+set_special_configs(emqx) ->
+    application:set_env(emqx, allow_anonymous, true),
+    application:set_env(emqx, enable_acl_cache, false),
+    application:set_env(emqx, acl_nomatch, deny),
+    application:set_env(emqx, plugins_loaded_file,
+                        emqx_ct_helpers:deps_path(emqx, "test/loaded_plguins")),
+    ok;
+set_special_configs(emqx_authz) ->
+    Rules = [#{config =>#{
+                 url => #{host => "fake.com",
+                          path => "/",
+                          port => 443,
+                          scheme => https},
+                 headers => #{},
+                 method => get,
+                 request_timeout => 5000
+                },
+               principal => all,
+               type => http}
+            ],
+    emqx_config:put([emqx_authz], #{rules => Rules}),
+    ok;
+set_special_configs(_App) ->
+    ok.
+
+%%------------------------------------------------------------------------------
+%% Testcases
+%%------------------------------------------------------------------------------
+
+t_authz(_) ->
+    ClientInfo = #{clientid => <<"clientid">>,
+                   username => <<"username">>,
+                   peerhost => {127,0,0,1},
+                   protocol => mqtt,
+                   mountpoint => <<"fake">>
+                   },
+
+    meck:expect(emqx_resource, query, fun(_, _) -> {ok, 204, fake_headers} end),
+    ?assertEqual(allow,
+                 emqx_access_control:authorize(ClientInfo, subscribe, <<"#">>)),
+
+    meck:expect(emqx_resource, query, fun(_, _) -> {ok, 200, fake_headers, fake_body} end),
+    ?assertEqual(allow,
+                 emqx_access_control:authorize(ClientInfo, publish, <<"#">>)),
+
+
+    meck:expect(emqx_resource, query, fun(_, _) -> {error, other} end),
+    ?assertEqual(deny,
+        emqx_access_control:authorize(ClientInfo, subscribe, <<"+">>)),
+    ?assertEqual(deny,
+        emqx_access_control:authorize(ClientInfo, publish, <<"+">>)),
+    ok.
+

+ 16 - 24
apps/emqx_connector/src/emqx_connector_http.erl

@@ -28,6 +28,10 @@
         , on_health_check/2
         , on_health_check/2
         ]).
         ]).
 
 
+-type url() :: emqx_http_lib:uri_map().
+-reflect_type([url/0]).
+-typerefl_from_string({url/0, emqx_http_lib, uri_parse}).
+
 -export([ structs/0
 -export([ structs/0
         , fields/1
         , fields/1
         , validations/0]).
         , validations/0]).
@@ -53,7 +57,6 @@ fields(config) ->
     , {connect_timeout, fun connect_timeout/1}
     , {connect_timeout, fun connect_timeout/1}
     , {max_retries,     fun max_retries/1}
     , {max_retries,     fun max_retries/1}
     , {retry_interval,  fun retry_interval/1}
     , {retry_interval,  fun retry_interval/1}
-    , {keepalive,       fun keepalive/1}
     , {pool_type,       fun pool_type/1}
     , {pool_type,       fun pool_type/1}
     , {pool_size,       fun pool_size/1}
     , {pool_size,       fun pool_size/1}
     , {ssl_opts,        #{type => hoconsc:ref(?MODULE, ssl_opts),
     , {ssl_opts,        #{type => hoconsc:ref(?MODULE, ssl_opts),
@@ -70,9 +73,12 @@ fields(ssl_opts) ->
 validations() ->
 validations() ->
     [ {check_ssl_opts, fun check_ssl_opts/1} ].
     [ {check_ssl_opts, fun check_ssl_opts/1} ].
 
 
-base_url(type) -> binary();
+base_url(type) -> url();
 base_url(nullable) -> false;
 base_url(nullable) -> false;
-base_url(validate) -> [fun check_base_url/1];
+base_url(validate) -> fun (#{query := _Query}) ->
+                              {error, "There must be no query in the base_url"};
+                          (_) -> ok
+                      end;
 base_url(_) -> undefined.
 base_url(_) -> undefined.
 
 
 connect_timeout(type) -> connect_timeout();
 connect_timeout(type) -> connect_timeout();
@@ -87,10 +93,6 @@ retry_interval(type) -> non_neg_integer();
 retry_interval(default) -> 1000;
 retry_interval(default) -> 1000;
 retry_interval(_) -> undefined.
 retry_interval(_) -> undefined.
 
 
-keepalive(type) -> non_neg_integer();
-keepalive(default) -> 5000;
-keepalive(_) -> undefined.
-
 pool_type(type) -> pool_type();
 pool_type(type) -> pool_type();
 pool_type(default) -> random;
 pool_type(default) -> random;
 pool_type(_) -> undefined.
 pool_type(_) -> undefined.
@@ -117,18 +119,16 @@ verify(default) -> false;
 verify(_) -> undefined.
 verify(_) -> undefined.
 
 
 %% ===================================================================
 %% ===================================================================
-on_start(InstId, #{url := URL,
+on_start(InstId, #{base_url := #{scheme := Scheme,
+                                 host := Host,
+                                 port := Port,
+                                 path := BasePath},
                    connect_timeout := ConnectTimeout,
                    connect_timeout := ConnectTimeout,
                    max_retries := MaxRetries,
                    max_retries := MaxRetries,
                    retry_interval := RetryInterval,
                    retry_interval := RetryInterval,
-                   keepalive := Keepalive,
                    pool_type := PoolType,
                    pool_type := PoolType,
                    pool_size := PoolSize} = Config) ->
                    pool_size := PoolSize} = Config) ->
     logger:info("starting http connector: ~p, config: ~p", [InstId, Config]),
     logger:info("starting http connector: ~p, config: ~p", [InstId, Config]),
-    {ok, #{scheme := Scheme,
-           host := Host,
-           port := Port,
-           path := BasePath}} = emqx_http_lib:uri_parse(URL),
     {Transport, TransportOpts} = case Scheme of
     {Transport, TransportOpts} = case Scheme of
                                      http ->
                                      http ->
                                          {tcp, []};
                                          {tcp, []};
@@ -143,7 +143,7 @@ on_start(InstId, #{url := URL,
                , {connect_timeout, ConnectTimeout}
                , {connect_timeout, ConnectTimeout}
                , {retry, MaxRetries}
                , {retry, MaxRetries}
                , {retry_timeout, RetryInterval}
                , {retry_timeout, RetryInterval}
-               , {keepalive, Keepalive}
+               , {keepalive, 5000}
                , {pool_type, PoolType}
                , {pool_type, PoolType}
                , {pool_size, PoolSize}
                , {pool_size, PoolSize}
                , {transport, Transport}
                , {transport, Transport}
@@ -192,19 +192,11 @@ on_health_check(_InstId, #{host := Host, port := Port} = State) ->
 %% Internal functions
 %% Internal functions
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
-check_base_url(URL) ->
-    case emqx_http_lib:uri_parse(URL) of
-        {error, _} -> false;
-        {ok, #{query := _}} -> false;
-        _ -> true
-    end.
-
 check_ssl_opts(Conf) ->
 check_ssl_opts(Conf) ->
     check_ssl_opts("base_url", Conf).
     check_ssl_opts("base_url", Conf).
 
 
 check_ssl_opts(URLFrom, Conf) ->
 check_ssl_opts(URLFrom, Conf) ->
-    URL = hocon_schema:get_value(URLFrom, Conf),
-    {ok, #{scheme := Scheme}} = emqx_http_lib:uri_parse(URL),
+    #{schema := Scheme} = hocon_schema:get_value(URLFrom, Conf),
     SSLOpts = hocon_schema:get_value("ssl_opts", Conf),
     SSLOpts = hocon_schema:get_value("ssl_opts", Conf),
     case {Scheme, maps:size(SSLOpts)} of
     case {Scheme, maps:size(SSLOpts)} of
         {http, 0} -> true;
         {http, 0} -> true;
@@ -216,4 +208,4 @@ check_ssl_opts(URLFrom, Conf) ->
 update_path(BasePath, {Path, Headers}) ->
 update_path(BasePath, {Path, Headers}) ->
     {filename:join(BasePath, Path), Headers};
     {filename:join(BasePath, Path), Headers};
 update_path(BasePath, {Path, Headers, Body}) ->
 update_path(BasePath, {Path, Headers, Body}) ->
-    {filename:join(BasePath, Path), Headers, Body}.
+    {filename:join(BasePath, Path), Headers, Body}.

+ 72 - 36
apps/emqx_retainer/etc/emqx_retainer.conf

@@ -6,40 +6,76 @@
 ##
 ##
 ## Notice that all nodes in the same cluster have to be configured to
 ## Notice that all nodes in the same cluster have to be configured to
 emqx_retainer: {
 emqx_retainer: {
-    ## enable/disable emqx_retainer
-    enable: true
-	## use the same storage_type.
-	##
-	## Value: ram | disc | disc_only
-	##  - ram: memory only
-	##  - disc: both memory and disc
-	##  - disc_only: disc only
-	##
-	## Default: ram
-	storage_type: ram
-
-	## Maximum number of retained messages. 0 means no limit.
-	##
-	## Value: Number >= 0
-	max_retained_messages: 0
-
-	## Maximum retained message size.
-	##
-	## Value: Bytes
-	max_payload_size: 1MB
-
-	## Expiry interval of the retained messages. Never expire if the value is 0.
-	##
-	## Value: Duration
-	##  - h: hour
-	##  - m: minute
-	##  - s: second
-	##
-	## Examples:
-	##  - 2h:  2 hours
-	##  - 30m: 30 minutes
-	##  - 20s: 20 seconds
-	##
-	## Default: 0s
-	expiry_interval: 0s
+  ## enable/disable emqx_retainer
+  enable: true
+
+  ## Periodic interval for cleaning up expired messages. Never clear if the value is 0.
+  ##
+  ## Value: Duration
+  ##  - h: hour
+  ##  - m: minute
+  ##  - s: second
+  ##
+  ## Examples:
+  ##  - 2h:  2 hours
+  ##  - 30m: 30 minutes
+  ##  - 20s: 20 seconds
+  ##
+  ## Default: 0s
+  msg_clear_interval: 0s
+
+  ## Message retention time. 0 means message will never be expired.
+  ##
+  ## Default: 0s
+  msg_expiry_interval: 0s
+
+  ## The message read and deliver flow rate control
+  ## When a client subscribe to a wildcard topic, may many retained messages will be loaded.
+  ## If you don't want these data loaded to the memory all at once, you can use this to control.
+  ## The processing flow:
+  ##   load max_read_number retained message from storage ->
+  ##    deliver ->
+  ##    repeat this, until all retianed messages are delivered
+  ##
+  flow_control: {
+    ## The max messages number per read from storage. 0 means no limit
+    ##
+    ## Default: 0
+    max_read_number: 0
+
+    ## The max number of retained message can be delivered in emqx per quota_release_interval.0 means no limit
+    ##
+    ## Default: 0
+    msg_deliver_quota: 0
+
+    ## deliver quota reset interval
+    ##
+    ## Default: 0s
+    quota_release_interval: 0s
+  }
+
+  ## Maximum retained message size.
+  ##
+  ## Value: Bytes
+  max_payload_size: 1MB
+
+  ## Storage connect parameters
+  ##
+  ## Value: mnesia
+  ##
+  connector:
+    [
+      {
+       type: mnesia
+       config: {
+            ## storage_type: ram | disc | disc_only
+            storage_type: ram
+
+            ## Maximum number of retained messages. 0 means no limit.
+            ##
+            ## Value: Number >= 0
+            max_retained_messages: 0
+       }
+      }
+    ]
 }
 }

+ 20 - 1
apps/emqx_retainer/include/emqx_retainer.hrl

@@ -14,7 +14,26 @@
 %% limitations under the License.
 %% limitations under the License.
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
+-include_lib("emqx/include/emqx.hrl").
+
 -define(APP, emqx_retainer).
 -define(APP, emqx_retainer).
 -define(TAB, ?APP).
 -define(TAB, ?APP).
--record(retained, {topic, msg, expiry_time}).
 -define(RETAINER_SHARD, emqx_retainer_shard).
 -define(RETAINER_SHARD, emqx_retainer_shard).
+
+-type topic() :: binary().
+-type payload() :: binary().
+-type message() :: #message{}.
+
+-type context() :: #{context_id := pos_integer(),
+                     atom() => term()}.
+
+-define(DELIVER_SEMAPHORE, deliver_remained_quota).
+-type semaphore() :: ?DELIVER_SEMAPHORE.
+-type cursor() :: undefined | term().
+-type result() :: term().
+
+-define(SHARED_CONTEXT_TAB, emqx_retainer_ctx).
+-record(shared_context, {key :: atom(), value :: term()}).
+-type shared_context_key() :: ?DELIVER_SEMAPHORE.
+
+-type backend() :: emqx_retainer_storage_mnesia.

+ 332 - 231
apps/emqx_retainer/src/emqx_retainer.erl

@@ -21,24 +21,24 @@
 -include("emqx_retainer.hrl").
 -include("emqx_retainer.hrl").
 -include_lib("emqx/include/emqx.hrl").
 -include_lib("emqx/include/emqx.hrl").
 -include_lib("emqx/include/logger.hrl").
 -include_lib("emqx/include/logger.hrl").
--include_lib("stdlib/include/ms_transform.hrl").
 
 
 -logger_header("[Retainer]").
 -logger_header("[Retainer]").
 
 
 -export([start_link/0]).
 -export([start_link/0]).
 
 
--export([unload/0
+-export([ on_session_subscribed/4
+        , on_message_publish/2
         ]).
         ]).
 
 
--export([ on_session_subscribed/3
-        , on_message_publish/1
-        ]).
-
--export([ clean/1
-        , update_config/1]).
+-export([ dispatch/4
+        , delete_message/2
+        , store_retained/2
+        , deliver/5]).
 
 
-%% for emqx_pool task func
--export([dispatch/2]).
+-export([ get_expiry_time/1
+        , update_config/1
+        , clean/0
+        , delete/1]).
 
 
 %% gen_server callbacks
 %% gen_server callbacks
 -export([ init/1
 -export([ init/1
@@ -49,62 +49,52 @@
         , code_change/3
         , code_change/3
         ]).
         ]).
 
 
--record(state, {stats_fun, stats_timer, expiry_timer}).
+-type state() :: #{ enable := boolean()
+                  , context_id := non_neg_integer()
+                  , context := undefined | context()
+                  , clear_timer := undefined | reference()
+                  , release_quota_timer := undefined | reference()
+                  , wait_quotas := list()
+                  }.
+
+-rlog_shard({?RETAINER_SHARD, ?TAB}).
 
 
--define(STATS_INTERVAL, timer:seconds(1)).
--define(DEF_STORAGE_TYPE, ram).
--define(DEF_MAX_RETAINED_MESSAGES, 0).
 -define(DEF_MAX_PAYLOAD_SIZE, (1024 * 1024)).
 -define(DEF_MAX_PAYLOAD_SIZE, (1024 * 1024)).
 -define(DEF_EXPIRY_INTERVAL, 0).
 -define(DEF_EXPIRY_INTERVAL, 0).
--define(DEF_ENABLE_VAL, false).
 
 
-%% convenient to generate stats_timer/expiry_timer
--define(MAKE_TIMER(State, Timer, Interval, Msg),
-        State#state{Timer = erlang:send_after(Interval, self(), Msg)}).
+-define(CAST(Msg), gen_server:cast(?MODULE, Msg)).
 
 
--rlog_shard({?RETAINER_SHARD, ?TAB}).
+-callback delete_message(context(), topic()) -> ok.
+-callback store_retained(context(), message()) -> ok.
+-callback read_message(context(), topic()) -> {ok, list()}.
+-callback match_messages(context(), topic(), cursor()) -> {ok, list(), cursor()}.
+-callback clear_expired(context()) -> ok.
+-callback clean(context()) -> ok.
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
-%% Load/Unload
+%% Hook API
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
-
-load() ->
-    _ = emqx:hook('session.subscribed', {?MODULE, on_session_subscribed, []}),
-    _ = emqx:hook('message.publish', {?MODULE, on_message_publish, []}),
-    ok.
-
-unload() ->
-    emqx:unhook('message.publish', {?MODULE, on_message_publish}),
-    emqx:unhook('session.subscribed', {?MODULE, on_session_subscribed}).
-
-on_session_subscribed(_, _, #{share := ShareName}) when ShareName =/= undefined ->
+on_session_subscribed(_, _, #{share := ShareName}, _) when ShareName =/= undefined ->
     ok;
     ok;
-on_session_subscribed(_, Topic, #{rh := Rh, is_new := IsNew}) ->
+on_session_subscribed(_, Topic, #{rh := Rh, is_new := IsNew}, Context) ->
     case Rh =:= 0 orelse (Rh =:= 1 andalso IsNew) of
     case Rh =:= 0 orelse (Rh =:= 1 andalso IsNew) of
-        true -> emqx_pool:async_submit(fun ?MODULE:dispatch/2, [self(), Topic]);
+        true -> dispatch(Context, Topic);
         _ -> ok
         _ -> ok
     end.
     end.
 
 
-%% @private
-dispatch(Pid, Topic) ->
-    Msgs = case emqx_topic:wildcard(Topic) of
-               false -> read_messages(Topic);
-               true  -> match_messages(Topic)
-           end,
-    [Pid ! {deliver, Topic, Msg} || Msg  <- sort_retained(Msgs)].
-
 %% RETAIN flag set to 1 and payload containing zero bytes
 %% RETAIN flag set to 1 and payload containing zero bytes
 on_message_publish(Msg = #message{flags   = #{retain := true},
 on_message_publish(Msg = #message{flags   = #{retain := true},
                                   topic   = Topic,
                                   topic   = Topic,
-                                  payload = <<>>}) ->
-    ekka_mnesia:dirty_delete(?TAB, topic2tokens(Topic)),
+                                  payload = <<>>},
+                   Context) ->
+    delete_message(Context, Topic),
     {ok, Msg};
     {ok, Msg};
 
 
-on_message_publish(Msg = #message{flags = #{retain := true}}) ->
+on_message_publish(Msg = #message{flags = #{retain := true}}, Context) ->
     Msg1 = emqx_message:set_header(retained, true, Msg),
     Msg1 = emqx_message:set_header(retained, true, Msg),
-    store_retained(Msg1),
+    store_retained(Context, Msg1),
     {ok, Msg};
     {ok, Msg};
-on_message_publish(Msg) ->
+on_message_publish(Msg, _) ->
     {ok, Msg}.
     {ok, Msg}.
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
@@ -116,71 +106,98 @@ on_message_publish(Msg) ->
 start_link() ->
 start_link() ->
     gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
     gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
 
 
--spec(clean(emqx_types:topic()) -> non_neg_integer()).
-clean(Topic) when is_binary(Topic) ->
-    case emqx_topic:wildcard(Topic) of
-        true -> match_delete_messages(Topic);
+-spec dispatch(context(), pid(), topic(), cursor()) -> ok.
+dispatch(Context, Pid, Topic, Cursor) ->
+    Mod = get_backend_module(),
+    case Cursor =/= undefined orelse emqx_topic:wildcard(Topic) of
         false ->
         false ->
-            Tokens = topic2tokens(Topic),
-            Fun = fun() ->
-                      case mnesia:read({?TAB, Tokens}) of
-                          [] -> 0;
-                          [_M] -> mnesia:delete({?TAB, Tokens}), 1
-                      end
-                  end,
-            {atomic, N} = ekka_mnesia:transaction(?RETAINER_SHARD, Fun), N
+            {ok, Result} = Mod:read_message(Context, Topic),
+            deliver(Result, Context, Pid, Topic, undefiend);
+        true  ->
+            {ok, Result, NewCursor} =  Mod:match_messages(Context, Topic, Cursor),
+            deliver(Result, Context, Pid, Topic, NewCursor)
+    end.
+
+deliver([], Context, Pid, Topic, Cursor) ->
+    case Cursor of
+        undefined ->
+            ok;
+        _ ->
+            dispatch(Context, Pid, Topic, Cursor)
+    end;
+deliver(Result, #{context_id := Id} = Context, Pid, Topic, Cursor) ->
+    case erlang:is_process_alive(Pid) of
+        false ->
+            ok;
+        _ ->
+            #{msg_deliver_quota := MaxDeliverNum} = emqx_config:get([?APP, flow_control]),
+            case MaxDeliverNum of
+                0 ->
+                    _ = [Pid ! {deliver, Topic, Msg} || Msg <- Result],
+                    ok;
+                _ ->
+                    case do_deliver(Result, Id, Pid, Topic) of
+                        ok ->
+                            deliver([], Context, Pid, Topic, Cursor);
+                        abort ->
+                            ok
+                    end
+            end
+    end.
+
+get_expiry_time(#message{headers = #{properties := #{'Message-Expiry-Interval' := 0}}}) ->
+    0;
+get_expiry_time(#message{headers = #{properties := #{'Message-Expiry-Interval' := Interval}},
+                         timestamp = Ts}) ->
+    Ts + Interval * 1000;
+get_expiry_time(#message{timestamp = Ts}) ->
+    Interval = emqx_config:get([?APP, msg_expiry_interval], ?DEF_EXPIRY_INTERVAL),
+    case Interval of
+        0 -> 0;
+        _ -> Ts + Interval
     end.
     end.
 
 
-%%--------------------------------------------------------------------
-%% Update Config
-%%--------------------------------------------------------------------
 -spec update_config(hocon:config()) -> ok.
 -spec update_config(hocon:config()) -> ok.
 update_config(Conf) ->
 update_config(Conf) ->
-    OldCfg = emqx_config:get([?APP]),
-    emqx_config:put([?APP], Conf),
-    check_enable_when_update(OldCfg).
+    gen_server:call(?MODULE, {?FUNCTION_NAME, Conf}).
+
+clean() ->
+    gen_server:call(?MODULE, ?FUNCTION_NAME).
+
+delete(Topic) ->
+    gen_server:call(?MODULE, {?FUNCTION_NAME, Topic}).
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 %% gen_server callbacks
 %% gen_server callbacks
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
 init([]) ->
 init([]) ->
-    StorageType = emqx_config:get([?MODULE, storage_type], ?DEF_STORAGE_TYPE),
-    ExpiryInterval = emqx_config:get([?MODULE, expiry_interval], ?DEF_EXPIRY_INTERVAL),
-    Copies = case StorageType of
-                 ram       -> ram_copies;
-                 disc      -> disc_copies;
-                 disc_only -> disc_only_copies
-             end,
-    StoreProps = [{ets, [compressed,
-                         {read_concurrency, true},
-                         {write_concurrency, true}]},
-                  {dets, [{auto_save, 1000}]}],
-    ok = ekka_mnesia:create_table(?TAB, [
-                {type, set},
-                {Copies, [node()]},
-                {record_name, retained},
-                {attributes, record_info(fields, retained)},
-                {storage_properties, StoreProps}]),
-    ok = ekka_mnesia:copy_table(?TAB, Copies),
-    ok = ekka_rlog:wait_for_shards([?RETAINER_SHARD], infinity),
-    case mnesia:table_info(?TAB, storage_type) of
-        Copies -> ok;
-        _Other ->
-            {atomic, ok} = mnesia:change_table_copy_type(?TAB, node(), Copies),
-            ok
-    end,
-    StatsFun = emqx_stats:statsfun('retained.count', 'retained.max'),
-    State = ?MAKE_TIMER(#state{stats_fun = StatsFun}, stats_timer, ?STATS_INTERVAL, stats),
-    check_enable_when_init(),
-    {ok, start_expire_timer(ExpiryInterval, State)}.
-
-start_expire_timer(0, State) ->
-    State;
-start_expire_timer(undefined, State) ->
-    State;
-start_expire_timer(Ms, State) ->
-    ?MAKE_TIMER(State, expiry_timer, Ms, expire).
+    init_shared_context(),
+    State = new_state(),
+    #{enable := Enable} = Cfg = emqx_config:get([?APP]),
+    {ok,
+     case Enable of
+         true ->
+             enable_retainer(State, Cfg);
+         _ ->
+             State
+     end}.
+
+handle_call({update_config, Conf}, _, State) ->
+    State2 = update_config(State, Conf),
+    emqx_config:put([?APP], Conf),
+    {reply, ok, State2};
+
+handle_call({wait_semaphore, Id}, From, #{wait_quotas := Waits} = State) ->
+    {noreply, State#{wait_quotas := [{Id, From} | Waits]}};
+
+handle_call(clean, _, #{context := Context} = State) ->
+    clean(Context),
+    {reply, ok, State};
+
+handle_call({delete, Topic}, _, #{context := Context} = State) ->
+    delete_message(Context, Topic),
+    {reply, ok, State};
 
 
 handle_call(Req, _From, State) ->
 handle_call(Req, _From, State) ->
     ?LOG(error, "Unexpected call: ~p", [Req]),
     ?LOG(error, "Unexpected call: ~p", [Req]),
@@ -190,22 +207,36 @@ handle_cast(Msg, State) ->
     ?LOG(error, "Unexpected cast: ~p", [Msg]),
     ?LOG(error, "Unexpected cast: ~p", [Msg]),
     {noreply, State}.
     {noreply, State}.
 
 
-handle_info(stats, State = #state{stats_fun = StatsFun}) ->
-    StatsFun(retained_count()),
-    {noreply, ?MAKE_TIMER(State, stats_timer, ?STATS_INTERVAL, stats), hibernate};
+handle_info(clear_expired, #{context := Context} = State) ->
+    Mod = get_backend_module(),
+    Mod:clear_expired(Context),
+    Interval = emqx_config:get([?APP, msg_clear_interval], ?DEF_EXPIRY_INTERVAL),
+    {noreply, State#{clear_timer := add_timer(Interval, clear_expired)}, hibernate};
 
 
-handle_info(expire, State) ->
-    ok = expire_messages(),
-    Interval = emqx_config:get([?MODULE, expiry_interval], ?DEF_EXPIRY_INTERVAL),
-    {noreply, start_expire_timer(Interval, State), hibernate};
+handle_info(release_deliver_quota, #{context := Context, wait_quotas := Waits} = State) ->
+    insert_shared_context(?DELIVER_SEMAPHORE, get_msg_deliver_quota()),
+    case Waits of
+        [] ->
+            ok;
+        _ ->
+            #{context_id := NowId} = Context,
+            Waits2 = lists:reverse(Waits),
+            lists:foreach(fun({Id, From}) ->
+                                  gen_server:reply(From, Id =:= NowId)
+                          end,
+                          Waits2)
+    end,
+    Interval = emqx_config:get([?APP, flow_control, quota_release_interval]),
+    {noreply, State#{release_quota_timer := add_timer(Interval, release_deliver_quota),
+                     wait_quotas := []}};
 
 
 handle_info(Info, State) ->
 handle_info(Info, State) ->
     ?LOG(error, "Unexpected info: ~p", [Info]),
     ?LOG(error, "Unexpected info: ~p", [Info]),
     {noreply, State}.
     {noreply, State}.
 
 
-terminate(_Reason, #state{stats_timer = TRef1, expiry_timer = TRef2}) ->
-    _ = erlang:cancel_timer(TRef1),
-    _ = erlang:cancel_timer(TRef2),
+terminate(_Reason, #{clear_timer := TRef1, release_quota_timer := TRef2}) ->
+    _ = stop_timer(TRef1),
+    _ = stop_timer(TRef2),
     ok.
     ok.
 
 
 code_change(_OldVsn, State, _Extra) ->
 code_change(_OldVsn, State, _Extra) ->
@@ -214,141 +245,211 @@ code_change(_OldVsn, State, _Extra) ->
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 %% Internal functions
 %% Internal functions
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
-sort_retained([]) -> [];
-sort_retained([Msg]) -> [Msg];
-sort_retained(Msgs)  ->
-    lists:sort(fun(#message{timestamp = Ts1}, #message{timestamp = Ts2}) ->
-                       Ts1 =< Ts2 end,
-               Msgs).
-
-store_retained(Msg = #message{topic = Topic, payload = Payload}) ->
-    case {is_table_full(), is_too_big(size(Payload))} of
-        {false, false} ->
-            ok = emqx_metrics:inc('messages.retained'),
-            ekka_mnesia:dirty_write(?TAB, #retained{topic = topic2tokens(Topic),
-                                                    msg = Msg,
-                                                    expiry_time = get_expiry_time(Msg)});
-        {true, false} ->
-            {atomic, _} = ekka_mnesia:transaction(?RETAINER_SHARD,
-                fun() ->
-                        case mnesia:read(?TAB, Topic) of
-                            [_] ->
-                                mnesia:write(?TAB,
-                                             #retained{topic = topic2tokens(Topic),
-                                                       msg = Msg,
-                                                       expiry_time = get_expiry_time(Msg)},
-                                             write);
-                            [] ->
-                                ?LOG(error,
-                                     "Cannot retain message(topic=~s) for table is full!", [Topic])
-                    end
-                end),
-            ok;
-        {true, _} ->
-            ?LOG(error, "Cannot retain message(topic=~s) for table is full!", [Topic]);
-        {_, true} ->
-            ?LOG(error, "Cannot retain message(topic=~s, payload_size=~p) "
-                        "for payload is too big!", [Topic, iolist_size(Payload)])
-    end.
-
-is_table_full() ->
-    Limit = emqx_config:get([?MODULE, max_retained_messages], ?DEF_MAX_RETAINED_MESSAGES),
-    Limit > 0 andalso (retained_count() > Limit).
+-spec new_state() -> state().
+new_state() ->
+    #{enable => false,
+      context_id => 0,
+      context => undefined,
+      clear_timer => undefined,
+      release_quota_timer => undefined,
+      wait_quotas => []}.
+
+-spec new_context(pos_integer()) -> context().
+new_context(Id) ->
+    #{context_id => Id}.
 
 
 is_too_big(Size) ->
 is_too_big(Size) ->
-    Limit = emqx_config:get([?MODULE, max_payload_size], ?DEF_MAX_PAYLOAD_SIZE),
+    Limit = emqx_config:get([?APP, max_payload_size], ?DEF_MAX_PAYLOAD_SIZE),
     Limit > 0 andalso (Size > Limit).
     Limit > 0 andalso (Size > Limit).
 
 
-get_expiry_time(#message{headers = #{properties := #{'Message-Expiry-Interval' := 0}}}) ->
-    0;
-get_expiry_time(#message{headers = #{properties := #{'Message-Expiry-Interval' := Interval}},
-                         timestamp = Ts}) ->
-    Ts + Interval * 1000;
-get_expiry_time(#message{timestamp = Ts}) ->
-    Interval = emqx_config:get([?MODULE, expiry_interval], ?DEF_EXPIRY_INTERVAL),
-    case Interval of
-        0 -> 0;
-        _ -> Ts + Interval
+%% @private
+dispatch(Context, Topic) ->
+    emqx_retainer_pool:async_submit(fun ?MODULE:dispatch/4,
+                                    [Context, self(), Topic, undefined]).
+
+-spec delete_message(context(), topic()) -> ok.
+delete_message(Context, Topic) ->
+    Mod = get_backend_module(),
+    Mod:delete_message(Context, Topic).
+
+-spec store_retained(context(), message()) -> ok.
+store_retained(Context, #message{topic = Topic, payload = Payload} = Msg) ->
+    case is_too_big(erlang:byte_size(Payload)) of
+        false ->
+            Mod = get_backend_module(),
+            Mod:store_retained(Context, Msg);
+        _ ->
+            ?ERROR("Cannot retain message(topic=~s, payload_size=~p) for payload is too big!",
+                   [Topic, iolist_size(Payload)])
     end.
     end.
 
 
-%%--------------------------------------------------------------------
-%% Internal funcs
-%%--------------------------------------------------------------------
+-spec clean(context()) -> ok.
+clean(Context) ->
+    Mod = get_backend_module(),
+    Mod:clean(Context).
 
 
--spec(retained_count() -> non_neg_integer()).
-retained_count() -> mnesia:table_info(?TAB, size).
-
-topic2tokens(Topic) ->
-    emqx_topic:words(Topic).
-
-expire_messages() ->
-    NowMs = erlang:system_time(millisecond),
-    MsHd = #retained{topic = '$1', msg = '_', expiry_time = '$3'},
-    Ms = [{MsHd, [{'=/=','$3',0}, {'<','$3',NowMs}], ['$1']}],
-    {atomic, _} = ekka_mnesia:transaction(?RETAINER_SHARD,
-        fun() ->
-            Keys = mnesia:select(?TAB, Ms, write),
-            lists:foreach(fun(Key) -> mnesia:delete({?TAB, Key}) end, Keys)
-        end),
+-spec do_deliver(list(term()), pos_integer(), pid(), topic()) -> ok | abort.
+do_deliver([Msg | T], Id, Pid, Topic) ->
+    case require_semaphore(?DELIVER_SEMAPHORE, Id) of
+        true ->
+            Pid ! {deliver, Topic, Msg},
+            do_deliver(T, Id, Pid, Topic);
+        _ ->
+            abort
+    end;
+do_deliver([], _, _, _) ->
     ok.
     ok.
 
 
--spec(read_messages(emqx_types:topic())
-      -> [emqx_types:message()]).
-read_messages(Topic) ->
-    Tokens = topic2tokens(Topic),
-    case mnesia:dirty_read(?TAB, Tokens) of
-        [] -> [];
-        [#retained{msg = Msg, expiry_time = Et}] ->
-            case Et =:= 0 orelse Et >= erlang:system_time(millisecond) of
-                true -> [Msg];
-                false -> []
-            end
-    end.
+-spec require_semaphore(semaphore(), pos_integer()) -> boolean().
+require_semaphore(Semaphore, Id) ->
+    Remained = ets:update_counter(?SHARED_CONTEXT_TAB,
+                                  Semaphore,
+                                  {#shared_context.value, -1, 0, 0}),
+    wait_semaphore(Remained, Id).
+
+-spec wait_semaphore(non_neg_integer(), pos_integer()) -> boolean().
+wait_semaphore(0, Id) ->
+    gen_server:call(?MODULE, {?FUNCTION_NAME, Id}, infinity);
+wait_semaphore(_, _) ->
+    true.
+
+-spec init_shared_context() -> ok.
+init_shared_context() ->
+    ?SHARED_CONTEXT_TAB = ets:new(?SHARED_CONTEXT_TAB,
+                                  [ set, named_table, public
+                                  , {keypos, #shared_context.key}
+                                  , {write_concurrency, true}
+                                  , {read_concurrency, true}]),
+    lists:foreach(fun({K, V}) ->
+                          insert_shared_context(K, V)
+                  end,
+                  [{?DELIVER_SEMAPHORE, get_msg_deliver_quota()}]).
 
 
--spec(match_messages(emqx_types:topic())
-      -> [emqx_types:message()]).
-match_messages(Filter) ->
-    NowMs = erlang:system_time(millisecond),
-    Cond = condition(emqx_topic:words(Filter)),
-    MsHd = #retained{topic = Cond, msg = '$2', expiry_time = '$3'},
-    Ms = [{MsHd, [{'=:=','$3',0}], ['$2']},
-          {MsHd, [{'>','$3',NowMs}], ['$2']}],
-    mnesia:dirty_select(?TAB, Ms).
-
--spec(match_delete_messages(emqx_types:topic())
-      -> DeletedCnt :: non_neg_integer()).
-match_delete_messages(Filter) ->
-    Cond = condition(emqx_topic:words(Filter)),
-    MsHd = #retained{topic = Cond, msg = '_', expiry_time = '_'},
-    Ms = [{MsHd, [], ['$_']}],
-    Rs = mnesia:dirty_select(?TAB, Ms),
-    lists:foreach(fun(R) -> ekka_mnesia:dirty_delete_object(?TAB, R) end, Rs),
-    length(Rs).
 
 
-%% @private
-condition(Ws) ->
-    Ws1 = [case W =:= '+' of true -> '_'; _ -> W end || W <- Ws],
-    case lists:last(Ws1) =:= '#' of
-        false -> Ws1;
-        _ -> (Ws1 -- ['#']) ++ '_'
-    end.
+-spec insert_shared_context(shared_context_key(), term()) -> ok.
+insert_shared_context(Key, Term) ->
+    ets:insert(?SHARED_CONTEXT_TAB, #shared_context{key = Key, value = Term}),
+    ok.
 
 
--spec check_enable_when_init() -> ok.
-check_enable_when_init() ->
-    case emqx_config:get([?APP, enable], ?DEF_ENABLE_VAL) of
-        true -> load();
-        _  -> ok
-    end.
+-spec get_msg_deliver_quota() -> non_neg_integer().
+get_msg_deliver_quota() ->
+    emqx_config:get([?APP, flow_control, msg_deliver_quota]).
 
 
--spec check_enable_when_update(hocon:config()) -> ok.
-check_enable_when_update(OldCfg) ->
-    OldVal = maps:get(enable, OldCfg, undefined),
-    case emqx_config:get([?APP, enable], ?DEF_ENABLE_VAL) of
-        OldVal ->
-            ok;
+-spec update_config(state(), hocons:config()) -> state().
+update_config(#{clear_timer := ClearTimer,
+                release_quota_timer := QuotaTimer} = State, Conf) ->
+    #{enable := Enable,
+      connector := [Connector | _],
+      flow_control := #{quota_release_interval := QuotaInterval},
+      msg_clear_interval := ClearInterval} = Conf,
+
+    #{connector := [OldConnector | _]} = emqx_config:get([?APP]),
+
+    case Enable of
         true ->
         true ->
-            load();
+            StorageType = maps:get(type, Connector),
+            OldStrorageType = maps:get(type, OldConnector),
+            case OldStrorageType of
+                StorageType ->
+                    State#{clear_timer := check_timer(ClearTimer,
+                                                      ClearInterval,
+                                                      clear_expired),
+                           release_quota_timer := check_timer(QuotaTimer,
+                                                              QuotaInterval,
+                                                              release_deliver_quota)};
+                _ ->
+                    State2 = disable_retainer(State),
+                    enable_retainer(State2, Conf)
+            end;
         _ ->
         _ ->
-            unload()
+            disable_retainer(State)
+    end.
+
+-spec enable_retainer(state(), hocon:config()) -> state().
+enable_retainer(#{context_id := ContextId} = State,
+                #{msg_clear_interval := ClearInterval,
+                  flow_control := #{quota_release_interval := ReleaseInterval},
+                  connector := [Connector | _]}) ->
+    NewContextId = ContextId + 1,
+    Context = create_resource(new_context(NewContextId), Connector),
+    load(Context),
+    State#{enable := true,
+           context_id := NewContextId,
+           context := Context,
+           clear_timer := add_timer(ClearInterval, clear_expired),
+           release_quota_timer := add_timer(ReleaseInterval, release_deliver_quota)}.
+
+-spec disable_retainer(state()) -> state().
+disable_retainer(#{clear_timer := TRef1,
+                   release_quota_timer := TRef2,
+                   context := Context,
+                   wait_quotas := Waits} = State) ->
+    unload(),
+    ok = lists:foreach(fun(E) -> gen_server:reply(E, false) end, Waits),
+    ok = close_resource(Context),
+    State#{enable := false,
+           clear_timer := stop_timer(TRef1),
+           release_quota_timer := stop_timer(TRef2),
+           wait_quotas := []}.
+
+-spec stop_timer(undefined | reference()) -> undefined.
+stop_timer(undefined) ->
+    undefined;
+stop_timer(TimerRef) ->
+    _ = erlang:cancel_timer(TimerRef),
+    undefined.
+
+add_timer(0, _) ->
+    undefined;
+add_timer(undefined, _) ->
+    undefined;
+add_timer(Ms, Content) ->
+    erlang:send_after(Ms, self(), Content).
+
+check_timer(undefined, Ms, Context) ->
+    add_timer(Ms, Context);
+check_timer(Timer, 0, _) ->
+    stop_timer(Timer);
+check_timer(Timer, undefined, _) ->
+    stop_timer(Timer);
+check_timer(Timer, _, _) ->
+    Timer.
+
+-spec get_backend_module() -> backend().
+get_backend_module() ->
+    [#{type := Backend} | _] = emqx_config:get([?APP, connector]),
+    erlang:list_to_existing_atom(io_lib:format("~s_~s", [?APP, Backend])).
+
+create_resource(Context, #{type := mnesia, config := Cfg}) ->
+    emqx_retainer_mnesia:create_resource(Cfg),
+    Context;
+
+create_resource(Context, #{type := DB, config := Config}) ->
+    ResourceID = erlang:iolist_to_binary([io_lib:format("~s_~s", [?APP, DB])]),
+    case emqx_resource:create(
+           ResourceID,
+           list_to_existing_atom(io_lib:format("~s_~s", [emqx_connector, DB])),
+           Config) of
+        {ok, _} ->
+            Context#{resource_id => ResourceID};
+        {error, already_created} ->
+            Context#{resource_id => ResourceID};
+        {error, Reason} ->
+            error({load_config_error, Reason})
     end.
     end.
 
 
+-spec close_resource(context()) -> ok | {error, term()}.
+close_resource(#{resource_id := ResourceId}) ->
+    emqx_resource:stop(ResourceId);
+close_resource(_) ->
+    ok.
+
+-spec load(context()) -> ok.
+load(Context) ->
+    _ = emqx:hook('session.subscribed', {?MODULE, on_session_subscribed, [Context]}),
+    _ = emqx:hook('message.publish', {?MODULE, on_message_publish, [Context]}),
+    ok.
+
+unload() ->
+    emqx:unhook('message.publish', {?MODULE, on_message_publish}),
+    emqx:unhook('session.subscribed', {?MODULE, on_session_subscribed}).

+ 1 - 2
apps/emqx_retainer/src/emqx_retainer_app.erl

@@ -30,6 +30,5 @@ start(_Type, _Args) ->
     {ok, Sup}.
     {ok, Sup}.
 
 
 stop(_State) ->
 stop(_State) ->
-    emqx_retainer_cli:unload(),
-    emqx_retainer:unload().
+    emqx_retainer_cli:unload().
 
 

+ 0 - 20
apps/emqx_retainer/src/emqx_retainer_cli.erl

@@ -27,26 +27,6 @@
 load() ->
 load() ->
     emqx_ctl:register_command(retainer, {?MODULE, cmd}, []).
     emqx_ctl:register_command(retainer, {?MODULE, cmd}, []).
 
 
-cmd(["info"]) ->
-    emqx_ctl:print("retained/total: ~w~n", [mnesia:table_info(?TAB, size)]);
-
-cmd(["topics"]) ->
-    case mnesia:dirty_all_keys(?TAB) of
-        []     -> ignore;
-        Topics -> lists:foreach(fun(Topic) -> emqx_ctl:print("~s~n", [Topic]) end, Topics)
-    end;
-
-cmd(["clean"]) ->
-    Size = mnesia:table_info(?TAB, size),
-    case ekka_mnesia:clear_table(?TAB) of
-        {atomic, ok} -> emqx_ctl:print("Cleaned ~p retained messages~n", [Size]);
-        {aborted, R} -> emqx_ctl:print("Aborted ~p~n", [R])
-    end;
-
-cmd(["clean", Topic]) ->
-    Lines = emqx_retainer:clean(list_to_binary(Topic)),
-    emqx_ctl:print("Cleaned ~p retained messages~n", [Lines]);
-
 cmd(_) ->
 cmd(_) ->
     emqx_ctl:usage([{"retainer info",   "Show the count of retained messages"},
     emqx_ctl:usage([{"retainer info",   "Show the count of retained messages"},
                     {"retainer topics", "Show all topics of retained messages"},
                     {"retainer topics", "Show all topics of retained messages"},

+ 241 - 0
apps/emqx_retainer/src/emqx_retainer_mnesia.erl

@@ -0,0 +1,241 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+-module(emqx_retainer_mnesia).
+
+-behaviour(emqx_retainer).
+
+-include("emqx_retainer.hrl").
+-include_lib("emqx/include/logger.hrl").
+-include_lib("stdlib/include/ms_transform.hrl").
+-include_lib("stdlib/include/qlc.hrl").
+
+-logger_header("[Retainer]").
+
+-export([delete_message/2
+        , store_retained/2
+        , read_message/2
+        , match_messages/3
+        , clear_expired/1
+        , clean/1]).
+
+-export([create_resource/1]).
+
+-define(DEF_MAX_RETAINED_MESSAGES, 0).
+
+-rlog_shard({?RETAINER_SHARD, ?TAB}).
+
+-record(retained, {topic, msg, expiry_time}).
+
+-type batch_read_result() ::
+        {ok, list(emqx:message()), cursor()}.
+
+%%--------------------------------------------------------------------
+%% emqx_retainer_storage callbacks
+%%--------------------------------------------------------------------
+create_resource(#{storage_type := StorageType}) ->
+    Copies = case StorageType of
+                 ram       -> ram_copies;
+                 disc      -> disc_copies;
+                 disc_only -> disc_only_copies
+             end,
+    StoreProps = [{ets, [compressed,
+                         {read_concurrency, true},
+                         {write_concurrency, true}]},
+                  {dets, [{auto_save, 1000}]}],
+    ok = ekka_mnesia:create_table(?TAB, [
+                {type, set},
+                {Copies, [node()]},
+                {record_name, retained},
+                {attributes, record_info(fields, retained)},
+                {storage_properties, StoreProps}]),
+    ok = ekka_mnesia:copy_table(?TAB, Copies),
+    ok = ekka_rlog:wait_for_shards([?RETAINER_SHARD], infinity),
+    case mnesia:table_info(?TAB, storage_type) of
+        Copies -> ok;
+        _Other ->
+            {atomic, ok} = mnesia:change_table_copy_type(?TAB, node(), Copies),
+            ok
+    end.
+
+store_retained(_, Msg =#message{topic = Topic}) ->
+    ExpiryTime = emqx_retainer:get_expiry_time(Msg),
+    case is_table_full() of
+        false ->
+            ok = emqx_metrics:inc('messages.retained'),
+            ekka_mnesia:dirty_write(?TAB,
+                                    #retained{topic = topic2tokens(Topic),
+                                              msg = Msg,
+                                              expiry_time = ExpiryTime});
+        _ ->
+            Tokens = topic2tokens(Topic),
+            Fun = fun() ->
+                          case mnesia:read(?TAB, Tokens) of
+                              [_] ->
+                                  mnesia:write(?TAB,
+                                               #retained{topic = Tokens,
+                                                         msg = Msg,
+                                                         expiry_time = ExpiryTime},
+                                               write);
+                              [] ->
+                                  ?LOG(error,
+                                       "Cannot retain message(topic=~s) for table is full!",
+                                       [Topic]),
+                                  ok
+                          end
+            end,
+            {atomic, ok} = ekka_mnesia:transaction(?RETAINER_SHARD, Fun),
+            ok
+    end.
+
+clear_expired(_) ->
+    NowMs = erlang:system_time(millisecond),
+    MsHd = #retained{topic = '$1', msg = '_', expiry_time = '$3'},
+    Ms = [{MsHd, [{'=/=', '$3', 0}, {'<', '$3', NowMs}], ['$1']}],
+    Fun = fun() ->
+                  Keys = mnesia:select(?TAB, Ms, write),
+                  lists:foreach(fun(Key) -> mnesia:delete({?TAB, Key}) end, Keys)
+          end,
+    {atomic, _} = ekka_mnesia:transaction(?RETAINER_SHARD, Fun),
+    ok.
+
+delete_message(_, Topic) ->
+    case emqx_topic:wildcard(Topic) of
+        true -> match_delete_messages(Topic);
+        false ->
+            Tokens = topic2tokens(Topic),
+            Fun = fun() ->
+                       mnesia:delete({?TAB, Tokens})
+                  end,
+            case ekka_mnesia:transaction(?RETAINER_SHARD, Fun) of
+                {atomic, Result} ->
+                    Result;
+                ok ->
+                    ok
+                end
+    end,
+    ok.
+
+read_message(_, Topic) ->
+    {ok, read_messages(Topic)}.
+
+match_messages(_, Topic, Cursor) ->
+    MaxReadNum = emqx_config:get([?APP, flow_control, max_read_number]),
+    case Cursor of
+        undefined ->
+            case MaxReadNum of
+                0 ->
+                    {ok, sort_retained(match_messages(Topic)), undefined};
+                _ ->
+                    start_batch_read(Topic, MaxReadNum)
+            end;
+        _ ->
+            batch_read_messages(Cursor, MaxReadNum)
+    end.
+
+clean(_) ->
+    ekka_mnesia:clear_table(?TAB),
+    ok.
+%%--------------------------------------------------------------------
+%% Internal functions
+%%--------------------------------------------------------------------
+sort_retained([]) -> [];
+sort_retained([Msg]) -> [Msg];
+sort_retained(Msgs)  ->
+    lists:sort(fun(#message{timestamp = Ts1}, #message{timestamp = Ts2}) ->
+                       Ts1 =< Ts2 end,
+               Msgs).
+
+%%--------------------------------------------------------------------
+%% Internal funcs
+%%--------------------------------------------------------------------
+topic2tokens(Topic) ->
+    emqx_topic:words(Topic).
+
+-spec start_batch_read(topic(), pos_integer()) -> batch_read_result().
+start_batch_read(Topic, MaxReadNum) ->
+    Ms = make_match_spec(Topic),
+    TabQH = ets:table(?TAB, [{traverse, {select, Ms}}]),
+    QH = qlc:q([E || E <- TabQH]),
+    Cursor = qlc:cursor(QH),
+    batch_read_messages(Cursor, MaxReadNum).
+
+-spec batch_read_messages(emqx_retainer_storage:cursor(), pos_integer()) -> batch_read_result().
+batch_read_messages(Cursor, MaxReadNum) ->
+    Answers = qlc:next_answers(Cursor, MaxReadNum),
+    Orders = sort_retained(Answers),
+    case erlang:length(Orders) < MaxReadNum of
+        true ->
+            qlc:delete_cursor(Cursor),
+            {ok, Orders, undefined};
+        _ ->
+            {ok, Orders, Cursor}
+    end.
+
+-spec(read_messages(emqx_types:topic())
+      -> [emqx_types:message()]).
+read_messages(Topic) ->
+    Tokens = topic2tokens(Topic),
+    case mnesia:dirty_read(?TAB, Tokens) of
+        [] -> [];
+        [#retained{msg = Msg, expiry_time = Et}] ->
+            case Et =:= 0 orelse Et >= erlang:system_time(millisecond) of
+                true -> [Msg];
+                false -> []
+            end
+    end.
+
+-spec(match_messages(emqx_types:topic())
+      -> [emqx_types:message()]).
+match_messages(Filter) ->
+    Ms = make_match_spec(Filter),
+    mnesia:dirty_select(?TAB, Ms).
+
+-spec(match_delete_messages(emqx_types:topic()) -> ok).
+match_delete_messages(Filter) ->
+    Cond = condition(emqx_topic:words(Filter)),
+    MsHd = #retained{topic = Cond, msg = '_', expiry_time = '_'},
+    Ms = [{MsHd, [], ['$_']}],
+    Rs = mnesia:dirty_select(?TAB, Ms),
+    lists:foreach(fun(R) -> ekka_mnesia:dirty_delete_object(?TAB, R) end, Rs).
+
+%% @private
+condition(Ws) ->
+    Ws1 = [case W =:= '+' of true -> '_'; _ -> W end || W <- Ws],
+    case lists:last(Ws1) =:= '#' of
+        false -> Ws1;
+        _ -> (Ws1 -- ['#']) ++ '_'
+    end.
+
+-spec make_match_spec(topic()) -> ets:match_spec().
+make_match_spec(Filter) ->
+    NowMs = erlang:system_time(millisecond),
+    Cond = condition(emqx_topic:words(Filter)),
+    MsHd = #retained{topic = Cond, msg = '$2', expiry_time = '$3'},
+    [{MsHd, [{'=:=', '$3', 0}], ['$2']},
+     {MsHd, [{'>', '$3', NowMs}], ['$2']}].
+
+-spec is_table_full() -> boolean().
+is_table_full() ->
+    [#{config := Cfg} | _] = emqx_config:get([?APP, connector]),
+    Limit = maps:get(max_retained_messages,
+                     Cfg,
+                     ?DEF_MAX_RETAINED_MESSAGES),
+    Limit > 0 andalso (table_size() >= Limit).
+
+-spec table_size() -> non_neg_integer().
+table_size() ->
+    mnesia:table_info(?TAB, size).

+ 182 - 0
apps/emqx_retainer/src/emqx_retainer_pool.erl

@@ -0,0 +1,182 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+-module(emqx_retainer_pool).
+
+-behaviour(gen_server).
+
+-include_lib("emqx/include/logger.hrl").
+
+%% API
+-export([start_link/2,
+         async_submit/2]).
+
+%% gen_server callbacks
+-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
+         terminate/2, code_change/3, format_status/2]).
+
+-define(POOL, ?MODULE).
+
+%%%===================================================================
+%%% API
+%%%===================================================================
+async_submit(Fun, Args) ->
+    cast({async_submit, {Fun, Args}}).
+
+%%--------------------------------------------------------------------
+%% @doc
+%% Starts the server
+%% @end
+%%--------------------------------------------------------------------
+-spec start_link(atom(), pos_integer()) -> {ok, Pid :: pid()} |
+          {error, Error :: {already_started, pid()}} |
+          {error, Error :: term()} |
+          ignore.
+start_link(Pool, Id) ->
+    gen_server:start_link({local, emqx_misc:proc_name(?MODULE, Id)},
+                          ?MODULE, [Pool, Id], [{hibernate_after, 1000}]).
+
+%%%===================================================================
+%%% gen_server callbacks
+%%%===================================================================
+
+%%--------------------------------------------------------------------
+%% @private
+%% @doc
+%% Initializes the server
+%% @end
+%%--------------------------------------------------------------------
+-spec init(Args :: term()) -> {ok, State :: term()} |
+          {ok, State :: term(), Timeout :: timeout()} |
+          {ok, State :: term(), hibernate} |
+          {stop, Reason :: term()} |
+          ignore.
+init([Pool, Id]) ->
+    true = gproc_pool:connect_worker(Pool, {Pool, Id}),
+    {ok, #{pool => Pool, id => Id}}.
+
+%%--------------------------------------------------------------------
+%% @private
+%% @doc
+%% Handling call messages
+%% @end
+%%--------------------------------------------------------------------
+-spec handle_call(Request :: term(), From :: {pid(), term()}, State :: term()) ->
+          {reply, Reply :: term(), NewState :: term()} |
+          {reply, Reply :: term(), NewState :: term(), Timeout :: timeout()} |
+          {reply, Reply :: term(), NewState :: term(), hibernate} |
+          {noreply, NewState :: term()} |
+          {noreply, NewState :: term(), Timeout :: timeout()} |
+          {noreply, NewState :: term(), hibernate} |
+          {stop, Reason :: term(), Reply :: term(), NewState :: term()} |
+          {stop, Reason :: term(), NewState :: term()}.
+handle_call(Req, _From, State) ->
+    ?LOG(error, "Unexpected call: ~p", [Req]),
+    {reply, ignored, State}.
+
+%%--------------------------------------------------------------------
+%% @private
+%% @doc
+%% Handling cast messages
+%% @end
+%%--------------------------------------------------------------------
+-spec handle_cast(Request :: term(), State :: term()) ->
+          {noreply, NewState :: term()} |
+          {noreply, NewState :: term(), Timeout :: timeout()} |
+          {noreply, NewState :: term(), hibernate} |
+          {stop, Reason :: term(), NewState :: term()}.
+handle_cast({async_submit, Task}, State) ->
+    try run(Task)
+    catch _:Error:Stacktrace ->
+            ?LOG(error, "Error: ~0p, ~0p", [Error, Stacktrace])
+    end,
+    {noreply, State};
+
+handle_cast(Msg, State) ->
+    ?LOG(error, "Unexpected cast: ~p", [Msg]),
+    {noreply, State}.
+
+%%--------------------------------------------------------------------
+%% @private
+%% @doc
+%% Handling all non call/cast messages
+%% @end
+%%--------------------------------------------------------------------
+-spec handle_info(Info :: timeout() | term(), State :: term()) ->
+          {noreply, NewState :: term()} |
+          {noreply, NewState :: term(), Timeout :: timeout()} |
+          {noreply, NewState :: term(), hibernate} |
+          {stop, Reason :: normal | term(), NewState :: term()}.
+handle_info(Info, State) ->
+    ?LOG(error, "Unexpected info: ~p", [Info]),
+    {noreply, State}.
+
+%%--------------------------------------------------------------------
+%% @private
+%% @doc
+%% This function is called by a gen_server when it is about to
+%% terminate. It should be the opposite of Module:init/1 and do any
+%% necessary cleaning up. When it returns, the gen_server terminates
+%% with Reason. The return value is ignored.
+%% @end
+%%--------------------------------------------------------------------
+-spec terminate(Reason :: normal | shutdown | {shutdown, term()} | term(),
+                State :: term()) -> any().
+terminate(_Reason, #{pool := Pool, id := Id}) ->
+    gproc_pool:disconnect_worker(Pool, {Pool, Id}).
+%%--------------------------------------------------------------------
+%% @private
+%% @doc
+%% Convert process state when code is changed
+%% @end
+%%--------------------------------------------------------------------
+-spec code_change(OldVsn :: term() | {down, term()},
+                  State :: term(),
+                  Extra :: term()) -> {ok, NewState :: term()} |
+          {error, Reason :: term()}.
+code_change(_OldVsn, State, _Extra) ->
+    {ok, State}.
+
+%%--------------------------------------------------------------------
+%% @private
+%% @doc
+%% This function is called for changing the form and appearance
+%% of gen_server status when it is returned from sys:get_status/1,2
+%% or when it appears in termination error logs.
+%% @end
+%%--------------------------------------------------------------------
+-spec format_status(Opt :: normal | terminate,
+                    Status :: list()) -> Status :: term().
+format_status(_Opt, Status) ->
+    Status.
+
+%%%===================================================================
+%%% Internal functions
+%%%===================================================================
+%% @private
+cast(Msg) ->
+    gen_server:cast(worker(), Msg).
+
+%% @private
+worker() ->
+    gproc_pool:pick_worker(?POOL).
+
+run({M, F, A}) ->
+    erlang:apply(M, F, A);
+run({F, A}) when is_function(F), is_list(A) ->
+    erlang:apply(F, A);
+run(Fun) when is_function(Fun) ->
+    Fun().

+ 28 - 7
apps/emqx_retainer/src/emqx_retainer_schema.erl

@@ -2,20 +2,35 @@
 
 
 -include_lib("typerefl/include/types.hrl").
 -include_lib("typerefl/include/types.hrl").
 
 
--type storage_type() :: ram | disc | disc_only.
-
--reflect_type([storage_type/0]).
-
 -export([structs/0, fields/1]).
 -export([structs/0, fields/1]).
 
 
+-define(TYPE(Type), hoconsc:t(Type)).
+
 structs() -> ["emqx_retainer"].
 structs() -> ["emqx_retainer"].
 
 
 fields("emqx_retainer") ->
 fields("emqx_retainer") ->
     [ {enable, t(boolean(), false)}
     [ {enable, t(boolean(), false)}
-    , {storage_type, t(storage_type(), ram)}
-    , {max_retained_messages, t(integer(), 0, fun is_pos_integer/1)}
+    , {msg_expiry_interval, t(emqx_schema:duration_ms(), "0s")}
+    , {msg_clear_interval, t(emqx_schema:duration_ms(), "0s")}
+    , {connector, connector()}
+    , {flow_control, ?TYPE(hoconsc:ref(?MODULE, flow_control))}
     , {max_payload_size, t(emqx_schema:bytesize(), "1MB")}
     , {max_payload_size, t(emqx_schema:bytesize(), "1MB")}
-    , {expiry_interval, t(emqx_schema:duration_ms(), "0s")}
+    ];
+
+fields(mnesia_connector) ->
+    [ {type, ?TYPE(hoconsc:union([mnesia]))}
+    , {config, ?TYPE(hoconsc:ref(?MODULE, mnesia_connector_cfg))}
+    ];
+
+fields(mnesia_connector_cfg) ->
+    [ {storage_type, t(hoconsc:union([ram, disc, disc_only]), ram)}
+    , {max_retained_messages, t(integer(), 0, fun is_pos_integer/1)}
+    ];
+
+fields(flow_control) ->
+    [ {max_read_number, t(integer(), 0, fun is_pos_integer/1)}
+    , {msg_deliver_quota, t(integer(), 0, fun is_pos_integer/1)}
+    , {quota_release_interval, t(emqx_schema:duration_ms(), "0ms")}
     ].
     ].
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
@@ -28,5 +43,11 @@ t(Type, Default, Validator) ->
     hoconsc:t(Type, #{default => Default,
     hoconsc:t(Type, #{default => Default,
                       validator => Validator}).
                       validator => Validator}).
 
 
+union_array(Item) when is_list(Item) ->
+    hoconsc:array(hoconsc:union(Item)).
+
 is_pos_integer(V) ->
 is_pos_integer(V) ->
     V >= 0.
     V >= 0.
+
+connector() ->
+    #{type => union_array([hoconsc:ref(?MODULE, mnesia_connector)])}.

+ 4 - 2
apps/emqx_retainer/src/emqx_retainer_sup.erl

@@ -26,11 +26,13 @@ start_link() ->
     supervisor:start_link({local, ?MODULE}, ?MODULE, []).
     supervisor:start_link({local, ?MODULE}, ?MODULE, []).
 
 
 init([]) ->
 init([]) ->
+    PoolSpec = emqx_pool_sup:spec([emqx_retainer_pool, random, emqx_vm:schedulers(),
+                                   {emqx_retainer_pool, start_link, []}]),
     {ok, {{one_for_one, 10, 3600},
     {ok, {{one_for_one, 10, 3600},
           [#{id       => retainer,
           [#{id       => retainer,
              start    => {emqx_retainer, start_link, []},
              start    => {emqx_retainer, start_link, []},
              restart  => permanent,
              restart  => permanent,
              shutdown => 5000,
              shutdown => 5000,
              type     => worker,
              type     => worker,
-             modules  => [emqx_retainer]}]}}.
-
+             modules  => [emqx_retainer]},
+           PoolSpec]}}.

+ 19 - 13
apps/emqx_retainer/test/emqx_retainer_SUITE.erl

@@ -19,7 +19,7 @@
 -compile(export_all).
 -compile(export_all).
 -compile(nowarn_export_all).
 -compile(nowarn_export_all).
 
 
--define(APP, emqx).
+-define(APP, emqx_retainer).
 
 
 -include_lib("eunit/include/eunit.hrl").
 -include_lib("eunit/include/eunit.hrl").
 -include_lib("common_test/include/ct.hrl").
 -include_lib("common_test/include/ct.hrl").
@@ -39,27 +39,34 @@ end_per_suite(_Config) ->
     emqx_ct_helpers:stop_apps([emqx_retainer]).
     emqx_ct_helpers:stop_apps([emqx_retainer]).
 
 
 init_per_testcase(TestCase, Config) ->
 init_per_testcase(TestCase, Config) ->
-    emqx_retainer:clean(<<"#">>),
+    emqx_retainer:clean(),
     Interval = case TestCase of
     Interval = case TestCase of
                    t_message_expiry_2 -> 2000;
                    t_message_expiry_2 -> 2000;
                    _ -> 0
                    _ -> 0
                end,
                end,
-    init_emqx_retainer_conf(Interval),
+    OldCfg = emqx_config:get([?APP]),
+    emqx_config:put([?APP], OldCfg#{msg_expiry_interval := Interval}),
     application:ensure_all_started(emqx_retainer),
     application:ensure_all_started(emqx_retainer),
     Config.
     Config.
 
 
 set_special_configs(emqx_retainer) ->
 set_special_configs(emqx_retainer) ->
-    init_emqx_retainer_conf(0);
+    init_emqx_retainer_conf();
 set_special_configs(_) ->
 set_special_configs(_) ->
     ok.
     ok.
 
 
-init_emqx_retainer_conf(Expiry) ->
-    emqx_config:put([emqx_retainer],
+init_emqx_retainer_conf() ->
+    emqx_config:put([?APP],
                     #{enable => true,
                     #{enable => true,
-                      storage_type => ram,
-                      max_retained_messages => 0,
-                      max_payload_size => 1024 * 1024,
-                      expiry_interval => Expiry}).
+                      msg_expiry_interval => 0,
+                      msg_clear_interval => 0,
+                      connector => [#{type => mnesia,
+                                      config =>
+                                          #{max_retained_messages => 0,
+                                            storage_type => ram}}],
+                      flow_control => #{max_read_number => 0,
+                                        msg_deliver_quota => 0,
+                                        quota_release_interval => 0},
+                      max_payload_size => 1024 * 1024}).
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 %% Test Cases
 %% Test Cases
@@ -177,8 +184,8 @@ t_clean(_) ->
     {ok, #{}, [0]} = emqtt:subscribe(C1, <<"retained/#">>, [{qos, 0}, {rh, 0}]),
     {ok, #{}, [0]} = emqtt:subscribe(C1, <<"retained/#">>, [{qos, 0}, {rh, 0}]),
     ?assertEqual(3, length(receive_messages(3))),
     ?assertEqual(3, length(receive_messages(3))),
 
 
-    1 = emqx_retainer:clean(<<"retained/test/0">>),
-    2 = emqx_retainer:clean(<<"retained/+">>),
+    ok = emqx_retainer:delete(<<"retained/test/0">>),
+    ok = emqx_retainer:delete(<<"retained/+">>),
     {ok, #{}, [0]} = emqtt:subscribe(C1, <<"retained/#">>, [{qos, 0}, {rh, 0}]),
     {ok, #{}, [0]} = emqtt:subscribe(C1, <<"retained/#">>, [{qos, 0}, {rh, 0}]),
     ?assertEqual(0, length(receive_messages(3))),
     ?assertEqual(0, length(receive_messages(3))),
 
 
@@ -203,4 +210,3 @@ receive_messages(Count, Msgs) ->
     after 2000 ->
     after 2000 ->
             Msgs
             Msgs
     end.
     end.
-

+ 3 - 10
apps/emqx_retainer/test/emqx_retainer_api_SUITE.erl

@@ -56,21 +56,14 @@ init_per_testcase(_, Config) ->
     Config.
     Config.
 
 
 set_special_configs(emqx_retainer) ->
 set_special_configs(emqx_retainer) ->
-    init_emqx_retainer_conf(0);
+    emqx_retainer_SUITE:init_emqx_retainer_conf();
 set_special_configs(emqx_management) ->
 set_special_configs(emqx_management) ->
     emqx_config:put([emqx_management], #{listeners => [#{protocol => http, port => 8081}],
     emqx_config:put([emqx_management], #{listeners => [#{protocol => http, port => 8081}],
-        applications =>[#{id => "admin", secret => "public"}]}),
+                                         applications =>[#{id => "admin", secret => "public"}]}),
     ok;
     ok;
 set_special_configs(_) ->
 set_special_configs(_) ->
     ok.
     ok.
 
 
-init_emqx_retainer_conf(Expiry) ->
-    emqx_config:put([emqx_retainer],
-                    #{enable => true,
-                      storage_type => ram,
-                      max_retained_messages => 0,
-                      max_payload_size => 1024 * 1024,
-                      expiry_interval => Expiry}).
 %%------------------------------------------------------------------------------
 %%------------------------------------------------------------------------------
 %% Test Cases
 %% Test Cases
 %%------------------------------------------------------------------------------
 %%------------------------------------------------------------------------------
@@ -78,7 +71,7 @@ init_emqx_retainer_conf(Expiry) ->
 t_config(_Config) ->
 t_config(_Config) ->
     {ok, Return} = request_http_rest_lookup(["retainer"]),
     {ok, Return} = request_http_rest_lookup(["retainer"]),
     NowCfg = get_http_data(Return),
     NowCfg = get_http_data(Return),
-    NewCfg = NowCfg#{<<"expiry_interval">> => timer:seconds(60)},
+    NewCfg = NowCfg#{<<"msg_expiry_interval">> => timer:seconds(60)},
     RetainerConf = #{<<"emqx_retainer">> => NewCfg},
     RetainerConf = #{<<"emqx_retainer">> => NewCfg},
 
 
     {ok, _} = request_http_rest_update(["retainer?action=test"], RetainerConf),
     {ok, _} = request_http_rest_update(["retainer?action=test"], RetainerConf),

+ 1 - 7
apps/emqx_retainer/test/mqtt_protocol_v5_SUITE.erl

@@ -38,13 +38,7 @@ end_per_suite(_Config) ->
 %% Helpers
 %% Helpers
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 set_special_configs(emqx_retainer) ->
 set_special_configs(emqx_retainer) ->
-    emqx_config:put([emqx_retainer],
-                    #{enable => true,
-                      storage_type => ram,
-                      max_retained_messages => 0,
-                      max_payload_size => 1024 * 1024,
-                      expiry_interval => 0});
-
+    emqx_retainer_SUITE:init_emqx_retainer_conf();
 set_special_configs(_) ->
 set_special_configs(_) ->
     ok.
     ok.
 
 

+ 1 - 1
rebar.config

@@ -61,7 +61,7 @@
     , {getopt, "1.0.1"}
     , {getopt, "1.0.1"}
     , {snabbkaffe, {git, "https://github.com/kafka4beam/snabbkaffe.git", {tag, "0.13.0"}}}
     , {snabbkaffe, {git, "https://github.com/kafka4beam/snabbkaffe.git", {tag, "0.13.0"}}}
     , {hocon, {git, "https://github.com/emqx/hocon.git", {tag, "0.10.3"}}}
     , {hocon, {git, "https://github.com/emqx/hocon.git", {tag, "0.10.3"}}}
-    , {emqx_http_lib, {git, "https://github.com/emqx/emqx_http_lib.git", {tag, "0.2.1"}}}
+    , {emqx_http_lib, {git, "https://github.com/emqx/emqx_http_lib.git", {tag, "0.3.0"}}}
     , {esasl, {git, "https://github.com/emqx/esasl", {tag, "0.1.0"}}}
     , {esasl, {git, "https://github.com/emqx/esasl", {tag, "0.1.0"}}}
     ]}.
     ]}.