Просмотр исходного кода

test(gw): more testcases for emqx_gateway_cm_registry

JianBo He 4 лет назад
Родитель
Сommit
056e284bc2

+ 30 - 25
apps/emqx_gateway/src/emqx_gateway_cm_registry.erl

@@ -17,6 +17,8 @@
 %% @doc The gateway connection registry
 -module(emqx_gateway_cm_registry).
 
+-include("include/emqx_gateway.hrl").
+
 -behaviour(gen_server).
 
 -export([start_link/1]).
@@ -27,6 +29,8 @@
 
 -export([lookup_channels/2]).
 
+-export([tabname/1]).
+
 %% gen_server callbacks
 -export([ init/1
         , handle_call/3
@@ -41,39 +45,40 @@
 
 -record(channel, {chid, pid}).
 
-%% @doc Start the global channel registry.
--spec(start_link(atom()) -> gen_server:startlink_ret()).
-start_link(Type) ->
-    gen_server:start_link(?MODULE, [Type], []).
+%% @doc Start the global channel registry for the gived gateway name.
+-spec(start_link(gateway_name()) -> gen_server:startlink_ret()).
+start_link(Name) ->
+    gen_server:start_link(?MODULE, [Name], []).
 
--spec tabname(atom()) -> atom().
-tabname(Type) ->
-    list_to_atom(lists:concat([emqx_gateway_, Type, '_channel_registry'])).
+-spec tabname(gateway_name()) -> atom().
+tabname(Name) ->
+    %% XXX: unsafe ??
+    list_to_atom(lists:concat([emqx_gateway_, Name, '_channel_registry'])).
 
 %%--------------------------------------------------------------------
 %% APIs
 %%--------------------------------------------------------------------
 
 %% @doc Register a global channel.
--spec register_channel(atom(), binary() | {binary(), pid()}) -> ok.
-register_channel(Type, ClientId) when is_binary(ClientId) ->
-    register_channel(Type, {ClientId, self()});
+-spec register_channel(gateway_name(), binary() | {binary(), pid()}) -> ok.
+register_channel(Name, ClientId) when is_binary(ClientId) ->
+    register_channel(Name, {ClientId, self()});
 
-register_channel(Type, {ClientId, ChanPid}) when is_binary(ClientId), is_pid(ChanPid) ->
-    mria:dirty_write(tabname(Type), record(ClientId, ChanPid)).
+register_channel(Name, {ClientId, ChanPid}) when is_binary(ClientId), is_pid(ChanPid) ->
+    mria:dirty_write(tabname(Name), record(ClientId, ChanPid)).
 
 %% @doc Unregister a global channel.
--spec unregister_channel(atom(), binary() | {binary(), pid()}) -> ok.
-unregister_channel(Type, ClientId) when is_binary(ClientId) ->
-    unregister_channel(Type, {ClientId, self()});
+-spec unregister_channel(gateway_name(), binary() | {binary(), pid()}) -> ok.
+unregister_channel(Name, ClientId) when is_binary(ClientId) ->
+    unregister_channel(Name, {ClientId, self()});
 
-unregister_channel(Type, {ClientId, ChanPid}) when is_binary(ClientId), is_pid(ChanPid) ->
-    mria:dirty_delete_object(tabname(Type), record(ClientId, ChanPid)).
+unregister_channel(Name, {ClientId, ChanPid}) when is_binary(ClientId), is_pid(ChanPid) ->
+    mria:dirty_delete_object(tabname(Name), record(ClientId, ChanPid)).
 
 %% @doc Lookup the global channels.
--spec lookup_channels(atom(), binary()) -> list(pid()).
-lookup_channels(Type, ClientId) ->
-    [ChanPid || #channel{pid = ChanPid} <- mnesia:dirty_read(tabname(Type), ClientId)].
+-spec lookup_channels(gateway_name(), binary()) -> list(pid()).
+lookup_channels(Name, ClientId) ->
+    [ChanPid || #channel{pid = ChanPid} <- mnesia:dirty_read(tabname(Name), ClientId)].
 
 record(ClientId, ChanPid) ->
     #channel{chid = ClientId, pid = ChanPid}.
@@ -82,8 +87,8 @@ record(ClientId, ChanPid) ->
 %% gen_server callbacks
 %%--------------------------------------------------------------------
 
-init([Type]) ->
-    Tab = tabname(Type),
+init([Name]) ->
+    Tab = tabname(Name),
     ok = mria:create_table(Tab, [
                 {type, bag},
                 {rlog_shard, ?CM_SHARD},
@@ -94,7 +99,7 @@ init([Type]) ->
                                              {write_concurrency, true}]}]}]),
     ok = mria:wait_for_tables([Tab]),
     ok = ekka:monitor(membership),
-    {ok, #{type => Type}}.
+    {ok, #{name => Name}}.
 
 handle_call(Req, _From, State) ->
     logger:error("Unexpected call: ~p", [Req]),
@@ -104,8 +109,8 @@ handle_cast(Msg, State) ->
     logger:error("Unexpected cast: ~p", [Msg]),
     {noreply, State}.
 
-handle_info({membership, {mnesia, down, Node}}, State = #{type := Type}) ->
-    Tab = tabname(Type),
+handle_info({membership, {mnesia, down, Node}}, State = #{name := Name}) ->
+    Tab = tabname(Name),
     global:trans({?LOCK, self()},
                  fun() ->
                      mria:transaction(?CM_SHARD, fun cleanup_channels/2, [Node, Tab])

+ 97 - 0
apps/emqx_gateway/test/emqx_gateway_cm_registry_SUITE.erl

@@ -0,0 +1,97 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2022 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+-module(emqx_gateway_cm_registry_SUITE).
+
+-include_lib("eunit/include/eunit.hrl").
+
+-compile(export_all).
+-compile(nowarn_export_all).
+
+-define(GWNAME, mqttsn).
+-define(CLIENTID, <<"client1">>).
+
+-define(CONF_DEFAULT, <<"gateway {}">>).
+
+%%--------------------------------------------------------------------
+%% setups
+%%--------------------------------------------------------------------
+
+all() -> emqx_common_test_helpers:all(?MODULE).
+
+init_per_suite(Conf) ->
+    emqx_config:erase(gateway),
+    emqx_config:init_load(emqx_gateway_schema, ?CONF_DEFAULT),
+    emqx_common_test_helpers:start_apps([]),
+    Conf.
+
+end_per_suite(_Conf) ->
+    emqx_common_test_helpers:stop_apps([]).
+
+init_per_testcase(_TestCase, Conf) ->
+    {ok, Pid} = emqx_gateway_cm_registry:start_link(?GWNAME),
+    [{registry, Pid} | Conf].
+
+end_per_testcase(_TestCase, Conf) ->
+    Pid = proplists:get_value(registry, Conf),
+    gen_server:stop(Pid),
+    Conf.
+
+%%--------------------------------------------------------------------
+%% cases
+%%--------------------------------------------------------------------
+
+t_tabname(_) ->
+    ?assertEqual(
+       emqx_gateway_gw_name_channel_registry,
+       emqx_gateway_cm_registry:tabname(gw_name)).
+
+t_register_unregister_channel(_) ->
+    ok = emqx_gateway_cm_registry:register_channel(?GWNAME, ?CLIENTID),
+    ?assertEqual(
+       [{channel, ?CLIENTID, self()}],
+       ets:tab2list(emqx_gateway_cm_registry:tabname(?GWNAME))),
+
+    ?assertEqual(
+       [self()],
+       emqx_gateway_cm_registry:lookup_channels(?GWNAME, ?CLIENTID)),
+
+    ok = emqx_gateway_cm_registry:unregister_channel(?GWNAME, ?CLIENTID),
+
+    ?assertEqual(
+       [], 
+       ets:tab2list(emqx_gateway_cm_registry:tabname(?GWNAME))),
+    ?assertEqual(
+       [],
+       emqx_gateway_cm_registry:lookup_channels(?GWNAME, ?CLIENTID)).
+
+t_cleanup_channels(Conf) ->
+    Pid = proplists:get_value(registry, Conf),
+    emqx_gateway_cm_registry:register_channel(?GWNAME, ?CLIENTID),
+    ?assertEqual(
+       [self()],
+       emqx_gateway_cm_registry:lookup_channels(?GWNAME, ?CLIENTID)),
+    Pid ! {membership, {mnesia, down, node()}},
+    ct:sleep(100),
+    ?assertEqual(
+       [],
+       emqx_gateway_cm_registry:lookup_channels(?GWNAME, ?CLIENTID)).
+
+t_unexpected_msg_handling(Conf) ->
+    Pid = proplists:get_value(registry, Conf),
+    _ = Pid ! unexpected_info,
+    ok = gen_server:cast(Pid, unexpected_cast),
+    ignored = gen_server:call(Pid, unexpected_call).